path: root/train_ti.py
Diffstat (limited to 'train_ti.py')
-rw-r--r--  train_ti.py  379
1 file changed, 201 insertions(+), 178 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index f60e3e5..c6f0b3a 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir
 
 logger = get_logger(__name__)
 
-warnings.filterwarnings('ignore')
+warnings.filterwarnings("ignore")
 
 
 torch.backends.cuda.matmul.allow_tf32 = True
@@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0)
 
 
 def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
+    parser = argparse.ArgumentParser(description="Simple example of a training script.")
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -65,12 +63,12 @@ def parse_args():
         "--train_data_file",
         type=str,
         default=None,
-        help="A CSV file containing the training data."
+        help="A CSV file containing the training data.",
     )
     parser.add_argument(
         "--train_data_template",
         type=str,
-        nargs='*',
+        nargs="*",
         default="template",
     )
     parser.add_argument(
@@ -80,59 +78,47 @@ def parse_args():
         help="The name of the current project.",
     )
     parser.add_argument(
-        "--auto_cycles",
-        type=str,
-        default="o",
-        help="Cycles to run automatically."
+        "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
     )
     parser.add_argument(
-        "--cycle_decay",
-        type=float,
-        default=1.0,
-        help="Learning rate decay per cycle."
+        "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
     )
     parser.add_argument(
         "--placeholder_tokens",
         type=str,
-        nargs='*',
+        nargs="*",
        help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
         "--initializer_tokens",
         type=str,
-        nargs='*',
-        help="A token to use as initializer word."
+        nargs="*",
+        help="A token to use as initializer word.",
     )
     parser.add_argument(
-        "--filter_tokens",
-        type=str,
-        nargs='*',
-        help="Tokens to filter the dataset by."
+        "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
     )
     parser.add_argument(
         "--initializer_noise",
         type=float,
         default=0,
-        help="Noise to apply to the initializer word"
+        help="Noise to apply to the initializer word",
     )
     parser.add_argument(
         "--alias_tokens",
         type=str,
-        nargs='*',
+        nargs="*",
         default=[],
-        help="Tokens to create an alias for."
+        help="Tokens to create an alias for.",
     )
     parser.add_argument(
         "--inverted_initializer_tokens",
         type=str,
-        nargs='*',
-        help="A token to use as initializer word."
+        nargs="*",
+        help="A token to use as initializer word.",
     )
     parser.add_argument(
-        "--num_vectors",
-        type=int,
-        nargs='*',
-        help="Number of vectors per embedding."
+        "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
     )
     parser.add_argument(
         "--sequential",
@@ -147,7 +133,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=0,
-        help="How many class images to generate."
+        help="How many class images to generate.",
     )
     parser.add_argument(
         "--class_image_dir",
@@ -158,7 +144,7 @@ def parse_args():
     parser.add_argument(
         "--exclude_collections",
         type=str,
-        nargs='*',
+        nargs="*",
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
@@ -181,14 +167,11 @@ def parse_args():
     parser.add_argument(
         "--collection",
         type=str,
-        nargs='*',
+        nargs="*",
         help="A collection to filter the dataset.",
     )
     parser.add_argument(
-        "--seed",
-        type=int,
-        default=None,
-        help="A seed for reproducible training."
+        "--seed", type=int, default=None, help="A seed for reproducible training."
     )
     parser.add_argument(
         "--resolution",
@@ -244,7 +227,7 @@ def parse_args():
         type=str,
         default="auto",
         choices=["all", "trailing", "leading", "between", "auto", "off"],
-        help='Vector shuffling algorithm.',
+        help="Vector shuffling algorithm.",
     )
     parser.add_argument(
         "--offset_noise_strength",
@@ -256,18 +239,10 @@ def parse_args():
         "--input_pertubation",
         type=float,
         default=0,
-        help="The scale of input pretubation. Recommended 0.1."
-    )
-    parser.add_argument(
-        "--num_train_epochs",
-        type=int,
-        default=None
-    )
-    parser.add_argument(
-        "--num_train_steps",
-        type=int,
-        default=2000
-    )
+        help="The scale of input pretubation. Recommended 0.1.",
+    )
+    parser.add_argument("--num_train_epochs", type=int, default=None)
+    parser.add_argument("--num_train_steps", type=int, default=2000)
     parser.add_argument(
         "--gradient_accumulation_steps",
         type=int,
@@ -299,27 +274,31 @@ def parse_args():
         "--lr_scheduler",
         type=str,
         default="one_cycle",
-        choices=["linear", "cosine", "cosine_with_restarts", "polynomial",
-                 "constant", "constant_with_warmup", "one_cycle"],
-        help='The scheduler type to use.',
+        choices=[
+            "linear",
+            "cosine",
+            "cosine_with_restarts",
+            "polynomial",
+            "constant",
+            "constant_with_warmup",
+            "one_cycle",
+        ],
+        help="The scheduler type to use.",
     )
     parser.add_argument(
         "--lr_warmup_epochs",
         type=int,
         default=10,
-        help="Number of steps for the warmup in the lr scheduler."
+        help="Number of steps for the warmup in the lr scheduler.",
     )
     parser.add_argument(
-        "--lr_mid_point",
-        type=float,
-        default=0.3,
-        help="OneCycle schedule mid point."
+        "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
     )
     parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
-        help="Number of restart cycles in the lr scheduler."
+        help="Number of restart cycles in the lr scheduler.",
     )
     parser.add_argument(
         "--lr_warmup_func",
@@ -331,7 +310,7 @@ def parse_args():
         "--lr_warmup_exp",
         type=int,
         default=1,
-        help='If lr_warmup_func is "cos", exponent to modify the function'
+        help='If lr_warmup_func is "cos", exponent to modify the function',
     )
     parser.add_argument(
         "--lr_annealing_func",
@@ -343,89 +322,67 @@ def parse_args():
         "--lr_annealing_exp",
         type=int,
         default=1,
-        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
+        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
     )
     parser.add_argument(
         "--lr_min_lr",
         type=float,
         default=0.04,
-        help="Minimum learning rate in the lr scheduler."
+        help="Minimum learning rate in the lr scheduler.",
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=4/5
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
-        "--min_snr_gamma",
-        type=int,
-        default=5,
-        help="MinSNR gamma."
+        "--use_ema", action="store_true", help="Whether to use EMA model."
     )
