| author | Volpeon <git@volpeon.ink> | 2022-11-27 19:07:23 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-11-27 19:07:23 +0100 |
| commit | 839ddbf68680739c45235639bd565a3eb7cb8871 (patch) | |
| tree | b73a60e9f8daeacdced45c83f0170dda1f3137fb | |
| parent | Fix (diff) | |
Fixed and improved Textual Inversion training
-rw-r--r--  textual_inversion.py | 112
1 file changed, 68 insertions(+), 44 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index b676088..20b1617 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -444,11 +444,25 @@ class Checkpointer:
 
             data_enum = enumerate(data)
 
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt.format(identifier=self.instance_identifier)
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
             for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(identifier=self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                 samples = pipeline(
                     prompt=prompt,
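Note on the hunk above (illustrative, not part of the patch): `data_enum = enumerate(data)` is a one-shot iterator, so the old loop body, which rebuilt `batches` from it on every pass, drained it on the first sample batch and left the remaining batches with empty prompt lists; it also only collected up to `sample_batch_size` items instead of `sample_batch_size * sample_batches`. Materializing the batches and prompts once and slicing per batch avoids both problems. A minimal sketch of the iterator pitfall, with made-up data:

```python
# Minimal sketch of the exhausted-iterator pitfall; the data here is made up.
data = ["a", "b", "c", "d"]
data_enum = enumerate(data)

first = [item for _, item in data_enum]   # consumes the iterator: ['a', 'b', 'c', 'd']
second = [item for _, item in data_enum]  # already exhausted: []
print(first, second)
```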
@@ -468,7 +482,7 @@ class Checkpointer:
                 del samples
 
             image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+            image_grid.save(file_path, quality=85)
 
             del all_samples
             del image_grid
@@ -485,6 +499,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
     global_step_offset = 0
     if args.resume_from is not None:
         basepath = Path(args.resume_from)
@@ -496,7 +515,7 @@ def main():
         print("We've trained %d steps so far" % global_step_offset)
     else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
         basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -508,11 +527,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
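Note on the seed handling above (illustrative, not part of the patch): `torch.random.seed()` reseeds torch's generator with a random value and returns that 64-bit seed; shifting it right by 32 keeps the fallback inside the 32-bit range that NumPy-based seeding (and therefore `set_seed`) accepts. A minimal sketch, with `seed` standing in for `args.seed`:

```python
import torch

seed = None  # stand-in for args.seed; None means no seed was passed

# torch.random.seed() returns a random 64-bit value; >> 32 keeps it below 2**32,
# the upper bound NumPy's seeding accepts.
seed = seed or (torch.random.seed() >> 32)
assert 0 <= seed < 2**32
```

One side effect of using `or`: an explicit seed of 0 is falsy and would be replaced as well.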
@@ -520,17 +536,6 @@ def main():
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-        for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
@@ -539,15 +544,23 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+    print(f"Adding text embeddings: {args.placeholder_token}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
+
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    print(f"Added {num_added_tokens} new tokens.")
+
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
@@ -555,6 +568,10 @@ def main():
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
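The hunks above move token registration to after the models are loaded and initialize the new placeholder embeddings from the initializer tokens. A simplified, self-contained sketch of that pattern follows; the model name and token strings are made up for the example, and the loop mirrors the patch's copy of initializer embeddings into the freshly added rows:

```python
import torch
from transformers import CLIPTextModel, CLIPTokenizer

# Assumed values, for the sketch only.
model_name = "openai/clip-vit-large-patch14"
placeholder_tokens = ["<token1>"]
initializer_tokens = ["cat"]

tokenizer = CLIPTokenizer.from_pretrained(model_name)
text_encoder = CLIPTextModel.from_pretrained(model_name)

# Use the first sub-token of each initializer word as the starting embedding.
initializer_token_ids = torch.tensor([
    tokenizer.encode(token, add_special_tokens=False)[0]
    for token in initializer_tokens
])

tokenizer.add_tokens(placeholder_tokens)
placeholder_token_ids = tokenizer.convert_tokens_to_ids(placeholder_tokens)

# Grow the embedding matrix for the new tokens, then copy the initializer
# embeddings into the rows that were just added.
text_encoder.resize_token_embeddings(len(tokenizer))
token_embeds = text_encoder.get_input_embeddings().weight.data
initializer_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)

with torch.no_grad():
    for token_id, embedding in zip(placeholder_token_ids, initializer_embeddings):
        token_embeds[token_id] = embedding
```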
@@ -567,12 +584,13 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(unet.parameters())
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
+    freeze_params(itertools.chain(
         text_encoder.text_model.encoder.parameters(),
         text_encoder.text_model.final_layer_norm.parameters(),
         text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    ))
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -600,6 +618,12 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
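The `weight_dtype` introduced above is what lets the frozen VAE and UNet run in reduced precision further down in the patch (they are moved with `.to(accelerator.device, dtype=weight_dtype)`, and pixel values and text embeddings are cast to match). A minimal sketch of the mapping, assuming the same "fp16"/"bf16" strings that accelerate uses:

```python
import torch

def weight_dtype_for(mixed_precision: str) -> torch.dtype:
    """Map a mixed-precision setting to the dtype used for the frozen models."""
    if mixed_precision == "fp16":
        return torch.float16
    if mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

print(weight_dtype_for("fp16"))  # torch.float16
print(weight_dtype_for("no"))    # torch.float32
```

Only the models that are not trained get this treatment; the text encoder, whose embeddings are being optimized, is left to accelerate's own mixed-precision handling.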
@@ -612,7 +636,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
         input_ids = prompt_processor.unify_input_ids(input_ids)
 
@@ -647,27 +671,25 @@ def main():
     missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
     if len(missing_data) != 0:
-        batched_data = [missing_data[i:i+args.sample_batch_size]
-                        for i in range(0, len(missing_data), args.sample_batch_size)]
-
-        scheduler = EulerAncestralDiscreteScheduler(
-            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-        )
+        batched_data = [
+            missing_data[i:i+args.sample_batch_size]
+            for i in range(0, len(missing_data), args.sample_batch_size)
+        ]
 
         pipeline = VlpnStableDiffusion(
             text_encoder=text_encoder,
             vae=vae,
             unet=unet,
             tokenizer=tokenizer,
-            scheduler=scheduler,
+            scheduler=checkpoint_scheduler,
         ).to(accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
         with torch.autocast("cuda"), torch.inference_mode():
             for batch in batched_data:
-                image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
-                nprompt = [p.nprompt for p in batch]
+                image_name = [item.class_image_path for item in batch]
+                prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                nprompt = [item.nprompt for item in batch]
 
                 images = pipeline(
                     prompt=prompt,
@@ -720,8 +742,8 @@ def main():
     )
 
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -812,7 +834,7 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
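`torch.randn_like(latents)` replaces `torch.randn(latents.shape).to(latents.device)` here and in the validation loop below: it samples noise with the same shape, dtype and device as the latents in one call, which matters now that the VAE may produce float16 latents. A minimal sketch:

```python
import torch

latents = torch.zeros(2, 4, 64, 64, dtype=torch.float16)

noise = torch.randn_like(latents)                        # matches shape, dtype and device
legacy = torch.randn(latents.shape).to(latents.device)   # stays float32 unless cast explicitly

assert noise.dtype == torch.float16
assert legacy.dtype == torch.float32
```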
@@ -825,6 +847,7 @@ def main():
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -907,7 +930,7 @@ def main():
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
 
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                           (bsz,), device=latents.device)
@@ -916,6 +939,7 @@ def main():
                 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 