author    Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
committer Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
commit    e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch)
tree      574f7a794feab13e1cf0ed18522a66d4737b6db3
parent    Unified training script structure (diff)
Cleanup
-rw-r--r--  data/csv.py               | 21
-rw-r--r--  infer.py                  | 19
-rw-r--r--  models/clip/tokenizer.py  |  5
-rw-r--r--  train_dreambooth.py       | 33
-rw-r--r--  train_ti.py               | 33
-rw-r--r--  training/common.py        | 97
-rw-r--r--  training/util.py          | 26
7 files changed, 103 insertions, 131 deletions
diff --git a/data/csv.py b/data/csv.py
index a3fef30..df3ee77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,20 +100,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(
-    num_class_images: int,
-    weight_dtype: torch.dtype,
-    tokenizer: CLIPTokenizer,
-    examples
-):
+def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
+    with_prior = all("class_prompt_ids" in example for example in examples)
+
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    # concat class and instance examples for prior preservation
-    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+    if with_prior:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -125,6 +121,7 @@ def collate_fn(
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
+        "with_prior": torch.tensor(with_prior),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -166,7 +163,6 @@ class VlpnDataModule():
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
-        num_workers: int = 0
     ):
         super().__init__()
 
@@ -194,7 +190,6 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.num_workers = num_workers
         self.batch_size = batch_size
         self.dtype = dtype
 
@@ -290,16 +285,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
 
diff --git a/infer.py b/infer.py
index 2b07b21..36b5a2c 100644
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
-
-    patch_managed_embeddings(text_encoder)
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    )
+    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
+
+    patch_managed_embeddings(pipeline.text_encoder)
+
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 39c41ed..789b525 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]):
     return shuffle_all(tokens)
 
 
+ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]]
+
+
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def set_dropout(self, dropout: float):
         self.dropout = dropout
 
-    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+    def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
         if algorithm == "leading":
             self.vector_shuffle = shuffle_leading
         elif algorithm == "trailing":
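The ShuffleAlgorithm alias introduced above only names the union type that set_use_vector_shuffle already accepted, so call sites keep passing the same strings or booleans. A small stand-alone sketch of the same dispatch pattern, offered as an illustration rather than the library's own code:

from typing import Literal, Union

ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]]

def describe_shuffle(algorithm: ShuffleAlgorithm) -> str:
    # Mirrors the branch structure of MultiCLIPTokenizer.set_use_vector_shuffle.
    if algorithm == "leading":
        return "shuffle every vector except the first"
    if algorithm == "trailing":
        return "shuffle every vector except the last"
    return "another strategy or disabled"

print(describe_shuffle("trailing"))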
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a1802a0..c180170 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -195,15 +195,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -577,24 +568,24 @@ def main():
     )
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -618,7 +609,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -627,7 +618,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_unet = EMAModel(
@@ -726,7 +719,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -830,7 +822,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )
@@ -848,7 +839,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -873,7 +865,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e2)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -886,7 +878,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=0,
             num_epochs=args.num_train_epochs,
diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -181,15 +181,6 @@ def parse_args():
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -575,24 +566,24 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -616,7 +607,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=global_step_offset,
             num_epochs=args.num_train_epochs,
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
 
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                    checkpointer.save_samples(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
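The loss_step hunks above fold the separate timestep generator into a single torch.Generator that also drives the noise draw, so evaluation batches become fully reproducible for a given seed + step. A standalone sketch of that pattern (shapes and seed values below are made up for illustration):

import torch

seed, step = 1234, 7
latents = torch.zeros(2, 4, 8, 8)  # stand-in for VAE-encoded latents

# eval path: one generator seeds both the noise and the timestep draw
generator = torch.Generator(device=latents.device).manual_seed(seed + step)
noise = torch.randn(
    latents.shape,
    dtype=latents.dtype,
    layout=latents.layout,
    device=latents.device,
    generator=generator
)
timesteps = torch.randint(0, 1000, (latents.shape[0],), generator=generator, device=latents.device)

# re-seeding reproduces the exact same noise, which keeps validation loss comparable across epochs
generator = torch.Generator(device=latents.device).manual_seed(seed + step)
assert torch.equal(noise, torch.randn(latents.shape, dtype=latents.dtype, layout=latents.layout,
                                      device=latents.device, generator=generator))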
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
 
     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass
 
     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                     height=self.sample_image_size,
                     width=self.sample_image_size,
                     generator=gen,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
+                    guidance_scale=self.sample_guidance_scale,
+                    num_inference_steps=self.sample_steps,
                     output_type='pil'
                 ).images
 
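With the CheckpointerBase change above, sampling options (steps, guidance scale, image size, batch layout) are now fixed at construction time and save_samples shrinks to (pipeline, step). A toy stand-in to show the new call shape; the class name and the None pipeline are inventions for this sketch:

from pathlib import Path

class DemoCheckpointer:
    def __init__(self, output_dir: Path, sample_steps: int = 20, sample_guidance_scale: float = 7.5,
                 sample_image_size: int = 768, sample_batches: int = 1, sample_batch_size: int = 1):
        self.output_dir = output_dir
        self.sample_steps = sample_steps
        self.sample_guidance_scale = sample_guidance_scale
        self.sample_image_size = sample_image_size
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    def save_samples(self, pipeline, step: int):
        # Per-call sampling arguments are gone; everything comes from the instance.
        print(f"step {step}: {self.sample_steps} steps, guidance {self.sample_guidance_scale}, "
              f"{self.sample_image_size}px, {self.sample_batches}x{self.sample_batch_size} images")

DemoCheckpointer(Path("output")).save_samples(pipeline=None, step=0)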