 data/csv.py              | 21
 infer.py                 | 19
 models/clip/tokenizer.py |  5
 train_dreambooth.py      | 33
 train_ti.py              | 33
 training/common.py       | 97
 training/util.py         | 26

7 files changed, 103 insertions(+), 131 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index a3fef30..df3ee77 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -100,20 +100,16 @@ def generate_buckets(
     return buckets, bucket_items, bucket_assignments
 
 
-def collate_fn(
-    num_class_images: int,
-    weight_dtype: torch.dtype,
-    tokenizer: CLIPTokenizer,
-    examples
-):
+def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples):
+    with_prior = all("class_prompt_ids" in example for example in examples)
+
     prompt_ids = [example["prompt_ids"] for example in examples]
     nprompt_ids = [example["nprompt_ids"] for example in examples]
 
     input_ids = [example["instance_prompt_ids"] for example in examples]
     pixel_values = [example["instance_images"] for example in examples]
 
-    # concat class and instance examples for prior preservation
-    if num_class_images != 0 and "class_prompt_ids" in examples[0]:
+    if with_prior:
         input_ids += [example["class_prompt_ids"] for example in examples]
         pixel_values += [example["class_images"] for example in examples]
 
@@ -125,6 +121,7 @@ def collate_fn(
     inputs = unify_input_ids(tokenizer, input_ids)
 
     batch = {
+        "with_prior": torch.tensor(with_prior),
         "prompt_ids": prompts.input_ids,
         "nprompt_ids": nprompts.input_ids,
         "input_ids": inputs.input_ids,
@@ -166,7 +163,6 @@ class VlpnDataModule():
         seed: Optional[int] = None,
         filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         dtype: torch.dtype = torch.float32,
-        num_workers: int = 0
     ):
         super().__init__()
 
@@ -194,7 +190,6 @@ class VlpnDataModule():
         self.valid_set_repeat = valid_set_repeat
         self.seed = seed
         self.filter = filter
-        self.num_workers = num_workers
         self.batch_size = batch_size
         self.dtype = dtype
 
@@ -290,16 +285,16 @@ class VlpnDataModule():
             size=self.size, interpolation=self.interpolation,
         )
 
-        collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer)
 
         self.train_dataloader = DataLoader(
             train_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
         self.val_dataloader = DataLoader(
             val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers
+            batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
 
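collate_fn no longer needs num_class_images: it decides per batch whether prior preservation applies and records that decision in the batch itself, so loss_step can read it later. A minimal sketch of the detection logic, using hypothetical example dicts (real examples also carry image tensors):

    # Sketch only: how the new collate_fn infers prior preservation from the batch.
    import torch

    examples = [
        {"instance_prompt_ids": [1, 2], "class_prompt_ids": [3, 4]},   # hypothetical values
        {"instance_prompt_ids": [5, 6], "class_prompt_ids": [7, 8]},
    ]

    # True only if every example in the batch carries class data
    with_prior = all("class_prompt_ids" in example for example in examples)

    input_ids = [e["instance_prompt_ids"] for e in examples]
    if with_prior:
        # class prompts are appended after the instance prompts, so loss_step
        # can split the batch in half again with torch.chunk
        input_ids += [e["class_prompt_ids"] for e in examples]

    batch = {"with_prior": torch.tensor(with_prior), "input_ids": input_ids}
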
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir):
 def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
-    scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
-
-    patch_managed_embeddings(text_encoder)
-
-    pipeline = VlpnStableDiffusion(
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        scheduler=scheduler,
-    )
+    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype)
+
+    patch_managed_embeddings(pipeline.text_encoder)
+
     pipeline.enable_xformers_memory_efficient_attention()
     pipeline.enable_vae_slicing()
     pipeline.to("cuda")
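The per-component loading is replaced by a single from_pretrained call on the pipeline class, which resolves the tokenizer, text encoder, VAE, UNet, and scheduler subfolders itself; the embedding patch is then applied to the already-loaded text encoder. A rough usage sketch (module paths and the model path are assumptions, not confirmed by this diff):

    import torch
    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion  # assumed module path
    from models.clip.embeddings import patch_managed_embeddings                       # assumed module path

    model = "path/to/model"  # placeholder
    pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=torch.float16)

    # managed embeddings are patched onto the text encoder the pipeline already holds
    patch_managed_embeddings(pipeline.text_encoder)

    pipeline.enable_vae_slicing()
    pipeline.to("cuda")
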
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 39c41ed..789b525 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]):
     return shuffle_all(tokens)
 
 
+ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]]
+
+
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
     def set_dropout(self, dropout: float):
         self.dropout = dropout
 
-    def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]):
+    def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm):
         if algorithm == "leading":
             self.vector_shuffle = shuffle_leading
         elif algorithm == "trailing":
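ShuffleAlgorithm only names the union type that set_use_vector_shuffle already accepted, so existing call sites keep working. A short sketch of both spellings (the model path is a placeholder):

    from models.clip.tokenizer import MultiCLIPTokenizer  # assumed module path

    tokenizer = MultiCLIPTokenizer.from_pretrained("path/to/model", subfolder="tokenizer")  # placeholder
    tokenizer.set_use_vector_shuffle("trailing")  # same call as before, now typed via the alias
    tokenizer.set_use_vector_shuffle(False)       # a plain bool is still accepted
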
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a1802a0..c180170 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -195,15 +195,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -577,24 +568,24 @@ def main():
     )
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -618,7 +609,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -627,7 +618,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_unet = EMAModel(
@@ -726,7 +719,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -830,7 +822,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images,
         args.prior_loss_weight,
         args.seed,
     )
@@ -848,7 +839,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -873,7 +865,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e2)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -886,7 +878,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
             sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
             checkpoint_frequency=args.checkpoint_frequency,
             global_step_offset=0,
             num_epochs=args.num_train_epochs,
diff --git a/train_ti.py b/train_ti.py
index d2ca7eb..d752927 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -181,15 +181,6 @@ def parse_args():
         help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
     )
     parser.add_argument(
-        "--dataloader_num_workers",
-        type=int,
-        default=0,
-        help=(
-            "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
-            " process."
-        ),
-    )
-    parser.add_argument(
         "--num_train_epochs",
         type=int,
         default=100
@@ -575,24 +566,24 @@ def main():
 
     global_step_offset = args.global_step
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.project), now)
-    basepath.mkdir(parents=True, exist_ok=True)
+    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
+    output_dir.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{basepath}",
+        logging_dir=f"{output_dir}",
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         mixed_precision=args.mixed_precision
     )
 
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
 
     if args.seed is None:
         args.seed = torch.random.seed() >> 32
 
     set_seed(args.seed)
 
-    save_args(basepath, args)
+    save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
         args.pretrained_model_name_or_path)
@@ -616,7 +607,7 @@ def main():
         added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
         print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    placeholder_token_ids = add_placeholder_tokens(
+    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
         tokenizer=tokenizer,
         embeddings=embeddings,
         placeholder_tokens=args.placeholder_tokens,
@@ -625,7 +616,9 @@ def main():
     )
 
     if len(placeholder_token_ids) != 0:
-        print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}")
+        initializer_token_id_lens = [len(id) for id in initializer_token_ids]
+        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
+        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
 
     if args.use_ema:
         ema_embeddings = EMAModel(
@@ -708,7 +701,6 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
-        num_workers=args.dataloader_num_workers,
         seed=args.seed,
         filter=keyword_filter,
         dtype=weight_dtype
@@ -807,7 +799,6 @@ def main():
         noise_scheduler,
         unet,
         text_encoder,
-        args.num_class_images != 0,
         args.prior_loss_weight,
         args.seed,
     )
@@ -825,7 +816,8 @@ def main():
         scheduler=sample_scheduler,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
-        output_dir=basepath,
+        output_dir=output_dir,
+        sample_steps=args.sample_steps,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
@@ -849,7 +841,7 @@ def main():
         )
         lr_finder.run(num_epochs=100, end_lr=1e3)
 
-        plt.savefig(basepath.joinpath("lr.png"), dpi=300)
+        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
         plt.close()
     else:
         train_loop(
@@ -862,7 +854,6 @@ def main():
             val_dataloader=val_dataloader,
             loss_step=loss_step_,
            sample_frequency=args.sample_frequency,
-            sample_steps=args.sample_steps,
            checkpoint_frequency=args.checkpoint_frequency,
            global_step_offset=global_step_offset,
            num_epochs=args.num_train_epochs,
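Because add_placeholder_tokens now also returns the initializer token ids, both training scripts can report how many initializer tokens back each placeholder. A small illustration of the new stats line, with made-up ids:

    placeholder_tokens = ["<emb1>", "<emb2>"]           # hypothetical values
    placeholder_token_ids = [[49408], [49409, 49410]]   # hypothetical, one entry per placeholder
    initializer_token_ids = [[320], [1125, 268]]        # hypothetical tokens used to seed each embedding

    initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
    placeholder_token_stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
    print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")
    # -> Added 2 new tokens: [('<emb1>', [49408], 1), ('<emb2>', [49409, 49410], 2)]
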
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
 
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                    checkpointer.save_samples(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
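The two separate RNG paths in loss_step collapse into a single generator: during validation (eval=True) it is seeded with seed + step, so the sampled noise and the sampled timesteps are both reproducible for a given step, while during training it stays None and torch falls back to the global RNG. A condensed sketch of that pattern (function name and arguments are placeholders for illustration):

    import torch

    def sample_noise_and_timesteps(latents, num_train_timesteps, seed, step, eval=False):
        # One generator drives both draws; deterministic only in eval mode.
        generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

        noise = torch.randn(
            latents.shape,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )
        timesteps = torch.randint(
            0, num_train_timesteps, (latents.shape[0],),
            generator=generator, device=latents.device,
        ).long()
        return noise, timesteps

    # Calling this twice with eval=True and the same step returns identical tensors.
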
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
 
     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass
 
     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                         height=self.sample_image_size,
                         width=self.sample_image_size,
                         generator=gen,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
+                        guidance_scale=self.sample_guidance_scale,
+                        num_inference_steps=self.sample_steps,
                         output_type='pil'
                     ).images
 
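With sample_steps and sample_guidance_scale stored on the checkpointer, callers need to pass nothing beyond the current step, which is why train_loop's sample_steps parameter disappears. A rough caller-side sketch (the Checkpointer subclass and its extra constructor arguments are abbreviated and assumed, not spelled out by this diff):

    checkpointer = Checkpointer(               # hypothetical subclass of CheckpointerBase
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=output_dir,
        sample_steps=args.sample_steps,        # sampling settings now live on the checkpointer
        sample_guidance_scale=7.5,
        sample_image_size=args.sample_image_size,
        seed=args.seed,
    )

    # later, inside train_loop:
    checkpointer.save_samples(global_step + global_step_offset)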