Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  489
1 file changed, 269 insertions, 220 deletions
diff --git a/train_lora.py b/train_lora.py
index c74dd8f..fccf48d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
16from accelerate.logging import get_logger 16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
18from peft import LoraConfig, get_peft_model 18from peft import LoraConfig, get_peft_model
19
19# from diffusers.models.attention_processor import AttnProcessor 20# from diffusers.models.attention_processor import AttnProcessor
20from diffusers.utils.import_utils import is_xformers_available 21from diffusers.utils.import_utils import is_xformers_available
21import transformers 22import transformers
@@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir
34 35
35# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py 36# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
36UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] 37UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"]
37UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] 38UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # []
38TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] 39TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"]
39TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] 40TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [
40TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] 41 "out_proj",
42 "k_proj",
43] # []
44TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [
45 "token_embedding"
46]
41 47
42 48
43logger = get_logger(__name__) 49logger = get_logger(__name__)
44 50
45warnings.filterwarnings('ignore') 51warnings.filterwarnings("ignore")
46 52
47 53
48torch.backends.cuda.matmul.allow_tf32 = True 54torch.backends.cuda.matmul.allow_tf32 = True
@@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True)
55hidet.torch.dynamo_config.search_space(0) 61hidet.torch.dynamo_config.search_space(0)
56 62
57 63
58if is_xformers_available(): 64def patch_xformers(dtype):
59 import xformers 65 if is_xformers_available():
60 import xformers.ops 66 import xformers
61 67 import xformers.ops
62 orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention 68
63 def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): 69 orig_xformers_memory_efficient_attention = (
64 return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) 70 xformers.ops.memory_efficient_attention
65 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention 71 )
72
73 def xformers_memory_efficient_attention(
74 query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs
75 ):
76 return orig_xformers_memory_efficient_attention(
77 query.to(dtype), key.to(dtype), value.to(dtype), **kwargs
78 )
79
80 xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention
66 81
67 82
68def parse_args(): 83def parse_args():
69 parser = argparse.ArgumentParser( 84 parser = argparse.ArgumentParser(description="Simple example of a training script.")
70 description="Simple example of a training script."
71 )
72 parser.add_argument( 85 parser.add_argument(
73 "--pretrained_model_name_or_path", 86 "--pretrained_model_name_or_path",
74 type=str, 87 type=str,
@@ -85,7 +98,7 @@ def parse_args():
85 "--train_data_file", 98 "--train_data_file",
86 type=str, 99 type=str,
87 default=None, 100 default=None,
88 help="A folder containing the training data." 101 help="A folder containing the training data.",
89 ) 102 )
90 parser.add_argument( 103 parser.add_argument(
91 "--train_data_template", 104 "--train_data_template",
@@ -96,13 +109,13 @@ def parse_args():
96 "--train_set_pad", 109 "--train_set_pad",
97 type=int, 110 type=int,
98 default=None, 111 default=None,
99 help="The number to fill train dataset items up to." 112 help="The number to fill train dataset items up to.",
100 ) 113 )
101 parser.add_argument( 114 parser.add_argument(
102 "--valid_set_pad", 115 "--valid_set_pad",
103 type=int, 116 type=int,
104 default=None, 117 default=None,
105 help="The number to fill validation dataset items up to." 118 help="The number to fill validation dataset items up to.",
106 ) 119 )
107 parser.add_argument( 120 parser.add_argument(
108 "--project", 121 "--project",
@@ -111,64 +124,52 @@ def parse_args():
111 help="The name of the current project.", 124 help="The name of the current project.",
112 ) 125 )
113 parser.add_argument( 126 parser.add_argument(
114 "--auto_cycles", 127 "--auto_cycles", type=str, default="o", help="Cycles to run automatically."
115 type=str,
116 default="o",
117 help="Cycles to run automatically."
118 ) 128 )
119 parser.add_argument( 129 parser.add_argument(
120 "--cycle_decay", 130 "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle."
121 type=float,
122 default=1.0,
123 help="Learning rate decay per cycle."
124 ) 131 )
125 parser.add_argument( 132 parser.add_argument(
126 "--placeholder_tokens", 133 "--placeholder_tokens",
127 type=str, 134 type=str,
128 nargs='*', 135 nargs="*",
129 help="A token to use as a placeholder for the concept.", 136 help="A token to use as a placeholder for the concept.",
130 ) 137 )
131 parser.add_argument( 138 parser.add_argument(
132 "--initializer_tokens", 139 "--initializer_tokens",
133 type=str, 140 type=str,
134 nargs='*', 141 nargs="*",
135 help="A token to use as initializer word." 142 help="A token to use as initializer word.",
136 ) 143 )
137 parser.add_argument( 144 parser.add_argument(
138 "--filter_tokens", 145 "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by."
139 type=str,
140 nargs='*',
141 help="Tokens to filter the dataset by."
142 ) 146 )
143 parser.add_argument( 147 parser.add_argument(
144 "--initializer_noise", 148 "--initializer_noise",
145 type=float, 149 type=float,
146 default=0, 150 default=0,
147 help="Noise to apply to the initializer word" 151 help="Noise to apply to the initializer word",
148 ) 152 )
149 parser.add_argument( 153 parser.add_argument(
150 "--alias_tokens", 154 "--alias_tokens",
151 type=str, 155 type=str,
152 nargs='*', 156 nargs="*",
153 default=[], 157 default=[],
154 help="Tokens to create an alias for." 158 help="Tokens to create an alias for.",
155 ) 159 )
156 parser.add_argument( 160 parser.add_argument(
157 "--inverted_initializer_tokens", 161 "--inverted_initializer_tokens",
158 type=str, 162 type=str,
159 nargs='*', 163 nargs="*",
160 help="A token to use as initializer word." 164 help="A token to use as initializer word.",
161 ) 165 )
162 parser.add_argument( 166 parser.add_argument(
163 "--num_vectors", 167 "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding."
164 type=int,
165 nargs='*',
166 help="Number of vectors per embedding."
167 ) 168 )
168 parser.add_argument( 169 parser.add_argument(
169 "--exclude_collections", 170 "--exclude_collections",
170 type=str, 171 type=str,
171 nargs='*', 172 nargs="*",
172 help="Exclude all items with a listed collection.", 173 help="Exclude all items with a listed collection.",
173 ) 174 )
174 parser.add_argument( 175 parser.add_argument(
@@ -214,7 +215,7 @@ def parse_args():
214 "--num_class_images", 215 "--num_class_images",
215 type=int, 216 type=int,
216 default=0, 217 default=0,
217 help="How many class images to generate." 218 help="How many class images to generate.",
218 ) 219 )
219 parser.add_argument( 220 parser.add_argument(
220 "--class_image_dir", 221 "--class_image_dir",
@@ -242,14 +243,11 @@ def parse_args():
242 parser.add_argument( 243 parser.add_argument(
243 "--collection", 244 "--collection",
244 type=str, 245 type=str,
245 nargs='*', 246 nargs="*",
246 help="A collection to filter the dataset.", 247 help="A collection to filter the dataset.",
247 ) 248 )
248 parser.add_argument( 249 parser.add_argument(
249 "--seed", 250 "--seed", type=int, default=None, help="A seed for reproducible training."
250 type=int,
251 default=None,
252 help="A seed for reproducible training."
253 ) 251 )
254 parser.add_argument( 252 parser.add_argument(
255 "--resolution", 253 "--resolution",
@@ -270,18 +268,10 @@ def parse_args():
270 "--input_pertubation", 268 "--input_pertubation",
271 type=float, 269 type=float,
272 default=0, 270 default=0,
273 help="The scale of input pretubation. Recommended 0.1." 271 help="The scale of input pretubation. Recommended 0.1.",
274 )
275 parser.add_argument(
276 "--num_train_epochs",
277 type=int,
278 default=None
279 )
280 parser.add_argument(
281 "--num_train_steps",
282 type=int,
283 default=2000
284 ) 272 )
273 parser.add_argument("--num_train_epochs", type=int, default=None)
274 parser.add_argument("--num_train_steps", type=int, default=2000)
285 parser.add_argument( 275 parser.add_argument(
286 "--gradient_accumulation_steps", 276 "--gradient_accumulation_steps",
287 type=int, 277 type=int,
@@ -289,22 +279,19 @@ def parse_args():
289 help="Number of updates steps to accumulate before performing a backward/update pass.", 279 help="Number of updates steps to accumulate before performing a backward/update pass.",
290 ) 280 )
291 parser.add_argument( 281 parser.add_argument(
292 "--lora_r", 282 "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True"
293 type=int,
294 default=8,
295 help="Lora rank, only used if use_lora is True"
296 ) 283 )
297 parser.add_argument( 284 parser.add_argument(
298 "--lora_alpha", 285 "--lora_alpha",
299 type=int, 286 type=int,
300 default=32, 287 default=32,
301 help="Lora alpha, only used if use_lora is True" 288 help="Lora alpha, only used if use_lora is True",
302 ) 289 )
303 parser.add_argument( 290 parser.add_argument(
304 "--lora_dropout", 291 "--lora_dropout",
305 type=float, 292 type=float,
306 default=0.0, 293 default=0.0,
307 help="Lora dropout, only used if use_lora is True" 294 help="Lora dropout, only used if use_lora is True",
308 ) 295 )
309 parser.add_argument( 296 parser.add_argument(
310 "--lora_bias", 297 "--lora_bias",
@@ -344,7 +331,7 @@ def parse_args():
344 parser.add_argument( 331 parser.add_argument(
345 "--train_text_encoder_cycles", 332 "--train_text_encoder_cycles",
346 default=999999, 333 default=999999,
347 help="Number of epochs the text encoder will be trained." 334 help="Number of epochs the text encoder will be trained.",
348 ) 335 )
349 parser.add_argument( 336 parser.add_argument(
350 "--find_lr", 337 "--find_lr",
@@ -378,27 +365,31 @@ def parse_args():
378 "--lr_scheduler", 365 "--lr_scheduler",
379 type=str, 366 type=str,
380 default="one_cycle", 367 default="one_cycle",
381 choices=["linear", "cosine", "cosine_with_restarts", "polynomial", 368 choices=[
382 "constant", "constant_with_warmup", "one_cycle"], 369 "linear",
383 help='The scheduler type to use.', 370 "cosine",
371 "cosine_with_restarts",
372 "polynomial",
373 "constant",
374 "constant_with_warmup",
375 "one_cycle",
376 ],
377 help="The scheduler type to use.",
384 ) 378 )
385 parser.add_argument( 379 parser.add_argument(
386 "--lr_warmup_epochs", 380 "--lr_warmup_epochs",
387 type=int, 381 type=int,
388 default=10, 382 default=10,
389 help="Number of steps for the warmup in the lr scheduler." 383 help="Number of steps for the warmup in the lr scheduler.",
390 ) 384 )
391 parser.add_argument( 385 parser.add_argument(
392 "--lr_mid_point", 386 "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point."
393 type=float,
394 default=0.3,
395 help="OneCycle schedule mid point."
396 ) 387 )
397 parser.add_argument( 388 parser.add_argument(
398 "--lr_cycles", 389 "--lr_cycles",
399 type=int, 390 type=int,
400 default=None, 391 default=None,
401 help="Number of restart cycles in the lr scheduler (if supported)." 392 help="Number of restart cycles in the lr scheduler (if supported).",
402 ) 393 )
403 parser.add_argument( 394 parser.add_argument(
404 "--lr_warmup_func", 395 "--lr_warmup_func",
@@ -410,7 +401,7 @@ def parse_args():
410 "--lr_warmup_exp", 401 "--lr_warmup_exp",
411 type=int, 402 type=int,
412 default=1, 403 default=1,
413 help='If lr_warmup_func is "cos", exponent to modify the function' 404 help='If lr_warmup_func is "cos", exponent to modify the function',
414 ) 405 )
415 parser.add_argument( 406 parser.add_argument(
416 "--lr_annealing_func", 407 "--lr_annealing_func",
@@ -422,69 +413,76 @@ def parse_args():
422 "--lr_annealing_exp", 413 "--lr_annealing_exp",
423 type=int, 414 type=int,
424 default=3, 415 default=3,
425 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' 416 help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function',
426 ) 417 )
427 parser.add_argument( 418 parser.add_argument(
428 "--lr_min_lr", 419 "--lr_min_lr",
429 type=float, 420 type=float,
430 default=0.04, 421 default=0.04,
431 help="Minimum learning rate in the lr scheduler." 422 help="Minimum learning rate in the lr scheduler.",
432 )
433 parser.add_argument(
434 "--min_snr_gamma",
435 type=int,
436 default=5,
437 help="MinSNR gamma."
438 ) 423 )
424 parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.")
439 parser.add_argument( 425 parser.add_argument(
440 "--schedule_sampler", 426 "--schedule_sampler",
441 type=str, 427 type=str,
442 default="uniform", 428 default="uniform",
443 choices=["uniform", "loss-second-moment"], 429 choices=["uniform", "loss-second-moment"],
444 help="Noise schedule sampler." 430 help="Noise schedule sampler.",
445 ) 431 )
446 parser.add_argument( 432 parser.add_argument(
447 "--optimizer", 433 "--optimizer",
448 type=str, 434 type=str,
449 default="adan", 435 default="adan",
450 choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], 436 choices=[
451 help='Optimizer to use' 437 "adam",
438 "adam8bit",
439 "adan",
440 "lion",
441 "dadam",
442 "dadan",
443 "dlion",
444 "adafactor",
445 ],
446 help="Optimizer to use",
452 ) 447 )
453 parser.add_argument( 448 parser.add_argument(
454 "--dadaptation_d0", 449 "--dadaptation_d0",
455 type=float, 450 type=float,
456 default=1e-6, 451 default=1e-6,
457 help="The d0 parameter for Dadaptation optimizers." 452 help="The d0 parameter for Dadaptation optimizers.",
453 )
454 parser.add_argument(
455 "--dadaptation_growth_rate",
456 type=float,
457 default=math.inf,
458 help="The growth_rate parameter for Dadaptation optimizers.",
458 ) 459 )
459 parser.add_argument( 460 parser.add_argument(
460 "--adam_beta1", 461 "--adam_beta1",
461 type=float, 462 type=float,
462 default=None, 463 default=None,
463 help="The beta1 parameter for the Adam optimizer." 464 help="The beta1 parameter for the Adam optimizer.",
464 ) 465 )
465 parser.add_argument( 466 parser.add_argument(
466 "--adam_beta2", 467 "--adam_beta2",
467 type=float, 468 type=float,
468 default=None, 469 default=None,
469 help="The beta2 parameter for the Adam optimizer." 470 help="The beta2 parameter for the Adam optimizer.",
470 ) 471 )
471 parser.add_argument( 472 parser.add_argument(
472 "--adam_weight_decay", 473 "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use."
473 type=float,
474 default=2e-2,
475 help="Weight decay to use."
476 ) 474 )
477 parser.add_argument( 475 parser.add_argument(
478 "--adam_epsilon", 476 "--adam_epsilon",
479 type=float, 477 type=float,
480 default=1e-08, 478 default=1e-08,
481 help="Epsilon value for the Adam optimizer" 479 help="Epsilon value for the Adam optimizer",
482 ) 480 )
483 parser.add_argument( 481 parser.add_argument(
484 "--adam_amsgrad", 482 "--adam_amsgrad",
485 type=bool, 483 type=bool,
486 default=False, 484 default=False,
487 help="Amsgrad value for the Adam optimizer" 485 help="Amsgrad value for the Adam optimizer",
488 ) 486 )
489 parser.add_argument( 487 parser.add_argument(
490 "--mixed_precision", 488 "--mixed_precision",
@@ -547,19 +545,19 @@ def parse_args():
547 "--valid_set_size", 545 "--valid_set_size",
548 type=int, 546 type=int,
549 default=None, 547 default=None,
550 help="Number of images in the validation dataset." 548 help="Number of images in the validation dataset.",
551 ) 549 )
552 parser.add_argument( 550 parser.add_argument(
553 "--valid_set_repeat", 551 "--valid_set_repeat",
554 type=int, 552 type=int,
555 default=1, 553 default=1,
556 help="Times the images in the validation dataset are repeated." 554 help="Times the images in the validation dataset are repeated.",
557 ) 555 )
558 parser.add_argument( 556 parser.add_argument(
559 "--train_batch_size", 557 "--train_batch_size",
560 type=int, 558 type=int,
561 default=1, 559 default=1,
562 help="Batch size (per device) for the training dataloader." 560 help="Batch size (per device) for the training dataloader.",
563 ) 561 )
564 parser.add_argument( 562 parser.add_argument(
565 "--sample_steps", 563 "--sample_steps",
@@ -571,19 +569,10 @@ def parse_args():
571 "--prior_loss_weight", 569 "--prior_loss_weight",
572 type=float, 570 type=float,
573 default=1.0, 571 default=1.0,
574 help="The weight of prior preservation loss." 572 help="The weight of prior preservation loss.",
575 )
576 parser.add_argument(
577 "--run_pti",
578 action="store_true",
579 help="Whether to run PTI."
580 )
581 parser.add_argument(
582 "--emb_alpha",
583 type=float,
584 default=1.0,
585 help="Embedding alpha"
586 ) 573 )
574 parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
575 parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
587 parser.add_argument( 576 parser.add_argument(
588 "--emb_dropout", 577 "--emb_dropout",
589 type=float, 578 type=float,
@@ -591,27 +580,16 @@ def parse_args():
591 help="Embedding dropout probability.", 580 help="Embedding dropout probability.",
592 ) 581 )
593 parser.add_argument( 582 parser.add_argument(
594 "--use_emb_decay", 583 "--use_emb_decay", action="store_true", help="Whether to use embedding decay."
595 action="store_true",
596 help="Whether to use embedding decay."
597 ) 584 )
598 parser.add_argument( 585 parser.add_argument(
599 "--emb_decay_target", 586 "--emb_decay_target", default=0.4, type=float, help="Embedding decay target."
600 default=0.4,
601 type=float,
602 help="Embedding decay target."
603 ) 587 )
604 parser.add_argument( 588 parser.add_argument(
605 "--emb_decay", 589 "--emb_decay", default=1e2, type=float, help="Embedding decay factor."
606 default=1e+2,
607 type=float,
608 help="Embedding decay factor."
609 ) 590 )
610 parser.add_argument( 591 parser.add_argument(
611 "--max_grad_norm", 592 "--max_grad_norm", default=1.0, type=float, help="Max gradient norm."
612 default=1.0,
613 type=float,
614 help="Max gradient norm."
615 ) 593 )
616 parser.add_argument( 594 parser.add_argument(
617 "--noise_timesteps", 595 "--noise_timesteps",
@@ -622,7 +600,7 @@ def parse_args():
622 "--config", 600 "--config",
623 type=str, 601 type=str,
624 default=None, 602 default=None,
625 help="Path to a JSON configuration file containing arguments for invoking this script." 603 help="Path to a JSON configuration file containing arguments for invoking this script.",
626 ) 604 )
627 605
628 args = parser.parse_args() 606 args = parser.parse_args()
@@ -649,29 +627,44 @@ def parse_args():
649 args.placeholder_tokens = [args.placeholder_tokens] 627 args.placeholder_tokens = [args.placeholder_tokens]
650 628
651 if isinstance(args.initializer_tokens, str): 629 if isinstance(args.initializer_tokens, str):
652 args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) 630 args.initializer_tokens = [args.initializer_tokens] * len(
631 args.placeholder_tokens
632 )
653 633
654 if len(args.placeholder_tokens) == 0: 634 if len(args.placeholder_tokens) == 0:
655 args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] 635 args.placeholder_tokens = [
636 f"<*{i}>" for i in range(len(args.initializer_tokens))
637 ]
656 638
657 if len(args.initializer_tokens) == 0: 639 if len(args.initializer_tokens) == 0:
658 args.initializer_tokens = args.placeholder_tokens.copy() 640 args.initializer_tokens = args.placeholder_tokens.copy()
659 641
660 if len(args.placeholder_tokens) != len(args.initializer_tokens): 642 if len(args.placeholder_tokens) != len(args.initializer_tokens):
661 raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") 643 raise ValueError(
644 "--placeholder_tokens and --initializer_tokens must have the same number of items"
645 )
662 646
663 if isinstance(args.inverted_initializer_tokens, str): 647 if isinstance(args.inverted_initializer_tokens, str):
664 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) 648 args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(
649 args.placeholder_tokens
650 )
665 651
666 if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: 652 if (
653 isinstance(args.inverted_initializer_tokens, list)
654 and len(args.inverted_initializer_tokens) != 0
655 ):
667 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] 656 args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens]
668 args.initializer_tokens += args.inverted_initializer_tokens 657 args.initializer_tokens += args.inverted_initializer_tokens
669 658
670 if isinstance(args.num_vectors, int): 659 if isinstance(args.num_vectors, int):
671 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) 660 args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens)
672 661
673 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): 662 if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(
674 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 663 args.num_vectors
664 ):
665 raise ValueError(
666 "--placeholder_tokens and --num_vectors must have the same number of items"
667 )
675 668
676 if args.alias_tokens is None: 669 if args.alias_tokens is None:
677 args.alias_tokens = [] 670 args.alias_tokens = []
@@ -695,15 +688,15 @@ def parse_args():
695 raise ValueError("You must specify --output_dir") 688 raise ValueError("You must specify --output_dir")
696 689
697 if args.adam_beta1 is None: 690 if args.adam_beta1 is None:
698 if args.optimizer in ('adam', 'adam8bit'): 691 if args.optimizer in ("adam", "adam8bit", "dadam"):
699 args.adam_beta1 = 0.9 692 args.adam_beta1 = 0.9
700 elif args.optimizer == 'lion': 693 elif args.optimizer in ("lion", "dlion"):
701 args.adam_beta1 = 0.95 694 args.adam_beta1 = 0.95
702 695
703 if args.adam_beta2 is None: 696 if args.adam_beta2 is None:
704 if args.optimizer in ('adam', 'adam8bit'): 697 if args.optimizer in ("adam", "adam8bit", "dadam"):
705 args.adam_beta2 = 0.999 698 args.adam_beta2 = 0.999
706 elif args.optimizer == 'lion': 699 elif args.optimizer in ("lion", "dlion"):
707 args.adam_beta2 = 0.98 700 args.adam_beta2 = 0.98
708 701
709 return args 702 return args
@@ -719,7 +712,7 @@ def main():
719 accelerator = Accelerator( 712 accelerator = Accelerator(
720 log_with=LoggerType.TENSORBOARD, 713 log_with=LoggerType.TENSORBOARD,
721 project_dir=f"{output_dir}", 714 project_dir=f"{output_dir}",
722 mixed_precision=args.mixed_precision 715 mixed_precision=args.mixed_precision,
723 ) 716 )
724 717
725 weight_dtype = torch.float32 718 weight_dtype = torch.float32
@@ -728,6 +721,8 @@ def main():
728 elif args.mixed_precision == "bf16": 721 elif args.mixed_precision == "bf16":
729 weight_dtype = torch.bfloat16 722 weight_dtype = torch.bfloat16
730 723
724 patch_xformers(weight_dtype)
725
731 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) 726 logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG)
732 727
733 if args.seed is None: 728 if args.seed is None:
@@ -737,12 +732,18 @@ def main():
737 732
738 save_args(output_dir, args) 733 save_args(output_dir, args)
739 734
740 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) 735 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
741 schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) 736 args.pretrained_model_name_or_path
742 737 )
738 schedule_sampler = create_named_schedule_sampler(
739 args.schedule_sampler, noise_scheduler.config.num_train_timesteps
740 )
741
743 def ensure_embeddings(): 742 def ensure_embeddings():
744 if args.lora_text_encoder_emb: 743 if args.lora_text_encoder_emb:
745 raise ValueError("Can't use TI options when training token embeddings with LoRA") 744 raise ValueError(
745 "Can't use TI options when training token embeddings with LoRA"
746 )
746 return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) 747 return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout)
747 748
748 unet_config = LoraConfig( 749 unet_config = LoraConfig(
@@ -757,7 +758,9 @@ def main():
757 text_encoder_config = LoraConfig( 758 text_encoder_config = LoraConfig(
758 r=args.lora_text_encoder_r, 759 r=args.lora_text_encoder_r,
759 lora_alpha=args.lora_text_encoder_alpha, 760 lora_alpha=args.lora_text_encoder_alpha,
760 target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, 761 target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING
762 if args.lora_text_encoder_emb
763 else TEXT_ENCODER_TARGET_MODULES,
761 lora_dropout=args.lora_text_encoder_dropout, 764 lora_dropout=args.lora_text_encoder_dropout,
762 bias=args.lora_text_encoder_bias, 765 bias=args.lora_text_encoder_bias,
763 ) 766 )
@@ -787,7 +790,7 @@ def main():
787 790
788 if len(args.alias_tokens) != 0: 791 if len(args.alias_tokens) != 0:
789 embeddings = ensure_embeddings() 792 embeddings = ensure_embeddings()
790 793
791 alias_placeholder_tokens = args.alias_tokens[::2] 794 alias_placeholder_tokens = args.alias_tokens[::2]
792 alias_initializer_tokens = args.alias_tokens[1::2] 795 alias_initializer_tokens = args.alias_tokens[1::2]
793 796
@@ -795,27 +798,33 @@ def main():
795 tokenizer=tokenizer, 798 tokenizer=tokenizer,
796 embeddings=embeddings, 799 embeddings=embeddings,
797 placeholder_tokens=alias_placeholder_tokens, 800 placeholder_tokens=alias_placeholder_tokens,
798 initializer_tokens=alias_initializer_tokens 801 initializer_tokens=alias_initializer_tokens,
799 ) 802 )
800 embeddings.persist() 803 embeddings.persist()
801 print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") 804 print(
805 f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}"
806 )
802 807
803 placeholder_tokens = [] 808 placeholder_tokens = []
804 placeholder_token_ids = [] 809 placeholder_token_ids = []
805 810
806 if args.embeddings_dir is not None: 811 if args.embeddings_dir is not None:
807 embeddings = ensure_embeddings() 812 embeddings = ensure_embeddings()
808 813
809 embeddings_dir = Path(args.embeddings_dir) 814 embeddings_dir = Path(args.embeddings_dir)
810 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 815 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
811 raise ValueError("--embeddings_dir must point to an existing directory") 816 raise ValueError("--embeddings_dir must point to an existing directory")
812 817
813 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 818 added_tokens, added_ids = load_embeddings_from_dir(
819 tokenizer, embeddings, embeddings_dir
820 )
814 821
815 placeholder_tokens = added_tokens 822 placeholder_tokens = added_tokens
816 placeholder_token_ids = added_ids 823 placeholder_token_ids = added_ids
817 824
818 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 825 print(
826 f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}"
827 )
819 828
820 if args.train_dir_embeddings: 829 if args.train_dir_embeddings:
821 print("Training embeddings from embeddings dir") 830 print("Training embeddings from embeddings dir")
@@ -824,7 +833,7 @@ def main():
824 833
825 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: 834 if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings:
826 embeddings = ensure_embeddings() 835 embeddings = ensure_embeddings()
827 836
828 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( 837 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
829 tokenizer=tokenizer, 838 tokenizer=tokenizer,
830 embeddings=embeddings, 839 embeddings=embeddings,
@@ -836,23 +845,34 @@ def main():
836 845
837 placeholder_tokens = args.placeholder_tokens 846 placeholder_tokens = args.placeholder_tokens
838 847
839 stats = list(zip( 848 stats = list(
840 placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids 849 zip(
841 )) 850 placeholder_tokens,
851 placeholder_token_ids,
852 args.initializer_tokens,
853 initializer_token_ids,
854 )
855 )
842 print(f"Training embeddings: {stats}") 856 print(f"Training embeddings: {stats}")
843 857
844 if args.scale_lr: 858 if args.scale_lr:
845 args.learning_rate_unet = ( 859 args.learning_rate_unet = (
846 args.learning_rate_unet * args.gradient_accumulation_steps * 860 args.learning_rate_unet
847 args.train_batch_size * accelerator.num_processes 861 * args.gradient_accumulation_steps
862 * args.train_batch_size
863 * accelerator.num_processes
848 ) 864 )
849 args.learning_rate_text = ( 865 args.learning_rate_text = (
850 args.learning_rate_text * args.gradient_accumulation_steps * 866 args.learning_rate_text
851 args.train_batch_size * accelerator.num_processes 867 * args.gradient_accumulation_steps
868 * args.train_batch_size
869 * accelerator.num_processes
852 ) 870 )
853 args.learning_rate_emb = ( 871 args.learning_rate_emb = (
854 args.learning_rate_emb * args.gradient_accumulation_steps * 872 args.learning_rate_emb
855 args.train_batch_size * accelerator.num_processes 873 * args.gradient_accumulation_steps
874 * args.train_batch_size
875 * accelerator.num_processes
856 ) 876 )
857 877
858 if args.find_lr: 878 if args.find_lr:
@@ -861,11 +881,13 @@ def main():
861 args.learning_rate_emb = 1e-6 881 args.learning_rate_emb = 1e-6
862 args.lr_scheduler = "exponential_growth" 882 args.lr_scheduler = "exponential_growth"
863 883
864 if args.optimizer == 'adam8bit': 884 if args.optimizer == "adam8bit":
865 try: 885 try:
866 import bitsandbytes as bnb 886 import bitsandbytes as bnb
867 except ImportError: 887 except ImportError:
868 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") 888 raise ImportError(
889 "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
890 )
869 891
870 create_optimizer = partial( 892 create_optimizer = partial(
871 bnb.optim.AdamW8bit, 893 bnb.optim.AdamW8bit,
@@ -874,7 +896,7 @@ def main():
874 eps=args.adam_epsilon, 896 eps=args.adam_epsilon,
875 amsgrad=args.adam_amsgrad, 897 amsgrad=args.adam_amsgrad,
876 ) 898 )
877 elif args.optimizer == 'adam': 899 elif args.optimizer == "adam":
878 create_optimizer = partial( 900 create_optimizer = partial(
879 torch.optim.AdamW, 901 torch.optim.AdamW,
880 betas=(args.adam_beta1, args.adam_beta2), 902 betas=(args.adam_beta1, args.adam_beta2),
@@ -882,11 +904,13 @@ def main():
882 eps=args.adam_epsilon, 904 eps=args.adam_epsilon,
883 amsgrad=args.adam_amsgrad, 905 amsgrad=args.adam_amsgrad,
884 ) 906 )
885 elif args.optimizer == 'adan': 907 elif args.optimizer == "adan":
886 try: 908 try:
887 import timm.optim 909 import timm.optim
888 except ImportError: 910 except ImportError:
889 raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") 911 raise ImportError(
912 "To use Adan, please install the PyTorch Image Models library: `pip install timm`."
913 )
890 914
891 create_optimizer = partial( 915 create_optimizer = partial(
892 timm.optim.Adan, 916 timm.optim.Adan,
@@ -894,11 +918,13 @@ def main():
894 eps=args.adam_epsilon, 918 eps=args.adam_epsilon,
895 no_prox=True, 919 no_prox=True,
896 ) 920 )
897 elif args.optimizer == 'lion': 921 elif args.optimizer == "lion":
898 try: 922 try:
899 import lion_pytorch 923 import lion_pytorch
900 except ImportError: 924 except ImportError:
901 raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") 925 raise ImportError(
926 "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`."
927 )
902 928
903 create_optimizer = partial( 929 create_optimizer = partial(
904 lion_pytorch.Lion, 930 lion_pytorch.Lion,
@@ -906,7 +932,7 @@ def main():
906 weight_decay=args.adam_weight_decay, 932 weight_decay=args.adam_weight_decay,
907 use_triton=True, 933 use_triton=True,
908 ) 934 )
909 elif args.optimizer == 'adafactor': 935 elif args.optimizer == "adafactor":
910 create_optimizer = partial( 936 create_optimizer = partial(
911 transformers.optimization.Adafactor, 937 transformers.optimization.Adafactor,
912 weight_decay=args.adam_weight_decay, 938 weight_decay=args.adam_weight_decay,
@@ -920,11 +946,13 @@ def main():
920 args.learning_rate_unet = None 946 args.learning_rate_unet = None
921 args.learning_rate_text = None 947 args.learning_rate_text = None
922 args.learning_rate_emb = None 948 args.learning_rate_emb = None
923 elif args.optimizer == 'dadam': 949 elif args.optimizer == "dadam":
924 try: 950 try:
925 import dadaptation 951 import dadaptation
926 except ImportError: 952 except ImportError:
927 raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") 953 raise ImportError(
954 "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`."
955 )
928 956
929 create_optimizer = partial( 957 create_optimizer = partial(
930 dadaptation.DAdaptAdam, 958 dadaptation.DAdaptAdam,
@@ -933,29 +961,35 @@ def main():
933 eps=args.adam_epsilon, 961 eps=args.adam_epsilon,
934 decouple=True, 962 decouple=True,
935 d0=args.dadaptation_d0, 963 d0=args.dadaptation_d0,
964 growth_rate=args.dadaptation_growth_rate,
936 ) 965 )
937 966
938 args.learning_rate_unet = 1.0 967 args.learning_rate_unet = 1.0
939 args.learning_rate_text = 1.0 968 args.learning_rate_text = 1.0
940 args.learning_rate_emb = 1.0 969 args.learning_rate_emb = 1.0
941 elif args.optimizer == 'dadan': 970 elif args.optimizer == "dadan":
942 try: 971 try:
943 import dadaptation 972 import dadaptation
944 except ImportError: 973 except ImportError:
945 raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") 974 raise ImportError(
975 "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`."
976 )
946 977
947 create_optimizer = partial( 978 create_optimizer = partial(
948 dadaptation.DAdaptAdan, 979 dadaptation.DAdaptAdan,
949 weight_decay=args.adam_weight_decay, 980 weight_decay=args.adam_weight_decay,
950 eps=args.adam_epsilon, 981 eps=args.adam_epsilon,
951 d0=args.dadaptation_d0, 982 d0=args.dadaptation_d0,
983 growth_rate=args.dadaptation_growth_rate,
952 ) 984 )
953 985
954 args.learning_rate_unet = 1.0 986 args.learning_rate_unet = 1.0
955 args.learning_rate_text = 1.0 987 args.learning_rate_text = 1.0
956 args.learning_rate_emb = 1.0 988 args.learning_rate_emb = 1.0
989 elif args.optimizer == "dlion":
990 raise ImportError("DLion has not been merged into dadaptation yet")
957 else: 991 else:
958 raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") 992 raise ValueError(f'Unknown --optimizer "{args.optimizer}"')
959 993
960 trainer = partial( 994 trainer = partial(
961 train, 995 train,
@@ -1026,25 +1060,33 @@ def main():
1026 1060
1027 if args.run_pti and len(placeholder_tokens) != 0: 1061 if args.run_pti and len(placeholder_tokens) != 0:
1028 embeddings = ensure_embeddings() 1062 embeddings = ensure_embeddings()
1029 1063
1030 filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] 1064 filter_tokens = [
1065 token for token in args.filter_tokens if token in placeholder_tokens
1066 ]
1031 1067
1032 pti_datamodule = create_datamodule( 1068 pti_datamodule = create_datamodule(
1033 valid_set_size=0, 1069 valid_set_size=0,
1034 batch_size=args.train_batch_size, 1070 batch_size=args.train_batch_size,
1035 filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), 1071 filter=partial(
1072 keyword_filter, filter_tokens, args.collection, args.exclude_collections
1073 ),
1036 ) 1074 )
1037 pti_datamodule.setup() 1075 pti_datamodule.setup()
1038 1076
1039 num_train_epochs = args.num_train_epochs 1077 num_train_epochs = args.num_train_epochs
1040 pti_sample_frequency = args.sample_frequency 1078 pti_sample_frequency = args.sample_frequency
1041 if num_train_epochs is None: 1079 if num_train_epochs is None:
1042 num_train_epochs = math.ceil( 1080 num_train_epochs = (
1043 args.num_train_steps / len(pti_datamodule.train_dataset) 1081 math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset))
1044 ) * args.gradient_accumulation_steps 1082 * args.gradient_accumulation_steps
1045 pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) 1083 )
1084 pti_sample_frequency = math.ceil(
1085 num_train_epochs * (pti_sample_frequency / args.num_train_steps)
1086 )
1046 num_training_steps_per_epoch = math.ceil( 1087 num_training_steps_per_epoch = math.ceil(
1047 len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) 1088 len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps
1089 )
1048 num_train_steps = num_training_steps_per_epoch * num_train_epochs 1090 num_train_steps = num_training_steps_per_epoch * num_train_epochs
1049 if args.sample_num is not None: 1091 if args.sample_num is not None:
1050 pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) 1092 pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -1060,11 +1102,15 @@ def main():
1060 print(f"============ PTI ============") 1102 print(f"============ PTI ============")
1061 print("") 1103 print("")
1062 1104
1063 pti_optimizer = create_optimizer([{ 1105 pti_optimizer = create_optimizer(
1064 "params": text_encoder.text_model.embeddings.token_embedding.parameters(), 1106 [
1065 "lr": args.learning_rate_emb, 1107 {
1066 "weight_decay": 0, 1108 "params": text_encoder.text_model.embeddings.token_embedding.parameters(),
1067 }]) 1109 "lr": args.learning_rate_emb,
1110 "weight_decay": 0,
1111 }
1112 ]
1113 )
1068 1114
1069 pti_lr_scheduler = create_lr_scheduler( 1115 pti_lr_scheduler = create_lr_scheduler(
1070 "constant_with_warmup", 1116 "constant_with_warmup",
@@ -1113,11 +1159,16 @@ def main():
1113 num_train_epochs = args.num_train_epochs 1159 num_train_epochs = args.num_train_epochs
1114 lora_sample_frequency = args.sample_frequency 1160 lora_sample_frequency = args.sample_frequency
1115 if num_train_epochs is None: 1161 if num_train_epochs is None:
1116 num_train_epochs = math.ceil( 1162 num_train_epochs = (
1117 args.num_train_steps / len(lora_datamodule.train_dataset) 1163 math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset))
1118 ) * args.gradient_accumulation_steps 1164 * args.gradient_accumulation_steps
1119 lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) 1165 )
1120 num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) 1166 lora_sample_frequency = math.ceil(
1167 num_train_epochs * (lora_sample_frequency / args.num_train_steps)
1168 )
1169 num_training_steps_per_epoch = math.ceil(
1170 len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps
1171 )
1121 num_train_steps = num_training_steps_per_epoch * num_train_epochs 1172 num_train_steps = num_training_steps_per_epoch * num_train_epochs
1122 if args.sample_num is not None: 1173 if args.sample_num is not None:
1123 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) 1174 lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
@@ -1131,7 +1182,6 @@ def main():
1131 1182
1132 training_iter = 0 1183 training_iter = 0
1133 auto_cycles = list(args.auto_cycles) 1184 auto_cycles = list(args.auto_cycles)
1134 learning_rate_emb = args.learning_rate_emb
1135 learning_rate_unet = args.learning_rate_unet 1185 learning_rate_unet = args.learning_rate_unet
1136 learning_rate_text = args.learning_rate_text 1186 learning_rate_text = args.learning_rate_text
1137 lr_scheduler = args.lr_scheduler 1187 lr_scheduler = args.lr_scheduler
@@ -1145,21 +1195,15 @@ def main():
1145 1195
1146 params_to_optimize = [ 1196 params_to_optimize = [
1147 { 1197 {
1148 "params": ( 1198 "params": (param for param in unet.parameters() if param.requires_grad),
1149 param
1150 for param in unet.parameters()
1151 if param.requires_grad
1152 ),
1153 "lr": learning_rate_unet, 1199 "lr": learning_rate_unet,
1154 }, 1200 },
1155 { 1201 {
1156 "params": ( 1202 "params": (
1157 param 1203 param for param in text_encoder.parameters() if param.requires_grad
1158 for param in text_encoder.parameters()
1159 if param.requires_grad
1160 ), 1204 ),
1161 "lr": learning_rate_text, 1205 "lr": learning_rate_text,
1162 } 1206 },
1163 ] 1207 ]
1164 group_labels = ["unet", "text"] 1208 group_labels = ["unet", "text"]
1165 1209
@@ -1169,19 +1213,26 @@ def main():
1169 if len(auto_cycles) != 0: 1213 if len(auto_cycles) != 0:
1170 response = auto_cycles.pop(0) 1214 response = auto_cycles.pop(0)
1171 else: 1215 else:
1172 response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") 1216 response = input(
1217 "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> "
1218 )
1173 1219
1174 if response.lower().strip() == "o": 1220 if response.lower().strip() == "o":
1175 if args.learning_rate_emb is not None:
1176 learning_rate_emb = args.learning_rate_emb * 2
1177 if args.learning_rate_unet is not None: 1221 if args.learning_rate_unet is not None:
1178 learning_rate_unet = args.learning_rate_unet * 2 1222 learning_rate_unet = (
1223 args.learning_rate_unet * 2 * (args.cycle_decay**training_iter)
1224 )
1179 if args.learning_rate_text is not None: 1225 if args.learning_rate_text is not None:
1180 learning_rate_text = args.learning_rate_text * 2 1226 learning_rate_text = (
1227 args.learning_rate_text * 2 * (args.cycle_decay**training_iter)
1228 )
1181 else: 1229 else:
1182 learning_rate_emb = args.learning_rate_emb 1230 learning_rate_unet = args.learning_rate_unet * (
1183 learning_rate_unet = args.learning_rate_unet 1231 args.cycle_decay**training_iter
1184 learning_rate_text = args.learning_rate_text 1232 )
1233 learning_rate_text = args.learning_rate_text * (
1234 args.cycle_decay**training_iter
1235 )
1185 1236
1186 if response.lower().strip() == "o": 1237 if response.lower().strip() == "o":
1187 lr_scheduler = "one_cycle" 1238 lr_scheduler = "one_cycle"
@@ -1204,9 +1255,11 @@ def main():
1204 print("") 1255 print("")
1205 print(f"============ LoRA cycle {training_iter + 1}: {response} ============") 1256 print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
1206 print("") 1257 print("")
1207 1258
1208 for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): 1259 for group, lr in zip(
1209 group['lr'] = lr 1260 lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]
1261 ):
1262 group["lr"] = lr
1210 1263
1211 lora_lr_scheduler = create_lr_scheduler( 1264 lora_lr_scheduler = create_lr_scheduler(
1212 lr_scheduler, 1265 lr_scheduler,
@@ -1218,7 +1271,9 @@ def main():
1218 warmup_epochs=lr_warmup_epochs, 1271 warmup_epochs=lr_warmup_epochs,
1219 ) 1272 )
1220 1273
1221 lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" 1274 lora_checkpoint_output_dir = (
1275 output_dir / lora_project / f"model_{training_iter}"
1276 )
1222 1277
1223 trainer( 1278 trainer(
1224 strategy=lora_strategy, 1279 strategy=lora_strategy,
@@ -1246,12 +1301,6 @@ def main():
1246 ) 1301 )
1247 1302
1248 training_iter += 1 1303 training_iter += 1
1249 if learning_rate_emb is not None:
1250 learning_rate_emb *= args.cycle_decay
1251 if learning_rate_unet is not None:
1252 learning_rate_unet *= args.cycle_decay
1253 if learning_rate_text is not None:
1254 learning_rate_text *= args.cycle_decay
1255 1304
1256 accelerator.end_training() 1305 accelerator.end_training()
1257 1306