-rw-r--r--  textual_inversion.py | 112
1 files changed, 68 insertions, 44 deletions

diff --git a/textual_inversion.py b/textual_inversion.py
index b676088..20b1617 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -444,11 +444,25 @@ class Checkpointer:
 
             data_enum = enumerate(data)
 
+            batches = [
+                batch
+                for j, batch in data_enum
+                if j * data.batch_size < self.sample_batch_size * self.sample_batches
+            ]
+            prompts = [
+                prompt.format(identifier=self.instance_identifier)
+                for batch in batches
+                for prompt in batch["prompts"]
+            ]
+            nprompts = [
+                prompt
+                for batch in batches
+                for prompt in batch["nprompts"]
+            ]
+
             for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(identifier=self.instance_identifier)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
+                nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
 
                 samples = pipeline(
                     prompt=prompt,
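For context, the hunk above precomputes the batch and prompt lists once and then slices them per sample batch; the old in-loop comprehension exhausted the enumerate(data) iterator on its first pass, so later iterations saw no prompts. A minimal standalone sketch of the new slicing pattern (toy values, not part of the diff):

    # Toy illustration of the per-batch slicing used in the new loop.
    sample_batch_size, sample_batches = 2, 3
    prompts = [f"prompt {i}" for i in range(sample_batch_size * sample_batches)]

    for i in range(sample_batches):
        # Each slice feeds one pipeline() call with `sample_batch_size` prompts.
        chunk = prompts[i * sample_batch_size:(i + 1) * sample_batch_size]
        print(chunk)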
@@ -468,7 +482,7 @@ class Checkpointer:
                 del samples
 
             image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+            image_grid.save(file_path, quality=85)
 
             del all_samples
             del image_grid
@@ -485,6 +499,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if len(args.placeholder_token) != 0:
+        instance_identifier = instance_identifier.format(args.placeholder_token[0])
+
     global_step_offset = 0
     if args.resume_from is not None:
         basepath = Path(args.resume_from)
@@ -496,7 +515,7 @@ def main():
         print("We've trained %d steps so far" % global_step_offset)
     else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
         basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -508,11 +527,8 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
-
-    args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
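A note on the new seed handling above: torch.random.seed() re-seeds PyTorch non-deterministically and returns a 64-bit integer, and the right shift by 32 presumably keeps the value inside the 32-bit range accepted by helpers such as numpy.random.seed(), which set_seed uses under the hood. A minimal sketch of the same fallback (hypothetical pick_seed helper, not part of the diff):

    import torch

    def pick_seed(seed):
        # Fall back to a random 32-bit value when no seed was supplied;
        # mirrors `args.seed or (torch.random.seed() >> 32)` from the diff.
        return seed or (torch.random.seed() >> 32)

    print(pick_seed(None))   # random value < 2**32
    print(pick_seed(1234))   # explicit seed is kept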
@@ -520,17 +536,6 @@ def main():
     elif args.pretrained_model_name_or_path:
         tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
 
-    # Convert the initializer_token, placeholder_token to ids
-    initializer_token_ids = torch.stack([
-        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
-        for token in args.initializer_token
-    ])
-
-    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
-    print(f"Added {num_added_tokens} new tokens.")
-
-    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
-
     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
@@ -539,15 +544,23 @@ def main():
     checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         args.pretrained_model_name_or_path, subfolder='scheduler')
 
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
-
     unet.set_use_memory_efficient_attention_xformers(True)
 
     if args.gradient_checkpointing:
         text_encoder.gradient_checkpointing_enable()
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+    print(f"Adding text embeddings: {args.placeholder_token}")
+
+    # Convert the initializer_token, placeholder_token to ids
+    initializer_token_ids = torch.stack([
+        torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
+        for token in args.initializer_token
+    ])
+
+    num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+    print(f"Added {num_added_tokens} new tokens.")
+
+    placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
@@ -555,6 +568,10 @@ def main():
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
     original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     if args.resume_checkpoint is not None:
         token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
@@ -567,12 +584,13 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(unet.parameters())
     # Freeze all parameters except for the token embeddings in text encoder
-    params_to_freeze = itertools.chain(
+    freeze_params(itertools.chain(
         text_encoder.text_model.encoder.parameters(),
         text_encoder.text_model.final_layer_norm.parameters(),
         text_encoder.text_model.embeddings.position_embedding.parameters(),
-    )
-    freeze_params(params_to_freeze)
+    ))
+
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate = (
@@ -600,6 +618,12 @@ def main():
         eps=args.adam_epsilon,
     )
 
+    weight_dtype = torch.float32
+    if args.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif args.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     def collate_fn(examples):
         prompts = [example["prompts"] for example in examples]
         nprompts = [example["nprompts"] for example in examples]
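The weight_dtype added above maps the --mixed_precision setting to a torch dtype that the later hunks apply to the frozen vae/unet weights, the batched pixel values, and the text-encoder output. A standalone sketch of the mapping (hypothetical resolve_weight_dtype helper; "no" is assumed to be the flag's default value):

    import torch

    def resolve_weight_dtype(mixed_precision: str) -> torch.dtype:
        # Mirrors the mapping in the diff; anything other than fp16/bf16
        # (e.g. the usual default "no") stays in full precision.
        if mixed_precision == "fp16":
            return torch.float16
        if mixed_precision == "bf16":
            return torch.bfloat16
        return torch.float32

    assert resolve_weight_dtype("no") is torch.float32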
@@ -612,7 +636,7 @@ def main():
             pixel_values += [example["class_images"] for example in examples]
 
         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
 
         input_ids = prompt_processor.unify_input_ids(input_ids)
 
@@ -647,27 +671,25 @@ def main():
         missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
-            batched_data = [missing_data[i:i+args.sample_batch_size]
-                            for i in range(0, len(missing_data), args.sample_batch_size)]
-
-            scheduler = EulerAncestralDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            )
+            batched_data = [
+                missing_data[i:i+args.sample_batch_size]
+                for i in range(0, len(missing_data), args.sample_batch_size)
+            ]
 
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder,
                 vae=vae,
                 unet=unet,
                 tokenizer=tokenizer,
-                scheduler=scheduler,
+                scheduler=checkpoint_scheduler,
             ).to(accelerator.device)
             pipeline.set_progress_bar_config(dynamic_ncols=True)
 
             with torch.autocast("cuda"), torch.inference_mode():
                 for batch in batched_data:
-                    image_name = [p.class_image_path for p in batch]
-                    prompt = [p.prompt.format(identifier=args.class_identifier) for p in batch]
-                    nprompt = [p.nprompt for p in batch]
+                    image_name = [item.class_image_path for item in batch]
+                    prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
+                    nprompt = [item.nprompt for item in batch]
 
                     images = pipeline(
                         prompt=prompt,
@@ -720,8 +742,8 @@ def main():
     )
 
     # Move vae and unet to device
-    vae.to(accelerator.device)
-    unet.to(accelerator.device)
+    vae.to(accelerator.device, dtype=weight_dtype)
+    unet.to(accelerator.device, dtype=weight_dtype)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -812,7 +834,7 @@ def main():
                 latents = latents * 0.18215
 
                 # Sample noise that we'll add to the latents
-                noise = torch.randn(latents.shape).to(latents.device)
+                noise = torch.randn_like(latents)
                 bsz = latents.shape[0]
                 # Sample a random timestep for each image
                 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
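torch.randn_like, used above, draws noise with the same shape, dtype and device as its input, so half-precision latents get matching noise directly, whereas the old torch.randn(latents.shape).to(latents.device) always produced float32 noise on the CPU before moving it to the device. A small sketch:

    import torch

    latents = torch.zeros(2, 4, 64, 64)  # stand-in for vae.encode(...) latents

    # randn_like inherits shape, dtype and device from its input tensor.
    noise = torch.randn_like(latents)
    assert noise.shape == latents.shape
    assert noise.dtype == latents.dtype and noise.device == latents.device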
@@ -825,6 +847,7 @@ def main():
 
                 # Get the text embedding for conditioning
                 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                 # Predict the noise residual
                 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -907,7 +930,7 @@ def main():
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
-                    noise = torch.randn(latents.shape).to(latents.device)
+                    noise = torch.randn_like(latents)
                     bsz = latents.shape[0]
                     timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
                                               (bsz,), device=latents.device)
@@ -916,6 +939,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
                     encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+                    encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
 
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
