From 839ddbf68680739c45235639bd565a3eb7cb8871 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 27 Nov 2022 19:07:23 +0100
Subject: Fixed and improved Textual Inversion training

---
 textual_inversion.py | 112 +++++++++++++++++++++++++++++++--------------------
 1 file changed, 68 insertions(+), 44 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index b676088..20b1617 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -444,11 +444,25 @@ class Checkpointer:
 
         data_enum = enumerate(data)
 
+        batches = [
+            batch
+            for j, batch in data_enum
+            if j * data.batch_size < self.sample_batch_size * self.sample_batches
+        ]
+        prompts = [
+            prompt.format(identifier=self.instance_identifier)
+            for batch in batches
+            for prompt in batch["prompts"]
+        ]
+        nprompts = [
+            prompt
+            for batch in batches
+            for prompt in batch["nprompts"]
+        ]
+
         for i in range(self.sample_batches):
-            batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-            prompt = [prompt.format(identifier=self.instance_identifier)
-                      for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-            nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+            prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+            nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
             samples = pipeline(
                 prompt=prompt,
@@ -468,7 +482,7 @@ class Checkpointer:
                 del samples
 
         image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-        image_grid.save(file_path)
+        image_grid.save(file_path, quality=85)
 
         del all_samples
         del image_grid
@@ -485,6 +499,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
     global_step_offset = 0
     if args.resume_from is not None:
         basepath = Path(args.resume_from)
@@ -496,7 +515,7 @@ def main():
         print("We've trained %d steps so far" % global_step_offset)
     else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
         basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -508,11 +527,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
     # Load the tokenizer and add the placeholder token as an additional special token
     if args.tokenizer_name:
@@ -520,17 +536,6 @@ def main():
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-        for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
@@ -539,15 +544,23 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+    print(f"Adding text embeddings: {args.placeholder_token}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
+
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    print(f"Added {num_added_tokens} new tokens.")
+
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
@@ -555,6 +568,10 @@ def main():
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
@@ -567,12 +584,13 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(unet.parameters())
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
+    freeze_params(itertools.chain(
         text_encoder.text_model.encoder.parameters(),
         text_encoder.text_model.final_layer_norm.parameters(),
         text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    ))
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -600,6 +618,12 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -612,7 +636,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
         input_ids = prompt_processor.unify_input_ids(input_ids)
 
@@ -647,27 +671,25 @@ def main():
         missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
-
-            scheduler = EulerAncestralDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
 
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder,
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer,
-                scheduler=scheduler,
+                scheduler=checkpoint_scheduler,
             ).to(accelerator.device)
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             with torch.autocast("cuda"), torch.inference_mode():
                 for batch in batched_data:
-                    image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
-                    nprompt = [p.nprompt for p in batch]
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
                         prompt=prompt,
@@ -720,8 +742,8 @@ def main():
     )
 
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -812,7 +834,7 @@ def main():
                     latents = latents * 0.18215
 
                     # Sample noise that we'll add to the latents
-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     # Sample a random timestep for each image
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -825,6 +847,7 @@ def main():
 
                     # Get the text embedding for conditioning
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -907,7 +930,7 @@ def main():
                         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                         latents = latents * 0.18215
 
-                        noise = torch.randn(latents.shape).to(latents.device)
+                        noise = torch.randn_like(latents)
                         bsz = latents.shape[0]
                         timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                                   (bsz,), device=latents.device)
@@ -916,6 +939,7 @@ def main():
                         noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                         encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                         noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
-- 
cgit v1.2.3-70-g09d2
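The Checkpointer hunk above fixes the sample-prompt batching: the old comprehension exhausted the `data_enum` iterator on the first loop pass and capped collection at `self.sample_batch_size` items, so later sample batches saw empty prompt lists. The new code gathers `sample_batch_size * sample_batches` prompts once, then slices a fixed-size window per batch. A toy illustration of that sliced-window pattern, with plain lists standing in for the dataloader (all numbers made up):

batch_size = 2
sample_batches = 3
sample_batch_size = 2

# Pretend dataloader output: one list of prompts per batch.
data = [["a", "b"], ["c", "d"], ["e", "f"], ["g", "h"]]

# Collect enough batches once, up front.
batches = [
    batch
    for j, batch in enumerate(data)
    if j * batch_size < sample_batch_size * sample_batches
]
prompts = [p for batch in batches for p in batch]

# Then slice a fixed-size window per sample batch.
for i in range(sample_batches):
    window = prompts[i * sample_batch_size:(i + 1) * sample_batch_size]
    print(i, window)  # 0 ['a', 'b'] / 1 ['c', 'd'] / 2 ['e', 'f']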
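The seeding change means a run always has a concrete seed: when none is passed, one is drawn from PyTorch, stored back into `args.seed`, and handed to `set_seed`, so the run can be reproduced later. A minimal sketch of the idea, assuming Accelerate's `set_seed`; the `resolve_seed` helper name is illustrative, not from the patch:

import torch
from accelerate.utils import set_seed

def resolve_seed(seed=None):
    # torch.random.seed() reseeds the default generator and returns the fresh
    # 64-bit seed; keeping only the top 32 bits stays inside the 32-bit range
    # that numpy's RNG (seeded inside set_seed) accepts.
    return seed or (torch.random.seed() >> 32)

seed = resolve_seed()  # e.g. resolve_seed(args.seed)
set_seed(seed)
print(f"Using seed {seed}")

Note that `or`, as in the patch, treats an explicit seed of 0 as unset; an `is None` check would be needed if 0 must remain a valid seed.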
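The reordered token setup now does what the surrounding comments promise: encode each initializer word, keep only its first sub-token, register the placeholder tokens, resize the embedding matrix, and copy the initializer embeddings into the newly added rows (the copy loop is new in this commit). A self-contained sketch using Hugging Face transformers; the function name and example tokens are illustrative:

import torch
from transformers import CLIPTextModel, CLIPTokenizer

def add_placeholder_tokens(tokenizer, text_encoder, placeholder_tokens, initializer_tokens):
    # Keep only the first sub-token of each initializer word as the source embedding.
    initializer_token_ids = torch.stack([
        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
        for token in initializer_tokens
    ])

    tokenizer.add_tokens(placeholder_tokens)
    placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

    # Grow the embedding matrix for the new tokens, then overwrite the freshly
    # initialized rows with the initializer embeddings.
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_embeds = text_encoder.get_input_embeddings().weight.data
    initializer_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

    for token_id, embeddings in zip(placeholder_token_ids, initializer_embeddings):
        token_embeds[token_id] = embeddings

    return placeholder_token_ids

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
ids = add_placeholder_tokens(tokenizer, text_encoder, ["<my-token>"], ["dog"])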
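The `weight_dtype` plumbing follows a common memory-saving pattern for textual inversion: only the new embeddings are trained, so the frozen VAE and UNet can be cast down to fp16/bf16, as long as every tensor that flows into them is cast to match. A condensed sketch of that pattern (the model moves appear as comments since the models aren't constructed here; variable names mirror the patch):

import torch

def resolve_weight_dtype(mixed_precision):
    # Frozen models may run in half precision; trainable weights stay fp32.
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

weight_dtype = resolve_weight_dtype("fp16")

# Frozen models: move and downcast in one call.
# vae.to(accelerator.device, dtype=weight_dtype)
# unet.to(accelerator.device, dtype=weight_dtype)

# Anything handed to them must match:
# pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
# encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)

# torch.randn_like inherits dtype and device from its input, which is also why
# it replaces torch.randn(latents.shape).to(latents.device) in the diff:
latents = torch.zeros(2, 4, 64, 64, dtype=weight_dtype)
noise = torch.randn_like(latents)
assert noise.dtype == weight_dtype and noise.shape == latents.shape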