+    parser.add_argument("--ema_inv_gamma", type=float, default=1.0)
+    parser.add_argument("--ema_power", type=float, default=4 / 5)
+    parser.add_argument("--ema_max_decay", type=float, default=0.9999)
+    parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
     parser.add_argument(
         "--schedule_sampler",
         type=str,
         default="uniform",
         choices=["uniform", "loss-second-moment"],
-        help="Noise schedule sampler."
+        help="Noise schedule sampler.",
     )
     parser.add_argument(
         "--optimizer",
         type=str,
         default="adan",
         choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"],
-        help='Optimizer to use'
+        help="Optimizer to use",
     )
     parser.add_argument(
         "--dadaptation_d0",
         type=float,
         default=1e-6,
-        help="The d0 parameter for Dadaptation optimizers."
+        help="The d0 parameter for Dadaptation optimizers.",
     )
     parser.add_argument(
         "--adam_beta1",
         type=float,
         default=None,
-        help="The beta1 parameter for the Adam optimizer."
+        help="The beta1 parameter for the Adam optimizer.",
     )
     parser.add_argument(
         "--adam_beta2",
         type=float,
         default=None,
-        help="The beta2 parameter for the Adam optimizer."
+        help="The beta2 parameter for the Adam optimizer.",
     )
     parser.add_argument(
-        "--adam_weight_decay",
-        type=float,
-        default=2e-2,
-        help="Weight decay to use."
+        "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
     )
     parser.add_argument(
         "--adam_epsilon",
         type=float,
         default=1e-08,
-        help="Epsilon value for the Adam optimizer"
+        help="Epsilon value for the Adam optimizer",
     )
     parser.add_argument(
         "--adam_amsgrad",
         type=bool,
         default=False,
-        help="Amsgrad value for the Adam optimizer"
+        help="Amsgrad value for the Adam optimizer",
     )
     parser.add_argument(
         "--mixed_precision",
@@ -456,7 +413,7 @@ def parse_args():
     )
     parser.add_argument(
         "--no_milestone_checkpoints",
-        action='store_true',
+        action="store_true",
         help="If checkpoints are saved on maximum accuracy",
     )
     parser.add_argument(
@@ -493,25 +450,25 @@ def parse_args():
         "--valid_set_size",
         type=int,
         default=None,
-        help="Number of images in the validation dataset."
+        help="Number of images in the validation dataset.",
     )
     parser.add_argument(
         "--train_set_pad",
         type=int,
         default=None,
-        help="The number to fill train dataset items up to."
+        help="The number to fill train dataset items up to.",
     )
     parser.add_argument(
         "--valid_set_pad",
         type=int,
         default=None,
-        help="The number to fill validation dataset items up to."
+        help="The number to fill validation dataset items up to.",
     )
     parser.add_argument(
         "--train_batch_size",
         type=int,
         default=1,
-        help="Batch size (per device) for the training dataloader."
+        help="Batch size (per device) for the training dataloader.",
     )
     parser.add_argument(
         "--sample_steps",
@@ -523,14 +480,9 @@ def parse_args():
         "--prior_loss_weight",
         type=float,
         default=1.0,
-        help="The weight of prior preservation loss."
-    )
-    parser.add_argument(
-        "--emb_alpha",
-        type=float,
-        default=1.0,
-        help="Embedding alpha"
-    )
+        help="The weight of prior preservation loss.",
+    )
+    parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
     parser.add_argument(
         "--emb_dropout",
         type=float,
@@ -538,21 +490,13 @@ def parse_args():
         help="Embedding dropout probability.",
     )
     parser.add_argument(
-        "--use_emb_decay",
-        action="store_true",
-        help="Whether to use embedding decay."
+        "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
     )
     parser.add_argument(
-        "--emb_decay_target",
-        default=0.4,
-        type=float,
-        help="Embedding decay target."
+        "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
     )
     parser.add_argument(
-        "--emb_decay",
-        default=1e+2,
-        type=float,
-        help="Embedding decay factor."
+        "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
     )
     parser.add_argument(
         "--noise_timesteps",
@@ -563,7 +507,7 @@ def parse_args():
         "--resume_from",
         type=str,
         default=None,
-        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
+        help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)",
     )
     parser.add_argument(
         "--global_step",
@@ -574,7 +518,7 @@ def parse_args():
         "--config",
         type=str,
         default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script."
+        help="Path to a JSON configuration file containing arguments for invoking this script.",
     )
 
     args = parser.parse_args()
