diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 379 |
1 files changed, 201 insertions, 178 deletions
diff --git a/train_ti.py b/train_ti.py index f60e3e5..c6f0b3a 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir | |||
| 32 | 32 | ||
| 33 | logger = get_logger(__name__) | 33 | logger = get_logger(__name__) |
| 34 | 34 | ||
| 35 | warnings.filterwarnings('ignore') | 35 | warnings.filterwarnings("ignore") |
| 36 | 36 | ||
| 37 | 37 | ||
| 38 | torch.backends.cuda.matmul.allow_tf32 = True | 38 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0) | |||
| 46 | 46 | ||
| 47 | 47 | ||
| 48 | def parse_args(): | 48 | def parse_args(): |
| 49 | parser = argparse.ArgumentParser( | 49 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 50 | description="Simple example of a training script." | ||
| 51 | ) | ||
| 52 | parser.add_argument( | 50 | parser.add_argument( |
| 53 | "--pretrained_model_name_or_path", | 51 | "--pretrained_model_name_or_path", |
| 54 | type=str, | 52 | type=str, |
| @@ -65,12 +63,12 @@ def parse_args(): | |||
| 65 | "--train_data_file", | 63 | "--train_data_file", |
| 66 | type=str, | 64 | type=str, |
| 67 | default=None, | 65 | default=None, |
| 68 | help="A CSV file containing the training data." | 66 | help="A CSV file containing the training data.", |
| 69 | ) | 67 | ) |
| 70 | parser.add_argument( | 68 | parser.add_argument( |
| 71 | "--train_data_template", | 69 | "--train_data_template", |
| 72 | type=str, | 70 | type=str, |
| 73 | nargs='*', | 71 | nargs="*", |
| 74 | default="template", | 72 | default="template", |
| 75 | ) | 73 | ) |
| 76 | parser.add_argument( | 74 | parser.add_argument( |
| @@ -80,59 +78,47 @@ def parse_args(): | |||
| 80 | help="The name of the current project.", | 78 | help="The name of the current project.", |
| 81 | ) | 79 | ) |
| 82 | parser.add_argument( | 80 | parser.add_argument( |
| 83 | "--auto_cycles", | 81 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 84 | type=str, | ||
| 85 | default="o", | ||
| 86 | help="Cycles to run automatically." | ||
| 87 | ) | 82 | ) |
| 88 | parser.add_argument( | 83 | parser.add_argument( |
| 89 | "--cycle_decay", | 84 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." |
| 90 | type=float, | ||
| 91 | default=1.0, | ||
| 92 | help="Learning rate decay per cycle." | ||
| 93 | ) | 85 | ) |
| 94 | parser.add_argument( | 86 | parser.add_argument( |
| 95 | "--placeholder_tokens", | 87 | "--placeholder_tokens", |
| 96 | type=str, | 88 | type=str, |
| 97 | nargs='*', | 89 | nargs="*", |
| 98 | help="A token to use as a placeholder for the concept.", | 90 | help="A token to use as a placeholder for the concept.", |
| 99 | ) | 91 | ) |
| 100 | parser.add_argument( | 92 | parser.add_argument( |
| 101 | "--initializer_tokens", | 93 | "--initializer_tokens", |
| 102 | type=str, | 94 | type=str, |
| 103 | nargs='*', | 95 | nargs="*", |
| 104 | help="A token to use as initializer word." | 96 | help="A token to use as initializer word.", |
| 105 | ) | 97 | ) |
| 106 | parser.add_argument( | 98 | parser.add_argument( |
| 107 | "--filter_tokens", | 99 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." |
| 108 | type=str, | ||
| 109 | nargs='*', | ||
| 110 | help="Tokens to filter the dataset by." | ||
| 111 | ) | 100 | ) |
| 112 | parser.add_argument( | 101 | parser.add_argument( |
| 113 | "--initializer_noise", | 102 | "--initializer_noise", |
| 114 | type=float, | 103 | type=float, |
| 115 | default=0, | 104 | default=0, |
| 116 | help="Noise to apply to the initializer word" | 105 | help="Noise to apply to the initializer word", |
| 117 | ) | 106 | ) |
| 118 | parser.add_argument( | 107 | parser.add_argument( |
| 119 | "--alias_tokens", | 108 | "--alias_tokens", |
| 120 | type=str, | 109 | type=str, |
| 121 | nargs='*', | 110 | nargs="*", |
| 122 | default=[], | 111 | default=[], |
| 123 | help="Tokens to create an alias for." | 112 | help="Tokens to create an alias for.", |
| 124 | ) | 113 | ) |
| 125 | parser.add_argument( | 114 | parser.add_argument( |
| 126 | "--inverted_initializer_tokens", | 115 | "--inverted_initializer_tokens", |
| 127 | type=str, | 116 | type=str, |
| 128 | nargs='*', | 117 | nargs="*", |
| 129 | help="A token to use as initializer word." | 118 | help="A token to use as initializer word.", |
| 130 | ) | 119 | ) |
| 131 | parser.add_argument( | 120 | parser.add_argument( |
| 132 | "--num_vectors", | 121 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." |
| 133 | type=int, | ||
| 134 | nargs='*', | ||
| 135 | help="Number of vectors per embedding." | ||
| 136 | ) | 122 | ) |
| 137 | parser.add_argument( | 123 | parser.add_argument( |
| 138 | "--sequential", | 124 | "--sequential", |
| @@ -147,7 +133,7 @@ def parse_args(): | |||
| 147 | "--num_class_images", | 133 | "--num_class_images", |
| 148 | type=int, | 134 | type=int, |
| 149 | default=0, | 135 | default=0, |
| 150 | help="How many class images to generate." | 136 | help="How many class images to generate.", |
| 151 | ) | 137 | ) |
| 152 | parser.add_argument( | 138 | parser.add_argument( |
| 153 | "--class_image_dir", | 139 | "--class_image_dir", |
| @@ -158,7 +144,7 @@ def parse_args(): | |||
| 158 | parser.add_argument( | 144 | parser.add_argument( |
| 159 | "--exclude_collections", | 145 | "--exclude_collections", |
| 160 | type=str, | 146 | type=str, |
| 161 | nargs='*', | 147 | nargs="*", |
| 162 | help="Exclude all items with a listed collection.", | 148 | help="Exclude all items with a listed collection.", |
| 163 | ) | 149 | ) |
| 164 | parser.add_argument( | 150 | parser.add_argument( |
| @@ -181,14 +167,11 @@ def parse_args(): | |||
| 181 | parser.add_argument( | 167 | parser.add_argument( |
| 182 | "--collection", | 168 | "--collection", |
| 183 | type=str, | 169 | type=str, |
| 184 | nargs='*', | 170 | nargs="*", |
| 185 | help="A collection to filter the dataset.", | 171 | help="A collection to filter the dataset.", |
| 186 | ) | 172 | ) |
| 187 | parser.add_argument( | 173 | parser.add_argument( |
| 188 | "--seed", | 174 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 189 | type=int, | ||
| 190 | default=None, | ||
| 191 | help="A seed for reproducible training." | ||
| 192 | ) | 175 | ) |
| 193 | parser.add_argument( | 176 | parser.add_argument( |
| 194 | "--resolution", | 177 | "--resolution", |
| @@ -244,7 +227,7 @@ def parse_args(): | |||
| 244 | type=str, | 227 | type=str, |
| 245 | default="auto", | 228 | default="auto", |
| 246 | choices=["all", "trailing", "leading", "between", "auto", "off"], | 229 | choices=["all", "trailing", "leading", "between", "auto", "off"], |
| 247 | help='Vector shuffling algorithm.', | 230 | help="Vector shuffling algorithm.", |
| 248 | ) | 231 | ) |
| 249 | parser.add_argument( | 232 | parser.add_argument( |
| 250 | "--offset_noise_strength", | 233 | "--offset_noise_strength", |
| @@ -256,18 +239,10 @@ def parse_args(): | |||
| 256 | "--input_pertubation", | 239 | "--input_pertubation", |
| 257 | type=float, | 240 | type=float, |
| 258 | default=0, | 241 | default=0, |
| 259 | help="The scale of input pretubation. Recommended 0.1." | 242 | help="The scale of input pretubation. Recommended 0.1.", |
| 260 | ) | ||
| 261 | parser.add_argument( | ||
| 262 | "--num_train_epochs", | ||
| 263 | type=int, | ||
| 264 | default=None | ||
| 265 | ) | ||
| 266 | parser.add_argument( | ||
| 267 | "--num_train_steps", | ||
| 268 | type=int, | ||
| 269 | default=2000 | ||
| 270 | ) | 243 | ) |
| 244 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 245 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 271 | parser.add_argument( | 246 | parser.add_argument( |
| 272 | "--gradient_accumulation_steps", | 247 | "--gradient_accumulation_steps", |
| 273 | type=int, | 248 | type=int, |
| @@ -299,27 +274,31 @@ def parse_args(): | |||
| 299 | "--lr_scheduler", | 274 | "--lr_scheduler", |
| 300 | type=str, | 275 | type=str, |
| 301 | default="one_cycle", | 276 | default="one_cycle", |
| 302 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 277 | choices=[ |
| 303 | "constant", "constant_with_warmup", "one_cycle"], | 278 | "linear", |
| 304 | help='The scheduler type to use.', | 279 | "cosine", |
| 280 | "cosine_with_restarts", | ||
| 281 | "polynomial", | ||
| 282 | "constant", | ||
| 283 | "constant_with_warmup", | ||
| 284 | "one_cycle", | ||
| 285 | ], | ||
| 286 | help="The scheduler type to use.", | ||
| 305 | ) | 287 | ) |
| 306 | parser.add_argument( | 288 | parser.add_argument( |
| 307 | "--lr_warmup_epochs", | 289 | "--lr_warmup_epochs", |
| 308 | type=int, | 290 | type=int, |
| 309 | default=10, | 291 | default=10, |
| 310 | help="Number of steps for the warmup in the lr scheduler." | 292 | help="Number of steps for the warmup in the lr scheduler.", |
| 311 | ) | 293 | ) |
| 312 | parser.add_argument( | 294 | parser.add_argument( |
| 313 | "--lr_mid_point", | 295 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 314 | type=float, | ||
| 315 | default=0.3, | ||
| 316 | help="OneCycle schedule mid point." | ||
| 317 | ) | 296 | ) |
| 318 | parser.add_argument( | 297 | parser.add_argument( |
| 319 | "--lr_cycles", | 298 | "--lr_cycles", |
| 320 | type=int, | 299 | type=int, |
| 321 | default=None, | 300 | default=None, |
| 322 | help="Number of restart cycles in the lr scheduler." | 301 | help="Number of restart cycles in the lr scheduler.", |
| 323 | ) | 302 | ) |
| 324 | parser.add_argument( | 303 | parser.add_argument( |
| 325 | "--lr_warmup_func", | 304 | "--lr_warmup_func", |
| @@ -331,7 +310,7 @@ def parse_args(): | |||
| 331 | "--lr_warmup_exp", | 310 | "--lr_warmup_exp", |
| 332 | type=int, | 311 | type=int, |
| 333 | default=1, | 312 | default=1, |
| 334 | help='If lr_warmup_func is "cos", exponent to modify the function' | 313 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 335 | ) | 314 | ) |
| 336 | parser.add_argument( | 315 | parser.add_argument( |
| 337 | "--lr_annealing_func", | 316 | "--lr_annealing_func", |
| @@ -343,89 +322,67 @@ def parse_args(): | |||
| 343 | "--lr_annealing_exp", | 322 | "--lr_annealing_exp", |
| 344 | type=int, | 323 | type=int, |
| 345 | default=1, | 324 | default=1, |
| 346 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 325 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 347 | ) | 326 | ) |
| 348 | parser.add_argument( | 327 | parser.add_argument( |
| 349 | "--lr_min_lr", | 328 | "--lr_min_lr", |
| 350 | type=float, | 329 | type=float, |
| 351 | default=0.04, | 330 | default=0.04, |
| 352 | help="Minimum learning rate in the lr scheduler." | 331 | help="Minimum learning rate in the lr scheduler.", |
| 353 | ) | 332 | ) |
| 354 | parser.add_argument( | 333 | parser.add_argument( |
| 355 | "--use_ema", | 334 | "--use_ema", action="store_true", help="Whether to use EMA model." |
| 356 | action="store_true", | ||
| 357 | help="Whether to use EMA model." | ||
| 358 | ) | ||
| 359 | parser.add_argument( | ||
| 360 | "--ema_inv_gamma", | ||
| 361 | type=float, | ||
| 362 | default=1.0 | ||
| 363 | ) | ||
| 364 | parser.add_argument( | ||
| 365 | "--ema_power", | ||
| 366 | type=float, | ||
| 367 | default=4/5 | ||
| 368 | ) | ||
| 369 | parser.add_argument( | ||
| 370 | "--ema_max_decay", | ||
| 371 | type=float, | ||
| 372 | default=0.9999 | ||
| 373 | ) | ||
| 374 | parser.add_argument( | ||
| 375 | "--min_snr_gamma", | ||
| 376 | type=int, | ||
| 377 | default=5, | ||
| 378 | help="MinSNR gamma." | ||
| 379 | ) | 335 | ) |
| 336 | parser.add_argument("--ema_inv_gamma", type=float, default=1.0) | ||
| 337 | parser.add_argument("--ema_power", type=float, default=4 / 5) | ||
| 338 | parser.add_argument("--ema_max_decay", type=float, default=0.9999) | ||
| 339 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
| 380 | parser.add_argument( | 340 | parser.add_argument( |
| 381 | "--schedule_sampler", | 341 | "--schedule_sampler", |
| 382 | type=str, | 342 | type=str, |
| 383 | default="uniform", | 343 | default="uniform", |
| 384 | choices=["uniform", "loss-second-moment"], | 344 | choices=["uniform", "loss-second-moment"], |
| 385 | help="Noise schedule sampler." | 345 | help="Noise schedule sampler.", |
| 386 | ) | 346 | ) |
| 387 | parser.add_argument( | 347 | parser.add_argument( |
| 388 | "--optimizer", | 348 | "--optimizer", |
| 389 | type=str, | 349 | type=str, |
| 390 | default="adan", | 350 | default="adan", |
| 391 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 351 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], |
| 392 | help='Optimizer to use' | 352 | help="Optimizer to use", |
| 393 | ) | 353 | ) |
| 394 | parser.add_argument( | 354 | parser.add_argument( |
| 395 | "--dadaptation_d0", | 355 | "--dadaptation_d0", |
| 396 | type=float, | 356 | type=float, |
| 397 | default=1e-6, | 357 | default=1e-6, |
| 398 | help="The d0 parameter for Dadaptation optimizers." | 358 | help="The d0 parameter for Dadaptation optimizers.", |
| 399 | ) | 359 | ) |
| 400 | parser.add_argument( | 360 | parser.add_argument( |
| 401 | "--adam_beta1", | 361 | "--adam_beta1", |
| 402 | type=float, | 362 | type=float, |
| 403 | default=None, | 363 | default=None, |
| 404 | help="The beta1 parameter for the Adam optimizer." | 364 | help="The beta1 parameter for the Adam optimizer.", |
| 405 | ) | 365 | ) |
| 406 | parser.add_argument( | 366 | parser.add_argument( |
| 407 | "--adam_beta2", | 367 | "--adam_beta2", |
| 408 | type=float, | 368 | type=float, |
| 409 | default=None, | 369 | default=None, |
| 410 | help="The beta2 parameter for the Adam optimizer." | 370 | help="The beta2 parameter for the Adam optimizer.", |
| 411 | ) | 371 | ) |
| 412 | parser.add_argument( | 372 | parser.add_argument( |
| 413 | "--adam_weight_decay", | 373 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 414 | type=float, | ||
| 415 | default=2e-2, | ||
| 416 | help="Weight decay to use." | ||
| 417 | ) | 374 | ) |
| 418 | parser.add_argument( | 375 | parser.add_argument( |
| 419 | "--adam_epsilon", | 376 | "--adam_epsilon", |
| 420 | type=float, | 377 | type=float, |
| 421 | default=1e-08, | 378 | default=1e-08, |
| 422 | help="Epsilon value for the Adam optimizer" | 379 | help="Epsilon value for the Adam optimizer", |
| 423 | ) | 380 | ) |
| 424 | parser.add_argument( | 381 | parser.add_argument( |
| 425 | "--adam_amsgrad", | 382 | "--adam_amsgrad", |
| 426 | type=bool, | 383 | type=bool, |
| 427 | default=False, | 384 | default=False, |
| 428 | help="Amsgrad value for the Adam optimizer" | 385 | help="Amsgrad value for the Adam optimizer", |
| 429 | ) | 386 | ) |
| 430 | parser.add_argument( | 387 | parser.add_argument( |
| 431 | "--mixed_precision", | 388 | "--mixed_precision", |
| @@ -456,7 +413,7 @@ def parse_args(): | |||
| 456 | ) | 413 | ) |
| 457 | parser.add_argument( | 414 | parser.add_argument( |
| 458 | "--no_milestone_checkpoints", | 415 | "--no_milestone_checkpoints", |
| 459 | action='store_true', | 416 | action="store_true", |
| 460 | help="If checkpoints are saved on maximum accuracy", | 417 | help="If checkpoints are saved on maximum accuracy", |
| 461 | ) | 418 | ) |
| 462 | parser.add_argument( | 419 | parser.add_argument( |
| @@ -493,25 +450,25 @@ def parse_args(): | |||
| 493 | "--valid_set_size", | 450 | "--valid_set_size", |
| 494 | type=int, | 451 | type=int, |
| 495 | default=None, | 452 | default=None, |
| 496 | help="Number of images in the validation dataset." | 453 | help="Number of images in the validation dataset.", |
| 497 | ) | 454 | ) |
| 498 | parser.add_argument( | 455 | parser.add_argument( |
| 499 | "--train_set_pad", | 456 | "--train_set_pad", |
| 500 | type=int, | 457 | type=int, |
| 501 | default=None, | 458 | default=None, |
| 502 | help="The number to fill train dataset items up to." | 459 | help="The number to fill train dataset items up to.", |
| 503 | ) | 460 | ) |
| 504 | parser.add_argument( | 461 | parser.add_argument( |
| 505 | "--valid_set_pad", | 462 | "--valid_set_pad", |
| 506 | type=int, | 463 | type=int, |
| 507 | default=None, | 464 | default=None, |
| 508 | help="The number to fill validation dataset items up to." | 465 | help="The number to fill validation dataset items up to.", |
| 509 | ) | 466 | ) |
| 510 | parser.add_argument( | 467 | parser.add_argument( |
| 511 | "--train_batch_size", | 468 | "--train_batch_size", |
| 512 | type=int, | 469 | type=int, |
| 513 | default=1, | 470 | default=1, |
| 514 | help="Batch size (per device) for the training dataloader." | 471 | help="Batch size (per device) for the training dataloader.", |
| 515 | ) | 472 | ) |
| 516 | parser.add_argument( | 473 | parser.add_argument( |
| 517 | "--sample_steps", | 474 | "--sample_steps", |
| @@ -523,14 +480,9 @@ def parse_args(): | |||
| 523 | "--prior_loss_weight", | 480 | "--prior_loss_weight", |
| 524 | type=float, | 481 | type=float, |
| 525 | default=1.0, | 482 | default=1.0, |
| 526 | help="The weight of prior preservation loss." | 483 | help="The weight of prior preservation loss.", |
| 527 | ) | ||
| 528 | parser.add_argument( | ||
| 529 | "--emb_alpha", | ||
| 530 | type=float, | ||
| 531 | default=1.0, | ||
| 532 | help="Embedding alpha" | ||
| 533 | ) | 484 | ) |
| 485 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 534 | parser.add_argument( | 486 | parser.add_argument( |
| 535 | "--emb_dropout", | 487 | "--emb_dropout", |
| 536 | type=float, | 488 | type=float, |
| @@ -538,21 +490,13 @@ def parse_args(): | |||
| 538 | help="Embedding dropout probability.", | 490 | help="Embedding dropout probability.", |
| 539 | ) | 491 | ) |
| 540 | parser.add_argument( | 492 | parser.add_argument( |
| 541 | "--use_emb_decay", | 493 | "--use_emb_decay", action="store_true", help="Whether to use embedding decay." |
| 542 | action="store_true", | ||
| 543 | help="Whether to use embedding decay." | ||
| 544 | ) | 494 | ) |
| 545 | parser.add_argument( | 495 | parser.add_argument( |
| 546 | "--emb_decay_target", | 496 | "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." |
| 547 | default=0.4, | ||
| 548 | type=float, | ||
| 549 | help="Embedding decay target." | ||
| 550 | ) | 497 | ) |
| 551 | parser.add_argument( | 498 | parser.add_argument( |
| 552 | "--emb_decay", | 499 | "--emb_decay", default=1e2, type=float, help="Embedding decay factor." |
| 553 | default=1e+2, | ||
| 554 | type=float, | ||
| 555 | help="Embedding decay factor." | ||
| 556 | ) | 500 | ) |
| 557 | parser.add_argument( | 501 | parser.add_argument( |
| 558 | "--noise_timesteps", | 502 | "--noise_timesteps", |
| @@ -563,7 +507,7 @@ def parse_args(): | |||
| 563 | "--resume_from", | 507 | "--resume_from", |
| 564 | type=str, | 508 | type=str, |
| 565 | default=None, | 509 | default=None, |
| 566 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" | 510 | help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)", |
| 567 | ) | 511 | ) |
| 568 | parser.add_argument( | 512 | parser.add_argument( |
| 569 | "--global_step", | 513 | "--global_step", |
| @@ -574,7 +518,7 @@ def parse_args(): | |||
| 574 | "--config", | 518 | "--config", |
| 575 | type=str, | 519 | type=str, |
| 576 | default=None, | 520 | default=None, |
| 577 | help="Path to a JSON configuration file containing arguments for invoking this script." | 521 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 578 | ) | 522 | ) |
| 579 | 523 | ||
| 580 | args = parser.parse_args() | 524 | args = parser.parse_args() |
| @@ -595,29 +539,44 @@ def parse_args(): | |||
| 595 | args.placeholder_tokens = [args.placeholder_tokens] | 539 | args.placeholder_tokens = [args.placeholder_tokens] |
| 596 | 540 | ||
| 597 | if isinstance(args.initializer_tokens, str): | 541 | if isinstance(args.initializer_tokens, str): |
| 598 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | 542 | args.initializer_tokens = [args.initializer_tokens] * len( |
| 543 | args.placeholder_tokens | ||
| 544 | ) | ||
| 599 | 545 | ||
| 600 | if len(args.placeholder_tokens) == 0: | 546 | if len(args.placeholder_tokens) == 0: |
| 601 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | 547 | args.placeholder_tokens = [ |
| 548 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 549 | ] | ||
| 602 | 550 | ||
| 603 | if len(args.initializer_tokens) == 0: | 551 | if len(args.initializer_tokens) == 0: |
| 604 | args.initializer_tokens = args.placeholder_tokens.copy() | 552 | args.initializer_tokens = args.placeholder_tokens.copy() |
| 605 | 553 | ||
| 606 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | 554 | if len(args.placeholder_tokens) != len(args.initializer_tokens): |
| 607 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | 555 | raise ValueError( |
| 556 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 557 | ) | ||
| 608 | 558 | ||
| 609 | if isinstance(args.inverted_initializer_tokens, str): | 559 | if isinstance(args.inverted_initializer_tokens, str): |
| 610 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) | 560 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( |
| 561 | args.placeholder_tokens | ||
| 562 | ) | ||
| 611 | 563 | ||
| 612 | if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: | 564 | if ( |
| 565 | isinstance(args.inverted_initializer_tokens, list) | ||
| 566 | and len(args.inverted_initializer_tokens) != 0 | ||
| 567 | ): | ||
| 613 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | 568 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] |
| 614 | args.initializer_tokens += args.inverted_initializer_tokens | 569 | args.initializer_tokens += args.inverted_initializer_tokens |
| 615 | 570 | ||
| 616 | if isinstance(args.num_vectors, int): | 571 | if isinstance(args.num_vectors, int): |
| 617 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | 572 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) |
| 618 | 573 | ||
| 619 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): | 574 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( |
| 620 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 575 | args.num_vectors |
| 576 | ): | ||
| 577 | raise ValueError( | ||
| 578 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 579 | ) | ||
| 621 | 580 | ||
| 622 | if args.alias_tokens is None: | 581 | if args.alias_tokens is None: |
| 623 | args.alias_tokens = [] | 582 | args.alias_tokens = [] |
| @@ -639,16 +598,22 @@ def parse_args(): | |||
| 639 | ] | 598 | ] |
| 640 | 599 | ||
| 641 | if isinstance(args.train_data_template, str): | 600 | if isinstance(args.train_data_template, str): |
| 642 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 601 | args.train_data_template = [args.train_data_template] * len( |
| 602 | args.placeholder_tokens | ||
| 603 | ) | ||
| 643 | 604 | ||
| 644 | if len(args.placeholder_tokens) != len(args.train_data_template): | 605 | if len(args.placeholder_tokens) != len(args.train_data_template): |
| 645 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 606 | raise ValueError( |
| 607 | "--placeholder_tokens and --train_data_template must have the same number of items" | ||
| 608 | ) | ||
| 646 | 609 | ||
| 647 | if args.num_vectors is None: | 610 | if args.num_vectors is None: |
| 648 | args.num_vectors = [None] * len(args.placeholder_tokens) | 611 | args.num_vectors = [None] * len(args.placeholder_tokens) |
| 649 | else: | 612 | else: |
| 650 | if isinstance(args.train_data_template, list): | 613 | if isinstance(args.train_data_template, list): |
| 651 | raise ValueError("--train_data_template can't be a list in simultaneous mode") | 614 | raise ValueError( |
| 615 | "--train_data_template can't be a list in simultaneous mode" | ||
| 616 | ) | ||
| 652 | 617 | ||
| 653 | if isinstance(args.collection, str): | 618 | if isinstance(args.collection, str): |
| 654 | args.collection = [args.collection] | 619 | args.collection = [args.collection] |
| @@ -660,13 +625,13 @@ def parse_args(): | |||
| 660 | raise ValueError("You must specify --output_dir") | 625 | raise ValueError("You must specify --output_dir") |
| 661 | 626 | ||
| 662 | if args.adam_beta1 is None: | 627 | if args.adam_beta1 is None: |
| 663 | if args.optimizer == 'lion': | 628 | if args.optimizer == "lion": |
| 664 | args.adam_beta1 = 0.95 | 629 | args.adam_beta1 = 0.95 |
| 665 | else: | 630 | else: |
| 666 | args.adam_beta1 = 0.9 | 631 | args.adam_beta1 = 0.9 |
| 667 | 632 | ||
| 668 | if args.adam_beta2 is None: | 633 | if args.adam_beta2 is None: |
| 669 | if args.optimizer == 'lion': | 634 | if args.optimizer == "lion": |
| 670 | args.adam_beta2 = 0.98 | 635 | args.adam_beta2 = 0.98 |
| 671 | else: | 636 | else: |
| 672 | args.adam_beta2 = 0.999 | 637 | args.adam_beta2 = 0.999 |
| @@ -679,13 +644,13 @@ def main(): | |||
| 679 | 644 | ||
| 680 | global_step_offset = args.global_step | 645 | global_step_offset = args.global_step |
| 681 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 646 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 682 | output_dir = Path(args.output_dir)/slugify(args.project)/now | 647 | output_dir = Path(args.output_dir) / slugify(args.project) / now |
| 683 | output_dir.mkdir(parents=True, exist_ok=True) | 648 | output_dir.mkdir(parents=True, exist_ok=True) |
| 684 | 649 | ||
| 685 | accelerator = Accelerator( | 650 | accelerator = Accelerator( |
| 686 | log_with=LoggerType.TENSORBOARD, | 651 | log_with=LoggerType.TENSORBOARD, |
| 687 | project_dir=f"{output_dir}", | 652 | project_dir=f"{output_dir}", |
| 688 | mixed_precision=args.mixed_precision | 653 | mixed_precision=args.mixed_precision, |
| 689 | ) | 654 | ) |
| 690 | 655 | ||
| 691 | weight_dtype = torch.float32 | 656 | weight_dtype = torch.float32 |
| @@ -703,9 +668,15 @@ def main(): | |||
| 703 | 668 | ||
| 704 | save_args(output_dir, args) | 669 | save_args(output_dir, args) |
| 705 | 670 | ||
| 706 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) | 671 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 707 | embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) | 672 | args.pretrained_model_name_or_path |
| 708 | schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) | 673 | ) |
| 674 | embeddings = patch_managed_embeddings( | ||
| 675 | text_encoder, args.emb_alpha, args.emb_dropout | ||
| 676 | ) | ||
| 677 | schedule_sampler = create_named_schedule_sampler( | ||
| 678 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 679 | ) | ||
| 709 | 680 | ||
| 710 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 681 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
| 711 | tokenizer.set_dropout(args.vector_dropout) | 682 | tokenizer.set_dropout(args.vector_dropout) |
| @@ -717,16 +688,16 @@ def main(): | |||
| 717 | unet.enable_xformers_memory_efficient_attention() | 688 | unet.enable_xformers_memory_efficient_attention() |
| 718 | elif args.compile_unet: | 689 | elif args.compile_unet: |
| 719 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | 690 | unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False |
| 720 | 691 | ||
| 721 | proc = AttnProcessor() | 692 | proc = AttnProcessor() |
| 722 | 693 | ||
| 723 | def fn_recursive_set_proc(module: torch.nn.Module): | 694 | def fn_recursive_set_proc(module: torch.nn.Module): |
| 724 | if hasattr(module, "processor"): | 695 | if hasattr(module, "processor"): |
| 725 | module.processor = proc | 696 | module.processor = proc |
| 726 | 697 | ||
| 727 | for child in module.children(): | 698 | for child in module.children(): |
| 728 | fn_recursive_set_proc(child) | 699 | fn_recursive_set_proc(child) |
| 729 | 700 | ||
| 730 | fn_recursive_set_proc(unet) | 701 | fn_recursive_set_proc(unet) |
| 731 | 702 | ||
| 732 | if args.gradient_checkpointing: | 703 | if args.gradient_checkpointing: |
| @@ -751,18 +722,24 @@ def main(): | |||
| 751 | tokenizer=tokenizer, | 722 | tokenizer=tokenizer, |
| 752 | embeddings=embeddings, | 723 | embeddings=embeddings, |
| 753 | placeholder_tokens=alias_placeholder_tokens, | 724 | placeholder_tokens=alias_placeholder_tokens, |
| 754 | initializer_tokens=alias_initializer_tokens | 725 | initializer_tokens=alias_initializer_tokens, |
| 755 | ) | 726 | ) |
| 756 | embeddings.persist() | 727 | embeddings.persist() |
| 757 | print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") | 728 | print( |
| 729 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 730 | ) | ||
| 758 | 731 | ||
| 759 | if args.embeddings_dir is not None: | 732 | if args.embeddings_dir is not None: |
| 760 | embeddings_dir = Path(args.embeddings_dir) | 733 | embeddings_dir = Path(args.embeddings_dir) |
| 761 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 734 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 762 | raise ValueError("--embeddings_dir must point to an existing directory") | 735 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 763 | 736 | ||
| 764 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 737 | added_tokens, added_ids = load_embeddings_from_dir( |
| 765 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 738 | tokenizer, embeddings, embeddings_dir |
| 739 | ) | ||
| 740 | print( | ||
| 741 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 742 | ) | ||
| 766 | 743 | ||
| 767 | if args.train_dir_embeddings: | 744 | if args.train_dir_embeddings: |
| 768 | args.placeholder_tokens = added_tokens | 745 | args.placeholder_tokens = added_tokens |
| @@ -772,19 +749,23 @@ def main(): | |||
| 772 | 749 | ||
| 773 | if args.scale_lr: | 750 | if args.scale_lr: |
| 774 | args.learning_rate = ( | 751 | args.learning_rate = ( |
| 775 | args.learning_rate * args.gradient_accumulation_steps * | 752 | args.learning_rate |
| 776 | args.train_batch_size * accelerator.num_processes | 753 | * args.gradient_accumulation_steps |
| 754 | * args.train_batch_size | ||
| 755 | * accelerator.num_processes | ||
| 777 | ) | 756 | ) |
| 778 | 757 | ||
| 779 | if args.find_lr: | 758 | if args.find_lr: |
| 780 | args.learning_rate = 1e-5 | 759 | args.learning_rate = 1e-5 |
| 781 | args.lr_scheduler = "exponential_growth" | 760 | args.lr_scheduler = "exponential_growth" |
| 782 | 761 | ||
| 783 | if args.optimizer == 'adam8bit': | 762 | if args.optimizer == "adam8bit": |
| 784 | try: | 763 | try: |
| 785 | import bitsandbytes as bnb | 764 | import bitsandbytes as bnb |
| 786 | except ImportError: | 765 | except ImportError: |
| 787 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 766 | raise ImportError( |
| 767 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 768 | ) | ||
| 788 | 769 | ||
| 789 | create_optimizer = partial( | 770 | create_optimizer = partial( |
| 790 | bnb.optim.AdamW8bit, | 771 | bnb.optim.AdamW8bit, |
| @@ -793,7 +774,7 @@ def main(): | |||
| 793 | eps=args.adam_epsilon, | 774 | eps=args.adam_epsilon, |
| 794 | amsgrad=args.adam_amsgrad, | 775 | amsgrad=args.adam_amsgrad, |
| 795 | ) | 776 | ) |
| 796 | elif args.optimizer == 'adam': | 777 | elif args.optimizer == "adam": |
| 797 | create_optimizer = partial( | 778 | create_optimizer = partial( |
| 798 | torch.optim.AdamW, | 779 | torch.optim.AdamW, |
| 799 | betas=(args.adam_beta1, args.adam_beta2), | 780 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -801,11 +782,13 @@ def main(): | |||
| 801 | eps=args.adam_epsilon, | 782 | eps=args.adam_epsilon, |
| 802 | amsgrad=args.adam_amsgrad, | 783 | amsgrad=args.adam_amsgrad, |
| 803 | ) | 784 | ) |
| 804 | elif args.optimizer == 'adan': | 785 | elif args.optimizer == "adan": |
| 805 | try: | 786 | try: |
| 806 | import timm.optim | 787 | import timm.optim |
| 807 | except ImportError: | 788 | except ImportError: |
| 808 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 789 | raise ImportError( |
| 790 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 791 | ) | ||
| 809 | 792 | ||
| 810 | create_optimizer = partial( | 793 | create_optimizer = partial( |
| 811 | timm.optim.Adan, | 794 | timm.optim.Adan, |
| @@ -813,11 +796,13 @@ def main(): | |||
| 813 | eps=args.adam_epsilon, | 796 | eps=args.adam_epsilon, |
| 814 | no_prox=True, | 797 | no_prox=True, |
| 815 | ) | 798 | ) |
| 816 | elif args.optimizer == 'lion': | 799 | elif args.optimizer == "lion": |
| 817 | try: | 800 | try: |
| 818 | import lion_pytorch | 801 | import lion_pytorch |
| 819 | except ImportError: | 802 | except ImportError: |
| 820 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 803 | raise ImportError( |
| 804 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 805 | ) | ||
| 821 | 806 | ||
| 822 | create_optimizer = partial( | 807 | create_optimizer = partial( |
| 823 | lion_pytorch.Lion, | 808 | lion_pytorch.Lion, |
| @@ -825,7 +810,7 @@ def main(): | |||
| 825 | weight_decay=args.adam_weight_decay, | 810 | weight_decay=args.adam_weight_decay, |
| 826 | use_triton=True, | 811 | use_triton=True, |
| 827 | ) | 812 | ) |
| 828 | elif args.optimizer == 'adafactor': | 813 | elif args.optimizer == "adafactor": |
| 829 | create_optimizer = partial( | 814 | create_optimizer = partial( |
| 830 | transformers.optimization.Adafactor, | 815 | transformers.optimization.Adafactor, |
| 831 | weight_decay=args.adam_weight_decay, | 816 | weight_decay=args.adam_weight_decay, |
| @@ -837,11 +822,13 @@ def main(): | |||
| 837 | args.lr_scheduler = "adafactor" | 822 | args.lr_scheduler = "adafactor" |
| 838 | args.lr_min_lr = args.learning_rate | 823 | args.lr_min_lr = args.learning_rate |
| 839 | args.learning_rate = None | 824 | args.learning_rate = None |
| 840 | elif args.optimizer == 'dadam': | 825 | elif args.optimizer == "dadam": |
| 841 | try: | 826 | try: |
| 842 | import dadaptation | 827 | import dadaptation |
| 843 | except ImportError: | 828 | except ImportError: |
| 844 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 829 | raise ImportError( |
| 830 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 831 | ) | ||
| 845 | 832 | ||
| 846 | create_optimizer = partial( | 833 | create_optimizer = partial( |
| 847 | dadaptation.DAdaptAdam, | 834 | dadaptation.DAdaptAdam, |
| @@ -851,11 +838,13 @@ def main(): | |||
| 851 | decouple=True, | 838 | decouple=True, |
| 852 | d0=args.dadaptation_d0, | 839 | d0=args.dadaptation_d0, |
| 853 | ) | 840 | ) |
| 854 | elif args.optimizer == 'dadan': | 841 | elif args.optimizer == "dadan": |
| 855 | try: | 842 | try: |
| 856 | import dadaptation | 843 | import dadaptation |
| 857 | except ImportError: | 844 | except ImportError: |
| 858 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 845 | raise ImportError( |
| 846 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 847 | ) | ||
| 859 | 848 | ||
| 860 | create_optimizer = partial( | 849 | create_optimizer = partial( |
| 861 | dadaptation.DAdaptAdan, | 850 | dadaptation.DAdaptAdan, |
| @@ -864,7 +853,7 @@ def main(): | |||
| 864 | d0=args.dadaptation_d0, | 853 | d0=args.dadaptation_d0, |
| 865 | ) | 854 | ) |
| 866 | else: | 855 | else: |
| 867 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 856 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 868 | 857 | ||
| 869 | trainer = partial( | 858 | trainer = partial( |
| 870 | train, | 859 | train, |
| @@ -904,10 +893,21 @@ def main(): | |||
| 904 | sample_image_size=args.sample_image_size, | 893 | sample_image_size=args.sample_image_size, |
| 905 | ) | 894 | ) |
| 906 | 895 | ||
| 896 | optimizer = create_optimizer( | ||
| 897 | text_encoder.text_model.embeddings.token_embedding.parameters(), | ||
| 898 | lr=learning_rate, | ||
| 899 | ) | ||
| 900 | |||
| 907 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) | 901 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
| 908 | data_npgenerator = np.random.default_rng(args.seed) | 902 | data_npgenerator = np.random.default_rng(args.seed) |
| 909 | 903 | ||
| 910 | def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): | 904 | def run( |
| 905 | i: int, | ||
| 906 | placeholder_tokens: list[str], | ||
| 907 | initializer_tokens: list[str], | ||
| 908 | num_vectors: Union[int, list[int]], | ||
| 909 | data_template: str, | ||
| 910 | ): | ||
| 911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 911 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 912 | tokenizer=tokenizer, | 912 | tokenizer=tokenizer, |
| 913 | embeddings=embeddings, | 913 | embeddings=embeddings, |
| @@ -917,14 +917,23 @@ def main(): | |||
| 917 | initializer_noise=args.initializer_noise, | 917 | initializer_noise=args.initializer_noise, |
| 918 | ) | 918 | ) |
| 919 | 919 | ||
| 920 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) | 920 | stats = list( |
| 921 | zip( | ||
| 922 | placeholder_tokens, | ||
| 923 | placeholder_token_ids, | ||
| 924 | initializer_tokens, | ||
| 925 | initializer_token_ids, | ||
| 926 | ) | ||
| 927 | ) | ||
| 921 | 928 | ||
| 922 | print("") | 929 | print("") |
| 923 | print(f"============ TI batch {i + 1} ============") | 930 | print(f"============ TI batch {i + 1} ============") |
| 924 | print("") | 931 | print("") |
| 925 | print(stats) | 932 | print(stats) |
| 926 | 933 | ||
| 927 | filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] | 934 | filter_tokens = [ |
| 935 | token for token in args.filter_tokens if token in placeholder_tokens | ||
| 936 | ] | ||
| 928 | 937 | ||
| 929 | datamodule = VlpnDataModule( | 938 | datamodule = VlpnDataModule( |
| 930 | data_file=args.train_data_file, | 939 | data_file=args.train_data_file, |
| @@ -945,7 +954,9 @@ def main(): | |||
| 945 | valid_set_size=args.valid_set_size, | 954 | valid_set_size=args.valid_set_size, |
| 946 | train_set_pad=args.train_set_pad, | 955 | train_set_pad=args.train_set_pad, |
| 947 | valid_set_pad=args.valid_set_pad, | 956 | valid_set_pad=args.valid_set_pad, |
| 948 | filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), | 957 | filter=partial( |
| 958 | keyword_filter, filter_tokens, args.collection, args.exclude_collections | ||
| 959 | ), | ||
| 949 | dtype=weight_dtype, | 960 | dtype=weight_dtype, |
| 950 | generator=data_generator, | 961 | generator=data_generator, |
| 951 | npgenerator=data_npgenerator, | 962 | npgenerator=data_npgenerator, |
| @@ -955,11 +966,16 @@ def main(): | |||
| 955 | num_train_epochs = args.num_train_epochs | 966 | num_train_epochs = args.num_train_epochs |
| 956 | sample_frequency = args.sample_frequency | 967 | sample_frequency = args.sample_frequency |
| 957 | if num_train_epochs is None: | 968 | if num_train_epochs is None: |
| 958 | num_train_epochs = math.ceil( | 969 | num_train_epochs = ( |
| 959 | args.num_train_steps / len(datamodule.train_dataset) | 970 | math.ceil(args.num_train_steps / len(datamodule.train_dataset)) |
| 960 | ) * args.gradient_accumulation_steps | 971 | * args.gradient_accumulation_steps |
| 961 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | 972 | ) |
| 962 | num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) | 973 | sample_frequency = math.ceil( |
| 974 | num_train_epochs * (sample_frequency / args.num_train_steps) | ||
| 975 | ) | ||
| 976 | num_training_steps_per_epoch = math.ceil( | ||
| 977 | len(datamodule.train_dataset) / args.gradient_accumulation_steps | ||
| 978 | ) | ||
| 963 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | 979 | num_train_steps = num_training_steps_per_epoch * num_train_epochs |
| 964 | if args.sample_num is not None: | 980 | if args.sample_num is not None: |
| 965 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) | 981 | sample_frequency = math.ceil(num_train_epochs / args.sample_num) |
| @@ -988,7 +1004,8 @@ def main(): | |||
| 988 | response = auto_cycles.pop(0) | 1004 | response = auto_cycles.pop(0) |
| 989 | else: | 1005 | else: |
| 990 | response = input( | 1006 | response = input( |
| 991 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") | 1007 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " |
| 1008 | ) | ||
| 992 | 1009 | ||
| 993 | if response.lower().strip() == "o": | 1010 | if response.lower().strip() == "o": |
| 994 | if args.learning_rate is not None: | 1011 | if args.learning_rate is not None: |
| @@ -1018,10 +1035,8 @@ def main(): | |||
| 1018 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") | 1035 | print(f"------------ TI cycle {training_iter + 1}: {response} ------------") |
| 1019 | print("") | 1036 | print("") |
| 1020 | 1037 | ||
| 1021 | optimizer = create_optimizer( | 1038 | for group, lr in zip(optimizer.param_groups, [learning_rate]): |
| 1022 | text_encoder.text_model.embeddings.token_embedding.parameters(), | 1039 | group["lr"] = lr |
| 1023 | lr=learning_rate, | ||
| 1024 | ) | ||
| 1025 | 1040 | ||
| 1026 | lr_scheduler = get_scheduler( | 1041 | lr_scheduler = get_scheduler( |
| 1027 | lr_scheduler, | 1042 | lr_scheduler, |
| @@ -1040,7 +1055,9 @@ def main(): | |||
| 1040 | mid_point=args.lr_mid_point, | 1055 | mid_point=args.lr_mid_point, |
| 1041 | ) | 1056 | ) |
| 1042 | 1057 | ||
| 1043 | checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" | 1058 | checkpoint_output_dir = ( |
| 1059 | output_dir / project / f"checkpoints_{training_iter}" | ||
| 1060 | ) | ||
| 1044 | 1061 | ||
| 1045 | trainer( | 1062 | trainer( |
| 1046 | train_dataloader=datamodule.train_dataloader, | 1063 | train_dataloader=datamodule.train_dataloader, |
| @@ -1070,14 +1087,20 @@ def main(): | |||
| 1070 | accelerator.end_training() | 1087 | accelerator.end_training() |
| 1071 | 1088 | ||
| 1072 | if not args.sequential: | 1089 | if not args.sequential: |
| 1073 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 1090 | run( |
| 1091 | 0, | ||
| 1092 | args.placeholder_tokens, | ||
| 1093 | args.initializer_tokens, | ||
| 1094 | args.num_vectors, | ||
| 1095 | args.train_data_template, | ||
| 1096 | ) | ||
| 1074 | else: | 1097 | else: |
| 1075 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 1098 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
| 1076 | range(len(args.placeholder_tokens)), | 1099 | range(len(args.placeholder_tokens)), |
| 1077 | args.placeholder_tokens, | 1100 | args.placeholder_tokens, |
| 1078 | args.initializer_tokens, | 1101 | args.initializer_tokens, |
| 1079 | args.num_vectors, | 1102 | args.num_vectors, |
| 1080 | args.train_data_template | 1103 | args.train_data_template, |
| 1081 | ): | 1104 | ): |
| 1082 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) | 1105 | run(i, [placeholder_token], [initializer_token], num_vectors, data_template) |
| 1083 | embeddings.persist() | 1106 | embeddings.persist() |
