From 49a37b054ea7c1cdd8c0d7c44f3809ab8bee0693 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 6 Oct 2022 17:15:22 +0200 Subject: Update --- textual_inversion.py | 112 ++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 101 insertions(+), 11 deletions(-) (limited to 'textual_inversion.py') diff --git a/textual_inversion.py b/textual_inversion.py index 7919ebd..11c324d 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -25,7 +25,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion import json import os -from data.textual_inversion.csv import CSVDataModule +from data.csv import CSVDataModule logger = get_logger(__name__) @@ -68,10 +68,10 @@ def parse_args(): help="A token to use as initializer word." ) parser.add_argument( - "--vectors_per_token", - type=int, - default=1, - help="Vectors per token." + "--use_class_images", + action="store_true", + default=True, + help="Include class images in the loss calculation a la Dreambooth.", ) parser.add_argument( "--repeats", @@ -233,6 +233,12 @@ def parse_args(): default=30, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) + parser.add_argument( + "--prior_loss_weight", + type=float, + default=1.0, + help="The weight of prior preservation loss." + ) parser.add_argument( "--resume_from", type=str, @@ -395,7 +401,8 @@ class Checkpointer: for i in range(self.sample_batches): batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] + prompt = [prompt.format(self.placeholder_token) + for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] with self.accelerator.autocast(): @@ -556,25 +563,94 @@ def main(): eps=args.adam_epsilon, ) - # TODO (patil-suraj): laod scheduler using args noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 + beta_start=0.00085, + beta_end=0.012, + beta_schedule="scaled_linear", + num_train_timesteps=1000 ) + def collate_fn(examples): + prompts = [example["prompts"] for example in examples] + nprompts = [example["nprompts"] for example in examples] + input_ids = [example["instance_prompt_ids"] for example in examples] + pixel_values = [example["instance_images"] for example in examples] + + # concat class and instance examples for prior preservation + if args.use_class_images and "class_prompt_ids" in examples[0]: + input_ids += [example["class_prompt_ids"] for example in examples] + pixel_values += [example["class_images"] for example in examples] + + pixel_values = torch.stack(pixel_values) + pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) + + input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids + + batch = { + "prompts": prompts, + "nprompts": nprompts, + "input_ids": input_ids, + "pixel_values": pixel_values, + } + return batch + datamodule = CSVDataModule( data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, + instance_identifier=args.placeholder_token, + class_identifier=args.initializer_token if args.use_class_images else None, + class_subdir="ti_cls", size=args.resolution, - placeholder_token=args.placeholder_token, repeats=args.repeats, center_crop=args.center_crop, - valid_set_size=args.sample_batch_size*args.sample_batches + valid_set_size=args.sample_batch_size*args.sample_batches, + collate_fn=collate_fn ) datamodule.prepare_data() datamodule.setup() + if args.use_class_images: + missing_data = [item for item in datamodule.data if not item[1].exists()] + + if len(missing_data) != 0: + batched_data = [missing_data[i:i+args.sample_batch_size] + for i in range(0, len(missing_data), args.sample_batch_size)] + + scheduler = EulerAScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) + + pipeline = VlpnStableDiffusion( + text_encoder=text_encoder, + vae=vae, + unet=unet, + tokenizer=tokenizer, + scheduler=scheduler, + ).to(accelerator.device) + pipeline.enable_attention_slicing() + + for batch in batched_data: + image_name = [p[1] for p in batch] + prompt = [p[2].format(args.initializer_token) for p in batch] + nprompt = [p[3] for p in batch] + + with accelerator.autocast(): + images = pipeline( + prompt=prompt, + negative_prompt=nprompt, + num_inference_steps=args.sample_steps + ).images + + for i, image in enumerate(images): + image.save(image_name[i]) + + del pipeline + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + train_dataloader = datamodule.train_dataloader() val_dataloader = datamodule.val_dataloader() @@ -693,7 +769,21 @@ def main(): # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample - loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + if args.use_class_images: + # Chunk the noise and noise_pred into two parts and compute the loss on each part separately. + noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0) + noise, noise_prior = torch.chunk(noise, 2, dim=0) + + # Compute instance loss + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() + + # Compute prior loss + prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean() + + # Add the prior loss to the instance loss. + loss = loss + args.prior_loss_weight * prior_loss + else: + loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() accelerator.backward(loss) -- cgit v1.2.3-54-g00ecf