Diffstat (limited to 'train_lora.py')
-rw-r--r-- | train_lora.py | 489 |
1 files changed, 269 insertions, 220 deletions
diff --git a/train_lora.py b/train_lora.py
index c74dd8f..fccf48d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -16,6 +16,7 @@ from accelerate import Accelerator
16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
18 | from peft import LoraConfig, get_peft_model | 18 | from peft import LoraConfig, get_peft_model |
19 | |||
19 | # from diffusers.models.attention_processor import AttnProcessor | 20 | # from diffusers.models.attention_processor import AttnProcessor |
20 | from diffusers.utils.import_utils import is_xformers_available | 21 | from diffusers.utils.import_utils import is_xformers_available |
21 | import transformers | 22 | import transformers |
@@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir
34 | 35 | ||
35 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 36 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
36 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] | 37 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
37 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] | 38 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # [] |
38 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | 39 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] |
39 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | 40 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [ |
40 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 41 | "out_proj", |
42 | "k_proj", | ||
43 | ] # [] | ||
44 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [ | ||
45 | "token_embedding" | ||
46 | ] | ||
41 | 47 | ||
42 | 48 | ||
43 | logger = get_logger(__name__) | 49 | logger = get_logger(__name__) |
44 | 50 | ||
45 | warnings.filterwarnings('ignore') | 51 | warnings.filterwarnings("ignore") |
46 | 52 | ||
47 | 53 | ||
48 | torch.backends.cuda.matmul.allow_tf32 = True | 54 | torch.backends.cuda.matmul.allow_tf32 = True |
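The widened target-module lists above are what the LoRA configs in main() point at; the sketch below shows how such a list is typically consumed, with illustrative values standing in for the --lora_r/--lora_alpha/--lora_dropout/--lora_bias flags (the actual config construction appears further down in this diff, and the unet variable is assumed to be the UNet returned by get_models()):

    from peft import LoraConfig, get_peft_model

    unet_config = LoraConfig(
        r=8,                                  # --lora_r
        lora_alpha=32,                        # --lora_alpha
        target_modules=UNET_TARGET_MODULES,   # now also adapts the key projections
        lora_dropout=0.0,                     # --lora_dropout
        bias="none",                          # --lora_bias (illustrative value)
    )
    # After wrapping, only the injected LoRA weights remain trainable.
    unet = get_peft_model(unet, unet_config)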
@@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True)
55 | hidet.torch.dynamo_config.search_space(0) | 61 | hidet.torch.dynamo_config.search_space(0) |
56 | 62 | ||
57 | 63 | ||
58 | if is_xformers_available(): | 64 | def patch_xformers(dtype): |
59 | import xformers | 65 | if is_xformers_available(): |
60 | import xformers.ops | 66 | import xformers |
61 | 67 | import xformers.ops | |
62 | orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention | 68 | |
63 | def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): | 69 | orig_xformers_memory_efficient_attention = ( |
64 | return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) | 70 | xformers.ops.memory_efficient_attention |
65 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | 71 | ) |
72 | |||
73 | def xformers_memory_efficient_attention( | ||
74 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
75 | ): | ||
76 | return orig_xformers_memory_efficient_attention( | ||
77 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
78 | ) | ||
79 | |||
80 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
66 | 81 | ||
67 | 82 | ||
68 | def parse_args(): | 83 | def parse_args(): |
69 | parser = argparse.ArgumentParser( | 84 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
70 | description="Simple example of a training script." | ||
71 | ) | ||
72 | parser.add_argument( | 85 | parser.add_argument( |
73 | "--pretrained_model_name_or_path", | 86 | "--pretrained_model_name_or_path", |
74 | type=str, | 87 | type=str, |
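The patch_xformers(dtype) helper introduced above replaces the earlier import-time monkey patch: instead of coercing query and value to the key's dtype, it casts all three attention inputs to one caller-supplied dtype, and it is only applied once the mixed-precision dtype is known. A minimal usage sketch, mirroring the call added in main() (args is the namespace produced by parse_args()):

    import torch

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # Every later xformers.ops.memory_efficient_attention call now casts
    # query, key and value to weight_dtype before running.
    patch_xformers(weight_dtype)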
@@ -85,7 +98,7 @@ def parse_args():
85 | "--train_data_file", | 98 | "--train_data_file", |
86 | type=str, | 99 | type=str, |
87 | default=None, | 100 | default=None, |
88 | help="A folder containing the training data." | 101 | help="A folder containing the training data.", |
89 | ) | 102 | ) |
90 | parser.add_argument( | 103 | parser.add_argument( |
91 | "--train_data_template", | 104 | "--train_data_template", |
@@ -96,13 +109,13 @@ def parse_args():
96 | "--train_set_pad", | 109 | "--train_set_pad", |
97 | type=int, | 110 | type=int, |
98 | default=None, | 111 | default=None, |
99 | help="The number to fill train dataset items up to." | 112 | help="The number to fill train dataset items up to.", |
100 | ) | 113 | ) |
101 | parser.add_argument( | 114 | parser.add_argument( |
102 | "--valid_set_pad", | 115 | "--valid_set_pad", |
103 | type=int, | 116 | type=int, |
104 | default=None, | 117 | default=None, |
105 | help="The number to fill validation dataset items up to." | 118 | help="The number to fill validation dataset items up to.", |
106 | ) | 119 | ) |
107 | parser.add_argument( | 120 | parser.add_argument( |
108 | "--project", | 121 | "--project", |
@@ -111,64 +124,52 @@ def parse_args():
111 | help="The name of the current project.", | 124 | help="The name of the current project.", |
112 | ) | 125 | ) |
113 | parser.add_argument( | 126 | parser.add_argument( |
114 | "--auto_cycles", | 127 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
115 | type=str, | ||
116 | default="o", | ||
117 | help="Cycles to run automatically." | ||
118 | ) | 128 | ) |
119 | parser.add_argument( | 129 | parser.add_argument( |
120 | "--cycle_decay", | 130 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
121 | type=float, | ||
122 | default=1.0, | ||
123 | help="Learning rate decay per cycle." | ||
124 | ) | 131 | ) |
125 | parser.add_argument( | 132 | parser.add_argument( |
126 | "--placeholder_tokens", | 133 | "--placeholder_tokens", |
127 | type=str, | 134 | type=str, |
128 | nargs='*', | 135 | nargs="*", |
129 | help="A token to use as a placeholder for the concept.", | 136 | help="A token to use as a placeholder for the concept.", |
130 | ) | 137 | ) |
131 | parser.add_argument( | 138 | parser.add_argument( |
132 | "--initializer_tokens", | 139 | "--initializer_tokens", |
133 | type=str, | 140 | type=str, |
134 | nargs='*', | 141 | nargs="*", |
135 | help="A token to use as initializer word." | 142 | help="A token to use as initializer word.", |
136 | ) | 143 | ) |
137 | parser.add_argument( | 144 | parser.add_argument( |
138 | "--filter_tokens", | 145 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
139 | type=str, | ||
140 | nargs='*', | ||
141 | help="Tokens to filter the dataset by." | ||
142 | ) | 146 | ) |
143 | parser.add_argument( | 147 | parser.add_argument( |
144 | "--initializer_noise", | 148 | "--initializer_noise", |
145 | type=float, | 149 | type=float, |
146 | default=0, | 150 | default=0, |
147 | help="Noise to apply to the initializer word" | 151 | help="Noise to apply to the initializer word", |
148 | ) | 152 | ) |
149 | parser.add_argument( | 153 | parser.add_argument( |
150 | "--alias_tokens", | 154 | "--alias_tokens", |
151 | type=str, | 155 | type=str, |
152 | nargs='*', | 156 | nargs="*", |
153 | default=[], | 157 | default=[], |
154 | help="Tokens to create an alias for." | 158 | help="Tokens to create an alias for.", |
155 | ) | 159 | ) |
156 | parser.add_argument( | 160 | parser.add_argument( |
157 | "--inverted_initializer_tokens", | 161 | "--inverted_initializer_tokens", |
158 | type=str, | 162 | type=str, |
159 | nargs='*', | 163 | nargs="*", |
160 | help="A token to use as initializer word." | 164 | help="A token to use as initializer word.", |
161 | ) | 165 | ) |
162 | parser.add_argument( | 166 | parser.add_argument( |
163 | "--num_vectors", | 167 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
164 | type=int, | ||
165 | nargs='*', | ||
166 | help="Number of vectors per embedding." | ||
167 | ) | 168 | ) |
168 | parser.add_argument( | 169 | parser.add_argument( |
169 | "--exclude_collections", | 170 | "--exclude_collections", |
170 | type=str, | 171 | type=str, |
171 | nargs='*', | 172 | nargs="*", |
172 | help="Exclude all items with a listed collection.", | 173 | help="Exclude all items with a listed collection.", |
173 | ) | 174 | ) |
174 | parser.add_argument( | 175 | parser.add_argument( |
@@ -214,7 +215,7 @@ def parse_args():
214 | "--num_class_images", | 215 | "--num_class_images", |
215 | type=int, | 216 | type=int, |
216 | default=0, | 217 | default=0, |
217 | help="How many class images to generate." | 218 | help="How many class images to generate.", |
218 | ) | 219 | ) |
219 | parser.add_argument( | 220 | parser.add_argument( |
220 | "--class_image_dir", | 221 | "--class_image_dir", |
@@ -242,14 +243,11 @@ def parse_args():
242 | parser.add_argument( | 243 | parser.add_argument( |
243 | "--collection", | 244 | "--collection", |
244 | type=str, | 245 | type=str, |
245 | nargs='*', | 246 | nargs="*", |
246 | help="A collection to filter the dataset.", | 247 | help="A collection to filter the dataset.", |
247 | ) | 248 | ) |
248 | parser.add_argument( | 249 | parser.add_argument( |
249 | "--seed", | 250 | "--seed", type=int, default=None, help="A seed for reproducible training." |
250 | type=int, | ||
251 | default=None, | ||
252 | help="A seed for reproducible training." | ||
253 | ) | 251 | ) |
254 | parser.add_argument( | 252 | parser.add_argument( |
255 | "--resolution", | 253 | "--resolution", |
@@ -270,18 +268,10 @@ def parse_args():
270 | "--input_pertubation", | 268 | "--input_pertubation", |
271 | type=float, | 269 | type=float, |
272 | default=0, | 270 | default=0, |
273 | help="The scale of input pretubation. Recommended 0.1." | 271 | help="The scale of input pretubation. Recommended 0.1.", |
274 | ) | ||
275 | parser.add_argument( | ||
276 | "--num_train_epochs", | ||
277 | type=int, | ||
278 | default=None | ||
279 | ) | ||
280 | parser.add_argument( | ||
281 | "--num_train_steps", | ||
282 | type=int, | ||
283 | default=2000 | ||
284 | ) | 272 | ) |
273 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
274 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
285 | parser.add_argument( | 275 | parser.add_argument( |
286 | "--gradient_accumulation_steps", | 276 | "--gradient_accumulation_steps", |
287 | type=int, | 277 | type=int, |
@@ -289,22 +279,19 @@ def parse_args():
289 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 279 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
290 | ) | 280 | ) |
291 | parser.add_argument( | 281 | parser.add_argument( |
292 | "--lora_r", | 282 | "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True" |
293 | type=int, | ||
294 | default=8, | ||
295 | help="Lora rank, only used if use_lora is True" | ||
296 | ) | 283 | ) |
297 | parser.add_argument( | 284 | parser.add_argument( |
298 | "--lora_alpha", | 285 | "--lora_alpha", |
299 | type=int, | 286 | type=int, |
300 | default=32, | 287 | default=32, |
301 | help="Lora alpha, only used if use_lora is True" | 288 | help="Lora alpha, only used if use_lora is True", |
302 | ) | 289 | ) |
303 | parser.add_argument( | 290 | parser.add_argument( |
304 | "--lora_dropout", | 291 | "--lora_dropout", |
305 | type=float, | 292 | type=float, |
306 | default=0.0, | 293 | default=0.0, |
307 | help="Lora dropout, only used if use_lora is True" | 294 | help="Lora dropout, only used if use_lora is True", |
308 | ) | 295 | ) |
309 | parser.add_argument( | 296 | parser.add_argument( |
310 | "--lora_bias", | 297 | "--lora_bias", |
@@ -344,7 +331,7 @@ def parse_args():
344 | parser.add_argument( | 331 | parser.add_argument( |
345 | "--train_text_encoder_cycles", | 332 | "--train_text_encoder_cycles", |
346 | default=999999, | 333 | default=999999, |
347 | help="Number of epochs the text encoder will be trained." | 334 | help="Number of epochs the text encoder will be trained.", |
348 | ) | 335 | ) |
349 | parser.add_argument( | 336 | parser.add_argument( |
350 | "--find_lr", | 337 | "--find_lr", |
@@ -378,27 +365,31 @@ def parse_args():
378 | "--lr_scheduler", | 365 | "--lr_scheduler", |
379 | type=str, | 366 | type=str, |
380 | default="one_cycle", | 367 | default="one_cycle", |
381 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 368 | choices=[ |
382 | "constant", "constant_with_warmup", "one_cycle"], | 369 | "linear", |
383 | help='The scheduler type to use.', | 370 | "cosine", |
371 | "cosine_with_restarts", | ||
372 | "polynomial", | ||
373 | "constant", | ||
374 | "constant_with_warmup", | ||
375 | "one_cycle", | ||
376 | ], | ||
377 | help="The scheduler type to use.", | ||
384 | ) | 378 | ) |
385 | parser.add_argument( | 379 | parser.add_argument( |
386 | "--lr_warmup_epochs", | 380 | "--lr_warmup_epochs", |
387 | type=int, | 381 | type=int, |
388 | default=10, | 382 | default=10, |
389 | help="Number of steps for the warmup in the lr scheduler." | 383 | help="Number of steps for the warmup in the lr scheduler.", |
390 | ) | 384 | ) |
391 | parser.add_argument( | 385 | parser.add_argument( |
392 | "--lr_mid_point", | 386 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
393 | type=float, | ||
394 | default=0.3, | ||
395 | help="OneCycle schedule mid point." | ||
396 | ) | 387 | ) |
397 | parser.add_argument( | 388 | parser.add_argument( |
398 | "--lr_cycles", | 389 | "--lr_cycles", |
399 | type=int, | 390 | type=int, |
400 | default=None, | 391 | default=None, |
401 | help="Number of restart cycles in the lr scheduler (if supported)." | 392 | help="Number of restart cycles in the lr scheduler (if supported).", |
402 | ) | 393 | ) |
403 | parser.add_argument( | 394 | parser.add_argument( |
404 | "--lr_warmup_func", | 395 | "--lr_warmup_func", |
@@ -410,7 +401,7 @@ def parse_args():
410 | "--lr_warmup_exp", | 401 | "--lr_warmup_exp", |
411 | type=int, | 402 | type=int, |
412 | default=1, | 403 | default=1, |
413 | help='If lr_warmup_func is "cos", exponent to modify the function' | 404 | help='If lr_warmup_func is "cos", exponent to modify the function', |
414 | ) | 405 | ) |
415 | parser.add_argument( | 406 | parser.add_argument( |
416 | "--lr_annealing_func", | 407 | "--lr_annealing_func", |
@@ -422,69 +413,76 @@ def parse_args():
422 | "--lr_annealing_exp", | 413 | "--lr_annealing_exp", |
423 | type=int, | 414 | type=int, |
424 | default=3, | 415 | default=3, |
425 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 416 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
426 | ) | 417 | ) |
427 | parser.add_argument( | 418 | parser.add_argument( |
428 | "--lr_min_lr", | 419 | "--lr_min_lr", |
429 | type=float, | 420 | type=float, |
430 | default=0.04, | 421 | default=0.04, |
431 | help="Minimum learning rate in the lr scheduler." | 422 | help="Minimum learning rate in the lr scheduler.", |
432 | ) | ||
433 | parser.add_argument( | ||
434 | "--min_snr_gamma", | ||
435 | type=int, | ||
436 | default=5, | ||
437 | help="MinSNR gamma." | ||
438 | ) | 423 | ) |
424 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
439 | parser.add_argument( | 425 | parser.add_argument( |
440 | "--schedule_sampler", | 426 | "--schedule_sampler", |
441 | type=str, | 427 | type=str, |
442 | default="uniform", | 428 | default="uniform", |
443 | choices=["uniform", "loss-second-moment"], | 429 | choices=["uniform", "loss-second-moment"], |
444 | help="Noise schedule sampler." | 430 | help="Noise schedule sampler.", |
445 | ) | 431 | ) |
446 | parser.add_argument( | 432 | parser.add_argument( |
447 | "--optimizer", | 433 | "--optimizer", |
448 | type=str, | 434 | type=str, |
449 | default="adan", | 435 | default="adan", |
450 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 436 | choices=[ |
451 | help='Optimizer to use' | 437 | "adam", |
438 | "adam8bit", | ||
439 | "adan", | ||
440 | "lion", | ||
441 | "dadam", | ||
442 | "dadan", | ||
443 | "dlion", | ||
444 | "adafactor", | ||
445 | ], | ||
446 | help="Optimizer to use", | ||
452 | ) | 447 | ) |
453 | parser.add_argument( | 448 | parser.add_argument( |
454 | "--dadaptation_d0", | 449 | "--dadaptation_d0", |
455 | type=float, | 450 | type=float, |
456 | default=1e-6, | 451 | default=1e-6, |
457 | help="The d0 parameter for Dadaptation optimizers." | 452 | help="The d0 parameter for Dadaptation optimizers.", |
453 | ) | ||
454 | parser.add_argument( | ||
455 | "--dadaptation_growth_rate", | ||
456 | type=float, | ||
457 | default=math.inf, | ||
458 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
458 | ) | 459 | ) |
459 | parser.add_argument( | 460 | parser.add_argument( |
460 | "--adam_beta1", | 461 | "--adam_beta1", |
461 | type=float, | 462 | type=float, |
462 | default=None, | 463 | default=None, |
463 | help="The beta1 parameter for the Adam optimizer." | 464 | help="The beta1 parameter for the Adam optimizer.", |
464 | ) | 465 | ) |
465 | parser.add_argument( | 466 | parser.add_argument( |
466 | "--adam_beta2", | 467 | "--adam_beta2", |
467 | type=float, | 468 | type=float, |
468 | default=None, | 469 | default=None, |
469 | help="The beta2 parameter for the Adam optimizer." | 470 | help="The beta2 parameter for the Adam optimizer.", |
470 | ) | 471 | ) |
471 | parser.add_argument( | 472 | parser.add_argument( |
472 | "--adam_weight_decay", | 473 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
473 | type=float, | ||
474 | default=2e-2, | ||
475 | help="Weight decay to use." | ||
476 | ) | 474 | ) |
477 | parser.add_argument( | 475 | parser.add_argument( |
478 | "--adam_epsilon", | 476 | "--adam_epsilon", |
479 | type=float, | 477 | type=float, |
480 | default=1e-08, | 478 | default=1e-08, |
481 | help="Epsilon value for the Adam optimizer" | 479 | help="Epsilon value for the Adam optimizer", |
482 | ) | 480 | ) |
483 | parser.add_argument( | 481 | parser.add_argument( |
484 | "--adam_amsgrad", | 482 | "--adam_amsgrad", |
485 | type=bool, | 483 | type=bool, |
486 | default=False, | 484 | default=False, |
487 | help="Amsgrad value for the Adam optimizer" | 485 | help="Amsgrad value for the Adam optimizer", |
488 | ) | 486 | ) |
489 | parser.add_argument( | 487 | parser.add_argument( |
490 | "--mixed_precision", | 488 | "--mixed_precision", |
@@ -547,19 +545,19 @@ def parse_args():
547 | "--valid_set_size", | 545 | "--valid_set_size", |
548 | type=int, | 546 | type=int, |
549 | default=None, | 547 | default=None, |
550 | help="Number of images in the validation dataset." | 548 | help="Number of images in the validation dataset.", |
551 | ) | 549 | ) |
552 | parser.add_argument( | 550 | parser.add_argument( |
553 | "--valid_set_repeat", | 551 | "--valid_set_repeat", |
554 | type=int, | 552 | type=int, |
555 | default=1, | 553 | default=1, |
556 | help="Times the images in the validation dataset are repeated." | 554 | help="Times the images in the validation dataset are repeated.", |
557 | ) | 555 | ) |
558 | parser.add_argument( | 556 | parser.add_argument( |
559 | "--train_batch_size", | 557 | "--train_batch_size", |
560 | type=int, | 558 | type=int, |
561 | default=1, | 559 | default=1, |
562 | help="Batch size (per device) for the training dataloader." | 560 | help="Batch size (per device) for the training dataloader.", |
563 | ) | 561 | ) |
564 | parser.add_argument( | 562 | parser.add_argument( |
565 | "--sample_steps", | 563 | "--sample_steps", |
@@ -571,19 +569,10 @@ def parse_args():
571 | "--prior_loss_weight", | 569 | "--prior_loss_weight", |
572 | type=float, | 570 | type=float, |
573 | default=1.0, | 571 | default=1.0, |
574 | help="The weight of prior preservation loss." | 572 | help="The weight of prior preservation loss.", |
575 | ) | ||
576 | parser.add_argument( | ||
577 | "--run_pti", | ||
578 | action="store_true", | ||
579 | help="Whether to run PTI." | ||
580 | ) | ||
581 | parser.add_argument( | ||
582 | "--emb_alpha", | ||
583 | type=float, | ||
584 | default=1.0, | ||
585 | help="Embedding alpha" | ||
586 | ) | 573 | ) |
574 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
575 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
587 | parser.add_argument( | 576 | parser.add_argument( |
588 | "--emb_dropout", | 577 | "--emb_dropout", |
589 | type=float, | 578 | type=float, |
@@ -591,27 +580,16 @@ def parse_args():
591 | help="Embedding dropout probability.", | 580 | help="Embedding dropout probability.", |
592 | ) | 581 | ) |
593 | parser.add_argument( | 582 | parser.add_argument( |
594 | "--use_emb_decay", | 583 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
595 | action="store_true", | ||
596 | help="Whether to use embedding decay." | ||
597 | ) | 584 | ) |
598 | parser.add_argument( | 585 | parser.add_argument( |
599 | "--emb_decay_target", | 586 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
600 | default=0.4, | ||
601 | type=float, | ||
602 | help="Embedding decay target." | ||
603 | ) | 587 | ) |
604 | parser.add_argument( | 588 | parser.add_argument( |
605 | "--emb_decay", | 589 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
606 | default=1e+2, | ||
607 | type=float, | ||
608 | help="Embedding decay factor." | ||
609 | ) | 590 | ) |
610 | parser.add_argument( | 591 | parser.add_argument( |
611 | "--max_grad_norm", | 592 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
612 | default=1.0, | ||
613 | type=float, | ||
614 | help="Max gradient norm." | ||
615 | ) | 593 | ) |
616 | parser.add_argument( | 594 | parser.add_argument( |
617 | "--noise_timesteps", | 595 | "--noise_timesteps", |
@@ -622,7 +600,7 @@ def parse_args():
622 | "--config", | 600 | "--config", |
623 | type=str, | 601 | type=str, |
624 | default=None, | 602 | default=None, |
625 | help="Path to a JSON configuration file containing arguments for invoking this script." | 603 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
626 | ) | 604 | ) |
627 | 605 | ||
628 | args = parser.parse_args() | 606 | args = parser.parse_args() |
@@ -649,29 +627,44 @@ def parse_args():
649 | args.placeholder_tokens = [args.placeholder_tokens] | 627 | args.placeholder_tokens = [args.placeholder_tokens] |
650 | 628 | ||
651 | if isinstance(args.initializer_tokens, str): | 629 | if isinstance(args.initializer_tokens, str): |
652 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 630 | args.initializer_tokens = [args.initializer_tokens] * len( |
631 | args.placeholder_tokens | ||
632 | ) | ||
653 | 633 | ||
654 | if len(args.placeholder_tokens) == 0: | 634 | if len(args.placeholder_tokens) == 0: |
655 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 635 | args.placeholder_tokens = [ |
636 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
637 | ] | ||
656 | 638 | ||
657 | if len(args.initializer_tokens) == 0: | 639 | if len(args.initializer_tokens) == 0: |
658 | args.initializer_tokens = args.placeholder_tokens.copy() | 640 | args.initializer_tokens = args.placeholder_tokens.copy() |
659 | 641 | ||
660 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 642 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
661 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 643 | raise ValueError( |
644 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
645 | ) | ||
662 | 646 | ||
663 | if isinstance(args.inverted_initializer_tokens, str): | 647 | if isinstance(args.inverted_initializer_tokens, str): |
664 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 648 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
649 | args.placeholder_tokens | ||
650 | ) | ||
665 | 651 | ||
666 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 652 | if ( |
653 | isinstance(args.inverted_initializer_tokens, list) | ||
654 | and len(args.inverted_initializer_tokens) != 0 | ||
655 | ): | ||
667 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 656 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
668 | args.initializer_tokens += args.inverted_initializer_tokens | 657 | args.initializer_tokens += args.inverted_initializer_tokens |
669 | 658 | ||
670 | if isinstance(args.num_vectors, int): | 659 | if isinstance(args.num_vectors, int): |
671 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 660 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
672 | 661 | ||
673 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 662 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
674 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 663 | args.num_vectors |
664 | ): | ||
665 | raise ValueError( | ||
666 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
667 | ) | ||
675 | 668 | ||
676 | if args.alias_tokens is None: | 669 | if args.alias_tokens is None: |
677 | args.alias_tokens = [] | 670 | args.alias_tokens = [] |
@@ -695,15 +688,15 @@ def parse_args():
695 | raise ValueError("You must specify --output_dir") | 688 | raise ValueError("You must specify --output_dir") |
696 | 689 | ||
697 | if args.adam_beta1 is None: | 690 | if args.adam_beta1 is None: |
698 | if args.optimizer in ('adam', 'adam8bit'): | 691 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
699 | args.adam_beta1 = 0.9 | 692 | args.adam_beta1 = 0.9 |
700 | elif args.optimizer == 'lion': | 693 | elif args.optimizer in ("lion", "dlion"): |
701 | args.adam_beta1 = 0.95 | 694 | args.adam_beta1 = 0.95 |
702 | 695 | ||
703 | if args.adam_beta2 is None: | 696 | if args.adam_beta2 is None: |
704 | if args.optimizer in ('adam', 'adam8bit'): | 697 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
705 | args.adam_beta2 = 0.999 | 698 | args.adam_beta2 = 0.999 |
706 | elif args.optimizer == 'lion': | 699 | elif args.optimizer in ("lion", "dlion"): |
707 | args.adam_beta2 = 0.98 | 700 | args.adam_beta2 = 0.98 |
708 | 701 | ||
709 | return args | 702 | return args |
@@ -719,7 +712,7 @@ def main():
719 | accelerator = Accelerator( | 712 | accelerator = Accelerator( |
720 | log_with=LoggerType.TENSORBOARD, | 713 | log_with=LoggerType.TENSORBOARD, |
721 | project_dir=f"{output_dir}", | 714 | project_dir=f"{output_dir}", |
722 | mixed_precision=args.mixed_precision | 715 | mixed_precision=args.mixed_precision, |
723 | ) | 716 | ) |
724 | 717 | ||
725 | weight_dtype = torch.float32 | 718 | weight_dtype = torch.float32 |
@@ -728,6 +721,8 @@ def main():
728 | elif args.mixed_precision == "bf16": | 721 | elif args.mixed_precision == "bf16": |
729 | weight_dtype = torch.bfloat16 | 722 | weight_dtype = torch.bfloat16 |
730 | 723 | ||
724 | patch_xformers(weight_dtype) | ||
725 | |||
731 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 726 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
732 | 727 | ||
733 | if args.seed is None: | 728 | if args.seed is None: |
@@ -737,12 +732,18 @@ def main():
737 | 732 | ||
738 | save_args(output_dir, args) | 733 | save_args(output_dir, args) |
739 | 734 | ||
740 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 735 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
741 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 736 | args.pretrained_model_name_or_path |
742 | 737 | ) | |
738 | schedule_sampler = create_named_schedule_sampler( | ||
739 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
740 | ) | ||
741 | |||
743 | def ensure_embeddings(): | 742 | def ensure_embeddings(): |
744 | if args.lora_text_encoder_emb: | 743 | if args.lora_text_encoder_emb: |
745 | raise ValueError("Can't use TI options when training token embeddings with LoRA") | 744 | raise ValueError( |
745 | "Can't use TI options when training token embeddings with LoRA" | ||
746 | ) | ||
746 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 747 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
747 | 748 | ||
748 | unet_config = LoraConfig( | 749 | unet_config = LoraConfig( |
@@ -757,7 +758,9 @@ def main():
757 | text_encoder_config = LoraConfig( | 758 | text_encoder_config = LoraConfig( |
758 | r=args.lora_text_encoder_r, | 759 | r=args.lora_text_encoder_r, |
759 | lora_alpha=args.lora_text_encoder_alpha, | 760 | lora_alpha=args.lora_text_encoder_alpha, |
760 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, | 761 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING |
762 | if args.lora_text_encoder_emb | ||
763 | else TEXT_ENCODER_TARGET_MODULES, | ||
761 | lora_dropout=args.lora_text_encoder_dropout, | 764 | lora_dropout=args.lora_text_encoder_dropout, |
762 | bias=args.lora_text_encoder_bias, | 765 | bias=args.lora_text_encoder_bias, |
763 | ) | 766 | ) |
@@ -787,7 +790,7 @@ def main():
787 | 790 | ||
788 | if len(args.alias_tokens) != 0: | 791 | if len(args.alias_tokens) != 0: |
789 | embeddings = ensure_embeddings() | 792 | embeddings = ensure_embeddings() |
790 | 793 | ||
791 | alias_placeholder_tokens = args.alias_tokens[::2] | 794 | alias_placeholder_tokens = args.alias_tokens[::2] |
792 | alias_initializer_tokens = args.alias_tokens[1::2] | 795 | alias_initializer_tokens = args.alias_tokens[1::2] |
793 | 796 | ||
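The --alias_tokens list is consumed as alternating placeholder/initializer pairs via the slicing just above; a small illustrative example (the token strings are made up):

    alias_tokens = ["<dog>", "dog", "<cat>", "cat"]    # illustrative --alias_tokens value
    alias_placeholder_tokens = alias_tokens[::2]        # ["<dog>", "<cat>"]
    alias_initializer_tokens = alias_tokens[1::2]       # ["dog", "cat"]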
@@ -795,27 +798,33 @@ def main():
795 | tokenizer=tokenizer, | 798 | tokenizer=tokenizer, |
796 | embeddings=embeddings, | 799 | embeddings=embeddings, |
797 | placeholder_tokens=alias_placeholder_tokens, | 800 | placeholder_tokens=alias_placeholder_tokens, |
798 | initializer_tokens=alias_initializer_tokens | 801 | initializer_tokens=alias_initializer_tokens, |
799 | ) | 802 | ) |
800 | embeddings.persist() | 803 | embeddings.persist() |
801 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 804 | print( |
805 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
806 | ) | ||
802 | 807 | ||
803 | placeholder_tokens = [] | 808 | placeholder_tokens = [] |
804 | placeholder_token_ids = [] | 809 | placeholder_token_ids = [] |
805 | 810 | ||
806 | if args.embeddings_dir is not None: | 811 | if args.embeddings_dir is not None: |
807 | embeddings = ensure_embeddings() | 812 | embeddings = ensure_embeddings() |
808 | 813 | ||
809 | embeddings_dir = Path(args.embeddings_dir) | 814 | embeddings_dir = Path(args.embeddings_dir) |
810 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 815 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
811 | raise ValueError("--embeddings_dir must point to an existing directory") | 816 | raise ValueError("--embeddings_dir must point to an existing directory") |
812 | 817 | ||
813 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 818 | added_tokens, added_ids = load_embeddings_from_dir( |
819 | tokenizer, embeddings, embeddings_dir | ||
820 | ) | ||
814 | 821 | ||
815 | placeholder_tokens = added_tokens | 822 | placeholder_tokens = added_tokens |
816 | placeholder_token_ids = added_ids | 823 | placeholder_token_ids = added_ids |
817 | 824 | ||
818 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 825 | print( |
826 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
827 | ) | ||
819 | 828 | ||
820 | if args.train_dir_embeddings: | 829 | if args.train_dir_embeddings: |
821 | print("Training embeddings from embeddings dir") | 830 | print("Training embeddings from embeddings dir") |
@@ -824,7 +833,7 @@ def main():
824 | 833 | ||
825 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 834 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: |
826 | embeddings = ensure_embeddings() | 835 | embeddings = ensure_embeddings() |
827 | 836 | ||
828 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 837 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
829 | tokenizer=tokenizer, | 838 | tokenizer=tokenizer, |
830 | embeddings=embeddings, | 839 | embeddings=embeddings, |
@@ -836,23 +845,34 @@ def main():
836 | 845 | ||
837 | placeholder_tokens = args.placeholder_tokens | 846 | placeholder_tokens = args.placeholder_tokens |
838 | 847 | ||
839 | stats = list(zip( | 848 | stats = list( |
840 | placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 849 | zip( |
841 | )) | 850 | placeholder_tokens, |
851 | placeholder_token_ids, | ||
852 | args.initializer_tokens, | ||
853 | initializer_token_ids, | ||
854 | ) | ||
855 | ) | ||
842 | print(f"Training embeddings: {stats}") | 856 | print(f"Training embeddings: {stats}") |
843 | 857 | ||
844 | if args.scale_lr: | 858 | if args.scale_lr: |
845 | args.learning_rate_unet = ( | 859 | args.learning_rate_unet = ( |
846 | args.learning_rate_unet * args.gradient_accumulation_steps * | 860 | args.learning_rate_unet |
847 | args.train_batch_size * accelerator.num_processes | 861 | * args.gradient_accumulation_steps |
862 | * args.train_batch_size | ||
863 | * accelerator.num_processes | ||
848 | ) | 864 | ) |
849 | args.learning_rate_text = ( | 865 | args.learning_rate_text = ( |
850 | args.learning_rate_text * args.gradient_accumulation_steps * | 866 | args.learning_rate_text |
851 | args.train_batch_size * accelerator.num_processes | 867 | * args.gradient_accumulation_steps |
868 | * args.train_batch_size | ||
869 | * accelerator.num_processes | ||
852 | ) | 870 | ) |
853 | args.learning_rate_emb = ( | 871 | args.learning_rate_emb = ( |
854 | args.learning_rate_emb * args.gradient_accumulation_steps * | 872 | args.learning_rate_emb |
855 | args.train_batch_size * accelerator.num_processes | 873 | * args.gradient_accumulation_steps |
874 | * args.train_batch_size | ||
875 | * accelerator.num_processes | ||
856 | ) | 876 | ) |
857 | 877 | ||
858 | if args.find_lr: | 878 | if args.find_lr: |
@@ -861,11 +881,13 @@ def main():
861 | args.learning_rate_emb = 1e-6 | 881 | args.learning_rate_emb = 1e-6 |
862 | args.lr_scheduler = "exponential_growth" | 882 | args.lr_scheduler = "exponential_growth" |
863 | 883 | ||
864 | if args.optimizer == 'adam8bit': | 884 | if args.optimizer == "adam8bit": |
865 | try: | 885 | try: |
866 | import bitsandbytes as bnb | 886 | import bitsandbytes as bnb |
867 | except ImportError: | 887 | except ImportError: |
868 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 888 | raise ImportError( |
889 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
890 | ) | ||
869 | 891 | ||
870 | create_optimizer = partial( | 892 | create_optimizer = partial( |
871 | bnb.optim.AdamW8bit, | 893 | bnb.optim.AdamW8bit, |
@@ -874,7 +896,7 @@ def main():
874 | eps=args.adam_epsilon, | 896 | eps=args.adam_epsilon, |
875 | amsgrad=args.adam_amsgrad, | 897 | amsgrad=args.adam_amsgrad, |
876 | ) | 898 | ) |
877 | elif args.optimizer == 'adam': | 899 | elif args.optimizer == "adam": |
878 | create_optimizer = partial( | 900 | create_optimizer = partial( |
879 | torch.optim.AdamW, | 901 | torch.optim.AdamW, |
880 | betas=(args.adam_beta1, args.adam_beta2), | 902 | betas=(args.adam_beta1, args.adam_beta2), |
@@ -882,11 +904,13 @@ def main():
882 | eps=args.adam_epsilon, | 904 | eps=args.adam_epsilon, |
883 | amsgrad=args.adam_amsgrad, | 905 | amsgrad=args.adam_amsgrad, |
884 | ) | 906 | ) |
885 | elif args.optimizer == 'adan': | 907 | elif args.optimizer == "adan": |
886 | try: | 908 | try: |
887 | import timm.optim | 909 | import timm.optim |
888 | except ImportError: | 910 | except ImportError: |
889 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 911 | raise ImportError( |
912 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
913 | ) | ||
890 | 914 | ||
891 | create_optimizer = partial( | 915 | create_optimizer = partial( |
892 | timm.optim.Adan, | 916 | timm.optim.Adan, |
@@ -894,11 +918,13 @@ def main():
894 | eps=args.adam_epsilon, | 918 | eps=args.adam_epsilon, |
895 | no_prox=True, | 919 | no_prox=True, |
896 | ) | 920 | ) |
897 | elif args.optimizer == 'lion': | 921 | elif args.optimizer == "lion": |
898 | try: | 922 | try: |
899 | import lion_pytorch | 923 | import lion_pytorch |
900 | except ImportError: | 924 | except ImportError: |
901 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 925 | raise ImportError( |
926 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
927 | ) | ||
902 | 928 | ||
903 | create_optimizer = partial( | 929 | create_optimizer = partial( |
904 | lion_pytorch.Lion, | 930 | lion_pytorch.Lion, |
@@ -906,7 +932,7 @@ def main():
906 | weight_decay=args.adam_weight_decay, | 932 | weight_decay=args.adam_weight_decay, |
907 | use_triton=True, | 933 | use_triton=True, |
908 | ) | 934 | ) |
909 | elif args.optimizer == 'adafactor': | 935 | elif args.optimizer == "adafactor": |
910 | create_optimizer = partial( | 936 | create_optimizer = partial( |
911 | transformers.optimization.Adafactor, | 937 | transformers.optimization.Adafactor, |
912 | weight_decay=args.adam_weight_decay, | 938 | weight_decay=args.adam_weight_decay, |
@@ -920,11 +946,13 @@ def main():
920 | args.learning_rate_unet = None | 946 | args.learning_rate_unet = None |
921 | args.learning_rate_text = None | 947 | args.learning_rate_text = None |
922 | args.learning_rate_emb = None | 948 | args.learning_rate_emb = None |
923 | elif args.optimizer == 'dadam': | 949 | elif args.optimizer == "dadam": |
924 | try: | 950 | try: |
925 | import dadaptation | 951 | import dadaptation |
926 | except ImportError: | 952 | except ImportError: |
927 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 953 | raise ImportError( |
954 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
955 | ) | ||
928 | 956 | ||
929 | create_optimizer = partial( | 957 | create_optimizer = partial( |
930 | dadaptation.DAdaptAdam, | 958 | dadaptation.DAdaptAdam, |
@@ -933,29 +961,35 @@ def main():
933 | eps=args.adam_epsilon, | 961 | eps=args.adam_epsilon, |
934 | decouple=True, | 962 | decouple=True, |
935 | d0=args.dadaptation_d0, | 963 | d0=args.dadaptation_d0, |
964 | growth_rate=args.dadaptation_growth_rate, | ||
936 | ) | 965 | ) |
937 | 966 | ||
938 | args.learning_rate_unet = 1.0 | 967 | args.learning_rate_unet = 1.0 |
939 | args.learning_rate_text = 1.0 | 968 | args.learning_rate_text = 1.0 |
940 | args.learning_rate_emb = 1.0 | 969 | args.learning_rate_emb = 1.0 |
941 | elif args.optimizer == 'dadan': | 970 | elif args.optimizer == "dadan": |
942 | try: | 971 | try: |
943 | import dadaptation | 972 | import dadaptation |
944 | except ImportError: | 973 | except ImportError: |
945 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 974 | raise ImportError( |
975 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
976 | ) | ||
946 | 977 | ||
947 | create_optimizer = partial( | 978 | create_optimizer = partial( |
948 | dadaptation.DAdaptAdan, | 979 | dadaptation.DAdaptAdan, |
949 | weight_decay=args.adam_weight_decay, | 980 | weight_decay=args.adam_weight_decay, |
950 | eps=args.adam_epsilon, | 981 | eps=args.adam_epsilon, |
951 | d0=args.dadaptation_d0, | 982 | d0=args.dadaptation_d0, |
983 | growth_rate=args.dadaptation_growth_rate, | ||
952 | ) | 984 | ) |
953 | 985 | ||
954 | args.learning_rate_unet = 1.0 | 986 | args.learning_rate_unet = 1.0 |
955 | args.learning_rate_text = 1.0 | 987 | args.learning_rate_text = 1.0 |
956 | args.learning_rate_emb = 1.0 | 988 | args.learning_rate_emb = 1.0 |
989 | elif args.optimizer == "dlion": | ||
990 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
957 | else: | 991 | else: |
958 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 992 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
959 | 993 | ||
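Each optimizer branch above builds create_optimizer as a functools.partial so that fixed hyperparameters are bound up front and the parameter groups are supplied later; for the D-Adaptation variants the group learning rates are forced to 1.0 because the method estimates its own step size, starting from --dadaptation_d0 and with its growth limited by the new --dadaptation_growth_rate flag (the new "dlion" choice currently just raises ImportError). A sketch of the DAdaptAdam case with the argparse defaults filled in; the keyword arguments mirror what the diff passes through:

    import math
    from functools import partial
    import dadaptation

    create_optimizer = partial(
        dadaptation.DAdaptAdam,
        betas=(0.9, 0.999),
        weight_decay=2e-2,
        eps=1e-08,
        decouple=True,
        d0=1e-6,                 # --dadaptation_d0
        growth_rate=math.inf,    # --dadaptation_growth_rate (new in this diff)
    )
    # Later: create_optimizer(params_to_optimize), with every group's "lr" set to 1.0.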
960 | trainer = partial( | 994 | trainer = partial( |
961 | train, | 995 | train, |
@@ -1026,25 +1060,33 @@ def main():
1026 | 1060 | ||
1027 | if args.run_pti and len(placeholder_tokens) != 0: | 1061 | if args.run_pti and len(placeholder_tokens) != 0: |
1028 | embeddings = ensure_embeddings() | 1062 | embeddings = ensure_embeddings() |
1029 | 1063 | ||
1030 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 1064 | filter_tokens = [ |
1065 | token for token in args.filter_tokens if token in placeholder_tokens | ||
1066 | ] | ||
1031 | 1067 | ||
1032 | pti_datamodule = create_datamodule( | 1068 | pti_datamodule = create_datamodule( |
1033 | valid_set_size=0, | 1069 | valid_set_size=0, |
1034 | batch_size=args.train_batch_size, | 1070 | batch_size=args.train_batch_size, |
1035 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 1071 | filter=partial( |
1072 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
1073 | ), | ||
1036 | ) | 1074 | ) |
1037 | pti_datamodule.setup() | 1075 | pti_datamodule.setup() |
1038 | 1076 | ||
1039 | num_train_epochs = args.num_train_epochs | 1077 | num_train_epochs = args.num_train_epochs |
1040 | pti_sample_frequency = args.sample_frequency | 1078 | pti_sample_frequency = args.sample_frequency |
1041 | if num_train_epochs is None: | 1079 | if num_train_epochs is None: |
1042 | num_train_epochs = math.ceil( | 1080 | num_train_epochs = ( |
1043 | args.num_train_steps / len(pti_datamodule.train_dataset) | 1081 | math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset)) |
1044 | ) * args.gradient_accumulation_steps | 1082 | * args.gradient_accumulation_steps |
1045 | pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) | 1083 | ) |
1084 | pti_sample_frequency = math.ceil( | ||
1085 | num_train_epochs * (pti_sample_frequency / args.num_train_steps) | ||
1086 | ) | ||
1046 | num_training_steps_per_epoch = math.ceil( | 1087 | num_training_steps_per_epoch = math.ceil( |
1047 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1088 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps |
1089 | ) | ||
1048 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1090 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
1049 | if args.sample_num is not None: | 1091 | if args.sample_num is not None: |
1050 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1092 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
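When --num_train_epochs is left unset, the PTI epoch count above is derived from --num_train_steps, and the sampling frequency is rescaled from steps to epochs; a worked example with illustrative sizes (dataset length and sample frequency are assumed values):

    import math

    num_train_steps = 2000              # --num_train_steps default
    dataset_len = 100                   # len(pti_datamodule.train_dataset), assumed
    gradient_accumulation_steps = 4
    sample_frequency = 200              # --sample_frequency, assumed

    num_train_epochs = (
        math.ceil(num_train_steps / dataset_len) * gradient_accumulation_steps
    )                                                                    # 80
    pti_sample_frequency = math.ceil(
        num_train_epochs * (sample_frequency / num_train_steps)
    )                                                                    # 8
    num_training_steps_per_epoch = math.ceil(
        dataset_len / gradient_accumulation_steps
    )                                                                    # 25
    num_train_steps = num_training_steps_per_epoch * num_train_epochs   # 2000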
@@ -1060,11 +1102,15 @@ def main():
1060 | print(f"============ PTI ============") | 1102 | print(f"============ PTI ============") |
1061 | print("") | 1103 | print("") |
1062 | 1104 | ||
1063 | pti_optimizer = create_optimizer([{ | 1105 | pti_optimizer = create_optimizer( |
1064 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | 1106 | [ |
1065 | "lr": args.learning_rate_emb, | 1107 | { |
1066 | "weight_decay": 0, | 1108 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
1067 | }]) | 1109 | "lr": args.learning_rate_emb, |
1110 | "weight_decay": 0, | ||
1111 | } | ||
1112 | ] | ||
1113 | ) | ||
1068 | 1114 | ||
1069 | pti_lr_scheduler = create_lr_scheduler( | 1115 | pti_lr_scheduler = create_lr_scheduler( |
1070 | "constant_with_warmup", | 1116 | "constant_with_warmup", |
@@ -1113,11 +1159,16 @@ def main():
1113 | num_train_epochs = args.num_train_epochs | 1159 | num_train_epochs = args.num_train_epochs |
1114 | lora_sample_frequency = args.sample_frequency | 1160 | lora_sample_frequency = args.sample_frequency |
1115 | if num_train_epochs is None: | 1161 | if num_train_epochs is None: |
1116 | num_train_epochs = math.ceil( | 1162 | num_train_epochs = ( |
1117 | args.num_train_steps / len(lora_datamodule.train_dataset) | 1163 | math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset)) |
1118 | ) * args.gradient_accumulation_steps | 1164 | * args.gradient_accumulation_steps |
1119 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 1165 | ) |
1120 | num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1166 | lora_sample_frequency = math.ceil( |
1167 | num_train_epochs * (lora_sample_frequency / args.num_train_steps) | ||
1168 | ) | ||
1169 | num_training_steps_per_epoch = math.ceil( | ||
1170 | len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps | ||
1171 | ) | ||
1121 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1172 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
1122 | if args.sample_num is not None: | 1173 | if args.sample_num is not None: |
1123 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1174 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
@@ -1131,7 +1182,6 @@ def main():
1131 | 1182 | ||
1132 | training_iter = 0 | 1183 | training_iter = 0 |
1133 | auto_cycles = list(args.auto_cycles) | 1184 | auto_cycles = list(args.auto_cycles) |
1134 | learning_rate_emb = args.learning_rate_emb | ||
1135 | learning_rate_unet = args.learning_rate_unet | 1185 | learning_rate_unet = args.learning_rate_unet |
1136 | learning_rate_text = args.learning_rate_text | 1186 | learning_rate_text = args.learning_rate_text |
1137 | lr_scheduler = args.lr_scheduler | 1187 | lr_scheduler = args.lr_scheduler |
@@ -1145,21 +1195,15 @@ def main():
1145 | 1195 | ||
1146 | params_to_optimize = [ | 1196 | params_to_optimize = [ |
1147 | { | 1197 | { |
1148 | "params": ( | 1198 | "params": (param for param in unet.parameters() if param.requires_grad), |
1149 | param | ||
1150 | for param in unet.parameters() | ||
1151 | if param.requires_grad | ||
1152 | ), | ||
1153 | "lr": learning_rate_unet, | 1199 | "lr": learning_rate_unet, |
1154 | }, | 1200 | }, |
1155 | { | 1201 | { |
1156 | "params": ( | 1202 | "params": ( |
1157 | param | 1203 | param for param in text_encoder.parameters() if param.requires_grad |
1158 | for param in text_encoder.parameters() | ||
1159 | if param.requires_grad | ||
1160 | ), | 1204 | ), |
1161 | "lr": learning_rate_text, | 1205 | "lr": learning_rate_text, |
1162 | } | 1206 | }, |
1163 | ] | 1207 | ] |
1164 | group_labels = ["unet", "text"] | 1208 | group_labels = ["unet", "text"] |
1165 | 1209 | ||
@@ -1169,19 +1213,26 @@ def main():
1169 | if len(auto_cycles) != 0: | 1213 | if len(auto_cycles) != 0: |
1170 | response = auto_cycles.pop(0) | 1214 | response = auto_cycles.pop(0) |
1171 | else: | 1215 | else: |
1172 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1216 | response = input( |
1217 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
1218 | ) | ||
1173 | 1219 | ||
1174 | if response.lower().strip() == "o": | 1220 | if response.lower().strip() == "o": |
1175 | if args.learning_rate_emb is not None: | ||
1176 | learning_rate_emb = args.learning_rate_emb * 2 | ||
1177 | if args.learning_rate_unet is not None: | 1221 | if args.learning_rate_unet is not None: |
1178 | learning_rate_unet = args.learning_rate_unet * 2 | 1222 | learning_rate_unet = ( |
1223 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
1224 | ) | ||
1179 | if args.learning_rate_text is not None: | 1225 | if args.learning_rate_text is not None: |
1180 | learning_rate_text = args.learning_rate_text * 2 | 1226 | learning_rate_text = ( |
1227 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
1228 | ) | ||
1181 | else: | 1229 | else: |
1182 | learning_rate_emb = args.learning_rate_emb | 1230 | learning_rate_unet = args.learning_rate_unet * ( |
1183 | learning_rate_unet = args.learning_rate_unet | 1231 | args.cycle_decay**training_iter |
1184 | learning_rate_text = args.learning_rate_text | 1232 | ) |
1233 | learning_rate_text = args.learning_rate_text * ( | ||
1234 | args.cycle_decay**training_iter | ||
1235 | ) | ||
1185 | 1236 | ||
1186 | if response.lower().strip() == "o": | 1237 | if response.lower().strip() == "o": |
1187 | lr_scheduler = "one_cycle" | 1238 | lr_scheduler = "one_cycle" |
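Per-cycle learning rates are now computed directly from the cycle index via cycle_decay**training_iter instead of being decayed in place after each cycle (the old in-place multiplication is removed at the end of this diff); a small numeric example with illustrative values:

    base_lr = 1e-4
    cycle_decay = 0.9                                   # --cycle_decay

    for training_iter in range(3):
        one_cycle_lr = base_lr * 2 * (cycle_decay ** training_iter)   # "o" response
        other_lr = base_lr * (cycle_decay ** training_iter)           # other responses
        # iter 0: 2.00e-4 / 1.00e-4
        # iter 1: 1.80e-4 / 0.90e-4
        # iter 2: 1.62e-4 / 0.81e-4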
@@ -1204,9 +1255,11 @@ def main():
1204 | print("") | 1255 | print("") |
1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1256 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
1206 | print("") | 1257 | print("") |
1207 | 1258 | ||
1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): | 1259 | for group, lr in zip( |
1209 | group['lr'] = lr | 1260 | lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text] |
1261 | ): | ||
1262 | group["lr"] = lr | ||
1210 | 1263 | ||
1211 | lora_lr_scheduler = create_lr_scheduler( | 1264 | lora_lr_scheduler = create_lr_scheduler( |
1212 | lr_scheduler, | 1265 | lr_scheduler, |
@@ -1218,7 +1271,9 @@ def main():
1218 | warmup_epochs=lr_warmup_epochs, | 1271 | warmup_epochs=lr_warmup_epochs, |
1219 | ) | 1272 | ) |
1220 | 1273 | ||
1221 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" | 1274 | lora_checkpoint_output_dir = ( |
1275 | output_dir / lora_project / f"model_{training_iter}" | ||
1276 | ) | ||
1222 | 1277 | ||
1223 | trainer( | 1278 | trainer( |
1224 | strategy=lora_strategy, | 1279 | strategy=lora_strategy, |
@@ -1246,12 +1301,6 @@ def main():
1246 | ) | 1301 | ) |
1247 | 1302 | ||
1248 | training_iter += 1 | 1303 | training_iter += 1 |
1249 | if learning_rate_emb is not None: | ||
1250 | learning_rate_emb *= args.cycle_decay | ||
1251 | if learning_rate_unet is not None: | ||
1252 | learning_rate_unet *= args.cycle_decay | ||
1253 | if learning_rate_text is not None: | ||
1254 | learning_rate_text *= args.cycle_decay | ||
1255 | 1304 | ||
1256 | accelerator.end_training() | 1305 | accelerator.end_training() |
1257 | 1306 | ||