path: root/textual_inversion.py
author     Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-08 21:56:54 +0200
commit     6aadb34af4fe5ca2dfc92fae8eee87610a5848ad (patch)
tree       f490b4794366e78f7b079eb04de1c7c00e17d34a /textual_inversion.py
parent     Fix small details (diff)
Update
Diffstat (limited to 'textual_inversion.py')
-rw-r--r--  textual_inversion.py | 57
1 file changed, 28 insertions(+), 29 deletions(-)
diff --git a/textual_inversion.py b/textual_inversion.py
index 4f2de9e..09871d4 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -4,6 +4,7 @@ import math
 import os
 import datetime
 import logging
+import json
 from pathlib import Path
 
 import numpy as np
@@ -22,8 +23,6 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-import json
-import os
 
 from data.csv import CSVDataModule
 
@@ -70,7 +69,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=4,
+        default=200,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
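Note: the default number of class images jumps from 4 to 200 per training image, a far more realistic budget for prior preservation. Since the flag is per training image, the generation workload scales multiplicatively; a quick illustrative calculation (the dataset size below is hypothetical):

# Illustrative arithmetic only, not code from this commit.
num_train_images = 20                        # hypothetical dataset size
num_class_images = 200                       # new default (was 4)
print(num_train_images * num_class_images)   # 4000 class images to generate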
@@ -141,7 +140,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="constant",
+        default="linear",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup"]'
@@ -402,20 +401,20 @@ class Checkpointer:
             generator=generator,
         )
 
-        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
+        with torch.inference_mode():
+            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
+                all_samples = []
+                file_path = samples_path.joinpath(pool, f"step_{step}.png")
+                file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+                data_enum = enumerate(data)
 
-            for i in range(self.sample_batches):
-                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
-                prompt = [prompt.format(self.placeholder_token)
-                          for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
-                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
+                for i in range(self.sample_batches):
+                    batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                    prompt = [prompt.format(self.placeholder_token)
+                              for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                    nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
-                with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         negative_prompt=nprompt,
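Note on the loop above: data_enum is created once, outside the for i loop, with the apparent intent that each pass resumes where the previous one stopped. A comprehension with an if filter, however, runs its iterator to exhaustion on first evaluation, so only the first pass actually yields batches. itertools.islice is the usual way to get truly resumable chunking; a standalone sketch of the difference (illustrative data, not this repo's loader):

# Sketch: resumable chunking from a shared iterator via islice.
from itertools import islice

data_iter = iter([f"batch{j}" for j in range(6)])  # created once, shared

first = list(islice(data_iter, 2))   # ['batch0', 'batch1']
second = list(islice(data_iter, 2))  # ['batch2', 'batch3'], resumes in place

print(first, second)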
@@ -429,15 +428,15 @@ class Checkpointer:
                         output_type='pil'
                     )["sample"]
 
-                all_samples += samples
+                    all_samples += samples
 
-                del samples
+                    del samples
 
-            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
+                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
+                image_grid.save(file_path)
 
-            del all_samples
-            del image_grid
+                del all_samples
+                del image_grid
 
         del unwrapped
         del scheduler
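Note: the net effect of the two hunks above is that sampling now runs under torch.inference_mode() instead of accelerate's autocast() context. The two are not interchangeable: autocast selects mixed-precision kernels, while inference_mode disables autograd tracking entirely (more aggressively than torch.no_grad(), since it also skips view and version-counter bookkeeping). A minimal generic sketch of the difference:

# Generic PyTorch sketch, not code from this repo.
import torch

layer = torch.nn.Linear(4, 4)
x = torch.randn(4, 4)

with torch.inference_mode():
    y = layer(x)
print(y.requires_grad)  # False: no autograd graph was recorded

if torch.cuda.is_available():
    with torch.autocast("cuda"):
        z = layer.cuda()(x.cuda())
    print(z.dtype)  # torch.float16: ops ran under mixed precision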
@@ -623,7 +622,7 @@ def main():
     datamodule.setup()
 
     if args.num_class_images != 0:
-        missing_data = [item for item in datamodule.data if not item[1].exists()]
+        missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
 
         if len(missing_data) != 0:
             batched_data = [missing_data[i:i+args.sample_batch_size]
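Note: datamodule items are now addressed by attribute (item.class_image_path here, p.prompt / p.nprompt in the next hunk) instead of by tuple index, and the pool switches from datamodule.data to datamodule.data_train. The item type itself is defined in data/csv.py, which this diff does not include; a hypothetical sketch of a record shape that would support both the old and the new access styles:

# Hypothetical sketch only: the real type lives in data/csv.py, outside this
# diff. Only class_image_path, prompt and nprompt are confirmed by the hunks;
# the remaining field name is a guess.
from pathlib import Path
from typing import NamedTuple

class CSVDataItem(NamedTuple):      # hypothetical name
    instance_image_path: Path       # guessed field (was p[0])
    class_image_path: Path          # was item[1] / p[1]
    prompt: str                     # was p[2]
    nprompt: str                    # was p[3]

item = CSVDataItem(Path("a.png"), Path("cls/a.png"), "a photo of {}", "")
assert item[1] == item.class_image_path  # NamedTuple keeps index access valid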
@@ -643,20 +642,20 @@ def main():
         pipeline.enable_attention_slicing()
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        for batch in batched_data:
-            image_name = [p[1] for p in batch]
-            prompt = [p[2].format(args.initializer_token) for p in batch]
-            nprompt = [p[3] for p in batch]
+        with torch.inference_mode():
+            for batch in batched_data:
+                image_name = [p.class_image_path for p in batch]
+                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                nprompt = [p.nprompt for p in batch]
 
-            with accelerator.autocast():
                 images = pipeline(
                     prompt=prompt,
                     negative_prompt=nprompt,
                     num_inference_steps=args.sample_steps
                 ).images
 
-            for i, image in enumerate(images):
-                image.save(image_name[i])
+                for i, image in enumerate(images):
+                    image.save(image_name[i])
 
         del pipeline
 
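Note: the batched_data comprehension in the earlier hunk (old line 629) is truncated at the hunk boundary; the visible part is the standard list-slicing idiom for splitting a list into batches. A generic sketch of that idiom (not the commit's actual continuation line):

# Generic chunking sketch; the commit's continuation is outside the hunk.
def chunk(items, size):
    """Split a list into consecutive batches of at most `size` items."""
    return [items[i:i + size] for i in range(0, len(items), size)]

print(chunk(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]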