@@ -595,29 +539,44 @@ def parse_args():
         args.placeholder_tokens = [args.placeholder_tokens]
 
     if isinstance(args.initializer_tokens, str):
-        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)
+        args.initializer_tokens = [args.initializer_tokens] * len(
+            args.placeholder_tokens
+        )
 
     if len(args.placeholder_tokens) == 0:
-        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]
+        args.placeholder_tokens = [
+            f"<*{i}>" for i in range(len(args.initializer_tokens))
+        ]
 
     if len(args.initializer_tokens) == 0:
         args.initializer_tokens = args.placeholder_tokens.copy()
 
     if len(args.placeholder_tokens) != len(args.initializer_tokens):
-        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")
+        raise ValueError(
+            "--placeholder_tokens and --initializer_tokens must have the same number of items"
+        )
 
     if isinstance(args.inverted_initializer_tokens, str):
-        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens)
+        args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
+            args.placeholder_tokens
+        )
 
-    if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0:
+    if (
+        isinstance(args.inverted_initializer_tokens, list)
+        and len(args.inverted_initializer_tokens) != 0
+    ):
         args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
         args.initializer_tokens += args.inverted_initializer_tokens
 
     if isinstance(args.num_vectors, int):
         args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
 
-    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors):
-        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
+    if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
+        args.num_vectors
+    ):
+        raise ValueError(
+            "--placeholder_tokens and --num_vectors must have the same number of items"
+        )
 
     if args.alias_tokens is None:
         args.alias_tokens = []
@@ -639,16 +598,22 @@ def parse_args():
         ]
 
         if isinstance(args.train_data_template, str):
-            args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens)
+            args.train_data_template = [args.train_data_template] * len(
+                args.placeholder_tokens
+            )
 
         if len(args.placeholder_tokens) != len(args.train_data_template):
-            raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items")
+            raise ValueError(
+                "--placeholder_tokens and --train_data_template must have the same number of items"
+            )
 
         if args.num_vectors is None:
             args.num_vectors = [None] * len(args.placeholder_tokens)
     else:
         if isinstance(args.train_data_template, list):
-            raise ValueError("--train_data_template can't be a list in simultaneous mode")
+            raise ValueError(
+                "--train_data_template can't be a list in simultaneous mode"
+            )
 
     if isinstance(args.collection, str):
         args.collection = [args.collection]
@@ -660,13 +625,13 @@ def parse_args():
         raise ValueError("You must specify --output_dir")
 
     if args.adam_beta1 is None:
-        if args.optimizer == 'lion':
+        if args.optimizer == "lion":
             args.adam_beta1 = 0.95
         else:
             args.adam_beta1 = 0.9
 
     if args.adam_beta2 is None:
-        if args.optimizer == 'lion':
+        if args.optimizer == "lion":
             args.adam_beta2 = 0.98
         else:
             args.adam_beta2 = 0.999
