summaryrefslogtreecommitdiffstats
path: root/train.py
diff options
context:
space:
mode:
Diffstat (limited to 'train.py')
-rw-r--r--train.py672
1 files changed, 672 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..d8644c4
--- /dev/null
+++ b/train.py
@@ -0,0 +1,672 @@
1import argparse
2import datetime
3import logging
4from pathlib import Path
5
6import torch
7import torch.utils.checkpoint
8
9from accelerate import Accelerator
10from accelerate.logging import get_logger
11from accelerate.utils import LoggerType, set_seed
12from slugify import slugify
13
14from data.csv import VlpnDataModule, VlpnDataItem
15from util import load_config, load_embeddings_from_dir
16
17from trainer.ti import TextualInversionTrainingStrategy
18from trainer.base import Trainer
19from training.optimization import get_scheduler
20from training.util import save_args, generate_class_images, add_placeholder_tokens, get_models
21
# Module-level logger routed through accelerate so messages respect the
# main-process/replica logging rules.
logger = get_logger(__name__)


# Allow TF32 matmuls on Ampere+ GPUs (faster, slightly reduced precision)
# and let cuDNN benchmark/auto-tune convolution algorithms for fixed shapes.
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True
27
28
def parse_args():
    """Parse command-line arguments for Textual Inversion training.

    If ``--config`` points to a JSON file, its values are loaded and then
    re-parsed through argparse so explicit CLI flags still take precedence.

    Returns:
        argparse.Namespace with normalized fields: ``placeholder_tokens``,
        ``initializer_tokens`` and ``num_vectors`` are lists of equal
        length; ``collection``/``exclude_collections`` are lists (or None).

    Raises:
        ValueError: if a required argument is missing or the token lists
            are inconsistent.
    """
    def str_to_bool(value):
        # argparse's `type=bool` treats ANY non-empty string as True
        # (so `--adam_amsgrad False` used to enable amsgrad). Parse the
        # common textual spellings explicitly instead.
        if isinstance(value, bool):
            return value
        return value.strip().lower() in ("1", "true", "yes", "y", "on")

    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs='*',
        help="A token to use as a placeholder for the concept.",
    )
    parser.add_argument(
        "--initializer_tokens",
        type=str,
        nargs='*',
        help="A token to use as initializer word."
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Disable tag shuffling.",
    )
    parser.add_argument(
        "--vector_dropout",
        type=int,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of steps for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=1
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        type=str_to_bool,
        default=False,
        help="Amsgrad value for the Adam optimizer"
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose"
            "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
            "and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=5,
        help="How often to save a checkpoint and sample image (in epochs)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to generate sample images (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--emb_decay_target",
        default=0.4,
        type=float,
        help="Embedding decay target."
    )
    parser.add_argument(
        "--emb_decay_factor",
        default=0,
        type=float,
        help="Embedding decay factor."
    )
    parser.add_argument(
        "--emb_decay_start",
        default=1e-4,
        type=float,
        help="Embedding decay start offset."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
    if args.config is not None:
        # Load defaults from the config file, then re-parse so explicit CLI
        # flags override configured values.
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    # `nargs='*'` arguments default to None when the flag is absent; normalize
    # to empty lists so the length checks below cannot raise TypeError.
    if args.placeholder_tokens is None:
        args.placeholder_tokens = []

    if args.initializer_tokens is None:
        args.initializer_tokens = []

    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if len(args.placeholder_tokens) == 0:
        # Auto-name one placeholder (<*0>, <*1>, ...) per initializer token.
        # The original code passed the list itself to range(), a TypeError.
        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

    if isinstance(args.initializer_tokens, str):
        # Broadcast a single initializer word across all placeholders.
        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

    if len(args.initializer_tokens) == 0:
        raise ValueError("You must specify --initializer_tokens")

    if len(args.placeholder_tokens) != len(args.initializer_tokens):
        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)

    if len(args.placeholder_tokens) != len(args.num_vectors):
        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args
475
476
def main():
    """Entry point: parse arguments, set up the accelerator, models and
    data pipeline, then run Textual Inversion training.

    Side effects: creates a timestamped output directory, writes a log
    file and the argument dump there, optionally generates class images,
    and launches training via `Trainer`.
    """
    args = parse_args()

    global_step_offset = args.global_step
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    # One run directory per invocation: <output_dir>/<project-slug>/<timestamp>
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

    if args.seed is None:
        # Derive a 32-bit seed from torch's RNG so it can be saved and reused.
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    # Persist the effective arguments (including the generated seed).
    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

    # Memory savers: sliced VAE decoding and xformers attention.
    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=args.placeholder_tokens,
        initializer_tokens=args.initializer_tokens,
        num_vectors=args.num_vectors
    )

    if len(placeholder_token_ids) != 0:
        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate = 1e-5

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    # Only the (temporary) token embeddings are optimized — the rest of the
    # text encoder, the VAE and the UNet stay frozen.
    optimizer = optimizer_class(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    def keyword_filter(item: VlpnDataItem):
        # Keep only items whose prompt mentions at least one placeholder token.
        cond_prompt = any(
            keyword in part
            for keyword in args.placeholder_tokens
            for part in item.prompt
        )
        # Fix: the original tested `args.collection in item.collection`, which
        # checks whether the whole list is an element of item.collection and
        # therefore never matched string collection names. Mirror the
        # exclude_collections any(...) pattern instead.
        cond_included = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        cond_excluded = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond_prompt and cond_included and cond_excluded

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        valid_set_repeat=args.valid_set_repeat,
        seed=args.seed,
        filter=keyword_filter,
        dtype=weight_dtype
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

    if args.num_class_images != 0:
        # Pre-generate class (prior-preservation) images before training.
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            sample_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

    trainer = Trainer(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        noise_scheduler=noise_scheduler,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
    )

    trainer(
        strategy_class=TextualInversionTrainingStrategy,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        sample_frequency=args.sample_frequency,
        checkpoint_frequency=args.checkpoint_frequency,
        global_step_offset=global_step_offset,
        prior_loss_weight=args.prior_loss_weight,
        output_dir=output_dir,
        placeholder_tokens=args.placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        learning_rate=args.learning_rate,
        sample_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed,
    )
669
670
# Script entry point: only run training when executed directly, not on import.
if __name__ == "__main__":
    main()