diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 7 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 8 | ||||
| -rw-r--r-- | training/strategy/lora.py | 12 | ||||
| -rw-r--r-- | training/strategy/ti.py | 4 |
4 files changed, 14 insertions, 17 deletions
diff --git a/training/functional.py b/training/functional.py index ebb48ab..015fe5e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -259,7 +259,7 @@ def snr_weight(noisy_latents, latents, gamma): | |||
| 259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) | 259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) |
| 260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) | 260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) |
| 261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) | 261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) |
| 262 | snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() | 262 | snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() |
| 263 | return snr_weight | 263 | return snr_weight |
| 264 | 264 | ||
| 265 | return torch.tensor( | 265 | return torch.tensor( |
| @@ -471,10 +471,7 @@ def train_loop( | |||
| 471 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
| 472 | } | 472 | } |
| 473 | if isDadaptation: | 473 | if isDadaptation: |
| 474 | logs["lr/d*lr"] = ( | 474 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] |
| 475 | optimizer.param_groups[0]["d"] * | ||
| 476 | optimizer.param_groups[0]["lr"] | ||
| 477 | ) | ||
| 478 | logs.update(on_log()) | 475 | logs.update(on_log()) |
| 479 | 476 | ||
| 480 | local_progress_bar.set_postfix(**logs) | 477 | local_progress_bar.set_postfix(**logs) |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e5e84c8..28fccff 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -137,8 +137,8 @@ def dreambooth_strategy_callbacks( | |||
| 137 | 137 | ||
| 138 | print("Saving model...") | 138 | print("Saving model...") |
| 139 | 139 | ||
| 140 | unet_ = accelerator.unwrap_model(unet, False) | 140 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 141 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 141 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 142 | 142 | ||
| 143 | with ema_context(): | 143 | with ema_context(): |
| 144 | pipeline = VlpnStableDiffusion( | 144 | pipeline = VlpnStableDiffusion( |
| @@ -160,8 +160,8 @@ def dreambooth_strategy_callbacks( | |||
| 160 | @torch.no_grad() | 160 | @torch.no_grad() |
| 161 | def on_sample(step): | 161 | def on_sample(step): |
| 162 | with ema_context(): | 162 | with ema_context(): |
| 163 | unet_ = accelerator.unwrap_model(unet, False) | 163 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 164 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 164 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 165 | 165 | ||
| 166 | orig_unet_dtype = unet_.dtype | 166 | orig_unet_dtype = unet_.dtype |
| 167 | orig_text_encoder_dtype = text_encoder_.dtype | 167 | orig_text_encoder_dtype = text_encoder_.dtype |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index aa75bec..1c8fad6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -47,7 +47,6 @@ def lora_strategy_callbacks( | |||
| 47 | save_samples_ = partial( | 47 | save_samples_ = partial( |
| 48 | save_samples, | 48 | save_samples, |
| 49 | accelerator=accelerator, | 49 | accelerator=accelerator, |
| 50 | text_encoder=text_encoder, | ||
| 51 | tokenizer=tokenizer, | 50 | tokenizer=tokenizer, |
| 52 | vae=vae, | 51 | vae=vae, |
| 53 | sample_scheduler=sample_scheduler, | 52 | sample_scheduler=sample_scheduler, |
| @@ -72,6 +71,7 @@ def lora_strategy_callbacks( | |||
| 72 | @contextmanager | 71 | @contextmanager |
| 73 | def on_train(epoch: int): | 72 | def on_train(epoch: int): |
| 74 | tokenizer.train() | 73 | tokenizer.train() |
| 74 | text_encoder.train() | ||
| 75 | yield | 75 | yield |
| 76 | 76 | ||
| 77 | @contextmanager | 77 | @contextmanager |
| @@ -89,8 +89,8 @@ def lora_strategy_callbacks( | |||
| 89 | def on_checkpoint(step, postfix): | 89 | def on_checkpoint(step, postfix): |
| 90 | print(f"Saving checkpoint for step {step}...") | 90 | print(f"Saving checkpoint for step {step}...") |
| 91 | 91 | ||
| 92 | unet_ = accelerator.unwrap_model(unet, False) | 92 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 93 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 93 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 94 | 94 | ||
| 95 | lora_config = {} | 95 | lora_config = {} |
| 96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) | 96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) |
| @@ -111,10 +111,10 @@ def lora_strategy_callbacks( | |||
| 111 | 111 | ||
| 112 | @torch.no_grad() | 112 | @torch.no_grad() |
| 113 | def on_sample(step): | 113 | def on_sample(step): |
| 114 | unet_ = accelerator.unwrap_model(unet, False) | 114 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 115 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 115 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 116 | 116 | ||
| 117 | save_samples_(step=step, unet=unet_) | 117 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
| 118 | 118 | ||
| 119 | del unet_ | 119 | del unet_ |
| 120 | del text_encoder_ | 120 | del text_encoder_ |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index bd0d178..2038e34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -156,8 +156,8 @@ def textual_inversion_strategy_callbacks( | |||
| 156 | @torch.no_grad() | 156 | @torch.no_grad() |
| 157 | def on_sample(step): | 157 | def on_sample(step): |
| 158 | with ema_context(): | 158 | with ema_context(): |
| 159 | unet_ = accelerator.unwrap_model(unet, False) | 159 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
| 160 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 160 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
| 161 | 161 | ||
| 162 | orig_unet_dtype = unet_.dtype | 162 | orig_unet_dtype = unet_.dtype |
| 163 | orig_text_encoder_dtype = text_encoder_.dtype | 163 | orig_text_encoder_dtype = text_encoder_.dtype |