@@ -679,13 +644,13 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(args.output_dir)/slugify(args.project)/now
+    output_dir = Path(args.output_dir) / slugify(args.project) / now
     output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
         project_dir=f"{output_dir}",
-        mixed_precision=args.mixed_precision
+        mixed_precision=args.mixed_precision,
     )
 
     weight_dtype = torch.float32
@@ -703,9 +668,15 @@ def main():
 
     save_args(output_dir, args)
 
-    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path)
-    embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
-    schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps)
+    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
+        args.pretrained_model_name_or_path
+    )
+    embeddings = patch_managed_embeddings(
+        text_encoder, args.emb_alpha, args.emb_dropout
+    )
+    schedule_sampler = create_named_schedule_sampler(
+        args.schedule_sampler, noise_scheduler.config.num_train_timesteps
+    )
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -717,16 +688,16 @@ def main():
         unet.enable_xformers_memory_efficient_attention()
     elif args.compile_unet:
         unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
-        
+
         proc = AttnProcessor()
-        
+
         def fn_recursive_set_proc(module: torch.nn.Module):
             if hasattr(module, "processor"):
                 module.processor = proc
-            
+
             for child in module.children():
                 fn_recursive_set_proc(child)
-        
+
         fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
@@ -751,18 +722,24 @@ def main():
             tokenizer=tokenizer,
             embeddings=embeddings,
             placeholder_tokens=alias_placeholder_tokens,
-            initializer_tokens=alias_initializer_tokens
+            initializer_tokens=alias_initializer_tokens,
         )
         embeddings.persist()
-        print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}")
+        print(
+            f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
+        )
 
     if args.embeddings_dir is not None:
         embeddings_dir = Path(args.embeddings_dir)
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")
 
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
+        added_tokens, added_ids = load_embeddings_from_dir(
+            tokenizer, embeddings, embeddings_dir
+        )
+        print(
+            f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
+        )
 
     if args.train_dir_embeddings:
         args.placeholder_tokens = added_tokens
@@ -772,19 +749,23 @@ def main():
 
     if args.scale_lr:
         args.learning_rate = (
-            args.learning_rate * args.gradient_accumulation_steps *
-            args.train_batch_size * accelerator.num_processes
+            args.learning_rate
+            * args.gradient_accumulation_steps
+            * args.train_batch_size
+            * accelerator.num_processes
        )
 
     if args.find_lr:
         args.learning_rate = 1e-5
         args.lr_scheduler = "exponential_growth"
 
-    if args.optimizer == 'adam8bit':
+    if args.optimizer == "adam8bit":
         try:
             import bitsandbytes as bnb
         except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+            raise ImportError(
+                "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
+            )
 
         create_optimizer = partial(
             bnb.optim.AdamW8bit,
@@ -793,7 +774,7 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'adam':
+    elif args.optimizer == "adam":
         create_optimizer = partial(
             torch.optim.AdamW,
             betas=(args.adam_beta1, args.adam_beta2),
@@ -801,11 +782,13 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    elif args.optimizer == 'adan':
+    elif args.optimizer == "adan":
         try:
             import timm.optim
         except ImportError:
-            raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.")
+            raise ImportError(
+                "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
+            )
 
         create_optimizer = partial(
             timm.optim.Adan,
@@ -813,11 +796,13 @@ def main():
             eps=args.adam_epsilon,
             no_prox=True,
         )
-    elif args.optimizer == 'lion':
+    elif args.optimizer == "lion":
         try:
             import lion_pytorch
         except ImportError:
-            raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.")
+            raise ImportError(
+                "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
+            )
 
         create_optimizer = partial(
             lion_pytorch.Lion,
@@ -825,7 +810,7 @@ def main():
             weight_decay=args.adam_weight_decay,
             use_triton=True,
         )
-    elif args.optimizer == 'adafactor':
+    elif args.optimizer == "adafactor":
         create_optimizer = partial(
             transformers.optimization.Adafactor,
             weight_decay=args.adam_weight_decay,
@@ -837,11 +822,13 @@ def main():
         args.lr_scheduler = "adafactor"
         args.lr_min_lr = args.learning_rate
         args.learning_rate = None
-    elif args.optimizer == 'dadam':
+    elif args.optimizer == "dadam":
         try:
             import dadaptation
         except ImportError:
-            raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.")
+            raise ImportError(
+                "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
+            )
 
         create_optimizer = partial(
             dadaptation.DAdaptAdam,
@@ -851,11 +838,13 @@ def main():
             decouple=True,
             d0=args.dadaptation_d0,
         )
-    elif args.optimizer == 'dadan':
+    elif args.optimizer == "dadan":
         try:
             import dadaptation
         except ImportError:
-            raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.")
+            raise ImportError(
+                "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
+            )
 
         create_optimizer = partial(
             dadaptation.DAdaptAdan,
@@ -864,7 +853,7 @@ def main():
             d0=args.dadaptation_d0,
         )
     else:
-        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
+        raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
 
     trainer = partial(
         train,
@@ -904,10 +893,21 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
+    optimizer = create_optimizer(
+        text_encoder.text_model.embeddings.token_embedding.parameters(),
+        lr=learning_rate,
+    )
+
     data_generator = torch.Generator(device="cpu").manual_seed(args.seed)
     data_npgenerator = np.random.default_rng(args.seed)
 
-    def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str):
+    def run(
+        i: int,
+        placeholder_tokens: list[str],
+        initializer_tokens: list[str],
+        num_vectors: Union[int, list[int]],
+        data_template: str,
+    ):
         placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
             tokenizer=tokenizer,
             embeddings=embeddings,
@@ -917,14 +917,23 @@ def main():
             initializer_noise=args.initializer_noise,
         )
 
-        stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids))
+        stats = list(
+            zip(
+                placeholder_tokens,
+                placeholder_token_ids,
+                initializer_tokens,
+                initializer_token_ids,
+            )
+        )
 
         print("")
         print(f"============ TI batch {i + 1} ============")
         print("")
         print(stats)
 
