diff options
-rw-r--r-- | data/csv.py | 2 | ||||
-rw-r--r-- | models/clip/util.py | 6 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 4 | ||||
-rw-r--r-- | train_dreambooth.py | 11 | ||||
-rw-r--r-- | train_lora.py | 3 | ||||
-rw-r--r-- | train_ti.py | 2 | ||||
-rw-r--r-- | training/functional.py | 7 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 8 | ||||
-rw-r--r-- | training/strategy/lora.py | 12 | ||||
-rw-r--r-- | training/strategy/ti.py | 4 |
10 files changed, 29 insertions, 30 deletions
diff --git a/data/csv.py b/data/csv.py index 619452e..fba5d4b 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
283 | 283 | ||
284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
285 | data_train, data_val = items, items[:self.batch_size] | 285 | data_train, data_val = items, items |
286 | else: | 286 | else: |
287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
288 | 288 | ||
diff --git a/models/clip/util.py b/models/clip/util.py index 8de8c19..883de6a 100644 --- a/models/clip/util.py +++ b/models/clip/util.py | |||
@@ -23,11 +23,11 @@ def get_extended_embeddings( | |||
23 | model_max_length = text_encoder.config.max_position_embeddings | 23 | model_max_length = text_encoder.config.max_position_embeddings |
24 | prompts = input_ids.shape[0] | 24 | prompts = input_ids.shape[0] |
25 | 25 | ||
26 | input_ids = input_ids.view((-1, model_max_length)).to(text_encoder.device) | 26 | input_ids = input_ids.view((-1, model_max_length)) |
27 | if position_ids is not None: | 27 | if position_ids is not None: |
28 | position_ids = position_ids.view((-1, model_max_length)).to(text_encoder.device) | 28 | position_ids = position_ids.view((-1, model_max_length)) |
29 | if attention_mask is not None: | 29 | if attention_mask is not None: |
30 | attention_mask = attention_mask.view((-1, model_max_length)).to(text_encoder.device) | 30 | attention_mask = attention_mask.view((-1, model_max_length)) |
31 | 31 | ||
32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] | 32 | text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] |
33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) | 33 | text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 4505a2a..dbd262f 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -291,7 +291,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
291 | else: | 291 | else: |
292 | attention_mask = None | 292 | attention_mask = None |
293 | 293 | ||
294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids, attention_mask) | 294 | prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) |
295 | prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) | ||
295 | 296 | ||
296 | return prompt_embeds | 297 | return prompt_embeds |
297 | 298 | ||
@@ -374,6 +375,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
374 | 375 | ||
375 | def decode_latents(self, latents): | 376 | def decode_latents(self, latents): |
376 | latents = 1 / self.vae.config.scaling_factor * latents | 377 | latents = 1 / self.vae.config.scaling_factor * latents |
378 | # image = self.vae.decode(latents.to(dtype=self.vae.dtype)).sample | ||
377 | image = self.vae.decode(latents).sample | 379 | image = self.vae.decode(latents).sample |
378 | image = (image / 2 + 0.5).clamp(0, 1) | 380 | image = (image / 2 + 0.5).clamp(0, 1) |
379 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 | 381 | # we always cast to float32 as this does not cause significant overhead and is compatible with bfloa16 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index f8f6e84..a85ae4c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -582,12 +582,15 @@ def main(): | |||
582 | ) | 582 | ) |
583 | datamodule.setup() | 583 | datamodule.setup() |
584 | 584 | ||
585 | optimizer = create_optimizer( | 585 | params_to_optimize = (unet.parameters(), ) |
586 | itertools.chain( | 586 | if args.train_text_encoder_epochs != 0: |
587 | unet.parameters(), | 587 | params_to_optimize += ( |
588 | text_encoder.text_model.encoder.parameters(), | 588 | text_encoder.text_model.encoder.parameters(), |
589 | text_encoder.text_model.final_layer_norm.parameters(), | 589 | text_encoder.text_model.final_layer_norm.parameters(), |
590 | ), | 590 | ) |
591 | |||
592 | optimizer = create_optimizer( | ||
593 | itertools.chain(*params_to_optimize), | ||
591 | lr=args.learning_rate, | 594 | lr=args.learning_rate, |
592 | ) | 595 | ) |
593 | 596 | ||
diff --git a/train_lora.py b/train_lora.py index 787f271..8dd3c86 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -424,9 +424,6 @@ def main(): | |||
424 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 424 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
425 | args.pretrained_model_name_or_path) | 425 | args.pretrained_model_name_or_path) |
426 | 426 | ||
427 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | ||
428 | tokenizer.set_dropout(args.vector_dropout) | ||
429 | |||
430 | vae.enable_slicing() | 427 | vae.enable_slicing() |
431 | vae.set_use_memory_efficient_attention_xformers(True) | 428 | vae.set_use_memory_efficient_attention_xformers(True) |
432 | unet.enable_xformers_memory_efficient_attention() | 429 | unet.enable_xformers_memory_efficient_attention() |
diff --git a/train_ti.py b/train_ti.py index 7aeff7c..9bc74c1 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -340,7 +340,7 @@ def parse_args(): | |||
340 | parser.add_argument( | 340 | parser.add_argument( |
341 | "--checkpoint_frequency", | 341 | "--checkpoint_frequency", |
342 | type=int, | 342 | type=int, |
343 | default=5, | 343 | default=999999, |
344 | help="How often to save a checkpoint and sample image (in epochs)", | 344 | help="How often to save a checkpoint and sample image (in epochs)", |
345 | ) | 345 | ) |
346 | parser.add_argument( | 346 | parser.add_argument( |
diff --git a/training/functional.py b/training/functional.py index ebb48ab..015fe5e 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -259,7 +259,7 @@ def snr_weight(noisy_latents, latents, gamma): | |||
259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) | 259 | sigma_mean_sq = F.mse_loss(sigma.float(), zeros.float(), reduction="none").mean([1, 2, 3]) |
260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) | 260 | snr = torch.div(alpha_mean_sq, sigma_mean_sq) |
261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) | 261 | gamma_over_snr = torch.div(torch.ones_like(snr) * gamma, snr) |
262 | snr_weight = torch.minimum(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() | 262 | snr_weight = torch.fmin(gamma_over_snr, torch.ones_like(gamma_over_snr)).float() |
263 | return snr_weight | 263 | return snr_weight |
264 | 264 | ||
265 | return torch.tensor( | 265 | return torch.tensor( |
@@ -471,10 +471,7 @@ def train_loop( | |||
471 | "lr": lr_scheduler.get_last_lr()[0], | 471 | "lr": lr_scheduler.get_last_lr()[0], |
472 | } | 472 | } |
473 | if isDadaptation: | 473 | if isDadaptation: |
474 | logs["lr/d*lr"] = ( | 474 | logs["lr/d*lr"] = optimizer.param_groups[0]["d"] * optimizer.param_groups[0]["lr"] |
475 | optimizer.param_groups[0]["d"] * | ||
476 | optimizer.param_groups[0]["lr"] | ||
477 | ) | ||
478 | logs.update(on_log()) | 475 | logs.update(on_log()) |
479 | 476 | ||
480 | local_progress_bar.set_postfix(**logs) | 477 | local_progress_bar.set_postfix(**logs) |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e5e84c8..28fccff 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -137,8 +137,8 @@ def dreambooth_strategy_callbacks( | |||
137 | 137 | ||
138 | print("Saving model...") | 138 | print("Saving model...") |
139 | 139 | ||
140 | unet_ = accelerator.unwrap_model(unet, False) | 140 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
141 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 141 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
142 | 142 | ||
143 | with ema_context(): | 143 | with ema_context(): |
144 | pipeline = VlpnStableDiffusion( | 144 | pipeline = VlpnStableDiffusion( |
@@ -160,8 +160,8 @@ def dreambooth_strategy_callbacks( | |||
160 | @torch.no_grad() | 160 | @torch.no_grad() |
161 | def on_sample(step): | 161 | def on_sample(step): |
162 | with ema_context(): | 162 | with ema_context(): |
163 | unet_ = accelerator.unwrap_model(unet, False) | 163 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
164 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 164 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
165 | 165 | ||
166 | orig_unet_dtype = unet_.dtype | 166 | orig_unet_dtype = unet_.dtype |
167 | orig_text_encoder_dtype = text_encoder_.dtype | 167 | orig_text_encoder_dtype = text_encoder_.dtype |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index aa75bec..1c8fad6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -47,7 +47,6 @@ def lora_strategy_callbacks( | |||
47 | save_samples_ = partial( | 47 | save_samples_ = partial( |
48 | save_samples, | 48 | save_samples, |
49 | accelerator=accelerator, | 49 | accelerator=accelerator, |
50 | text_encoder=text_encoder, | ||
51 | tokenizer=tokenizer, | 50 | tokenizer=tokenizer, |
52 | vae=vae, | 51 | vae=vae, |
53 | sample_scheduler=sample_scheduler, | 52 | sample_scheduler=sample_scheduler, |
@@ -72,6 +71,7 @@ def lora_strategy_callbacks( | |||
72 | @contextmanager | 71 | @contextmanager |
73 | def on_train(epoch: int): | 72 | def on_train(epoch: int): |
74 | tokenizer.train() | 73 | tokenizer.train() |
74 | text_encoder.train() | ||
75 | yield | 75 | yield |
76 | 76 | ||
77 | @contextmanager | 77 | @contextmanager |
@@ -89,8 +89,8 @@ def lora_strategy_callbacks( | |||
89 | def on_checkpoint(step, postfix): | 89 | def on_checkpoint(step, postfix): |
90 | print(f"Saving checkpoint for step {step}...") | 90 | print(f"Saving checkpoint for step {step}...") |
91 | 91 | ||
92 | unet_ = accelerator.unwrap_model(unet, False) | 92 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
93 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 93 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
94 | 94 | ||
95 | lora_config = {} | 95 | lora_config = {} |
96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) | 96 | state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet)) |
@@ -111,10 +111,10 @@ def lora_strategy_callbacks( | |||
111 | 111 | ||
112 | @torch.no_grad() | 112 | @torch.no_grad() |
113 | def on_sample(step): | 113 | def on_sample(step): |
114 | unet_ = accelerator.unwrap_model(unet, False) | 114 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
115 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 115 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
116 | 116 | ||
117 | save_samples_(step=step, unet=unet_) | 117 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) |
118 | 118 | ||
119 | del unet_ | 119 | del unet_ |
120 | del text_encoder_ | 120 | del text_encoder_ |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index bd0d178..2038e34 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -156,8 +156,8 @@ def textual_inversion_strategy_callbacks( | |||
156 | @torch.no_grad() | 156 | @torch.no_grad() |
157 | def on_sample(step): | 157 | def on_sample(step): |
158 | with ema_context(): | 158 | with ema_context(): |
159 | unet_ = accelerator.unwrap_model(unet, False) | 159 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=True) |
160 | text_encoder_ = accelerator.unwrap_model(text_encoder, False) | 160 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=True) |
161 | 161 | ||
162 | orig_unet_dtype = unet_.dtype | 162 | orig_unet_dtype = unet_.dtype |
163 | orig_text_encoder_dtype = text_encoder_.dtype | 163 | orig_text_encoder_dtype = text_encoder_.dtype |