diff options
| -rw-r--r-- | data/csv.py | 4 | ||||
| -rw-r--r-- | train_dreambooth.py | 7 | ||||
| -rw-r--r-- | train_lora.py | 7 | ||||
| -rw-r--r-- | train_ti.py | 7 | ||||
| -rw-r--r-- | training/functional.py | 18 |
5 files changed, 1 insertion(+), 42 deletions(-)
diff --git a/data/csv.py b/data/csv.py index d726033..43bf14c 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -108,7 +108,6 @@ def collate_fn( | |||
| 108 | dtype: torch.dtype, | 108 | dtype: torch.dtype, |
| 109 | tokenizer: CLIPTokenizer, | 109 | tokenizer: CLIPTokenizer, |
| 110 | max_token_id_length: Optional[int], | 110 | max_token_id_length: Optional[int], |
| 111 | with_guidance: bool, | ||
| 112 | with_prior_preservation: bool, | 111 | with_prior_preservation: bool, |
| 113 | examples, | 112 | examples, |
| 114 | ): | 113 | ): |
| @@ -195,7 +194,6 @@ class VlpnDataModule: | |||
| 195 | tokenizer: CLIPTokenizer, | 194 | tokenizer: CLIPTokenizer, |
| 196 | constant_prompt_length: bool = False, | 195 | constant_prompt_length: bool = False, |
| 197 | class_subdir: str = "cls", | 196 | class_subdir: str = "cls", |
| 198 | with_guidance: bool = False, | ||
| 199 | num_class_images: int = 1, | 197 | num_class_images: int = 1, |
| 200 | size: int = 768, | 198 | size: int = 768, |
| 201 | num_buckets: int = 0, | 199 | num_buckets: int = 0, |
| @@ -228,7 +226,6 @@ class VlpnDataModule: | |||
| 228 | self.class_root.mkdir(parents=True, exist_ok=True) | 226 | self.class_root.mkdir(parents=True, exist_ok=True) |
| 229 | self.placeholder_tokens = placeholder_tokens | 227 | self.placeholder_tokens = placeholder_tokens |
| 230 | self.num_class_images = num_class_images | 228 | self.num_class_images = num_class_images |
| 231 | self.with_guidance = with_guidance | ||
| 232 | 229 | ||
| 233 | self.constant_prompt_length = constant_prompt_length | 230 | self.constant_prompt_length = constant_prompt_length |
| 234 | self.max_token_id_length = None | 231 | self.max_token_id_length = None |
| @@ -356,7 +353,6 @@ class VlpnDataModule: | |||
| 356 | self.dtype, | 353 | self.dtype, |
| 357 | self.tokenizer, | 354 | self.tokenizer, |
| 358 | self.max_token_id_length, | 355 | self.max_token_id_length, |
| 359 | self.with_guidance, | ||
| 360 | self.num_class_images != 0, | 356 | self.num_class_images != 0, |
| 361 | ) | 357 | ) |
| 362 | 358 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0543a35..939a8f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -194,11 +194,6 @@ def parse_args(): | |||
| 194 | help="Shuffle tags.", | 194 | help="Shuffle tags.", |
| 195 | ) | 195 | ) |
| 196 | parser.add_argument( | 196 | parser.add_argument( |
| 197 | "--guidance_scale", | ||
| 198 | type=float, | ||
| 199 | default=0, | ||
| 200 | ) | ||
| 201 | parser.add_argument( | ||
| 202 | "--num_class_images", | 197 | "--num_class_images", |
| 203 | type=int, | 198 | type=int, |
| 204 | default=0, | 199 | default=0, |
| @@ -874,7 +869,6 @@ def main(): | |||
| 874 | dtype=weight_dtype, | 869 | dtype=weight_dtype, |
| 875 | seed=args.seed, | 870 | seed=args.seed, |
| 876 | compile_unet=args.compile_unet, | 871 | compile_unet=args.compile_unet, |
| 877 | guidance_scale=args.guidance_scale, | ||
| 878 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 872 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 879 | sample_scheduler=sample_scheduler, | 873 | sample_scheduler=sample_scheduler, |
| 880 | sample_batch_size=args.sample_batch_size, | 874 | sample_batch_size=args.sample_batch_size, |
| @@ -893,7 +887,6 @@ def main(): | |||
| 893 | tokenizer=tokenizer, | 887 | tokenizer=tokenizer, |
| 894 | constant_prompt_length=args.compile_unet, | 888 | constant_prompt_length=args.compile_unet, |
| 895 | class_subdir=args.class_image_dir, | 889 | class_subdir=args.class_image_dir, |
| 896 | with_guidance=args.guidance_scale != 0, | ||
| 897 | num_class_images=args.num_class_images, | 890 | num_class_images=args.num_class_images, |
| 898 | size=args.resolution, | 891 | size=args.resolution, |
| 899 | num_buckets=args.num_buckets, | 892 | num_buckets=args.num_buckets, |
diff --git a/train_lora.py b/train_lora.py index b7ee2d6..51dc827 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -207,11 +207,6 @@ def parse_args(): | |||
| 207 | help="Shuffle tags.", | 207 | help="Shuffle tags.", |
| 208 | ) | 208 | ) |
| 209 | parser.add_argument( | 209 | parser.add_argument( |
| 210 | "--guidance_scale", | ||
| 211 | type=float, | ||
| 212 | default=0, | ||
| 213 | ) | ||
| 214 | parser.add_argument( | ||
| 215 | "--num_class_images", | 210 | "--num_class_images", |
| 216 | type=int, | 211 | type=int, |
| 217 | default=0, | 212 | default=0, |
| @@ -998,7 +993,6 @@ def main(): | |||
| 998 | dtype=weight_dtype, | 993 | dtype=weight_dtype, |
| 999 | seed=args.seed, | 994 | seed=args.seed, |
| 1000 | compile_unet=args.compile_unet, | 995 | compile_unet=args.compile_unet, |
| 1001 | guidance_scale=args.guidance_scale, | ||
| 1002 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 996 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 1003 | sample_scheduler=sample_scheduler, | 997 | sample_scheduler=sample_scheduler, |
| 1004 | sample_batch_size=args.sample_batch_size, | 998 | sample_batch_size=args.sample_batch_size, |
| @@ -1022,7 +1016,6 @@ def main(): | |||
| 1022 | tokenizer=tokenizer, | 1016 | tokenizer=tokenizer, |
| 1023 | constant_prompt_length=args.compile_unet, | 1017 | constant_prompt_length=args.compile_unet, |
| 1024 | class_subdir=args.class_image_dir, | 1018 | class_subdir=args.class_image_dir, |
| 1025 | with_guidance=args.guidance_scale != 0, | ||
| 1026 | num_class_images=args.num_class_images, | 1019 | num_class_images=args.num_class_images, |
| 1027 | size=args.resolution, | 1020 | size=args.resolution, |
| 1028 | num_buckets=args.num_buckets, | 1021 | num_buckets=args.num_buckets, |
diff --git a/train_ti.py b/train_ti.py index 7d1ef19..7f93960 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -125,11 +125,6 @@ def parse_args(): | |||
| 125 | action="store_true", | 125 | action="store_true", |
| 126 | ) | 126 | ) |
| 127 | parser.add_argument( | 127 | parser.add_argument( |
| 128 | "--guidance_scale", | ||
| 129 | type=float, | ||
| 130 | default=0, | ||
| 131 | ) | ||
| 132 | parser.add_argument( | ||
| 133 | "--num_class_images", | 128 | "--num_class_images", |
| 134 | type=int, | 129 | type=int, |
| 135 | default=0, | 130 | default=0, |
| @@ -852,7 +847,6 @@ def main(): | |||
| 852 | dtype=weight_dtype, | 847 | dtype=weight_dtype, |
| 853 | seed=args.seed, | 848 | seed=args.seed, |
| 854 | compile_unet=args.compile_unet, | 849 | compile_unet=args.compile_unet, |
| 855 | guidance_scale=args.guidance_scale, | ||
| 856 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 850 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 857 | no_val=args.valid_set_size == 0, | 851 | no_val=args.valid_set_size == 0, |
| 858 | strategy=textual_inversion_strategy, | 852 | strategy=textual_inversion_strategy, |
| @@ -923,7 +917,6 @@ def main(): | |||
| 923 | batch_size=args.train_batch_size, | 917 | batch_size=args.train_batch_size, |
| 924 | tokenizer=tokenizer, | 918 | tokenizer=tokenizer, |
| 925 | class_subdir=args.class_image_dir, | 919 | class_subdir=args.class_image_dir, |
| 926 | with_guidance=args.guidance_scale != 0, | ||
| 927 | num_class_images=args.num_class_images, | 920 | num_class_images=args.num_class_images, |
| 928 | size=args.resolution, | 921 | size=args.resolution, |
| 929 | num_buckets=args.num_buckets, | 922 | num_buckets=args.num_buckets, |
diff --git a/training/functional.py b/training/functional.py index a3d1f08..43b03ac 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -342,7 +342,6 @@ def loss_step( | |||
| 342 | schedule_sampler: ScheduleSampler, | 342 | schedule_sampler: ScheduleSampler, |
| 343 | unet: UNet2DConditionModel, | 343 | unet: UNet2DConditionModel, |
| 344 | text_encoder: CLIPTextModel, | 344 | text_encoder: CLIPTextModel, |
| 345 | guidance_scale: float, | ||
| 346 | prior_loss_weight: float, | 345 | prior_loss_weight: float, |
| 347 | seed: int, | 346 | seed: int, |
| 348 | input_pertubation: float, | 347 | input_pertubation: float, |
| @@ -400,19 +399,6 @@ def loss_step( | |||
| 400 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False | 399 | noisy_latents, timesteps, encoder_hidden_states, return_dict=False |
| 401 | )[0] | 400 | )[0] |
| 402 | 401 | ||
| 403 | if guidance_scale != 0: | ||
| 404 | uncond_encoder_hidden_states = get_extended_embeddings( | ||
| 405 | text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] | ||
| 406 | ) | ||
| 407 | uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) | ||
| 408 | |||
| 409 | model_pred_uncond = unet( | ||
| 410 | noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False | ||
| 411 | )[0] | ||
| 412 | model_pred = model_pred_uncond + guidance_scale * ( | ||
| 413 | model_pred - model_pred_uncond | ||
| 414 | ) | ||
| 415 | |||
| 416 | # Get the target for loss depending on the prediction type | 402 | # Get the target for loss depending on the prediction type |
| 417 | if noise_scheduler.config.prediction_type == "epsilon": | 403 | if noise_scheduler.config.prediction_type == "epsilon": |
| 418 | target = noise | 404 | target = noise |
| @@ -425,7 +411,7 @@ def loss_step( | |||
| 425 | 411 | ||
| 426 | acc = (model_pred == target).float().mean() | 412 | acc = (model_pred == target).float().mean() |
| 427 | 413 | ||
| 428 | if guidance_scale == 0 and prior_loss_weight != 0: | 414 | if prior_loss_weight != 0: |
| 429 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 415 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 430 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 416 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 431 | target, target_prior = torch.chunk(target, 2, dim=0) | 417 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -727,7 +713,6 @@ def train( | |||
| 727 | milestone_checkpoints: bool = True, | 713 | milestone_checkpoints: bool = True, |
| 728 | cycle: int = 1, | 714 | cycle: int = 1, |
| 729 | global_step_offset: int = 0, | 715 | global_step_offset: int = 0, |
| 730 | guidance_scale: float = 0.0, | ||
| 731 | prior_loss_weight: float = 1.0, | 716 | prior_loss_weight: float = 1.0, |
| 732 | input_pertubation: float = 0.1, | 717 | input_pertubation: float = 0.1, |
| 733 | schedule_sampler: Optional[ScheduleSampler] = None, | 718 | schedule_sampler: Optional[ScheduleSampler] = None, |
| @@ -787,7 +772,6 @@ def train( | |||
| 787 | schedule_sampler, | 772 | schedule_sampler, |
| 788 | unet, | 773 | unet, |
| 789 | text_encoder, | 774 | text_encoder, |
| 790 | guidance_scale, | ||
| 791 | prior_loss_weight, | 775 | prior_loss_weight, |
| 792 | seed, | 776 | seed, |
| 793 | input_pertubation, | 777 | input_pertubation, |
