diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 17 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 4 | ||||
| -rw-r--r-- | training/strategy/lora.py | 4 | ||||
| -rw-r--r-- | training/strategy/ti.py | 23 | 
4 files changed, 34 insertions, 14 deletions
| diff --git a/training/functional.py b/training/functional.py index 2da0f69..ebc40de 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -42,7 +42,7 @@ class TrainingCallbacks(): | |||
| 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() | 42 | on_after_optimize: Callable[[Any, dict[str, float]], None] = const() | 
| 43 | on_after_epoch: Callable[[], None] = const() | 43 | on_after_epoch: Callable[[], None] = const() | 
| 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 44 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 
| 45 | on_sample: Callable[[int], None] = const() | 45 | on_sample: Callable[[int, int], None] = const() | 
| 46 | on_checkpoint: Callable[[int, str], None] = const() | 46 | on_checkpoint: Callable[[int, str], None] = const() | 
| 47 | 47 | ||
| 48 | 48 | ||
| @@ -96,6 +96,7 @@ def save_samples( | |||
| 96 | output_dir: Path, | 96 | output_dir: Path, | 
| 97 | seed: int, | 97 | seed: int, | 
| 98 | step: int, | 98 | step: int, | 
| 99 | cycle: int = 1, | ||
| 99 | batch_size: int = 1, | 100 | batch_size: int = 1, | 
| 100 | num_batches: int = 1, | 101 | num_batches: int = 1, | 
| 101 | num_steps: int = 20, | 102 | num_steps: int = 20, | 
| @@ -125,7 +126,7 @@ def save_samples( | |||
| 125 | 126 | ||
| 126 | for pool, data, gen in datasets: | 127 | for pool, data, gen in datasets: | 
| 127 | all_samples = [] | 128 | all_samples = [] | 
| 128 | file_path = output_dir / pool / f"step_{step}.jpg" | 129 | file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" | 
| 129 | file_path.parent.mkdir(parents=True, exist_ok=True) | 130 | file_path.parent.mkdir(parents=True, exist_ok=True) | 
| 130 | 131 | ||
| 131 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 132 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 
| @@ -455,7 +456,7 @@ def train_loop( | |||
| 455 | sample_frequency: int = 10, | 456 | sample_frequency: int = 10, | 
| 456 | checkpoint_frequency: int = 50, | 457 | checkpoint_frequency: int = 50, | 
| 457 | milestone_checkpoints: bool = True, | 458 | milestone_checkpoints: bool = True, | 
| 458 | initial_samples: bool = True, | 459 | cycle: int = 1, | 
| 459 | global_step_offset: int = 0, | 460 | global_step_offset: int = 0, | 
| 460 | num_epochs: int = 100, | 461 | num_epochs: int = 100, | 
| 461 | gradient_accumulation_steps: int = 1, | 462 | gradient_accumulation_steps: int = 1, | 
| @@ -518,12 +519,12 @@ def train_loop( | |||
| 518 | try: | 519 | try: | 
| 519 | for epoch in range(num_epochs): | 520 | for epoch in range(num_epochs): | 
| 520 | if accelerator.is_main_process: | 521 | if accelerator.is_main_process: | 
| 521 | if epoch % sample_frequency == 0 and (initial_samples or epoch != 0): | 522 | if epoch % sample_frequency == 0 and (cycle == 1 or epoch != 0): | 
| 522 | local_progress_bar.clear() | 523 | local_progress_bar.clear() | 
| 523 | global_progress_bar.clear() | 524 | global_progress_bar.clear() | 
| 524 | 525 | ||
| 525 | with on_eval(): | 526 | with on_eval(): | 
| 526 | on_sample(global_step) | 527 | on_sample(cycle, global_step) | 
| 527 | 528 | ||
| 528 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 529 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 
| 529 | local_progress_bar.clear() | 530 | local_progress_bar.clear() | 
| @@ -648,7 +649,7 @@ def train_loop( | |||
| 648 | if accelerator.is_main_process: | 649 | if accelerator.is_main_process: | 
| 649 | print("Finished!") | 650 | print("Finished!") | 
| 650 | with on_eval(): | 651 | with on_eval(): | 
| 651 | on_sample(global_step) | 652 | on_sample(cycle, global_step) | 
| 652 | on_checkpoint(global_step, "end") | 653 | on_checkpoint(global_step, "end") | 
| 653 | 654 | ||
| 654 | except KeyboardInterrupt: | 655 | except KeyboardInterrupt: | 
| @@ -680,7 +681,7 @@ def train( | |||
| 680 | sample_frequency: int = 20, | 681 | sample_frequency: int = 20, | 
| 681 | checkpoint_frequency: int = 50, | 682 | checkpoint_frequency: int = 50, | 
| 682 | milestone_checkpoints: bool = True, | 683 | milestone_checkpoints: bool = True, | 
| 683 | initial_samples: bool = True, | 684 | cycle: int = 1, | 
| 684 | global_step_offset: int = 0, | 685 | global_step_offset: int = 0, | 
| 685 | guidance_scale: float = 0.0, | 686 | guidance_scale: float = 0.0, | 
| 686 | prior_loss_weight: float = 1.0, | 687 | prior_loss_weight: float = 1.0, | 
| @@ -731,7 +732,7 @@ def train( | |||
| 731 | sample_frequency=sample_frequency, | 732 | sample_frequency=sample_frequency, | 
| 732 | checkpoint_frequency=checkpoint_frequency, | 733 | checkpoint_frequency=checkpoint_frequency, | 
| 733 | milestone_checkpoints=milestone_checkpoints, | 734 | milestone_checkpoints=milestone_checkpoints, | 
| 734 | initial_samples=initial_samples, | 735 | cycle=cycle, | 
| 735 | global_step_offset=global_step_offset, | 736 | global_step_offset=global_step_offset, | 
| 736 | num_epochs=num_train_epochs, | 737 | num_epochs=num_train_epochs, | 
| 737 | gradient_accumulation_steps=gradient_accumulation_steps, | 738 | gradient_accumulation_steps=gradient_accumulation_steps, | 
| diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 4ae28b7..e6fcc89 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -148,7 +148,7 @@ def dreambooth_strategy_callbacks( | |||
| 148 | torch.cuda.empty_cache() | 148 | torch.cuda.empty_cache() | 
| 149 | 149 | ||
| 150 | @torch.no_grad() | 150 | @torch.no_grad() | 
| 151 | def on_sample(step): | 151 | def on_sample(cycle, step): | 
| 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 152 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 
| 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 153 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 
| 154 | 154 | ||
| @@ -158,7 +158,7 @@ def dreambooth_strategy_callbacks( | |||
| 158 | unet_.to(dtype=weight_dtype) | 158 | unet_.to(dtype=weight_dtype) | 
| 159 | text_encoder_.to(dtype=weight_dtype) | 159 | text_encoder_.to(dtype=weight_dtype) | 
| 160 | 160 | ||
| 161 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 161 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) | 
| 162 | 162 | ||
| 163 | unet_.to(dtype=orig_unet_dtype) | 163 | unet_.to(dtype=orig_unet_dtype) | 
| 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 164 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 
| diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 48236fb..5c3012e 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -146,11 +146,11 @@ def lora_strategy_callbacks( | |||
| 146 | torch.cuda.empty_cache() | 146 | torch.cuda.empty_cache() | 
| 147 | 147 | ||
| 148 | @torch.no_grad() | 148 | @torch.no_grad() | 
| 149 | def on_sample(step): | 149 | def on_sample(cycle, step): | 
| 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 150 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 
| 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 151 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 
| 152 | 152 | ||
| 153 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 153 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) | 
| 154 | 154 | ||
| 155 | del unet_, text_encoder_ | 155 | del unet_, text_encoder_ | 
| 156 | 156 | ||
| diff --git a/training/strategy/ti.py b/training/strategy/ti.py index f0b84b5..6bbff64 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -104,10 +104,28 @@ def textual_inversion_strategy_callbacks( | |||
| 104 | yield | 104 | yield | 
| 105 | 105 | ||
| 106 | @torch.no_grad() | 106 | @torch.no_grad() | 
| 107 | def on_before_optimize(epoch: int): | ||
| 108 | if use_emb_decay: | ||
| 109 | params = [ | ||
| 110 | p | ||
| 111 | for p in text_encoder.text_model.embeddings.token_embedding.parameters() | ||
| 112 | if p.grad is not None | ||
| 113 | ] | ||
| 114 | return torch.stack(params) if len(params) != 0 else None | ||
| 115 | |||
| 116 | @torch.no_grad() | ||
| 107 | def on_after_optimize(w, lrs: dict[str, float]): | 117 | def on_after_optimize(w, lrs: dict[str, float]): | 
| 108 | if ema_embeddings is not None: | 118 | if ema_embeddings is not None: | 
| 109 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 119 | ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) | 
| 110 | 120 | ||
| 121 | if use_emb_decay and w is not None: | ||
| 122 | lr = lrs["emb"] or lrs["0"] | ||
| 123 | lambda_ = emb_decay * lr | ||
| 124 | |||
| 125 | if lambda_ != 0: | ||
| 126 | norm = w[:, :].norm(dim=-1, keepdim=True) | ||
| 127 | w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) | ||
| 128 | |||
| 111 | def on_log(): | 129 | def on_log(): | 
| 112 | if ema_embeddings is not None: | 130 | if ema_embeddings is not None: | 
| 113 | return {"ema_decay": ema_embeddings.decay} | 131 | return {"ema_decay": ema_embeddings.decay} | 
| @@ -125,7 +143,7 @@ def textual_inversion_strategy_callbacks( | |||
| 125 | ) | 143 | ) | 
| 126 | 144 | ||
| 127 | @torch.no_grad() | 145 | @torch.no_grad() | 
| 128 | def on_sample(step): | 146 | def on_sample(cycle, step): | 
| 129 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 147 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) | 
| 130 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 148 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) | 
| 131 | 149 | ||
| @@ -135,7 +153,7 @@ def textual_inversion_strategy_callbacks( | |||
| 135 | unet_.to(dtype=weight_dtype) | 153 | unet_.to(dtype=weight_dtype) | 
| 136 | text_encoder_.to(dtype=weight_dtype) | 154 | text_encoder_.to(dtype=weight_dtype) | 
| 137 | 155 | ||
| 138 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | 156 | save_samples_(cycle=cycle, step=step, unet=unet_, text_encoder=text_encoder_) | 
| 139 | 157 | ||
| 140 | unet_.to(dtype=orig_unet_dtype) | 158 | unet_.to(dtype=orig_unet_dtype) | 
| 141 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 159 | text_encoder_.to(dtype=orig_text_encoder_dtype) | 
| @@ -148,6 +166,7 @@ def textual_inversion_strategy_callbacks( | |||
| 148 | return TrainingCallbacks( | 166 | return TrainingCallbacks( | 
| 149 | on_train=on_train, | 167 | on_train=on_train, | 
| 150 | on_eval=on_eval, | 168 | on_eval=on_eval, | 
| 169 | on_before_optimize=on_before_optimize, | ||
| 151 | on_after_optimize=on_after_optimize, | 170 | on_after_optimize=on_after_optimize, | 
| 152 | on_log=on_log, | 171 | on_log=on_log, | 
| 153 | on_checkpoint=on_checkpoint, | 172 | on_checkpoint=on_checkpoint, | 
