-rw-r--r--  textual_inversion.py  112
1 file changed, 68 insertions, 44 deletions
diff --git a/textual_inversion.py b/textual_inversion.py
index b676088..20b1617 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -444,11 +444,25 @@ class Checkpointer:
 
             data_enum = enumerate(data)
 
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt.format(identifier=self.instance_identifier)
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
             for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(identifier=self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                 samples = pipeline(
                     prompt=prompt,
@@ -468,7 +482,7 @@ class Checkpointer:
                 del samples
 
             image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+            image_grid.save(file_path, quality=85)
 
             del all_samples
             del image_grid
@@ -485,6 +499,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
     global_step_offset = 0
     if args.resume_from is not None:
         basepath = Path(args.resume_from)
@@ -496,7 +515,7 @@ def main():
         print("We've trained %d steps so far" % global_step_offset)
     else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
         basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -508,11 +527,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
@@ -520,17 +536,6 @@ def main():
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-        for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
@@ -539,15 +544,23 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+    print(f"Adding text embeddings: {args.placeholder_token}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
+
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    print(f"Added {num_added_tokens} new tokens.")
+
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
@@ -555,6 +568,10 @@ def main():
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
@@ -567,12 +584,13 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(unet.parameters())
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
+    freeze_params(itertools.chain(
         text_encoder.text_model.encoder.parameters(),
         text_encoder.text_model.final_layer_norm.parameters(),
         text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    ))
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -600,6 +618,12 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
@@ -612,7 +636,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
         input_ids = prompt_processor.unify_input_ids(input_ids)
 
@@ -647,27 +671,25 @@ def main():
         missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
-
-            scheduler = EulerAncestralDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
 
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder,
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer,
-                scheduler=scheduler,
+                scheduler=checkpoint_scheduler,
             ).to(accelerator.device)
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             with torch.autocast("cuda"), torch.inference_mode():
                 for batch in batched_data:
-                    image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
-                    nprompt = [p.nprompt for p in batch]
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
                         prompt=prompt,
@@ -720,8 +742,8 @@ def main():
     )
 
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -812,7 +834,7 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
@@ -825,6 +847,7 @@ def main():
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -907,7 +930,7 @@ def main():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                           (bsz,), device=latents.device)
@@ -916,6 +939,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 