-rw-r--r--   data/csv.py             |  4 ----
-rw-r--r--   train_dreambooth.py     |  7 -------
-rw-r--r--   train_lora.py           |  7 -------
-rw-r--r--   train_ti.py             |  7 -------
-rw-r--r--   training/functional.py  | 18 +-----------------
5 files changed, 1 insertion(+), 42 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index d726033..43bf14c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -108,7 +108,6 @@ def collate_fn(
     dtype: torch.dtype,
     tokenizer: CLIPTokenizer,
     max_token_id_length: Optional[int],
-    with_guidance: bool,
     with_prior_preservation: bool,
     examples,
 ):
@@ -195,7 +194,6 @@ class VlpnDataModule:
         tokenizer: CLIPTokenizer,
         constant_prompt_length: bool = False,
         class_subdir: str = "cls",
-        with_guidance: bool = False,
         num_class_images: int = 1,
         size: int = 768,
         num_buckets: int = 0,
@@ -228,7 +226,6 @@ class VlpnDataModule:
         self.class_root.mkdir(parents=True, exist_ok=True)
         self.placeholder_tokens = placeholder_tokens
         self.num_class_images = num_class_images
-        self.with_guidance = with_guidance
 
         self.constant_prompt_length = constant_prompt_length
         self.max_token_id_length = None
@@ -356,7 +353,6 @@ class VlpnDataModule:
             self.dtype,
             self.tokenizer,
             self.max_token_id_length,
-            self.with_guidance,
             self.num_class_images != 0,
         )
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0543a35..939a8f3 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -194,11 +194,6 @@ def parse_args():
         help="Shuffle tags.",
     )
     parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=0,
-    )
-    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -874,7 +869,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -893,7 +887,6 @@ def main():
         tokenizer=tokenizer,
         constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/train_lora.py b/train_lora.py
index b7ee2d6..51dc827 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -207,11 +207,6 @@ def parse_args():
         help="Shuffle tags.",
     )
     parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=0,
-    )
-    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -998,7 +993,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         sample_scheduler=sample_scheduler,
         sample_batch_size=args.sample_batch_size,
@@ -1022,7 +1016,6 @@ def main():
         tokenizer=tokenizer,
         constant_prompt_length=args.compile_unet,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/train_ti.py b/train_ti.py
index 7d1ef19..7f93960 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -125,11 +125,6 @@ def parse_args():
         action="store_true",
     )
     parser.add_argument(
-        "--guidance_scale",
-        type=float,
-        default=0,
-    )
-    parser.add_argument(
         "--num_class_images",
         type=int,
         default=0,
@@ -852,7 +847,6 @@ def main():
         dtype=weight_dtype,
         seed=args.seed,
         compile_unet=args.compile_unet,
-        guidance_scale=args.guidance_scale,
         prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0,
         no_val=args.valid_set_size == 0,
         strategy=textual_inversion_strategy,
@@ -923,7 +917,6 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         class_subdir=args.class_image_dir,
-        with_guidance=args.guidance_scale != 0,
         num_class_images=args.num_class_images,
         size=args.resolution,
         num_buckets=args.num_buckets,
diff --git a/training/functional.py b/training/functional.py
index a3d1f08..43b03ac 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -342,7 +342,6 @@ def loss_step(
     schedule_sampler: ScheduleSampler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
     input_pertubation: float,
@@ -400,19 +399,6 @@ def loss_step(
         noisy_latents, timesteps, encoder_hidden_states, return_dict=False
     )[0]
 
-    if guidance_scale != 0:
-        uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
-        )
-        uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
-
-        model_pred_uncond = unet(
-            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
-        )[0]
-        model_pred = model_pred_uncond + guidance_scale * (
-            model_pred - model_pred_uncond
-        )
-
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
         target = noise
@@ -425,7 +411,7 @@
 
     acc = (model_pred == target).float().mean()
 
-    if guidance_scale == 0 and prior_loss_weight != 0:
+    if prior_loss_weight != 0:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -727,7 +713,6 @@ def train(
     milestone_checkpoints: bool = True,
     cycle: int = 1,
     global_step_offset: int = 0,
-    guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
     input_pertubation: float = 0.1,
     schedule_sampler: Optional[ScheduleSampler] = None,
@@ -787,7 +772,6 @@ def train(
         schedule_sampler,
         unet,
         text_encoder,
-        guidance_scale,
         prior_loss_weight,
         seed,
         input_pertubation,
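
For reference, the branch removed from loss_step blended the conditional and unconditional U-Net outputs with a classifier-free-guidance style formula before the loss was computed. Below is a minimal sketch of that arithmetic only; the helper name apply_cfg is hypothetical and not part of the repository.

import torch

def apply_cfg(
    model_pred: torch.Tensor,
    model_pred_uncond: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    # guidance_scale == 0 leaves the conditional prediction unchanged;
    # larger values push the result away from the unconditional prediction.
    if guidance_scale == 0:
        return model_pred
    return model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)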