diff options
-rw-r--r-- | environment.yaml | 3 | ||||
-rw-r--r-- | models/clip/embeddings.py | 8 | ||||
-rw-r--r-- | train_dreambooth.py | 11 | ||||
-rw-r--r-- | train_lora.py | 11 | ||||
-rw-r--r-- | train_ti.py | 21 | ||||
-rw-r--r-- | training/functional.py | 7 | ||||
-rw-r--r-- | training/strategy/ti.py | 15 |
7 files changed, 31 insertions, 45 deletions
diff --git a/environment.yaml b/environment.yaml index 57624a3..1e6ac60 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -13,16 +13,15 @@ dependencies: | |||
13 | - python=3.10.8 | 13 | - python=3.10.8 |
14 | - pytorch=1.13.1=*cuda* | 14 | - pytorch=1.13.1=*cuda* |
15 | - torchvision=0.14.1 | 15 | - torchvision=0.14.1 |
16 | - xformers=0.0.17.dev461 | ||
16 | - pip: | 17 | - pip: |
17 | - -e . | 18 | - -e . |
18 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 19 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
19 | - accelerate==0.16.0 | 20 | - accelerate==0.16.0 |
20 | - bitsandbytes==0.37.0 | 21 | - bitsandbytes==0.37.0 |
21 | - lion-pytorch==0.0.6 | ||
22 | - python-slugify>=6.1.2 | 22 | - python-slugify>=6.1.2 |
23 | - safetensors==0.2.8 | 23 | - safetensors==0.2.8 |
24 | - setuptools==65.6.3 | 24 | - setuptools==65.6.3 |
25 | - test-tube>=0.7.5 | 25 | - test-tube>=0.7.5 |
26 | - transformers==4.26.1 | 26 | - transformers==4.26.1 |
27 | - triton==2.0.0a2 | 27 | - triton==2.0.0a2 |
28 | - xformers==0.0.17.dev451 | ||
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 6c41c33..734730e 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py | |||
@@ -98,14 +98,6 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): | |||
98 | 98 | ||
99 | return embeds | 99 | return embeds |
100 | 100 | ||
101 | def normalize(self, target: float = 0.4, lambda_: float = 1.0): | ||
102 | if lambda_ == 0: | ||
103 | return | ||
104 | |||
105 | w = self.temp_token_embedding.weight | ||
106 | norm = w[self.temp_token_ids, :].norm(dim=-1, keepdim=True) | ||
107 | w[self.temp_token_ids].add_((w[self.temp_token_ids] / norm.clamp_min(1e-12)) * lambda_ * (target - norm)) | ||
108 | |||
109 | def forward( | 101 | def forward( |
110 | self, | 102 | self, |
111 | input_ids: Optional[torch.LongTensor] = None, | 103 | input_ids: Optional[torch.LongTensor] = None, |
diff --git a/train_dreambooth.py b/train_dreambooth.py index e039df0..431ff3d 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -288,7 +288,7 @@ def parse_args(): | |||
288 | "--optimizer", | 288 | "--optimizer", |
289 | type=str, | 289 | type=str, |
290 | default="adam", | 290 | default="adam", |
291 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 291 | help='Optimizer to use ["adam", "adam8bit"]' |
292 | ) | 292 | ) |
293 | parser.add_argument( | 293 | parser.add_argument( |
294 | "--adam_beta1", | 294 | "--adam_beta1", |
@@ -459,7 +459,7 @@ def main(): | |||
459 | save_args(output_dir, args) | 459 | save_args(output_dir, args) |
460 | 460 | ||
461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 461 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
462 | args.pretrained_model_name_or_path, noise_scheduler="deis") | 462 | args.pretrained_model_name_or_path) |
463 | 463 | ||
464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 464 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) |
465 | tokenizer.set_dropout(args.vector_dropout) | 465 | tokenizer.set_dropout(args.vector_dropout) |
@@ -513,13 +513,6 @@ def main(): | |||
513 | eps=args.adam_epsilon, | 513 | eps=args.adam_epsilon, |
514 | amsgrad=args.adam_amsgrad, | 514 | amsgrad=args.adam_amsgrad, |
515 | ) | 515 | ) |
516 | elif args.optimizer == 'lion': | ||
517 | try: | ||
518 | from lion_pytorch import Lion | ||
519 | except ImportError: | ||
520 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
521 | |||
522 | create_optimizer = partial(Lion, use_triton=True) | ||
523 | else: | 516 | else: |
524 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 517 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
525 | 518 | ||
diff --git a/train_lora.py b/train_lora.py index db5330a..a06591d 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -248,7 +248,7 @@ def parse_args(): | |||
248 | "--optimizer", | 248 | "--optimizer", |
249 | type=str, | 249 | type=str, |
250 | default="adam", | 250 | default="adam", |
251 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 251 | help='Optimizer to use ["adam", "adam8bit"]' |
252 | ) | 252 | ) |
253 | parser.add_argument( | 253 | parser.add_argument( |
254 | "--adam_beta1", | 254 | "--adam_beta1", |
@@ -419,7 +419,7 @@ def main(): | |||
419 | save_args(output_dir, args) | 419 | save_args(output_dir, args) |
420 | 420 | ||
421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 421 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
422 | args.pretrained_model_name_or_path, noise_scheduler="deis") | 422 | args.pretrained_model_name_or_path) |
423 | 423 | ||
424 | vae.enable_slicing() | 424 | vae.enable_slicing() |
425 | vae.set_use_memory_efficient_attention_xformers(True) | 425 | vae.set_use_memory_efficient_attention_xformers(True) |
@@ -488,13 +488,6 @@ def main(): | |||
488 | eps=args.adam_epsilon, | 488 | eps=args.adam_epsilon, |
489 | amsgrad=args.adam_amsgrad, | 489 | amsgrad=args.adam_amsgrad, |
490 | ) | 490 | ) |
491 | elif args.optimizer == 'lion': | ||
492 | try: | ||
493 | from lion_pytorch import Lion | ||
494 | except ImportError: | ||
495 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
496 | |||
497 | create_optimizer = partial(Lion, use_triton=True) | ||
498 | else: | 491 | else: |
499 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 492 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
500 | 493 | ||
diff --git a/train_ti.py b/train_ti.py index 12e3644..6dc07dd 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -86,7 +86,7 @@ def parse_args(): | |||
86 | help="Number of vectors per embedding." | 86 | help="Number of vectors per embedding." |
87 | ) | 87 | ) |
88 | parser.add_argument( | 88 | parser.add_argument( |
89 | "--simultaneous", | 89 | "--sequential", |
90 | action="store_true", | 90 | action="store_true", |
91 | ) | 91 | ) |
92 | parser.add_argument( | 92 | parser.add_argument( |
@@ -293,7 +293,7 @@ def parse_args(): | |||
293 | "--optimizer", | 293 | "--optimizer", |
294 | type=str, | 294 | type=str, |
295 | default="adam", | 295 | default="adam", |
296 | help='Optimizer to use ["adam", "adam8bit", "lion"]' | 296 | help='Optimizer to use ["adam", "adam8bit"]' |
297 | ) | 297 | ) |
298 | parser.add_argument( | 298 | parser.add_argument( |
299 | "--adam_beta1", | 299 | "--adam_beta1", |
@@ -343,6 +343,11 @@ def parse_args(): | |||
343 | help="How often to save a checkpoint and sample image (in epochs)", | 343 | help="How often to save a checkpoint and sample image (in epochs)", |
344 | ) | 344 | ) |
345 | parser.add_argument( | 345 | parser.add_argument( |
346 | "--no_milestone_checkpoints", | ||
347 | action='store_true', | ||
348 | help="If checkpoints are saved on maximum accuracy", | ||
349 | ) | ||
350 | parser.add_argument( | ||
346 | "--sample_frequency", | 351 | "--sample_frequency", |
347 | type=int, | 352 | type=int, |
348 | default=1, | 353 | default=1, |
@@ -480,7 +485,7 @@ def parse_args(): | |||
480 | if len(args.placeholder_tokens) != len(args.num_vectors): | 485 | if len(args.placeholder_tokens) != len(args.num_vectors): |
481 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 486 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
482 | 487 | ||
483 | if not args.simultaneous: | 488 | if args.sequential: |
484 | if isinstance(args.train_data_template, str): | 489 | if isinstance(args.train_data_template, str): |
485 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 490 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) |
486 | 491 | ||
@@ -586,13 +591,6 @@ def main(): | |||
586 | eps=args.adam_epsilon, | 591 | eps=args.adam_epsilon, |
587 | amsgrad=args.adam_amsgrad, | 592 | amsgrad=args.adam_amsgrad, |
588 | ) | 593 | ) |
589 | elif args.optimizer == 'lion': | ||
590 | try: | ||
591 | from lion_pytorch import Lion | ||
592 | except ImportError: | ||
593 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.") | ||
594 | |||
595 | create_optimizer = partial(Lion, use_triton=True) | ||
596 | else: | 594 | else: |
597 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 595 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") |
598 | 596 | ||
@@ -615,6 +613,7 @@ def main(): | |||
615 | num_train_epochs=args.num_train_epochs, | 613 | num_train_epochs=args.num_train_epochs, |
616 | sample_frequency=args.sample_frequency, | 614 | sample_frequency=args.sample_frequency, |
617 | checkpoint_frequency=args.checkpoint_frequency, | 615 | checkpoint_frequency=args.checkpoint_frequency, |
616 | milestone_checkpoints=not args.no_milestone_checkpoints, | ||
618 | global_step_offset=global_step_offset, | 617 | global_step_offset=global_step_offset, |
619 | # -- | 618 | # -- |
620 | tokenizer=tokenizer, | 619 | tokenizer=tokenizer, |
@@ -715,7 +714,7 @@ def main(): | |||
715 | 714 | ||
716 | plot_metrics(metrics, metrics_output_file) | 715 | plot_metrics(metrics, metrics_output_file) |
717 | 716 | ||
718 | if args.simultaneous: | 717 | if not args.sequential: |
719 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 718 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
720 | else: | 719 | else: |
721 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 720 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( |
diff --git a/training/functional.py b/training/functional.py index 85dd884..739d055 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -362,6 +362,7 @@ def train_loop( | |||
362 | loss_step: LossCallable, | 362 | loss_step: LossCallable, |
363 | sample_frequency: int = 10, | 363 | sample_frequency: int = 10, |
364 | checkpoint_frequency: int = 50, | 364 | checkpoint_frequency: int = 50, |
365 | milestone_checkpoints: bool = True, | ||
365 | global_step_offset: int = 0, | 366 | global_step_offset: int = 0, |
366 | num_epochs: int = 100, | 367 | num_epochs: int = 100, |
367 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 368 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
@@ -514,7 +515,7 @@ def train_loop( | |||
514 | accelerator.log(logs, step=global_step) | 515 | accelerator.log(logs, step=global_step) |
515 | 516 | ||
516 | if accelerator.is_main_process: | 517 | if accelerator.is_main_process: |
517 | if avg_acc_val.avg.item() > best_acc_val: | 518 | if avg_acc_val.avg.item() > best_acc_val and milestone_checkpoints: |
518 | local_progress_bar.clear() | 519 | local_progress_bar.clear() |
519 | global_progress_bar.clear() | 520 | global_progress_bar.clear() |
520 | 521 | ||
@@ -527,7 +528,7 @@ def train_loop( | |||
527 | accs.append(avg_acc_val.avg.item()) | 528 | accs.append(avg_acc_val.avg.item()) |
528 | else: | 529 | else: |
529 | if accelerator.is_main_process: | 530 | if accelerator.is_main_process: |
530 | if avg_acc.avg.item() > best_acc: | 531 | if avg_acc.avg.item() > best_acc and milestone_checkpoints: |
531 | local_progress_bar.clear() | 532 | local_progress_bar.clear() |
532 | global_progress_bar.clear() | 533 | global_progress_bar.clear() |
533 | 534 | ||
@@ -572,6 +573,7 @@ def train( | |||
572 | num_train_epochs: int = 100, | 573 | num_train_epochs: int = 100, |
573 | sample_frequency: int = 20, | 574 | sample_frequency: int = 20, |
574 | checkpoint_frequency: int = 50, | 575 | checkpoint_frequency: int = 50, |
576 | milestone_checkpoints: bool = True, | ||
575 | global_step_offset: int = 0, | 577 | global_step_offset: int = 0, |
576 | with_prior_preservation: bool = False, | 578 | with_prior_preservation: bool = False, |
577 | prior_loss_weight: float = 1.0, | 579 | prior_loss_weight: float = 1.0, |
@@ -626,6 +628,7 @@ def train( | |||
626 | loss_step=loss_step_, | 628 | loss_step=loss_step_, |
627 | sample_frequency=sample_frequency, | 629 | sample_frequency=sample_frequency, |
628 | checkpoint_frequency=checkpoint_frequency, | 630 | checkpoint_frequency=checkpoint_frequency, |
631 | milestone_checkpoints=milestone_checkpoints, | ||
629 | global_step_offset=global_step_offset, | 632 | global_step_offset=global_step_offset, |
630 | num_epochs=num_train_epochs, | 633 | num_epochs=num_train_epochs, |
631 | callbacks=callbacks, | 634 | callbacks=callbacks, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 66d3129..09beec4 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -116,10 +116,17 @@ def textual_inversion_strategy_callbacks( | |||
116 | @torch.no_grad() | 116 | @torch.no_grad() |
117 | def on_before_optimize(lr: float, epoch: int): | 117 | def on_before_optimize(lr: float, epoch: int): |
118 | if use_emb_decay: | 118 | if use_emb_decay: |
119 | text_encoder.text_model.embeddings.normalize( | 119 | lambda_ = emb_decay * lr |
120 | emb_decay_target, | 120 | |
121 | min(1.0, emb_decay * lr) | 121 | if lambda_ != 0: |
122 | ) | 122 | w = text_encoder.text_model.embeddings.temp_token_embedding.weight |
123 | |||
124 | mask = torch.zeros(w.size(0), dtype=torch.bool) | ||
125 | mask[text_encoder.text_model.embeddings.temp_token_ids] = True | ||
126 | mask[torch.all(w.grad == 0, dim=1)] = False | ||
127 | |||
128 | norm = w[mask, :].norm(dim=-1, keepdim=True) | ||
129 | w[mask].add_((w[mask] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
123 | 130 | ||
124 | def on_after_optimize(lr: float): | 131 | def on_after_optimize(lr: float): |
125 | if ema_embeddings is not None: | 132 | if ema_embeddings is not None: |