-rw-r--r--  data/csv.py              4
-rw-r--r--  train_dreambooth.py      7
-rw-r--r--  train_lora.py            7
-rw-r--r--  train_ti.py              7
-rw-r--r--  training/functional.py  18
5 files changed, 1 insertion, 42 deletions
diff --git a/data/csv.py b/data/csv.py
index d726033..43bf14c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -108,7 +108,6 @@ def collate_fn(
     dtype: torch.dtype,
     tokenizer: CLIPTokenizer,
     max_token_id_length: Optional[int],
-    with_guidance: bool,
     with_prior_preservation: bool,
     examples,
 ):
@@ -195,7 +194,6 @@ class VlpnDataModule:
         tokenizer: CLIPTokenizer,
         constant_prompt_length: bool = False,
         class_subdir: str = "cls",
-        with_guidance: bool = False,
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
@@ -228,7 +226,6 @@ class VlpnDataModule:
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.placeholder_tokens = placeholder_tokens
         self.num_class_images = num_class_images
-        self.with_guidance = with_guidance
 
         self.constant_prompt_length = constant_prompt_length
         self.max_token_id_length = None
@@ -356,7 +353,6 @@ class VlpnDataModule:
             self.dtype,
             self.tokenizer,
             self.max_token_id_length,
-            self.with_guidance,
             self.num_class_images != 0,
         )
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0543a35..939a8f3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -194,11 +194,6 @@ def parse_args():
194 help="Shuffle tags.", 194 help="Shuffle tags.",
195 ) 195 )
196 parser.add_argument( 196 parser.add_argument(
197 "--guidance_scale",
198 type=float,
199 default=0,
200 )
201 parser.add_argument(
202 "--num_class_images", 197 "--num_class_images",
203 type=int, 198 type=int,
204 default=0, 199 default=0,
@@ -874,7 +869,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -893,7 +887,6 @@ def main():
         tokenizer=tokenizer,
         constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/train_lora.py b/train_lora.py
index b7ee2d6..51dc827 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -207,11 +207,6 @@ def parse_args():
207 help="Shuffle tags.", 207 help="Shuffle tags.",
208 ) 208 )
209 parser.add_argument( 209 parser.add_argument(
210 "--guidance_scale",
211 type=float,
212 default=0,
213 )
214 parser.add_argument(
215 "--num_class_images", 210 "--num_class_images",
216 type=int, 211 type=int,
217 default=0, 212 default=0,
@@ -998,7 +993,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -1022,7 +1016,6 @@ def main():
         tokenizer=tokenizer,
         constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/train_ti.py b/train_ti.py
index 7d1ef19..7f93960 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -125,11 +125,6 @@ def parse_args():
         action="store_true",
     )
     parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=0,
-    )
-    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -852,7 +847,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
@@ -923,7 +917,6 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/training/functional.py b/training/functional.py
index a3d1f08..43b03ac 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -342,7 +342,6 @@ def loss_step(
     schedule_sampler: ScheduleSampler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
     input_pertubation: float,
@@ -400,19 +399,6 @@ def loss_step(
         noisy_latents, timesteps, encoder_hidden_states, return_dict=False
     )[0]
 
-    if guidance_scale != 0:
-        uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
-        )
-        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
-
-        model_pred_uncond = unet(
-            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
-        )[0]
-        model_pred = model_pred_uncond + guidance_scale * (
-            model_pred - model_pred_uncond
-        )
-
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
@@ -425,7 +411,7 @@ def loss_step(
 
     acc = (model_pred == target).float().mean()
 
-    if guidance_scale == 0 and prior_loss_weight != 0:
+    if prior_loss_weight != 0:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -727,7 +713,6 @@ def train(
     milestone_checkpoints: bool = True,
     cycle: int = 1,
     global_step_offset: int = 0,
-    guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     input_pertubation: float = 0.1,
     schedule_sampler: Optional[ScheduleSampler] = None,
@@ -787,7 +772,6 @@ def train(
         schedule_sampler,
         unet,
         text_encoder,
-        guidance_scale,
         prior_loss_weight,
         seed,
         input_pertubation,
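
Note: the branch removed from loss_step above applied a classifier-free-guidance-style mix during training: a second UNet forward pass on the negative (unconditional) embeddings, combined with the conditional prediction as pred_uncond + guidance_scale * (pred_cond - pred_uncond). For reference, a minimal, self-contained sketch of that mixing step is shown below; the helper name mix_guidance and the tensor shapes are illustrative only and are not part of the repository.

import torch

def mix_guidance(
    pred_cond: torch.Tensor, pred_uncond: torch.Tensor, guidance_scale: float
) -> torch.Tensor:
    # Same combination as the deleted branch in loss_step:
    # pred_uncond + guidance_scale * (pred_cond - pred_uncond)
    return pred_uncond + guidance_scale * (pred_cond - pred_uncond)

# Illustrative usage with dummy noise predictions (batch of 2 latents, 4x64x64).
pred_cond = torch.randn(2, 4, 64, 64)
pred_uncond = torch.randn(2, 4, 64, 64)
mixed = mix_guidance(pred_cond, pred_uncond, guidance_scale=7.5)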