From 11b7740deeef7903e81ba4c65a45853323a5fd5e Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 22 Jun 2023 07:34:04 +0200 Subject: Remove training guidance_scale --- data/csv.py | 4 ---- train_dreambooth.py | 7 ------- train_lora.py | 7 ------- train_ti.py | 7 ------- training/functional.py | 18 +----------------- 5 files changed, 1 insertion(+), 42 deletions(-) diff --git a/data/csv.py b/data/csv.py index d726033..43bf14c 100644 --- a/data/csv.py +++ b/data/csv.py @@ -108,7 +108,6 @@ def collate_fn( dtype: torch.dtype, tokenizer: CLIPTokenizer, max_token_id_length: Optional[int], - with_guidance: bool, with_prior_preservation: bool, examples, ): @@ -195,7 +194,6 @@ class VlpnDataModule: tokenizer: CLIPTokenizer, constant_prompt_length: bool = False, class_subdir: str = "cls", - with_guidance: bool = False, num_class_images: int = 1, size: int = 768, num_buckets: int = 0, @@ -228,7 +226,6 @@ class VlpnDataModule: self.class_root.mkdir(parents=True, exist_ok=True) self.placeholder_tokens = placeholder_tokens self.num_class_images = num_class_images - self.with_guidance = with_guidance self.constant_prompt_length = constant_prompt_length self.max_token_id_length = None @@ -356,7 +353,6 @@ class VlpnDataModule: self.dtype, self.tokenizer, self.max_token_id_length, - self.with_guidance, self.num_class_images != 0, ) diff --git a/train_dreambooth.py b/train_dreambooth.py index 0543a35..939a8f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -193,11 +193,6 @@ def parse_args(): action="store_true", help="Shuffle tags.", ) - parser.add_argument( - "--guidance_scale", - type=float, - default=0, - ) parser.add_argument( "--num_class_images", type=int, @@ -874,7 +869,6 @@ def main(): dtype=weight_dtype, seed=args.seed, compile_unet=args.compile_unet, - guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, sample_scheduler=sample_scheduler, sample_batch_size=args.sample_batch_size, @@ -893,7 +887,6 @@ def main(): tokenizer=tokenizer, constant_prompt_length=args.compile_unet, class_subdir=args.class_image_dir, - with_guidance=args.guidance_scale != 0, num_class_images=args.num_class_images, size=args.resolution, num_buckets=args.num_buckets, diff --git a/train_lora.py b/train_lora.py index b7ee2d6..51dc827 100644 --- a/train_lora.py +++ b/train_lora.py @@ -206,11 +206,6 @@ def parse_args(): action="store_true", help="Shuffle tags.", ) - parser.add_argument( - "--guidance_scale", - type=float, - default=0, - ) parser.add_argument( "--num_class_images", type=int, @@ -998,7 +993,6 @@ def main(): dtype=weight_dtype, seed=args.seed, compile_unet=args.compile_unet, - guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, sample_scheduler=sample_scheduler, sample_batch_size=args.sample_batch_size, @@ -1022,7 +1016,6 @@ def main(): tokenizer=tokenizer, constant_prompt_length=args.compile_unet, class_subdir=args.class_image_dir, - with_guidance=args.guidance_scale != 0, num_class_images=args.num_class_images, size=args.resolution, num_buckets=args.num_buckets, diff --git a/train_ti.py b/train_ti.py index 7d1ef19..7f93960 100644 --- a/train_ti.py +++ b/train_ti.py @@ -124,11 +124,6 @@ def parse_args(): "--sequential", action="store_true", ) - parser.add_argument( - "--guidance_scale", - type=float, - default=0, - ) parser.add_argument( "--num_class_images", type=int, @@ -852,7 +847,6 @@ def main(): dtype=weight_dtype, seed=args.seed, compile_unet=args.compile_unet, - guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, no_val=args.valid_set_size == 0, strategy=textual_inversion_strategy, @@ -923,7 +917,6 @@ def main(): batch_size=args.train_batch_size, tokenizer=tokenizer, class_subdir=args.class_image_dir, - with_guidance=args.guidance_scale != 0, num_class_images=args.num_class_images, size=args.resolution, num_buckets=args.num_buckets, diff --git a/training/functional.py b/training/functional.py index a3d1f08..43b03ac 100644 --- a/training/functional.py +++ b/training/functional.py @@ -342,7 +342,6 @@ def loss_step( schedule_sampler: ScheduleSampler, unet: UNet2DConditionModel, text_encoder: CLIPTextModel, - guidance_scale: float, prior_loss_weight: float, seed: int, input_pertubation: float, @@ -400,19 +399,6 @@ def loss_step( noisy_latents, timesteps, encoder_hidden_states, return_dict=False )[0] - if guidance_scale != 0: - uncond_encoder_hidden_states = get_extended_embeddings( - text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] - ) - uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) - - model_pred_uncond = unet( - noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False - )[0] - model_pred = model_pred_uncond + guidance_scale * ( - model_pred - model_pred_uncond - ) - # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": target = noise @@ -425,7 +411,7 @@ def loss_step( acc = (model_pred == target).float().mean() - if guidance_scale == 0 and prior_loss_weight != 0: + if prior_loss_weight != 0: # Chunk the noise and model_pred into two parts and compute the loss on each part separately. model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) target, target_prior = torch.chunk(target, 2, dim=0) @@ -727,7 +713,6 @@ def train( milestone_checkpoints: bool = True, cycle: int = 1, global_step_offset: int = 0, - guidance_scale: float = 0.0, prior_loss_weight: float = 1.0, input_pertubation: float = 0.1, schedule_sampler: Optional[ScheduleSampler] = None, @@ -787,7 +772,6 @@ def train( schedule_sampler, unet, text_encoder, - guidance_scale, prior_loss_weight, seed, input_pertubation, -- cgit v1.2.3-70-g09d2