Diffstat (limited to 'train_lora.py')
| -rw-r--r-- | train_lora.py | 489 |
1 file changed, 269 insertions, 220 deletions
diff --git a/train_lora.py b/train_lora.py
index c74dd8f..fccf48d 100644
--- a/train_lora.py
+++ b/train_lora.py
| @@ -16,6 +16,7 @@ from accelerate import Accelerator | |||
| 16 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 17 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 18 | from peft import LoraConfig, get_peft_model | 18 | from peft import LoraConfig, get_peft_model |
| 19 | |||
| 19 | # from diffusers.models.attention_processor import AttnProcessor | 20 | # from diffusers.models.attention_processor import AttnProcessor |
| 20 | from diffusers.utils.import_utils import is_xformers_available | 21 | from diffusers.utils.import_utils import is_xformers_available |
| 21 | import transformers | 22 | import transformers |
| @@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir | |||
| 34 | 35 | ||
| 35 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py | 36 | # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py |
| 36 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] | 37 | UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] |
| 37 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] | 38 | UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # [] |
| 38 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] | 39 | TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] |
| 39 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] | 40 | TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [ |
| 40 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] | 41 | "out_proj", |
| 42 | "k_proj", | ||
| 43 | ] # [] | ||
| 44 | TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [ | ||
| 45 | "token_embedding" | ||
| 46 | ] | ||
| 41 | 47 | ||
| 42 | 48 | ||
| 43 | logger = get_logger(__name__) | 49 | logger = get_logger(__name__) |
| 44 | 50 | ||
| 45 | warnings.filterwarnings('ignore') | 51 | warnings.filterwarnings("ignore") |
| 46 | 52 | ||
| 47 | 53 | ||
| 48 | torch.backends.cuda.matmul.allow_tf32 = True | 54 | torch.backends.cuda.matmul.allow_tf32 = True |
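Note on the hunk above: the commit widens the LoRA target-module lists so the key projections ("to_k"/"key" in the UNet, "k_proj" in the text encoder) and the output projections are adapted as well. These lists are later handed to peft's LoraConfig in main(). A minimal sketch of how such a list selects layers, assuming a generic Stable Diffusion 1.x checkpoint path in model_path (a placeholder) and the script's default rank/alpha/dropout values; the bias setting is assumed here:

    # Hedged sketch, not the script itself: peft matches each entry as a module-name
    # suffix, so "to_k"/"key" now wrap the attention key projections too.
    from diffusers import UNet2DConditionModel
    from peft import LoraConfig, get_peft_model

    UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value", "to_out.0", "to_k", "key"]

    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")  # model_path is a placeholder
    unet_config = LoraConfig(
        r=8,                                 # --lora_r default
        lora_alpha=32,                       # --lora_alpha default
        target_modules=UNET_TARGET_MODULES,
        lora_dropout=0.0,                    # --lora_dropout default
        bias="none",                         # --lora_bias; "none" assumed here
    )
    unet = get_peft_model(unet, unet_config)
    unet.print_trainable_parameters()        # only the injected low-rank matrices require grad

Adding the key projections increases the number of trainable LoRA parameters slightly, but lets the adapter reshape the full attention computation rather than only the query and value paths.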
| @@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True) | |||
| 55 | hidet.torch.dynamo_config.search_space(0) | 61 | hidet.torch.dynamo_config.search_space(0) |
| 56 | 62 | ||
| 57 | 63 | ||
| 58 | if is_xformers_available(): | 64 | def patch_xformers(dtype): |
| 59 | import xformers | 65 | if is_xformers_available(): |
| 60 | import xformers.ops | 66 | import xformers |
| 61 | 67 | import xformers.ops | |
| 62 | orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention | 68 | |
| 63 | def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): | 69 | orig_xformers_memory_efficient_attention = ( |
| 64 | return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) | 70 | xformers.ops.memory_efficient_attention |
| 65 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | 71 | ) |
| 72 | |||
| 73 | def xformers_memory_efficient_attention( | ||
| 74 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
| 75 | ): | ||
| 76 | return orig_xformers_memory_efficient_attention( | ||
| 77 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
| 78 | ) | ||
| 79 | |||
| 80 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 66 | 81 | ||
| 67 | 82 | ||
| 68 | def parse_args(): | 83 | def parse_args(): |
| 69 | parser = argparse.ArgumentParser( | 84 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 70 | description="Simple example of a training script." | ||
| 71 | ) | ||
| 72 | parser.add_argument( | 85 | parser.add_argument( |
| 73 | "--pretrained_model_name_or_path", | 86 | "--pretrained_model_name_or_path", |
| 74 | type=str, | 87 | type=str, |
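About the patch_xformers hunk above: the old module-level monkeypatch cast query and value to the key's dtype; the new helper casts all three tensors to one explicit dtype and is only applied once the training weight_dtype is known (main() calls patch_xformers(weight_dtype) after resolving --mixed_precision). A standalone restatement of the same wrapper, as a hedged sketch:

    # Hedged sketch of the dtype-normalizing wrapper. Under mixed precision the
    # query/key/value tensors can arrive in different dtypes (e.g. float32 vs
    # float16), which xformers rejects; casting all three to one dtype avoids that.
    import torch
    from diffusers.utils.import_utils import is_xformers_available


    def patch_xformers(dtype: torch.dtype) -> None:
        if not is_xformers_available():
            return

        import xformers.ops

        orig_attention = xformers.ops.memory_efficient_attention

        def patched_attention(query, key, value, **kwargs):
            # Cast q/k/v to the training dtype before delegating to xformers.
            return orig_attention(query.to(dtype), key.to(dtype), value.to(dtype), **kwargs)

        xformers.ops.memory_efficient_attention = patched_attention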
| @@ -85,7 +98,7 @@ def parse_args(): | |||
| 85 | "--train_data_file", | 98 | "--train_data_file", |
| 86 | type=str, | 99 | type=str, |
| 87 | default=None, | 100 | default=None, |
| 88 | help="A folder containing the training data." | 101 | help="A folder containing the training data.", |
| 89 | ) | 102 | ) |
| 90 | parser.add_argument( | 103 | parser.add_argument( |
| 91 | "--train_data_template", | 104 | "--train_data_template", |
| @@ -96,13 +109,13 @@ def parse_args(): | |||
| 96 | "--train_set_pad", | 109 | "--train_set_pad", |
| 97 | type=int, | 110 | type=int, |
| 98 | default=None, | 111 | default=None, |
| 99 | help="The number to fill train dataset items up to." | 112 | help="The number to fill train dataset items up to.", |
| 100 | ) | 113 | ) |
| 101 | parser.add_argument( | 114 | parser.add_argument( |
| 102 | "--valid_set_pad", | 115 | "--valid_set_pad", |
| 103 | type=int, | 116 | type=int, |
| 104 | default=None, | 117 | default=None, |
| 105 | help="The number to fill validation dataset items up to." | 118 | help="The number to fill validation dataset items up to.", |
| 106 | ) | 119 | ) |
| 107 | parser.add_argument( | 120 | parser.add_argument( |
| 108 | "--project", | 121 | "--project", |
| @@ -111,64 +124,52 @@ def parse_args(): | |||
| 111 | help="The name of the current project.", | 124 | help="The name of the current project.", |
| 112 | ) | 125 | ) |
| 113 | parser.add_argument( | 126 | parser.add_argument( |
| 114 | "--auto_cycles", | 127 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 115 | type=str, | ||
| 116 | default="o", | ||
| 117 | help="Cycles to run automatically." | ||
| 118 | ) | 128 | ) |
| 119 | parser.add_argument( | 129 | parser.add_argument( |
| 120 | "--cycle_decay", | 130 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
| 121 | type=float, | ||
| 122 | default=1.0, | ||
| 123 | help="Learning rate decay per cycle." | ||
| 124 | ) | 131 | ) |
| 125 | parser.add_argument( | 132 | parser.add_argument( |
| 126 | "--placeholder_tokens", | 133 | "--placeholder_tokens", |
| 127 | type=str, | 134 | type=str, |
| 128 | nargs='*', | 135 | nargs="*", |
| 129 | help="A token to use as a placeholder for the concept.", | 136 | help="A token to use as a placeholder for the concept.", |
| 130 | ) | 137 | ) |
| 131 | parser.add_argument( | 138 | parser.add_argument( |
| 132 | "--initializer_tokens", | 139 | "--initializer_tokens", |
| 133 | type=str, | 140 | type=str, |
| 134 | nargs='*', | 141 | nargs="*", |
| 135 | help="A token to use as initializer word." | 142 | help="A token to use as initializer word.", |
| 136 | ) | 143 | ) |
| 137 | parser.add_argument( | 144 | parser.add_argument( |
| 138 | "--filter_tokens", | 145 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
| 139 | type=str, | ||
| 140 | nargs='*', | ||
| 141 | help="Tokens to filter the dataset by." | ||
| 142 | ) | 146 | ) |
| 143 | parser.add_argument( | 147 | parser.add_argument( |
| 144 | "--initializer_noise", | 148 | "--initializer_noise", |
| 145 | type=float, | 149 | type=float, |
| 146 | default=0, | 150 | default=0, |
| 147 | help="Noise to apply to the initializer word" | 151 | help="Noise to apply to the initializer word", |
| 148 | ) | 152 | ) |
| 149 | parser.add_argument( | 153 | parser.add_argument( |
| 150 | "--alias_tokens", | 154 | "--alias_tokens", |
| 151 | type=str, | 155 | type=str, |
| 152 | nargs='*', | 156 | nargs="*", |
| 153 | default=[], | 157 | default=[], |
| 154 | help="Tokens to create an alias for." | 158 | help="Tokens to create an alias for.", |
| 155 | ) | 159 | ) |
| 156 | parser.add_argument( | 160 | parser.add_argument( |
| 157 | "--inverted_initializer_tokens", | 161 | "--inverted_initializer_tokens", |
| 158 | type=str, | 162 | type=str, |
| 159 | nargs='*', | 163 | nargs="*", |
| 160 | help="A token to use as initializer word." | 164 | help="A token to use as initializer word.", |
| 161 | ) | 165 | ) |
| 162 | parser.add_argument( | 166 | parser.add_argument( |
| 163 | "--num_vectors", | 167 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 164 | type=int, | ||
| 165 | nargs='*', | ||
| 166 | help="Number of vectors per embedding." | ||
| 167 | ) | 168 | ) |
| 168 | parser.add_argument( | 169 | parser.add_argument( |
| 169 | "--exclude_collections", | 170 | "--exclude_collections", |
| 170 | type=str, | 171 | type=str, |
| 171 | nargs='*', | 172 | nargs="*", |
| 172 | help="Exclude all items with a listed collection.", | 173 | help="Exclude all items with a listed collection.", |
| 173 | ) | 174 | ) |
| 174 | parser.add_argument( | 175 | parser.add_argument( |
| @@ -214,7 +215,7 @@ def parse_args(): | |||
| 214 | "--num_class_images", | 215 | "--num_class_images", |
| 215 | type=int, | 216 | type=int, |
| 216 | default=0, | 217 | default=0, |
| 217 | help="How many class images to generate." | 218 | help="How many class images to generate.", |
| 218 | ) | 219 | ) |
| 219 | parser.add_argument( | 220 | parser.add_argument( |
| 220 | "--class_image_dir", | 221 | "--class_image_dir", |
| @@ -242,14 +243,11 @@ def parse_args(): | |||
| 242 | parser.add_argument( | 243 | parser.add_argument( |
| 243 | "--collection", | 244 | "--collection", |
| 244 | type=str, | 245 | type=str, |
| 245 | nargs='*', | 246 | nargs="*", |
| 246 | help="A collection to filter the dataset.", | 247 | help="A collection to filter the dataset.", |
| 247 | ) | 248 | ) |
| 248 | parser.add_argument( | 249 | parser.add_argument( |
| 249 | "--seed", | 250 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 250 | type=int, | ||
| 251 | default=None, | ||
| 252 | help="A seed for reproducible training." | ||
| 253 | ) | 251 | ) |
| 254 | parser.add_argument( | 252 | parser.add_argument( |
| 255 | "--resolution", | 253 | "--resolution", |
| @@ -270,18 +268,10 @@ def parse_args(): | |||
| 270 | "--input_pertubation", | 268 | "--input_pertubation", |
| 271 | type=float, | 269 | type=float, |
| 272 | default=0, | 270 | default=0, |
| 273 | help="The scale of input pretubation. Recommended 0.1." | 271 | help="The scale of input pretubation. Recommended 0.1.", |
| 274 | ) | ||
| 275 | parser.add_argument( | ||
| 276 | "--num_train_epochs", | ||
| 277 | type=int, | ||
| 278 | default=None | ||
| 279 | ) | ||
| 280 | parser.add_argument( | ||
| 281 | "--num_train_steps", | ||
| 282 | type=int, | ||
| 283 | default=2000 | ||
| 284 | ) | 272 | ) |
| 273 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 274 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 285 | parser.add_argument( | 275 | parser.add_argument( |
| 286 | "--gradient_accumulation_steps", | 276 | "--gradient_accumulation_steps", |
| 287 | type=int, | 277 | type=int, |
| @@ -289,22 +279,19 @@ def parse_args(): | |||
| 289 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 279 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
| 290 | ) | 280 | ) |
| 291 | parser.add_argument( | 281 | parser.add_argument( |
| 292 | "--lora_r", | 282 | "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True" |
| 293 | type=int, | ||
| 294 | default=8, | ||
| 295 | help="Lora rank, only used if use_lora is True" | ||
| 296 | ) | 283 | ) |
| 297 | parser.add_argument( | 284 | parser.add_argument( |
| 298 | "--lora_alpha", | 285 | "--lora_alpha", |
| 299 | type=int, | 286 | type=int, |
| 300 | default=32, | 287 | default=32, |
| 301 | help="Lora alpha, only used if use_lora is True" | 288 | help="Lora alpha, only used if use_lora is True", |
| 302 | ) | 289 | ) |
| 303 | parser.add_argument( | 290 | parser.add_argument( |
| 304 | "--lora_dropout", | 291 | "--lora_dropout", |
| 305 | type=float, | 292 | type=float, |
| 306 | default=0.0, | 293 | default=0.0, |
| 307 | help="Lora dropout, only used if use_lora is True" | 294 | help="Lora dropout, only used if use_lora is True", |
| 308 | ) | 295 | ) |
| 309 | parser.add_argument( | 296 | parser.add_argument( |
| 310 | "--lora_bias", | 297 | "--lora_bias", |
| @@ -344,7 +331,7 @@ def parse_args(): | |||
| 344 | parser.add_argument( | 331 | parser.add_argument( |
| 345 | "--train_text_encoder_cycles", | 332 | "--train_text_encoder_cycles", |
| 346 | default=999999, | 333 | default=999999, |
| 347 | help="Number of epochs the text encoder will be trained." | 334 | help="Number of epochs the text encoder will be trained.", |
| 348 | ) | 335 | ) |
| 349 | parser.add_argument( | 336 | parser.add_argument( |
| 350 | "--find_lr", | 337 | "--find_lr", |
| @@ -378,27 +365,31 @@ def parse_args(): | |||
| 378 | "--lr_scheduler", | 365 | "--lr_scheduler", |
| 379 | type=str, | 366 | type=str, |
| 380 | default="one_cycle", | 367 | default="one_cycle", |
| 381 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 368 | choices=[ |
| 382 | "constant", "constant_with_warmup", "one_cycle"], | 369 | "linear", |
| 383 | help='The scheduler type to use.', | 370 | "cosine", |
| 371 | "cosine_with_restarts", | ||
| 372 | "polynomial", | ||
| 373 | "constant", | ||
| 374 | "constant_with_warmup", | ||
| 375 | "one_cycle", | ||
| 376 | ], | ||
| 377 | help="The scheduler type to use.", | ||
| 384 | ) | 378 | ) |
| 385 | parser.add_argument( | 379 | parser.add_argument( |
| 386 | "--lr_warmup_epochs", | 380 | "--lr_warmup_epochs", |
| 387 | type=int, | 381 | type=int, |
| 388 | default=10, | 382 | default=10, |
| 389 | help="Number of steps for the warmup in the lr scheduler." | 383 | help="Number of steps for the warmup in the lr scheduler.", |
| 390 | ) | 384 | ) |
| 391 | parser.add_argument( | 385 | parser.add_argument( |
| 392 | "--lr_mid_point", | 386 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 393 | type=float, | ||
| 394 | default=0.3, | ||
| 395 | help="OneCycle schedule mid point." | ||
| 396 | ) | 387 | ) |
| 397 | parser.add_argument( | 388 | parser.add_argument( |
| 398 | "--lr_cycles", | 389 | "--lr_cycles", |
| 399 | type=int, | 390 | type=int, |
| 400 | default=None, | 391 | default=None, |
| 401 | help="Number of restart cycles in the lr scheduler (if supported)." | 392 | help="Number of restart cycles in the lr scheduler (if supported).", |
| 402 | ) | 393 | ) |
| 403 | parser.add_argument( | 394 | parser.add_argument( |
| 404 | "--lr_warmup_func", | 395 | "--lr_warmup_func", |
| @@ -410,7 +401,7 @@ def parse_args(): | |||
| 410 | "--lr_warmup_exp", | 401 | "--lr_warmup_exp", |
| 411 | type=int, | 402 | type=int, |
| 412 | default=1, | 403 | default=1, |
| 413 | help='If lr_warmup_func is "cos", exponent to modify the function' | 404 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 414 | ) | 405 | ) |
| 415 | parser.add_argument( | 406 | parser.add_argument( |
| 416 | "--lr_annealing_func", | 407 | "--lr_annealing_func", |
| @@ -422,69 +413,76 @@ def parse_args(): | |||
| 422 | "--lr_annealing_exp", | 413 | "--lr_annealing_exp", |
| 423 | type=int, | 414 | type=int, |
| 424 | default=3, | 415 | default=3, |
| 425 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 416 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 426 | ) | 417 | ) |
| 427 | parser.add_argument( | 418 | parser.add_argument( |
| 428 | "--lr_min_lr", | 419 | "--lr_min_lr", |
| 429 | type=float, | 420 | type=float, |
| 430 | default=0.04, | 421 | default=0.04, |
| 431 | help="Minimum learning rate in the lr scheduler." | 422 | help="Minimum learning rate in the lr scheduler.", |
| 432 | ) | ||
| 433 | parser.add_argument( | ||
| 434 | "--min_snr_gamma", | ||
| 435 | type=int, | ||
| 436 | default=5, | ||
| 437 | help="MinSNR gamma." | ||
| 438 | ) | 423 | ) |
| 424 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
| 439 | parser.add_argument( | 425 | parser.add_argument( |
| 440 | "--schedule_sampler", | 426 | "--schedule_sampler", |
| 441 | type=str, | 427 | type=str, |
| 442 | default="uniform", | 428 | default="uniform", |
| 443 | choices=["uniform", "loss-second-moment"], | 429 | choices=["uniform", "loss-second-moment"], |
| 444 | help="Noise schedule sampler." | 430 | help="Noise schedule sampler.", |
| 445 | ) | 431 | ) |
| 446 | parser.add_argument( | 432 | parser.add_argument( |
| 447 | "--optimizer", | 433 | "--optimizer", |
| 448 | type=str, | 434 | type=str, |
| 449 | default="adan", | 435 | default="adan", |
| 450 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 436 | choices=[ |
| 451 | help='Optimizer to use' | 437 | "adam", |
| 438 | "adam8bit", | ||
| 439 | "adan", | ||
| 440 | "lion", | ||
| 441 | "dadam", | ||
| 442 | "dadan", | ||
| 443 | "dlion", | ||
| 444 | "adafactor", | ||
| 445 | ], | ||
| 446 | help="Optimizer to use", | ||
| 452 | ) | 447 | ) |
| 453 | parser.add_argument( | 448 | parser.add_argument( |
| 454 | "--dadaptation_d0", | 449 | "--dadaptation_d0", |
| 455 | type=float, | 450 | type=float, |
| 456 | default=1e-6, | 451 | default=1e-6, |
| 457 | help="The d0 parameter for Dadaptation optimizers." | 452 | help="The d0 parameter for Dadaptation optimizers.", |
| 453 | ) | ||
| 454 | parser.add_argument( | ||
| 455 | "--dadaptation_growth_rate", | ||
| 456 | type=float, | ||
| 457 | default=math.inf, | ||
| 458 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
| 458 | ) | 459 | ) |
| 459 | parser.add_argument( | 460 | parser.add_argument( |
| 460 | "--adam_beta1", | 461 | "--adam_beta1", |
| 461 | type=float, | 462 | type=float, |
| 462 | default=None, | 463 | default=None, |
| 463 | help="The beta1 parameter for the Adam optimizer." | 464 | help="The beta1 parameter for the Adam optimizer.", |
| 464 | ) | 465 | ) |
| 465 | parser.add_argument( | 466 | parser.add_argument( |
| 466 | "--adam_beta2", | 467 | "--adam_beta2", |
| 467 | type=float, | 468 | type=float, |
| 468 | default=None, | 469 | default=None, |
| 469 | help="The beta2 parameter for the Adam optimizer." | 470 | help="The beta2 parameter for the Adam optimizer.", |
| 470 | ) | 471 | ) |
| 471 | parser.add_argument( | 472 | parser.add_argument( |
| 472 | "--adam_weight_decay", | 473 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 473 | type=float, | ||
| 474 | default=2e-2, | ||
| 475 | help="Weight decay to use." | ||
| 476 | ) | 474 | ) |
| 477 | parser.add_argument( | 475 | parser.add_argument( |
| 478 | "--adam_epsilon", | 476 | "--adam_epsilon", |
| 479 | type=float, | 477 | type=float, |
| 480 | default=1e-08, | 478 | default=1e-08, |
| 481 | help="Epsilon value for the Adam optimizer" | 479 | help="Epsilon value for the Adam optimizer", |
| 482 | ) | 480 | ) |
| 483 | parser.add_argument( | 481 | parser.add_argument( |
| 484 | "--adam_amsgrad", | 482 | "--adam_amsgrad", |
| 485 | type=bool, | 483 | type=bool, |
| 486 | default=False, | 484 | default=False, |
| 487 | help="Amsgrad value for the Adam optimizer" | 485 | help="Amsgrad value for the Adam optimizer", |
| 488 | ) | 486 | ) |
| 489 | parser.add_argument( | 487 | parser.add_argument( |
| 490 | "--mixed_precision", | 488 | "--mixed_precision", |
| @@ -547,19 +545,19 @@ def parse_args(): | |||
| 547 | "--valid_set_size", | 545 | "--valid_set_size", |
| 548 | type=int, | 546 | type=int, |
| 549 | default=None, | 547 | default=None, |
| 550 | help="Number of images in the validation dataset." | 548 | help="Number of images in the validation dataset.", |
| 551 | ) | 549 | ) |
| 552 | parser.add_argument( | 550 | parser.add_argument( |
| 553 | "--valid_set_repeat", | 551 | "--valid_set_repeat", |
| 554 | type=int, | 552 | type=int, |
| 555 | default=1, | 553 | default=1, |
| 556 | help="Times the images in the validation dataset are repeated." | 554 | help="Times the images in the validation dataset are repeated.", |
| 557 | ) | 555 | ) |
| 558 | parser.add_argument( | 556 | parser.add_argument( |
| 559 | "--train_batch_size", | 557 | "--train_batch_size", |
| 560 | type=int, | 558 | type=int, |
| 561 | default=1, | 559 | default=1, |
| 562 | help="Batch size (per device) for the training dataloader." | 560 | help="Batch size (per device) for the training dataloader.", |
| 563 | ) | 561 | ) |
| 564 | parser.add_argument( | 562 | parser.add_argument( |
| 565 | "--sample_steps", | 563 | "--sample_steps", |
| @@ -571,19 +569,10 @@ def parse_args(): | |||
| 571 | "--prior_loss_weight", | 569 | "--prior_loss_weight", |
| 572 | type=float, | 570 | type=float, |
| 573 | default=1.0, | 571 | default=1.0, |
| 574 | help="The weight of prior preservation loss." | 572 | help="The weight of prior preservation loss.", |
| 575 | ) | ||
| 576 | parser.add_argument( | ||
| 577 | "--run_pti", | ||
| 578 | action="store_true", | ||
| 579 | help="Whether to run PTI." | ||
| 580 | ) | ||
| 581 | parser.add_argument( | ||
| 582 | "--emb_alpha", | ||
| 583 | type=float, | ||
| 584 | default=1.0, | ||
| 585 | help="Embedding alpha" | ||
| 586 | ) | 573 | ) |
| 574 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 575 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 587 | parser.add_argument( | 576 | parser.add_argument( |
| 588 | "--emb_dropout", | 577 | "--emb_dropout", |
| 589 | type=float, | 578 | type=float, |
| @@ -591,27 +580,16 @@ def parse_args(): | |||
| 591 | help="Embedding dropout probability.", | 580 | help="Embedding dropout probability.", |
| 592 | ) | 581 | ) |
| 593 | parser.add_argument( | 582 | parser.add_argument( |
| 594 | "--use_emb_decay", | 583 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
| 595 | action="store_true", | ||
| 596 | help="Whether to use embedding decay." | ||
| 597 | ) | 584 | ) |
| 598 | parser.add_argument( | 585 | parser.add_argument( |
| 599 | "--emb_decay_target", | 586 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
| 600 | default=0.4, | ||
| 601 | type=float, | ||
| 602 | help="Embedding decay target." | ||
| 603 | ) | 587 | ) |
| 604 | parser.add_argument( | 588 | parser.add_argument( |
| 605 | "--emb_decay", | 589 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
| 606 | default=1e+2, | ||
| 607 | type=float, | ||
| 608 | help="Embedding decay factor." | ||
| 609 | ) | 590 | ) |
| 610 | parser.add_argument( | 591 | parser.add_argument( |
| 611 | "--max_grad_norm", | 592 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." |
| 612 | default=1.0, | ||
| 613 | type=float, | ||
| 614 | help="Max gradient norm." | ||
| 615 | ) | 593 | ) |
| 616 | parser.add_argument( | 594 | parser.add_argument( |
| 617 | "--noise_timesteps", | 595 | "--noise_timesteps", |
| @@ -622,7 +600,7 @@ def parse_args(): | |||
| 622 | "--config", | 600 | "--config", |
| 623 | type=str, | 601 | type=str, |
| 624 | default=None, | 602 | default=None, |
| 625 | help="Path to a JSON configuration file containing arguments for invoking this script." | 603 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 626 | ) | 604 | ) |
| 627 | 605 | ||
| 628 | args = parser.parse_args() | 606 | args = parser.parse_args() |
| @@ -649,29 +627,44 @@ def parse_args(): | |||
| 649 | args.placeholder_tokens = [args.placeholder_tokens] | 627 | args.placeholder_tokens = [args.placeholder_tokens] |
| 650 | 628 | ||
| 651 | if isinstance(args.initializer_tokens, str): | 629 | if isinstance(args.initializer_tokens, str): |
| 652 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 630 | args.initializer_tokens = [args.initializer_tokens] * len( |
| 631 | args.placeholder_tokens | ||
| 632 | ) | ||
| 653 | 633 | ||
| 654 | if len(args.placeholder_tokens) == 0: | 634 | if len(args.placeholder_tokens) == 0: |
| 655 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 635 | args.placeholder_tokens = [ |
| 636 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 637 | ] | ||
| 656 | 638 | ||
| 657 | if len(args.initializer_tokens) == 0: | 639 | if len(args.initializer_tokens) == 0: |
| 658 | args.initializer_tokens = args.placeholder_tokens.copy() | 640 | args.initializer_tokens = args.placeholder_tokens.copy() |
| 659 | 641 | ||
| 660 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 642 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 661 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 643 | raise ValueError( |
| 644 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 645 | ) | ||
| 662 | 646 | ||
| 663 | if isinstance(args.inverted_initializer_tokens, str): | 647 | if isinstance(args.inverted_initializer_tokens, str): |
| 664 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 648 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
| 649 | args.placeholder_tokens | ||
| 650 | ) | ||
| 665 | 651 | ||
| 666 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 652 | if ( |
| 653 | isinstance(args.inverted_initializer_tokens, list) | ||
| 654 | and len(args.inverted_initializer_tokens) != 0 | ||
| 655 | ): | ||
| 667 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 656 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
| 668 | args.initializer_tokens += args.inverted_initializer_tokens | 657 | args.initializer_tokens += args.inverted_initializer_tokens |
| 669 | 658 | ||
| 670 | if isinstance(args.num_vectors, int): | 659 | if isinstance(args.num_vectors, int): |
| 671 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 660 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 672 | 661 | ||
| 673 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 662 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
| 674 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 663 | args.num_vectors |
| 664 | ): | ||
| 665 | raise ValueError( | ||
| 666 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 667 | ) | ||
| 675 | 668 | ||
| 676 | if args.alias_tokens is None: | 669 | if args.alias_tokens is None: |
| 677 | args.alias_tokens = [] | 670 | args.alias_tokens = [] |
| @@ -695,15 +688,15 @@ def parse_args(): | |||
| 695 | raise ValueError("You must specify --output_dir") | 688 | raise ValueError("You must specify --output_dir") |
| 696 | 689 | ||
| 697 | if args.adam_beta1 is None: | 690 | if args.adam_beta1 is None: |
| 698 | if args.optimizer in ('adam', 'adam8bit'): | 691 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 699 | args.adam_beta1 = 0.9 | 692 | args.adam_beta1 = 0.9 |
| 700 | elif args.optimizer == 'lion': | 693 | elif args.optimizer in ("lion", "dlion"): |
| 701 | args.adam_beta1 = 0.95 | 694 | args.adam_beta1 = 0.95 |
| 702 | 695 | ||
| 703 | if args.adam_beta2 is None: | 696 | if args.adam_beta2 is None: |
| 704 | if args.optimizer in ('adam', 'adam8bit'): | 697 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 705 | args.adam_beta2 = 0.999 | 698 | args.adam_beta2 = 0.999 |
| 706 | elif args.optimizer == 'lion': | 699 | elif args.optimizer in ("lion", "dlion"): |
| 707 | args.adam_beta2 = 0.98 | 700 | args.adam_beta2 = 0.98 |
| 708 | 701 | ||
| 709 | return args | 702 | return args |
| @@ -719,7 +712,7 @@ def main(): | |||
| 719 | accelerator = Accelerator( | 712 | accelerator = Accelerator( |
| 720 | log_with=LoggerType.TENSORBOARD, | 713 | log_with=LoggerType.TENSORBOARD, |
| 721 | project_dir=f"{output_dir}", | 714 | project_dir=f"{output_dir}", |
| 722 | mixed_precision=args.mixed_precision | 715 | mixed_precision=args.mixed_precision, |
| 723 | ) | 716 | ) |
| 724 | 717 | ||
| 725 | weight_dtype = torch.float32 | 718 | weight_dtype = torch.float32 |
| @@ -728,6 +721,8 @@ def main(): | |||
| 728 | elif args.mixed_precision == "bf16": | 721 | elif args.mixed_precision == "bf16": |
| 729 | weight_dtype = torch.bfloat16 | 722 | weight_dtype = torch.bfloat16 |
| 730 | 723 | ||
| 724 | patch_xformers(weight_dtype) | ||
| 725 | |||
| 731 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 726 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 732 | 727 | ||
| 733 | if args.seed is None: | 728 | if args.seed is None: |
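In the hunk above, weight_dtype is resolved from --mixed_precision and, new in this commit, handed to patch_xformers so the attention wrapper casts to the actual training dtype instead of whatever dtype the key tensor happens to carry. The mapping, restated as a small hedged sketch (args and patch_xformers are the script's own names):

    # Hedged sketch of the mixed-precision -> dtype resolution used above;
    # any other value (e.g. "no") leaves everything in float32.
    import torch

    DTYPE_BY_PRECISION = {"fp16": torch.float16, "bf16": torch.bfloat16}
    weight_dtype = DTYPE_BY_PRECISION.get(args.mixed_precision, torch.float32)
    patch_xformers(weight_dtype)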
| @@ -737,12 +732,18 @@ def main(): | |||
| 737 | 732 | ||
| 738 | save_args(output_dir, args) | 733 | save_args(output_dir, args) |
| 739 | 734 | ||
| 740 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 735 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 741 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 736 | args.pretrained_model_name_or_path |
| 742 | 737 | ) | |
| 738 | schedule_sampler = create_named_schedule_sampler( | ||
| 739 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 740 | ) | ||
| 741 | |||
| 743 | def ensure_embeddings(): | 742 | def ensure_embeddings(): |
| 744 | if args.lora_text_encoder_emb: | 743 | if args.lora_text_encoder_emb: |
| 745 | raise ValueError("Can't use TI options when training token embeddings with LoRA") | 744 | raise ValueError( |
| 745 | "Can't use TI options when training token embeddings with LoRA" | ||
| 746 | ) | ||
| 746 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 747 | return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) |
| 747 | 748 | ||
| 748 | unet_config = LoraConfig( | 749 | unet_config = LoraConfig( |
| @@ -757,7 +758,9 @@ def main(): | |||
| 757 | text_encoder_config = LoraConfig( | 758 | text_encoder_config = LoraConfig( |
| 758 | r=args.lora_text_encoder_r, | 759 | r=args.lora_text_encoder_r, |
| 759 | lora_alpha=args.lora_text_encoder_alpha, | 760 | lora_alpha=args.lora_text_encoder_alpha, |
| 760 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, | 761 | target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING |
| 762 | if args.lora_text_encoder_emb | ||
| 763 | else TEXT_ENCODER_TARGET_MODULES, | ||
| 761 | lora_dropout=args.lora_text_encoder_dropout, | 764 | lora_dropout=args.lora_text_encoder_dropout, |
| 762 | bias=args.lora_text_encoder_bias, | 765 | bias=args.lora_text_encoder_bias, |
| 763 | ) | 766 | ) |
| @@ -787,7 +790,7 @@ def main(): | |||
| 787 | 790 | ||
| 788 | if len(args.alias_tokens) != 0: | 791 | if len(args.alias_tokens) != 0: |
| 789 | embeddings = ensure_embeddings() | 792 | embeddings = ensure_embeddings() |
| 790 | 793 | ||
| 791 | alias_placeholder_tokens = args.alias_tokens[::2] | 794 | alias_placeholder_tokens = args.alias_tokens[::2] |
| 792 | alias_initializer_tokens = args.alias_tokens[1::2] | 795 | alias_initializer_tokens = args.alias_tokens[1::2] |
| 793 | 796 | ||
| @@ -795,27 +798,33 @@ def main(): | |||
| 795 | tokenizer=tokenizer, | 798 | tokenizer=tokenizer, |
| 796 | embeddings=embeddings, | 799 | embeddings=embeddings, |
| 797 | placeholder_tokens=alias_placeholder_tokens, | 800 | placeholder_tokens=alias_placeholder_tokens, |
| 798 | initializer_tokens=alias_initializer_tokens | 801 | initializer_tokens=alias_initializer_tokens, |
| 799 | ) | 802 | ) |
| 800 | embeddings.persist() | 803 | embeddings.persist() |
| 801 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 804 | print( |
| 805 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 806 | ) | ||
| 802 | 807 | ||
| 803 | placeholder_tokens = [] | 808 | placeholder_tokens = [] |
| 804 | placeholder_token_ids = [] | 809 | placeholder_token_ids = [] |
| 805 | 810 | ||
| 806 | if args.embeddings_dir is not None: | 811 | if args.embeddings_dir is not None: |
| 807 | embeddings = ensure_embeddings() | 812 | embeddings = ensure_embeddings() |
| 808 | 813 | ||
| 809 | embeddings_dir = Path(args.embeddings_dir) | 814 | embeddings_dir = Path(args.embeddings_dir) |
| 810 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 815 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 811 | raise ValueError("--embeddings_dir must point to an existing directory") | 816 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 812 | 817 | ||
| 813 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 818 | added_tokens, added_ids = load_embeddings_from_dir( |
| 819 | tokenizer, embeddings, embeddings_dir | ||
| 820 | ) | ||
| 814 | 821 | ||
| 815 | placeholder_tokens = added_tokens | 822 | placeholder_tokens = added_tokens |
| 816 | placeholder_token_ids = added_ids | 823 | placeholder_token_ids = added_ids |
| 817 | 824 | ||
| 818 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 825 | print( |
| 826 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 827 | ) | ||
| 819 | 828 | ||
| 820 | if args.train_dir_embeddings: | 829 | if args.train_dir_embeddings: |
| 821 | print("Training embeddings from embeddings dir") | 830 | print("Training embeddings from embeddings dir") |
| @@ -824,7 +833,7 @@ def main(): | |||
| 824 | 833 | ||
| 825 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | 834 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: |
| 826 | embeddings = ensure_embeddings() | 835 | embeddings = ensure_embeddings() |
| 827 | 836 | ||
| 828 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 837 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 829 | tokenizer=tokenizer, | 838 | tokenizer=tokenizer, |
| 830 | embeddings=embeddings, | 839 | embeddings=embeddings, |
| @@ -836,23 +845,34 @@ def main(): | |||
| 836 | 845 | ||
| 837 | placeholder_tokens = args.placeholder_tokens | 846 | placeholder_tokens = args.placeholder_tokens |
| 838 | 847 | ||
| 839 | stats = list(zip( | 848 | stats = list( |
| 840 | placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids | 849 | zip( |
| 841 | )) | 850 | placeholder_tokens, |
| 851 | placeholder_token_ids, | ||
| 852 | args.initializer_tokens, | ||
| 853 | initializer_token_ids, | ||
| 854 | ) | ||
| 855 | ) | ||
| 842 | print(f"Training embeddings: {stats}") | 856 | print(f"Training embeddings: {stats}") |
| 843 | 857 | ||
| 844 | if args.scale_lr: | 858 | if args.scale_lr: |
| 845 | args.learning_rate_unet = ( | 859 | args.learning_rate_unet = ( |
| 846 | args.learning_rate_unet * args.gradient_accumulation_steps * | 860 | args.learning_rate_unet |
| 847 | args.train_batch_size * accelerator.num_processes | 861 | * args.gradient_accumulation_steps |
| 862 | * args.train_batch_size | ||
| 863 | * accelerator.num_processes | ||
| 848 | ) | 864 | ) |
| 849 | args.learning_rate_text = ( | 865 | args.learning_rate_text = ( |
| 850 | args.learning_rate_text * args.gradient_accumulation_steps * | 866 | args.learning_rate_text |
| 851 | args.train_batch_size * accelerator.num_processes | 867 | * args.gradient_accumulation_steps |
| 868 | * args.train_batch_size | ||
| 869 | * accelerator.num_processes | ||
| 852 | ) | 870 | ) |
| 853 | args.learning_rate_emb = ( | 871 | args.learning_rate_emb = ( |
| 854 | args.learning_rate_emb * args.gradient_accumulation_steps * | 872 | args.learning_rate_emb |
| 855 | args.train_batch_size * accelerator.num_processes | 873 | * args.gradient_accumulation_steps |
| 874 | * args.train_batch_size | ||
| 875 | * accelerator.num_processes | ||
| 856 | ) | 876 | ) |
| 857 | 877 | ||
| 858 | if args.find_lr: | 878 | if args.find_lr: |
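With --scale_lr (hunk above), each base learning rate is multiplied by gradient_accumulation_steps × train_batch_size × num_processes, i.e. scaled linearly with the effective global batch size. A hedged worked example with made-up numbers:

    # Illustrative numbers only; the rule is linear LR scaling with the
    # effective batch size = per-device batch * grad accumulation * processes.
    base_lr = 1e-4
    train_batch_size = 4
    gradient_accumulation_steps = 2
    num_processes = 2

    effective_batch = train_batch_size * gradient_accumulation_steps * num_processes  # 16
    scaled_lr = base_lr * effective_batch                                              # 1.6e-3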
| @@ -861,11 +881,13 @@ def main(): | |||
| 861 | args.learning_rate_emb = 1e-6 | 881 | args.learning_rate_emb = 1e-6 |
| 862 | args.lr_scheduler = "exponential_growth" | 882 | args.lr_scheduler = "exponential_growth" |
| 863 | 883 | ||
| 864 | if args.optimizer == 'adam8bit': | 884 | if args.optimizer == "adam8bit": |
| 865 | try: | 885 | try: |
| 866 | import bitsandbytes as bnb | 886 | import bitsandbytes as bnb |
| 867 | except ImportError: | 887 | except ImportError: |
| 868 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 888 | raise ImportError( |
| 889 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 890 | ) | ||
| 869 | 891 | ||
| 870 | create_optimizer = partial( | 892 | create_optimizer = partial( |
| 871 | bnb.optim.AdamW8bit, | 893 | bnb.optim.AdamW8bit, |
| @@ -874,7 +896,7 @@ def main(): | |||
| 874 | eps=args.adam_epsilon, | 896 | eps=args.adam_epsilon, |
| 875 | amsgrad=args.adam_amsgrad, | 897 | amsgrad=args.adam_amsgrad, |
| 876 | ) | 898 | ) |
| 877 | elif args.optimizer == 'adam': | 899 | elif args.optimizer == "adam": |
| 878 | create_optimizer = partial( | 900 | create_optimizer = partial( |
| 879 | torch.optim.AdamW, | 901 | torch.optim.AdamW, |
| 880 | betas=(args.adam_beta1, args.adam_beta2), | 902 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -882,11 +904,13 @@ def main(): | |||
| 882 | eps=args.adam_epsilon, | 904 | eps=args.adam_epsilon, |
| 883 | amsgrad=args.adam_amsgrad, | 905 | amsgrad=args.adam_amsgrad, |
| 884 | ) | 906 | ) |
| 885 | elif args.optimizer == 'adan': | 907 | elif args.optimizer == "adan": |
| 886 | try: | 908 | try: |
| 887 | import timm.optim | 909 | import timm.optim |
| 888 | except ImportError: | 910 | except ImportError: |
| 889 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 911 | raise ImportError( |
| 912 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 913 | ) | ||
| 890 | 914 | ||
| 891 | create_optimizer = partial( | 915 | create_optimizer = partial( |
| 892 | timm.optim.Adan, | 916 | timm.optim.Adan, |
| @@ -894,11 +918,13 @@ def main(): | |||
| 894 | eps=args.adam_epsilon, | 918 | eps=args.adam_epsilon, |
| 895 | no_prox=True, | 919 | no_prox=True, |
| 896 | ) | 920 | ) |
| 897 | elif args.optimizer == 'lion': | 921 | elif args.optimizer == "lion": |
| 898 | try: | 922 | try: |
| 899 | import lion_pytorch | 923 | import lion_pytorch |
| 900 | except ImportError: | 924 | except ImportError: |
| 901 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 925 | raise ImportError( |
| 926 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 927 | ) | ||
| 902 | 928 | ||
| 903 | create_optimizer = partial( | 929 | create_optimizer = partial( |
| 904 | lion_pytorch.Lion, | 930 | lion_pytorch.Lion, |
| @@ -906,7 +932,7 @@ def main(): | |||
| 906 | weight_decay=args.adam_weight_decay, | 932 | weight_decay=args.adam_weight_decay, |
| 907 | use_triton=True, | 933 | use_triton=True, |
| 908 | ) | 934 | ) |
| 909 | elif args.optimizer == 'adafactor': | 935 | elif args.optimizer == "adafactor": |
| 910 | create_optimizer = partial( | 936 | create_optimizer = partial( |
| 911 | transformers.optimization.Adafactor, | 937 | transformers.optimization.Adafactor, |
| 912 | weight_decay=args.adam_weight_decay, | 938 | weight_decay=args.adam_weight_decay, |
| @@ -920,11 +946,13 @@ def main(): | |||
| 920 | args.learning_rate_unet = None | 946 | args.learning_rate_unet = None |
| 921 | args.learning_rate_text = None | 947 | args.learning_rate_text = None |
| 922 | args.learning_rate_emb = None | 948 | args.learning_rate_emb = None |
| 923 | elif args.optimizer == 'dadam': | 949 | elif args.optimizer == "dadam": |
| 924 | try: | 950 | try: |
| 925 | import dadaptation | 951 | import dadaptation |
| 926 | except ImportError: | 952 | except ImportError: |
| 927 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 953 | raise ImportError( |
| 954 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 955 | ) | ||
| 928 | 956 | ||
| 929 | create_optimizer = partial( | 957 | create_optimizer = partial( |
| 930 | dadaptation.DAdaptAdam, | 958 | dadaptation.DAdaptAdam, |
| @@ -933,29 +961,35 @@ def main(): | |||
| 933 | eps=args.adam_epsilon, | 961 | eps=args.adam_epsilon, |
| 934 | decouple=True, | 962 | decouple=True, |
| 935 | d0=args.dadaptation_d0, | 963 | d0=args.dadaptation_d0, |
| 964 | growth_rate=args.dadaptation_growth_rate, | ||
| 936 | ) | 965 | ) |
| 937 | 966 | ||
| 938 | args.learning_rate_unet = 1.0 | 967 | args.learning_rate_unet = 1.0 |
| 939 | args.learning_rate_text = 1.0 | 968 | args.learning_rate_text = 1.0 |
| 940 | args.learning_rate_emb = 1.0 | 969 | args.learning_rate_emb = 1.0 |
| 941 | elif args.optimizer == 'dadan': | 970 | elif args.optimizer == "dadan": |
| 942 | try: | 971 | try: |
| 943 | import dadaptation | 972 | import dadaptation |
| 944 | except ImportError: | 973 | except ImportError: |
| 945 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 974 | raise ImportError( |
| 975 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 976 | ) | ||
| 946 | 977 | ||
| 947 | create_optimizer = partial( | 978 | create_optimizer = partial( |
| 948 | dadaptation.DAdaptAdan, | 979 | dadaptation.DAdaptAdan, |
| 949 | weight_decay=args.adam_weight_decay, | 980 | weight_decay=args.adam_weight_decay, |
| 950 | eps=args.adam_epsilon, | 981 | eps=args.adam_epsilon, |
| 951 | d0=args.dadaptation_d0, | 982 | d0=args.dadaptation_d0, |
| 983 | growth_rate=args.dadaptation_growth_rate, | ||
| 952 | ) | 984 | ) |
| 953 | 985 | ||
| 954 | args.learning_rate_unet = 1.0 | 986 | args.learning_rate_unet = 1.0 |
| 955 | args.learning_rate_text = 1.0 | 987 | args.learning_rate_text = 1.0 |
| 956 | args.learning_rate_emb = 1.0 | 988 | args.learning_rate_emb = 1.0 |
| 989 | elif args.optimizer == "dlion": | ||
| 990 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
| 957 | else: | 991 | else: |
| 958 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 992 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 959 | 993 | ||
| 960 | trainer = partial( | 994 | trainer = partial( |
| 961 | train, | 995 | train, |
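The optimizer hunks above add a --dadaptation_growth_rate option (default math.inf, i.e. unbounded) and forward it together with --dadaptation_d0 to DAdaptAdam and DAdaptAdan; they also register a "dlion" choice that currently just raises, since D-Adaptation's Lion variant had not been merged upstream yet. With D-Adaptation the per-group lr acts as a multiplier on the internally estimated step size, which is why the script pins the learning rates to 1.0. A hedged sketch of the DAdaptAdam construction with the new knob:

    # Hedged sketch; parameter names follow the dadaptation package as used above.
    import math

    import dadaptation

    optimizer = dadaptation.DAdaptAdam(
        params_to_optimize,          # the unet/text-encoder param groups from the script
        lr=1.0,                      # relative multiplier under D-Adaptation
        betas=(0.9, 0.999),          # the "dadam" defaults set in parse_args()
        eps=1e-8,                    # --adam_epsilon default
        weight_decay=2e-2,           # --adam_weight_decay default
        decouple=True,               # AdamW-style decoupled weight decay
        d0=1e-6,                     # --dadaptation_d0: initial step-size estimate
        growth_rate=math.inf,        # --dadaptation_growth_rate: cap on how fast d may grow
    )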
| @@ -1026,25 +1060,33 @@ def main(): | |||
| 1026 | 1060 | ||
| 1027 | if args.run_pti and len(placeholder_tokens) != 0: | 1061 | if args.run_pti and len(placeholder_tokens) != 0: |
| 1028 | embeddings = ensure_embeddings() | 1062 | embeddings = ensure_embeddings() |
| 1029 | 1063 | ||
| 1030 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 1064 | filter_tokens = [ |
| 1065 | token for token in args.filter_tokens if token in placeholder_tokens | ||
| 1066 | ] | ||
| 1031 | 1067 | ||
| 1032 | pti_datamodule = create_datamodule( | 1068 | pti_datamodule = create_datamodule( |
| 1033 | valid_set_size=0, | 1069 | valid_set_size=0, |
| 1034 | batch_size=args.train_batch_size, | 1070 | batch_size=args.train_batch_size, |
| 1035 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 1071 | filter=partial( |
| 1072 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
| 1073 | ), | ||
| 1036 | ) | 1074 | ) |
| 1037 | pti_datamodule.setup() | 1075 | pti_datamodule.setup() |
| 1038 | 1076 | ||
| 1039 | num_train_epochs = args.num_train_epochs | 1077 | num_train_epochs = args.num_train_epochs |
| 1040 | pti_sample_frequency = args.sample_frequency | 1078 | pti_sample_frequency = args.sample_frequency |
| 1041 | if num_train_epochs is None: | 1079 | if num_train_epochs is None: |
| 1042 | num_train_epochs = math.ceil( | 1080 | num_train_epochs = ( |
| 1043 | args.num_train_steps / len(pti_datamodule.train_dataset) | 1081 | math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset)) |
| 1044 | ) * args.gradient_accumulation_steps | 1082 | * args.gradient_accumulation_steps |
| 1045 | pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) | 1083 | ) |
| 1084 | pti_sample_frequency = math.ceil( | ||
| 1085 | num_train_epochs * (pti_sample_frequency / args.num_train_steps) | ||
| 1086 | ) | ||
| 1046 | num_training_steps_per_epoch = math.ceil( | 1087 | num_training_steps_per_epoch = math.ceil( |
| 1047 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1088 | len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps |
| 1089 | ) | ||
| 1048 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1090 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
| 1049 | if args.sample_num is not None: | 1091 | if args.sample_num is not None: |
| 1050 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1092 | pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
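In the PTI setup above, --num_train_epochs may be left unset; it is then derived from --num_train_steps as ceil(steps / dataset length) × gradient_accumulation_steps, the step-based --sample_frequency is rescaled into epochs, and the total step count is recomputed from steps per epoch. A hedged worked example with made-up sizes:

    # Illustrative numbers only, mirroring the bookkeeping in the hunk above.
    import math

    num_train_steps = 2000      # --num_train_steps
    dataset_len = 120           # len(pti_datamodule.train_dataset), assumed
    grad_accum = 2              # --gradient_accumulation_steps
    sample_frequency = 200      # --sample_frequency, expressed in steps

    num_train_epochs = math.ceil(num_train_steps / dataset_len) * grad_accum               # 17 * 2 = 34
    sample_frequency = math.ceil(num_train_epochs * (sample_frequency / num_train_steps))  # ceil(3.4) = 4
    steps_per_epoch = math.ceil(dataset_len / grad_accum)                                  # 60
    num_train_steps = steps_per_epoch * num_train_epochs                                   # 2040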
| @@ -1060,11 +1102,15 @@ def main(): | |||
| 1060 | print(f"============ PTI ============") | 1102 | print(f"============ PTI ============") |
| 1061 | print("") | 1103 | print("") |
| 1062 | 1104 | ||
| 1063 | pti_optimizer = create_optimizer([{ | 1105 | pti_optimizer = create_optimizer( |
| 1064 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), | 1106 | [ |
| 1065 | "lr": args.learning_rate_emb, | 1107 | { |
| 1066 | "weight_decay": 0, | 1108 | "params": text_encoder.text_model.embeddings.token_embedding.parameters(), |
| 1067 | }]) | 1109 | "lr": args.learning_rate_emb, |
| 1110 | "weight_decay": 0, | ||
| 1111 | } | ||
| 1112 | ] | ||
| 1113 | ) | ||
| 1068 | 1114 | ||
| 1069 | pti_lr_scheduler = create_lr_scheduler( | 1115 | pti_lr_scheduler = create_lr_scheduler( |
| 1070 | "constant_with_warmup", | 1116 | "constant_with_warmup", |
| @@ -1113,11 +1159,16 @@ def main(): | |||
| 1113 | num_train_epochs = args.num_train_epochs | 1159 | num_train_epochs = args.num_train_epochs |
| 1114 | lora_sample_frequency = args.sample_frequency | 1160 | lora_sample_frequency = args.sample_frequency |
| 1115 | if num_train_epochs is None: | 1161 | if num_train_epochs is None: |
| 1116 | num_train_epochs = math.ceil( | 1162 | num_train_epochs = ( |
| 1117 | args.num_train_steps / len(lora_datamodule.train_dataset) | 1163 | math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset)) |
| 1118 | ) * args.gradient_accumulation_steps | 1164 | * args.gradient_accumulation_steps |
| 1119 | lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) | 1165 | ) |
| 1120 | num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) | 1166 | lora_sample_frequency = math.ceil( |
| 1167 | num_train_epochs * (lora_sample_frequency / args.num_train_steps) | ||
| 1168 | ) | ||
| 1169 | num_training_steps_per_epoch = math.ceil( | ||
| 1170 | len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps | ||
| 1171 | ) | ||
| 1121 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 1172 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
| 1122 | if args.sample_num is not None: | 1173 | if args.sample_num is not None: |
| 1123 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 1174 | lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| @@ -1131,7 +1182,6 @@ def main(): | |||
| 1131 | 1182 | ||
| 1132 | training_iter = 0 | 1183 | training_iter = 0 |
| 1133 | auto_cycles = list(args.auto_cycles) | 1184 | auto_cycles = list(args.auto_cycles) |
| 1134 | learning_rate_emb = args.learning_rate_emb | ||
| 1135 | learning_rate_unet = args.learning_rate_unet | 1185 | learning_rate_unet = args.learning_rate_unet |
| 1136 | learning_rate_text = args.learning_rate_text | 1186 | learning_rate_text = args.learning_rate_text |
| 1137 | lr_scheduler = args.lr_scheduler | 1187 | lr_scheduler = args.lr_scheduler |
| @@ -1145,21 +1195,15 @@ def main(): | |||
| 1145 | 1195 | ||
| 1146 | params_to_optimize = [ | 1196 | params_to_optimize = [ |
| 1147 | { | 1197 | { |
| 1148 | "params": ( | 1198 | "params": (param for param in unet.parameters() if param.requires_grad), |
| 1149 | param | ||
| 1150 | for param in unet.parameters() | ||
| 1151 | if param.requires_grad | ||
| 1152 | ), | ||
| 1153 | "lr": learning_rate_unet, | 1199 | "lr": learning_rate_unet, |
| 1154 | }, | 1200 | }, |
| 1155 | { | 1201 | { |
| 1156 | "params": ( | 1202 | "params": ( |
| 1157 | param | 1203 | param for param in text_encoder.parameters() if param.requires_grad |
| 1158 | for param in text_encoder.parameters() | ||
| 1159 | if param.requires_grad | ||
| 1160 | ), | 1204 | ), |
| 1161 | "lr": learning_rate_text, | 1205 | "lr": learning_rate_text, |
| 1162 | } | 1206 | }, |
| 1163 | ] | 1207 | ] |
| 1164 | group_labels = ["unet", "text"] | 1208 | group_labels = ["unet", "text"] |
| 1165 | 1209 | ||
| @@ -1169,19 +1213,26 @@ def main(): | |||
| 1169 | if len(auto_cycles) != 0: | 1213 | if len(auto_cycles) != 0: |
| 1170 | response = auto_cycles.pop(0) | 1214 | response = auto_cycles.pop(0) |
| 1171 | else: | 1215 | else: |
| 1172 | response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1216 | response = input( |
| 1217 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
| 1218 | ) | ||
| 1173 | 1219 | ||
| 1174 | if response.lower().strip() == "o": | 1220 | if response.lower().strip() == "o": |
| 1175 | if args.learning_rate_emb is not None: | ||
| 1176 | learning_rate_emb = args.learning_rate_emb * 2 | ||
| 1177 | if args.learning_rate_unet is not None: | 1221 | if args.learning_rate_unet is not None: |
| 1178 | learning_rate_unet = args.learning_rate_unet * 2 | 1222 | learning_rate_unet = ( |
| 1223 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
| 1224 | ) | ||
| 1179 | if args.learning_rate_text is not None: | 1225 | if args.learning_rate_text is not None: |
| 1180 | learning_rate_text = args.learning_rate_text * 2 | 1226 | learning_rate_text = ( |
| 1227 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
| 1228 | ) | ||
| 1181 | else: | 1229 | else: |
| 1182 | learning_rate_emb = args.learning_rate_emb | 1230 | learning_rate_unet = args.learning_rate_unet * ( |
| 1183 | learning_rate_unet = args.learning_rate_unet | 1231 | args.cycle_decay**training_iter |
| 1184 | learning_rate_text = args.learning_rate_text | 1232 | ) |
| 1233 | learning_rate_text = args.learning_rate_text * ( | ||
| 1234 | args.cycle_decay**training_iter | ||
| 1235 | ) | ||
| 1185 | 1236 | ||
| 1186 | if response.lower().strip() == "o": | 1237 | if response.lower().strip() == "o": |
| 1187 | lr_scheduler = "one_cycle" | 1238 | lr_scheduler = "one_cycle" |
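The cycle hunk above changes how --cycle_decay is applied: instead of multiplying the learning rates by the decay factor after each finished cycle (the block removed near the end of the diff, which also drops the now-unused learning_rate_emb copy), the rate for cycle N is computed up front as base_lr × cycle_decay ** N, with an extra factor of 2 only for one_cycle ("o") cycles. A hedged sketch of the selection logic:

    # Hedged sketch of the per-cycle learning-rate selection introduced above.
    def cycle_lr(base_lr, cycle_decay, training_iter, one_cycle):
        if base_lr is None:
            return None
        lr = base_lr * (cycle_decay ** training_iter)
        return lr * 2 if one_cycle else lr

    # e.g. base 1e-4, decay 0.9, one_cycle: cycle 0 -> 2.0e-4, cycle 1 -> 1.8e-4, cycle 2 -> 1.62e-4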
| @@ -1204,9 +1255,11 @@ def main(): | |||
| 1204 | print("") | 1255 | print("") |
| 1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1256 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
| 1206 | print("") | 1257 | print("") |
| 1207 | 1258 | ||
| 1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): | 1259 | for group, lr in zip( |
| 1209 | group['lr'] = lr | 1260 | lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text] |
| 1261 | ): | ||
| 1262 | group["lr"] = lr | ||
| 1210 | 1263 | ||
| 1211 | lora_lr_scheduler = create_lr_scheduler( | 1264 | lora_lr_scheduler = create_lr_scheduler( |
| 1212 | lr_scheduler, | 1265 | lr_scheduler, |
| @@ -1218,7 +1271,9 @@ def main(): | |||
| 1218 | warmup_epochs=lr_warmup_epochs, | 1271 | warmup_epochs=lr_warmup_epochs, |
| 1219 | ) | 1272 | ) |
| 1220 | 1273 | ||
| 1221 | lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" | 1274 | lora_checkpoint_output_dir = ( |
| 1275 | output_dir / lora_project / f"model_{training_iter}" | ||
| 1276 | ) | ||
| 1222 | 1277 | ||
| 1223 | trainer( | 1278 | trainer( |
| 1224 | strategy=lora_strategy, | 1279 | strategy=lora_strategy, |
| @@ -1246,12 +1301,6 @@ def main(): | |||
| 1246 | ) | 1301 | ) |
| 1247 | 1302 | ||
| 1248 | training_iter += 1 | 1303 | training_iter += 1 |
| 1249 | if learning_rate_emb is not None: | ||
| 1250 | learning_rate_emb *= args.cycle_decay | ||
| 1251 | if learning_rate_unet is not None: | ||
| 1252 | learning_rate_unet *= args.cycle_decay | ||
| 1253 | if learning_rate_text is not None: | ||
| 1254 | learning_rate_text *= args.cycle_decay | ||
| 1255 | 1304 | ||
| 1256 | accelerator.end_training() | 1305 | accelerator.end_training() |
| 1257 | 1306 | ||
