author    Volpeon <git@volpeon.ink>  2023-04-13 07:14:24 +0200
committer Volpeon <git@volpeon.ink>  2023-04-13 07:14:24 +0200
commit    a0b63ee7f4a8c793c0d200c86ef07677aa4cbf2e
tree      6a695b2b5a73cebc35ff9e581c70f1a0e75b62e8
parent    Experimental convnext discriminator support
Update
 models/convnext/discriminator.py |  8
 models/sparse.py                 |  2
 train_lora.py                    | 73
 train_ti.py                      | 13
 training/functional.py           | 35
 training/strategy/dreambooth.py  |  7
 training/strategy/lora.py        |  6
 training/strategy/ti.py          |  3
 8 files changed, 80 insertions(+), 67 deletions(-)
diff --git a/models/convnext/discriminator.py b/models/convnext/discriminator.py
index 7dbbe3a..571b915 100644
--- a/models/convnext/discriminator.py
+++ b/models/convnext/discriminator.py
@@ -15,13 +15,7 @@ class ConvNeXtDiscriminator():
         self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)
 
     def get_score(self, img):
-        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
-        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
-
-        img = ((img+1.)/2.).sub(img_mean).div(img_std)
-
-        img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True)
-        pred = self.net(img)
+        pred = self.get_all(img)
         return torch.softmax(pred, dim=-1)[:, 1]
 
     def get_all(self, img):
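The get_score() rewrite above removes preprocessing that was duplicated between the two methods. A minimal sketch of the class after this change, assuming get_all() (whose body is not shown in this diff) carries the normalization and resize that get_score() used to repeat, and that the ImageNet constants come from timm:

import torch
import torch.nn.functional as F
from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD

class ConvNeXtDiscriminator():
    def __init__(self, model, input_size):
        self.net = model
        self.input_size = input_size
        self.img_mean = torch.tensor(IMAGENET_DEFAULT_MEAN).view(1, -1, 1, 1)
        self.img_std = torch.tensor(IMAGENET_DEFAULT_STD).view(1, -1, 1, 1)

    def get_score(self, img):
        # Delegate preprocessing + forward pass, then score class 1.
        pred = self.get_all(img)
        return torch.softmax(pred, dim=-1)[:, 1]

    def get_all(self, img):
        # Assumed body: map [-1, 1] images to ImageNet statistics, resize,
        # and run the backbone. This mirrors the lines deleted above.
        img_mean = self.img_mean.to(device=img.device, dtype=img.dtype)
        img_std = self.img_std.to(device=img.device, dtype=img.dtype)
        img = ((img + 1.) / 2.).sub(img_mean).div(img_std)
        img = F.interpolate(img, size=(self.input_size, self.input_size),
                            mode='bicubic', align_corners=True)
        return self.net(img)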
diff --git a/models/sparse.py b/models/sparse.py
index bcb2897..07b3413 100644
--- a/models/sparse.py
+++ b/models/sparse.py
@@ -15,7 +15,7 @@ class PseudoSparseEmbedding(nn.Module):
         if dropout_p > 0.0:
             self.dropout = nn.Dropout(p=dropout_p)
         else:
-            self.dropout = lambda x: x
+            self.dropout = nn.Identity()
 
         self.register_buffer('mapping', torch.zeros(0, device=device, dtype=torch.long))
 
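The nn.Identity() swap is behavior-preserving in the forward pass, but it makes the no-op a real submodule, so the model prints cleanly and stays pickleable. A self-contained sketch (TinyEmbedding is a hypothetical stand-in for PseudoSparseEmbedding):

import torch
import torch.nn as nn

class TinyEmbedding(nn.Module):
    def __init__(self, dropout_p: float = 0.0):
        super().__init__()
        # nn.Identity() instead of `lambda x: x`: same output, but it is a
        # proper nn.Module, so repr() shows it and pickling does not break.
        self.dropout = nn.Dropout(p=dropout_p) if dropout_p > 0.0 else nn.Identity()

    def forward(self, x):
        return self.dropout(x)

emb = TinyEmbedding()
x = torch.randn(2, 4)
assert torch.equal(emb(x), x)   # identity in train and eval alike
torch.save(emb, "emb.pt")       # a lambda attribute would raise a PicklingError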
diff --git a/train_lora.py b/train_lora.py
index 29e40b2..073e939 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -87,6 +87,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -924,39 +930,15 @@ def main():
     if args.sample_num is not None:
         lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
-    params_to_optimize = []
     group_labels = []
     if len(args.placeholder_tokens) != 0:
-        params_to_optimize.append({
-            "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            "lr": args.learning_rate_emb,
-            "weight_decay": 0,
-        })
         group_labels.append("emb")
-    params_to_optimize += [
-        {
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_unet,
-        },
-        {
-            "params": (
-                param
-                for param in itertools.chain(
-                    text_encoder.text_model.encoder.parameters(),
-                    text_encoder.text_model.final_layer_norm.parameters(),
-                )
-                if param.requires_grad
-            ),
-            "lr": args.learning_rate_text,
-        },
-    ]
     group_labels += ["unet", "text"]
 
     training_iter = 0
+    learning_rate_emb = args.learning_rate_emb
+    learning_rate_unet = args.learning_rate_unet
+    learning_rate_text = args.learning_rate_text
 
     lora_project = "lora"
 
@@ -973,6 +955,37 @@ def main():
         print(f"============ LoRA cycle {training_iter + 1} ============")
         print("")
 
+        params_to_optimize = []
+
+        if len(args.placeholder_tokens) != 0:
+            params_to_optimize.append({
+                "params": text_encoder.text_model.embeddings.token_override_embedding.parameters(),
+                "lr": learning_rate_emb,
+                "weight_decay": 0,
+            })
+            group_labels.append("emb")
+        params_to_optimize += [
+            {
+                "params": (
+                    param
+                    for param in unet.parameters()
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_unet,
+            },
+            {
+                "params": (
+                    param
+                    for param in itertools.chain(
+                        text_encoder.text_model.encoder.parameters(),
+                        text_encoder.text_model.final_layer_norm.parameters(),
+                    )
+                    if param.requires_grad
+                ),
+                "lr": learning_rate_text,
+            },
+        ]
+
         lora_optimizer = create_optimizer(params_to_optimize)
 
         lora_lr_scheduler = create_lr_scheduler(
@@ -1002,6 +1015,12 @@ def main():
         )
 
         training_iter += 1
+        if args.learning_rate_emb is not None:
+            learning_rate_emb *= args.cycle_decay
+        if args.learning_rate_unet is not None:
+            learning_rate_unet *= args.cycle_decay
+        if args.learning_rate_text is not None:
+            learning_rate_text *= args.cycle_decay
 
     accelerator.end_training()
 
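train_lora.py now rebuilds params_to_optimize inside the cycle loop, so each cycle's fresh optimizer picks up the decayed per-group rates. An illustrative sketch of the resulting schedule (the starting rate and decay are made-up numbers, not repo defaults):

cycle_decay = 0.9          # args.cycle_decay
learning_rate_unet = 1e-4  # args.learning_rate_unet (hypothetical value)
for cycle in range(3):
    print(f"cycle {cycle}: lr_unet = {learning_rate_unet:.2e}")
    learning_rate_unet *= cycle_decay
# cycle 0: lr_unet = 1.00e-04
# cycle 1: lr_unet = 9.00e-05
# cycle 2: lr_unet = 8.10e-05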
diff --git a/train_ti.py b/train_ti.py
index 082e9b7..94ddbb6 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -72,6 +72,12 @@ def parse_args():
         help="How many cycles to run automatically."
     )
     parser.add_argument(
+        "--cycle_decay",
+        type=float,
+        default=1.0,
+        help="Learning rate decay per cycle."
+    )
+    parser.add_argument(
         "--placeholder_tokens",
         type=str,
         nargs='*',
@@ -672,7 +678,6 @@ def main():
         convnext.to(accelerator.device, dtype=weight_dtype)
         convnext.requires_grad_(False)
         convnext.eval()
-        disc = ConvNeXtDiscriminator(convnext, input_size=384)
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -815,7 +820,6 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
-        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
@@ -890,6 +894,7 @@ def main():
         sample_frequency = math.ceil(num_train_epochs / args.sample_num)
 
     training_iter = 0
+    learning_rate = args.learning_rate
 
     project = placeholder_tokens[0] if len(placeholder_tokens) == 1 else "ti"
 
@@ -908,7 +913,7 @@ def main():
 
         optimizer = create_optimizer(
             text_encoder.text_model.embeddings.token_override_embedding.parameters(),
-            lr=args.learning_rate,
+            lr=learning_rate,
         )
 
         lr_scheduler = get_scheduler(
@@ -948,6 +953,8 @@ def main():
         )
 
         training_iter += 1
+        if args.learning_rate is not None:
+            learning_rate *= args.cycle_decay
 
     accelerator.end_training()
 
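train_ti.py applies the same pattern with its single learning rate; note the decay is guarded so an unset rate stays None. Minimal sketch of the guard (decay_lr is a hypothetical helper, not a function in the repo):

def decay_lr(learning_rate, cycle_decay):
    # Mirrors the per-cycle update: only decay a rate that is actually set.
    if learning_rate is not None:
        learning_rate *= cycle_decay
    return learning_rate

assert decay_lr(None, 0.9) is None
assert abs(decay_lr(1e-3, 0.9) - 9e-4) < 1e-12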
diff --git a/training/functional.py b/training/functional.py
index be39776..ed8ae3a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -168,8 +168,7 @@ def save_samples(
         image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
         image_grid.save(file_path, quality=85)
 
-    del generator
-    del pipeline
+    del generator, pipeline
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -398,31 +397,32 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if disc is None:
-        if guidance_scale == 0 and prior_loss_weight != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
-
-        loss = loss.mean([1, 2, 3])
-    else:
+    acc = (model_pred == target).float().mean()
+
+    if guidance_scale == 0 and prior_loss_weight != 0:
+        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+        target, target_prior = torch.chunk(target, 2, dim=0)
+
+        # Compute instance loss
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+        # Compute prior loss
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+
+        # Add the prior loss to the instance loss.
+        loss = loss + prior_loss_weight * prior_loss
+    else:
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
+
+    loss = loss.mean([1, 2, 3])
+
+    if disc is not None:
         rec_latent = get_original(noise_scheduler, model_pred, noisy_latents, timesteps)
         rec_latent /= vae.config.scaling_factor
         rec_latent = rec_latent.to(dtype=vae.dtype)
         rec = vae.decode(rec_latent).sample
         loss = 1 - disc.get_score(rec)
-        del rec_latent, rec
 
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
@@ -432,7 +432,6 @@ def loss_step(
         loss *= mse_loss_weights
 
     loss = loss.mean()
-    acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
 
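The loss_step() restructure computes acc on the full tensors before any prior-loss chunking, always takes the MSE path, and only then lets the discriminator replace the loss. A condensed, runnable sketch of the new control flow (sketch_loss and disc_score are stand-ins; the real code decodes the prediction through the VAE before scoring):

import torch
import torch.nn.functional as F

def sketch_loss(model_pred, target, disc_score=None,
                guidance_scale=0, prior_loss_weight=0.0):
    # Accuracy now comes first, over the unchunked tensors.
    acc = (model_pred == target).float().mean()

    if guidance_scale == 0 and prior_loss_weight != 0:
        # Instance/prior halves, as in the diff above.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")

    loss = loss.mean([1, 2, 3])

    if disc_score is not None:
        # In the repo this decodes model_pred via the VAE and calls
        # disc.get_score(); disc_score stands in for that round trip.
        loss = 1 - disc_score(model_pred)

    return loss.mean(), acc

pred, tgt = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
loss, acc = sketch_loss(pred, tgt, prior_loss_weight=1.0)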
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index fa51bc7..4ae28b7 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -142,9 +142,7 @@ def dreambooth_strategy_callbacks(
         )
         pipeline.save_pretrained(checkpoint_output_dir)
 
-        del unet_
-        del text_encoder_
-        del pipeline
+        del unet_, text_encoder_, pipeline
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -165,8 +163,7 @@ def dreambooth_strategy_callbacks(
         unet_.to(dtype=orig_unet_dtype)
         text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-        del unet_
-        del text_encoder_
+        del unet_, text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index 73ec8f2..1517ee8 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -140,8 +140,7 @@ def lora_strategy_callbacks(
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
-        del unet_
-        del text_encoder_
+        del unet_, text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -153,8 +152,7 @@ def lora_strategy_callbacks(
 
         save_samples_(step=step, unet=unet_, text_encoder=text_encoder_)
 
-        del unet_
-        del text_encoder_
+        del unet_, text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 08af89d..ca7cc3d 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -158,8 +158,7 @@ def textual_inversion_strategy_callbacks(
         unet_.to(dtype=orig_unet_dtype)
         text_encoder_.to(dtype=orig_text_encoder_dtype)
 
-        del unet_
-        del text_encoder_
+        del unet_, text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
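The strategy changes are the same cosmetic consolidation throughout: one del with multiple targets unbinds each name exactly like consecutive del statements, after which empty_cache() can return freed blocks to the driver. Tiny sketch:

import torch

unet_, text_encoder_ = object(), object()
del unet_, text_encoder_   # identical effect to: del unet_; del text_encoder_

if torch.cuda.is_available():
    torch.cuda.empty_cache()   # release cached, now-unused allocator blocks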