-        filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens]
+        filter_tokens = [
+            token for token in args.filter_tokens if token in placeholder_tokens
+        ]
 
         datamodule = VlpnDataModule(
             data_file=args.train_data_file,
@@ -945,7 +954,9 @@ def main():
             valid_set_size=args.valid_set_size,
             train_set_pad=args.train_set_pad,
             valid_set_pad=args.valid_set_pad,
-            filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections),
+            filter=partial(
+                keyword_filter, filter_tokens, args.collection, args.exclude_collections
+            ),
             dtype=weight_dtype,
             generator=data_generator,
             npgenerator=data_npgenerator,
@@ -955,11 +966,16 @@ def main():
         num_train_epochs = args.num_train_epochs
         sample_frequency = args.sample_frequency
         if num_train_epochs is None:
-            num_train_epochs = math.ceil(
-                args.num_train_steps / len(datamodule.train_dataset)
-            ) * args.gradient_accumulation_steps
-            sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps))
-        num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps)
+            num_train_epochs = (
+                math.ceil(args.num_train_steps / len(datamodule.train_dataset))
+                * args.gradient_accumulation_steps
+            )
+            sample_frequency = math.ceil(
+                num_train_epochs * (sample_frequency / args.num_train_steps)
+            )
+        num_training_steps_per_epoch = math.ceil(
+            len(datamodule.train_dataset) / args.gradient_accumulation_steps
+        )
         num_train_steps = num_training_steps_per_epoch * num_train_epochs
         if args.sample_num is not None:
             sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -988,7 +1004,8 @@ def main():
                 response = auto_cycles.pop(0)
             else:
                 response = input(
-                    "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ")
+                    "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
+                )
 
             if response.lower().strip() == "o":
                 if args.learning_rate is not None:
@@ -1018,10 +1035,8 @@ def main():
             print(f"------------ TI cycle {training_iter + 1}: {response} ------------")
             print("")
 
-            optimizer = create_optimizer(
-                text_encoder.text_model.embeddings.token_embedding.parameters(),
-                lr=learning_rate,
-            )
+            for group, lr in zip(optimizer.param_groups, [learning_rate]):
+                group["lr"] = lr
 
             lr_scheduler = get_scheduler(
                 lr_scheduler,
@@ -1040,7 +1055,9 @@ def main():
                 mid_point=args.lr_mid_point,
             )
 
-            checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}"
+            checkpoint_output_dir = (
+                output_dir / project / f"checkpoints_{training_iter}"
+            )
 
             trainer(
                 train_dataloader=datamodule.train_dataloader,
@@ -1070,14 +1087,20 @@ def main():
         accelerator.end_training()
 
     if not args.sequential:
-        run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template)
+        run(
+            0,
+            args.placeholder_tokens,
+            args.initializer_tokens,
+            args.num_vectors,
+            args.train_data_template,
+        )
     else:
         for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
             range(len(args.placeholder_tokens)),
            args.placeholder_tokens,
            args.initializer_tokens,
            args.num_vectors,
-            args.train_data_template
+            args.train_data_template,
         ):
             run(i, [placeholder_token], [initializer_token], num_vectors, data_template)
             embeddings.persist()