diff options
| -rw-r--r-- | data/csv.py | 21 | ||||
| -rw-r--r-- | infer.py | 15 | ||||
| -rw-r--r-- | models/clip/tokenizer.py | 5 | ||||
| -rw-r--r-- | train_dreambooth.py | 33 | ||||
| -rw-r--r-- | train_ti.py | 33 | ||||
| -rw-r--r-- | training/common.py | 85 | ||||
| -rw-r--r-- | training/util.py | 26 |
7 files changed, 95 insertions, 123 deletions
diff --git a/data/csv.py b/data/csv.py index a3fef30..df3ee77 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -100,20 +100,16 @@ def generate_buckets( | |||
| 100 | return buckets, bucket_items, bucket_assignments | 100 | return buckets, bucket_items, bucket_assignments |
| 101 | 101 | ||
| 102 | 102 | ||
| 103 | def collate_fn( | 103 | def collate_fn(weight_dtype: torch.dtype, tokenizer: CLIPTokenizer, examples): |
| 104 | num_class_images: int, | 104 | with_prior = all("class_prompt_ids" in example for example in examples) |
| 105 | weight_dtype: torch.dtype, | 105 | |
| 106 | tokenizer: CLIPTokenizer, | ||
| 107 | examples | ||
| 108 | ): | ||
| 109 | prompt_ids = [example["prompt_ids"] for example in examples] | 106 | prompt_ids = [example["prompt_ids"] for example in examples] |
| 110 | nprompt_ids = [example["nprompt_ids"] for example in examples] | 107 | nprompt_ids = [example["nprompt_ids"] for example in examples] |
| 111 | 108 | ||
| 112 | input_ids = [example["instance_prompt_ids"] for example in examples] | 109 | input_ids = [example["instance_prompt_ids"] for example in examples] |
| 113 | pixel_values = [example["instance_images"] for example in examples] | 110 | pixel_values = [example["instance_images"] for example in examples] |
| 114 | 111 | ||
| 115 | # concat class and instance examples for prior preservation | 112 | if with_prior: |
| 116 | if num_class_images != 0 and "class_prompt_ids" in examples[0]: | ||
| 117 | input_ids += [example["class_prompt_ids"] for example in examples] | 113 | input_ids += [example["class_prompt_ids"] for example in examples] |
| 118 | pixel_values += [example["class_images"] for example in examples] | 114 | pixel_values += [example["class_images"] for example in examples] |
| 119 | 115 | ||
| @@ -125,6 +121,7 @@ def collate_fn( | |||
| 125 | inputs = unify_input_ids(tokenizer, input_ids) | 121 | inputs = unify_input_ids(tokenizer, input_ids) |
| 126 | 122 | ||
| 127 | batch = { | 123 | batch = { |
| 124 | "with_prior": torch.tensor(with_prior), | ||
| 128 | "prompt_ids": prompts.input_ids, | 125 | "prompt_ids": prompts.input_ids, |
| 129 | "nprompt_ids": nprompts.input_ids, | 126 | "nprompt_ids": nprompts.input_ids, |
| 130 | "input_ids": inputs.input_ids, | 127 | "input_ids": inputs.input_ids, |
| @@ -166,7 +163,6 @@ class VlpnDataModule(): | |||
| 166 | seed: Optional[int] = None, | 163 | seed: Optional[int] = None, |
| 167 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, | 164 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
| 168 | dtype: torch.dtype = torch.float32, | 165 | dtype: torch.dtype = torch.float32, |
| 169 | num_workers: int = 0 | ||
| 170 | ): | 166 | ): |
| 171 | super().__init__() | 167 | super().__init__() |
| 172 | 168 | ||
| @@ -194,7 +190,6 @@ class VlpnDataModule(): | |||
| 194 | self.valid_set_repeat = valid_set_repeat | 190 | self.valid_set_repeat = valid_set_repeat |
| 195 | self.seed = seed | 191 | self.seed = seed |
| 196 | self.filter = filter | 192 | self.filter = filter |
| 197 | self.num_workers = num_workers | ||
| 198 | self.batch_size = batch_size | 193 | self.batch_size = batch_size |
| 199 | self.dtype = dtype | 194 | self.dtype = dtype |
| 200 | 195 | ||
| @@ -290,16 +285,16 @@ class VlpnDataModule(): | |||
| 290 | size=self.size, interpolation=self.interpolation, | 285 | size=self.size, interpolation=self.interpolation, |
| 291 | ) | 286 | ) |
| 292 | 287 | ||
| 293 | collate_fn_ = partial(collate_fn, self.num_class_images, self.dtype, self.tokenizer) | 288 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer) |
| 294 | 289 | ||
| 295 | self.train_dataloader = DataLoader( | 290 | self.train_dataloader = DataLoader( |
| 296 | train_dataset, | 291 | train_dataset, |
| 297 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers | 292 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 298 | ) | 293 | ) |
| 299 | 294 | ||
| 300 | self.val_dataloader = DataLoader( | 295 | self.val_dataloader = DataLoader( |
| 301 | val_dataset, | 296 | val_dataset, |
| 302 | batch_size=None, pin_memory=True, collate_fn=collate_fn_, num_workers=self.num_workers | 297 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 303 | ) | 298 | ) |
| 304 | 299 | ||
| 305 | 300 | ||
| @@ -214,21 +214,10 @@ def load_embeddings(pipeline, embeddings_dir): | |||
| 214 | def create_pipeline(model, dtype): | 214 | def create_pipeline(model, dtype): |
| 215 | print("Loading Stable Diffusion pipeline...") | 215 | print("Loading Stable Diffusion pipeline...") |
| 216 | 216 | ||
| 217 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 217 | pipeline = VlpnStableDiffusion.from_pretrained(model, torch_dtype=dtype) |
| 218 | text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) | ||
| 219 | vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) | ||
| 220 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | ||
| 221 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | ||
| 222 | 218 | ||
| 223 | patch_managed_embeddings(text_encoder) | 219 | patch_managed_embeddings(pipeline.text_encoder) |
| 224 | 220 | ||
| 225 | pipeline = VlpnStableDiffusion( | ||
| 226 | text_encoder=text_encoder, | ||
| 227 | vae=vae, | ||
| 228 | unet=unet, | ||
| 229 | tokenizer=tokenizer, | ||
| 230 | scheduler=scheduler, | ||
| 231 | ) | ||
| 232 | pipeline.enable_xformers_memory_efficient_attention() | 221 | pipeline.enable_xformers_memory_efficient_attention() |
| 233 | pipeline.enable_vae_slicing() | 222 | pipeline.enable_vae_slicing() |
| 234 | pipeline.to("cuda") | 223 | pipeline.to("cuda") |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 39c41ed..789b525 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
| @@ -55,6 +55,9 @@ def shuffle_auto(tokens: list[int]): | |||
| 55 | return shuffle_all(tokens) | 55 | return shuffle_all(tokens) |
| 56 | 56 | ||
| 57 | 57 | ||
| 58 | ShuffleAlgorithm = Union[bool, Literal["all", "trailing", "leading", "between", "off"]] | ||
| 59 | |||
| 60 | |||
| 58 | class MultiCLIPTokenizer(CLIPTokenizer): | 61 | class MultiCLIPTokenizer(CLIPTokenizer): |
| 59 | def __init__(self, *args, **kwargs): | 62 | def __init__(self, *args, **kwargs): |
| 60 | super().__init__(*args, **kwargs) | 63 | super().__init__(*args, **kwargs) |
| @@ -73,7 +76,7 @@ class MultiCLIPTokenizer(CLIPTokenizer): | |||
| 73 | def set_dropout(self, dropout: float): | 76 | def set_dropout(self, dropout: float): |
| 74 | self.dropout = dropout | 77 | self.dropout = dropout |
| 75 | 78 | ||
| 76 | def set_use_vector_shuffle(self, algorithm: Union[bool, Literal["all", "trailing", "leading", "between", "off"]]): | 79 | def set_use_vector_shuffle(self, algorithm: ShuffleAlgorithm): |
| 77 | if algorithm == "leading": | 80 | if algorithm == "leading": |
| 78 | self.vector_shuffle = shuffle_leading | 81 | self.vector_shuffle = shuffle_leading |
| 79 | elif algorithm == "trailing": | 82 | elif algorithm == "trailing": |
diff --git a/train_dreambooth.py b/train_dreambooth.py index a1802a0..c180170 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -195,15 +195,6 @@ def parse_args(): | |||
| 195 | ), | 195 | ), |
| 196 | ) | 196 | ) |
| 197 | parser.add_argument( | 197 | parser.add_argument( |
| 198 | "--dataloader_num_workers", | ||
| 199 | type=int, | ||
| 200 | default=0, | ||
| 201 | help=( | ||
| 202 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
| 203 | " process." | ||
| 204 | ), | ||
| 205 | ) | ||
| 206 | parser.add_argument( | ||
| 207 | "--num_train_epochs", | 198 | "--num_train_epochs", |
| 208 | type=int, | 199 | type=int, |
| 209 | default=100 | 200 | default=100 |
| @@ -577,24 +568,24 @@ def main(): | |||
| 577 | ) | 568 | ) |
| 578 | 569 | ||
| 579 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 570 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 580 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 571 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 581 | basepath.mkdir(parents=True, exist_ok=True) | 572 | output_dir.mkdir(parents=True, exist_ok=True) |
| 582 | 573 | ||
| 583 | accelerator = Accelerator( | 574 | accelerator = Accelerator( |
| 584 | log_with=LoggerType.TENSORBOARD, | 575 | log_with=LoggerType.TENSORBOARD, |
| 585 | logging_dir=f"{basepath}", | 576 | logging_dir=f"{output_dir}", |
| 586 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 577 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 587 | mixed_precision=args.mixed_precision | 578 | mixed_precision=args.mixed_precision |
| 588 | ) | 579 | ) |
| 589 | 580 | ||
| 590 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 581 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
| 591 | 582 | ||
| 592 | if args.seed is None: | 583 | if args.seed is None: |
| 593 | args.seed = torch.random.seed() >> 32 | 584 | args.seed = torch.random.seed() >> 32 |
| 594 | 585 | ||
| 595 | set_seed(args.seed) | 586 | set_seed(args.seed) |
| 596 | 587 | ||
| 597 | save_args(basepath, args) | 588 | save_args(output_dir, args) |
| 598 | 589 | ||
| 599 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 590 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 600 | args.pretrained_model_name_or_path) | 591 | args.pretrained_model_name_or_path) |
| @@ -618,7 +609,7 @@ def main(): | |||
| 618 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 609 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 619 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 610 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 620 | 611 | ||
| 621 | placeholder_token_ids = add_placeholder_tokens( | 612 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 622 | tokenizer=tokenizer, | 613 | tokenizer=tokenizer, |
| 623 | embeddings=embeddings, | 614 | embeddings=embeddings, |
| 624 | placeholder_tokens=args.placeholder_tokens, | 615 | placeholder_tokens=args.placeholder_tokens, |
| @@ -627,7 +618,9 @@ def main(): | |||
| 627 | ) | 618 | ) |
| 628 | 619 | ||
| 629 | if len(placeholder_token_ids) != 0: | 620 | if len(placeholder_token_ids) != 0: |
| 630 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 621 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] |
| 622 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
| 623 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
| 631 | 624 | ||
| 632 | if args.use_ema: | 625 | if args.use_ema: |
| 633 | ema_unet = EMAModel( | 626 | ema_unet = EMAModel( |
| @@ -726,7 +719,6 @@ def main(): | |||
| 726 | template_key=args.train_data_template, | 719 | template_key=args.train_data_template, |
| 727 | valid_set_size=args.valid_set_size, | 720 | valid_set_size=args.valid_set_size, |
| 728 | valid_set_repeat=args.valid_set_repeat, | 721 | valid_set_repeat=args.valid_set_repeat, |
| 729 | num_workers=args.dataloader_num_workers, | ||
| 730 | seed=args.seed, | 722 | seed=args.seed, |
| 731 | filter=keyword_filter, | 723 | filter=keyword_filter, |
| 732 | dtype=weight_dtype | 724 | dtype=weight_dtype |
| @@ -830,7 +822,6 @@ def main(): | |||
| 830 | noise_scheduler, | 822 | noise_scheduler, |
| 831 | unet, | 823 | unet, |
| 832 | text_encoder, | 824 | text_encoder, |
| 833 | args.num_class_images, | ||
| 834 | args.prior_loss_weight, | 825 | args.prior_loss_weight, |
| 835 | args.seed, | 826 | args.seed, |
| 836 | ) | 827 | ) |
| @@ -848,7 +839,8 @@ def main(): | |||
| 848 | scheduler=sample_scheduler, | 839 | scheduler=sample_scheduler, |
| 849 | placeholder_tokens=args.placeholder_tokens, | 840 | placeholder_tokens=args.placeholder_tokens, |
| 850 | placeholder_token_ids=placeholder_token_ids, | 841 | placeholder_token_ids=placeholder_token_ids, |
| 851 | output_dir=basepath, | 842 | output_dir=output_dir, |
| 843 | sample_steps=args.sample_steps, | ||
| 852 | sample_image_size=args.sample_image_size, | 844 | sample_image_size=args.sample_image_size, |
| 853 | sample_batch_size=args.sample_batch_size, | 845 | sample_batch_size=args.sample_batch_size, |
| 854 | sample_batches=args.sample_batches, | 846 | sample_batches=args.sample_batches, |
| @@ -873,7 +865,7 @@ def main(): | |||
| 873 | ) | 865 | ) |
| 874 | lr_finder.run(num_epochs=100, end_lr=1e2) | 866 | lr_finder.run(num_epochs=100, end_lr=1e2) |
| 875 | 867 | ||
| 876 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 868 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
| 877 | plt.close() | 869 | plt.close() |
| 878 | else: | 870 | else: |
| 879 | train_loop( | 871 | train_loop( |
| @@ -886,7 +878,6 @@ def main(): | |||
| 886 | val_dataloader=val_dataloader, | 878 | val_dataloader=val_dataloader, |
| 887 | loss_step=loss_step_, | 879 | loss_step=loss_step_, |
| 888 | sample_frequency=args.sample_frequency, | 880 | sample_frequency=args.sample_frequency, |
| 889 | sample_steps=args.sample_steps, | ||
| 890 | checkpoint_frequency=args.checkpoint_frequency, | 881 | checkpoint_frequency=args.checkpoint_frequency, |
| 891 | global_step_offset=0, | 882 | global_step_offset=0, |
| 892 | num_epochs=args.num_train_epochs, | 883 | num_epochs=args.num_train_epochs, |
diff --git a/train_ti.py b/train_ti.py index d2ca7eb..d752927 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -181,15 +181,6 @@ def parse_args(): | |||
| 181 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', | 181 | help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]', |
| 182 | ) | 182 | ) |
| 183 | parser.add_argument( | 183 | parser.add_argument( |
| 184 | "--dataloader_num_workers", | ||
| 185 | type=int, | ||
| 186 | default=0, | ||
| 187 | help=( | ||
| 188 | "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main" | ||
| 189 | " process." | ||
| 190 | ), | ||
| 191 | ) | ||
| 192 | parser.add_argument( | ||
| 193 | "--num_train_epochs", | 184 | "--num_train_epochs", |
| 194 | type=int, | 185 | type=int, |
| 195 | default=100 | 186 | default=100 |
| @@ -575,24 +566,24 @@ def main(): | |||
| 575 | 566 | ||
| 576 | global_step_offset = args.global_step | 567 | global_step_offset = args.global_step |
| 577 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 568 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 578 | basepath = Path(args.output_dir).joinpath(slugify(args.project), now) | 569 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) |
| 579 | basepath.mkdir(parents=True, exist_ok=True) | 570 | output_dir.mkdir(parents=True, exist_ok=True) |
| 580 | 571 | ||
| 581 | accelerator = Accelerator( | 572 | accelerator = Accelerator( |
| 582 | log_with=LoggerType.TENSORBOARD, | 573 | log_with=LoggerType.TENSORBOARD, |
| 583 | logging_dir=f"{basepath}", | 574 | logging_dir=f"{output_dir}", |
| 584 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 575 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 585 | mixed_precision=args.mixed_precision | 576 | mixed_precision=args.mixed_precision |
| 586 | ) | 577 | ) |
| 587 | 578 | ||
| 588 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 579 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) |
| 589 | 580 | ||
| 590 | if args.seed is None: | 581 | if args.seed is None: |
| 591 | args.seed = torch.random.seed() >> 32 | 582 | args.seed = torch.random.seed() >> 32 |
| 592 | 583 | ||
| 593 | set_seed(args.seed) | 584 | set_seed(args.seed) |
| 594 | 585 | ||
| 595 | save_args(basepath, args) | 586 | save_args(output_dir, args) |
| 596 | 587 | ||
| 597 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 588 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| 598 | args.pretrained_model_name_or_path) | 589 | args.pretrained_model_name_or_path) |
| @@ -616,7 +607,7 @@ def main(): | |||
| 616 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 607 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) |
| 617 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 608 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") |
| 618 | 609 | ||
| 619 | placeholder_token_ids = add_placeholder_tokens( | 610 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 620 | tokenizer=tokenizer, | 611 | tokenizer=tokenizer, |
| 621 | embeddings=embeddings, | 612 | embeddings=embeddings, |
| 622 | placeholder_tokens=args.placeholder_tokens, | 613 | placeholder_tokens=args.placeholder_tokens, |
| @@ -625,7 +616,9 @@ def main(): | |||
| 625 | ) | 616 | ) |
| 626 | 617 | ||
| 627 | if len(placeholder_token_ids) != 0: | 618 | if len(placeholder_token_ids) != 0: |
| 628 | print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(args.placeholder_tokens, placeholder_token_ids))}") | 619 | initializer_token_id_lens = [len(id) for id in initializer_token_ids] |
| 620 | placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens)) | ||
| 621 | print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}") | ||
| 629 | 622 | ||
| 630 | if args.use_ema: | 623 | if args.use_ema: |
| 631 | ema_embeddings = EMAModel( | 624 | ema_embeddings = EMAModel( |
| @@ -708,7 +701,6 @@ def main(): | |||
| 708 | template_key=args.train_data_template, | 701 | template_key=args.train_data_template, |
| 709 | valid_set_size=args.valid_set_size, | 702 | valid_set_size=args.valid_set_size, |
| 710 | valid_set_repeat=args.valid_set_repeat, | 703 | valid_set_repeat=args.valid_set_repeat, |
| 711 | num_workers=args.dataloader_num_workers, | ||
| 712 | seed=args.seed, | 704 | seed=args.seed, |
| 713 | filter=keyword_filter, | 705 | filter=keyword_filter, |
| 714 | dtype=weight_dtype | 706 | dtype=weight_dtype |
| @@ -807,7 +799,6 @@ def main(): | |||
| 807 | noise_scheduler, | 799 | noise_scheduler, |
| 808 | unet, | 800 | unet, |
| 809 | text_encoder, | 801 | text_encoder, |
| 810 | args.num_class_images != 0, | ||
| 811 | args.prior_loss_weight, | 802 | args.prior_loss_weight, |
| 812 | args.seed, | 803 | args.seed, |
| 813 | ) | 804 | ) |
| @@ -825,7 +816,8 @@ def main(): | |||
| 825 | scheduler=sample_scheduler, | 816 | scheduler=sample_scheduler, |
| 826 | placeholder_tokens=args.placeholder_tokens, | 817 | placeholder_tokens=args.placeholder_tokens, |
| 827 | placeholder_token_ids=placeholder_token_ids, | 818 | placeholder_token_ids=placeholder_token_ids, |
| 828 | output_dir=basepath, | 819 | output_dir=output_dir, |
| 820 | sample_steps=args.sample_steps, | ||
| 829 | sample_image_size=args.sample_image_size, | 821 | sample_image_size=args.sample_image_size, |
| 830 | sample_batch_size=args.sample_batch_size, | 822 | sample_batch_size=args.sample_batch_size, |
| 831 | sample_batches=args.sample_batches, | 823 | sample_batches=args.sample_batches, |
| @@ -849,7 +841,7 @@ def main(): | |||
| 849 | ) | 841 | ) |
| 850 | lr_finder.run(num_epochs=100, end_lr=1e3) | 842 | lr_finder.run(num_epochs=100, end_lr=1e3) |
| 851 | 843 | ||
| 852 | plt.savefig(basepath.joinpath("lr.png"), dpi=300) | 844 | plt.savefig(output_dir.joinpath("lr.png"), dpi=300) |
| 853 | plt.close() | 845 | plt.close() |
| 854 | else: | 846 | else: |
| 855 | train_loop( | 847 | train_loop( |
| @@ -862,7 +854,6 @@ def main(): | |||
| 862 | val_dataloader=val_dataloader, | 854 | val_dataloader=val_dataloader, |
| 863 | loss_step=loss_step_, | 855 | loss_step=loss_step_, |
| 864 | sample_frequency=args.sample_frequency, | 856 | sample_frequency=args.sample_frequency, |
| 865 | sample_steps=args.sample_steps, | ||
| 866 | checkpoint_frequency=args.checkpoint_frequency, | 857 | checkpoint_frequency=args.checkpoint_frequency, |
| 867 | global_step_offset=global_step_offset, | 858 | global_step_offset=global_step_offset, |
| 868 | num_epochs=args.num_train_epochs, | 859 | num_epochs=args.num_train_epochs, |
diff --git a/training/common.py b/training/common.py index b6964a3..f5ab326 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -45,42 +45,44 @@ def generate_class_images( | |||
| 45 | ): | 45 | ): |
| 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 46 | missing_data = [item for item in data_train if not item.class_image_path.exists()] |
| 47 | 47 | ||
| 48 | if len(missing_data) != 0: | 48 | if len(missing_data) == 0: |
| 49 | batched_data = [ | 49 | return |
| 50 | missing_data[i:i+sample_batch_size] | ||
| 51 | for i in range(0, len(missing_data), sample_batch_size) | ||
| 52 | ] | ||
| 53 | 50 | ||
| 54 | pipeline = VlpnStableDiffusion( | 51 | batched_data = [ |
| 55 | text_encoder=text_encoder, | 52 | missing_data[i:i+sample_batch_size] |
| 56 | vae=vae, | 53 | for i in range(0, len(missing_data), sample_batch_size) |
| 57 | unet=unet, | 54 | ] |
| 58 | tokenizer=tokenizer, | 55 | |
| 59 | scheduler=scheduler, | 56 | pipeline = VlpnStableDiffusion( |
| 60 | ).to(accelerator.device) | 57 | text_encoder=text_encoder, |
| 61 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 58 | vae=vae, |
| 59 | unet=unet, | ||
| 60 | tokenizer=tokenizer, | ||
| 61 | scheduler=scheduler, | ||
| 62 | ).to(accelerator.device) | ||
| 63 | pipeline.set_progress_bar_config(dynamic_ncols=True) | ||
| 62 | 64 | ||
| 63 | with torch.inference_mode(): | 65 | with torch.inference_mode(): |
| 64 | for batch in batched_data: | 66 | for batch in batched_data: |
| 65 | image_name = [item.class_image_path for item in batch] | 67 | image_name = [item.class_image_path for item in batch] |
| 66 | prompt = [item.cprompt for item in batch] | 68 | prompt = [item.cprompt for item in batch] |
| 67 | nprompt = [item.nprompt for item in batch] | 69 | nprompt = [item.nprompt for item in batch] |
| 68 | 70 | ||
| 69 | images = pipeline( | 71 | images = pipeline( |
| 70 | prompt=prompt, | 72 | prompt=prompt, |
| 71 | negative_prompt=nprompt, | 73 | negative_prompt=nprompt, |
| 72 | height=sample_image_size, | 74 | height=sample_image_size, |
| 73 | width=sample_image_size, | 75 | width=sample_image_size, |
| 74 | num_inference_steps=sample_steps | 76 | num_inference_steps=sample_steps |
| 75 | ).images | 77 | ).images |
| 76 | 78 | ||
| 77 | for i, image in enumerate(images): | 79 | for i, image in enumerate(images): |
| 78 | image.save(image_name[i]) | 80 | image.save(image_name[i]) |
| 79 | 81 | ||
| 80 | del pipeline | 82 | del pipeline |
| 81 | 83 | ||
| 82 | if torch.cuda.is_available(): | 84 | if torch.cuda.is_available(): |
| 83 | torch.cuda.empty_cache() | 85 | torch.cuda.empty_cache() |
| 84 | 86 | ||
| 85 | 87 | ||
| 86 | def get_models(pretrained_model_name_or_path: str): | 88 | def get_models(pretrained_model_name_or_path: str): |
| @@ -119,7 +121,7 @@ def add_placeholder_tokens( | |||
| 119 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): | 121 | for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): |
| 120 | embeddings.add_embed(placeholder_token_id, initializer_token_id) | 122 | embeddings.add_embed(placeholder_token_id, initializer_token_id) |
| 121 | 123 | ||
| 122 | return placeholder_token_ids | 124 | return placeholder_token_ids, initializer_token_ids |
| 123 | 125 | ||
| 124 | 126 | ||
| 125 | def loss_step( | 127 | def loss_step( |
| @@ -127,7 +129,6 @@ def loss_step( | |||
| 127 | noise_scheduler: DDPMScheduler, | 129 | noise_scheduler: DDPMScheduler, |
| 128 | unet: UNet2DConditionModel, | 130 | unet: UNet2DConditionModel, |
| 129 | text_encoder: CLIPTextModel, | 131 | text_encoder: CLIPTextModel, |
| 130 | with_prior: bool, | ||
| 131 | prior_loss_weight: float, | 132 | prior_loss_weight: float, |
| 132 | seed: int, | 133 | seed: int, |
| 133 | step: int, | 134 | step: int, |
| @@ -138,16 +139,23 @@ def loss_step( | |||
| 138 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 139 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() |
| 139 | latents = latents * 0.18215 | 140 | latents = latents * 0.18215 |
| 140 | 141 | ||
| 142 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 143 | |||
| 141 | # Sample noise that we'll add to the latents | 144 | # Sample noise that we'll add to the latents |
| 142 | noise = torch.randn_like(latents) | 145 | noise = torch.randn( |
| 146 | latents.shape, | ||
| 147 | dtype=latents.dtype, | ||
| 148 | layout=latents.layout, | ||
| 149 | device=latents.device, | ||
| 150 | generator=generator | ||
| 151 | ) | ||
| 143 | bsz = latents.shape[0] | 152 | bsz = latents.shape[0] |
| 144 | # Sample a random timestep for each image | 153 | # Sample a random timestep for each image |
| 145 | timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | ||
| 146 | timesteps = torch.randint( | 154 | timesteps = torch.randint( |
| 147 | 0, | 155 | 0, |
| 148 | noise_scheduler.config.num_train_timesteps, | 156 | noise_scheduler.config.num_train_timesteps, |
| 149 | (bsz,), | 157 | (bsz,), |
| 150 | generator=timesteps_gen, | 158 | generator=generator, |
| 151 | device=latents.device, | 159 | device=latents.device, |
| 152 | ) | 160 | ) |
| 153 | timesteps = timesteps.long() | 161 | timesteps = timesteps.long() |
| @@ -176,7 +184,7 @@ def loss_step( | |||
| 176 | else: | 184 | else: |
| 177 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") | 185 | raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") |
| 178 | 186 | ||
| 179 | if with_prior: | 187 | if batch["with_prior"]: |
| 180 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. | 188 | # Chunk the noise and model_pred into two parts and compute the loss on each part separately. |
| 181 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) | 189 | model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0) |
| 182 | target, target_prior = torch.chunk(target, 2, dim=0) | 190 | target, target_prior = torch.chunk(target, 2, dim=0) |
| @@ -207,7 +215,6 @@ def train_loop( | |||
| 207 | val_dataloader: DataLoader, | 215 | val_dataloader: DataLoader, |
| 208 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 216 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 209 | sample_frequency: int = 10, | 217 | sample_frequency: int = 10, |
| 210 | sample_steps: int = 20, | ||
| 211 | checkpoint_frequency: int = 50, | 218 | checkpoint_frequency: int = 50, |
| 212 | global_step_offset: int = 0, | 219 | global_step_offset: int = 0, |
| 213 | num_epochs: int = 100, | 220 | num_epochs: int = 100, |
| @@ -251,7 +258,7 @@ def train_loop( | |||
| 251 | for epoch in range(num_epochs): | 258 | for epoch in range(num_epochs): |
| 252 | if accelerator.is_main_process: | 259 | if accelerator.is_main_process: |
| 253 | if epoch % sample_frequency == 0: | 260 | if epoch % sample_frequency == 0: |
| 254 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 261 | checkpointer.save_samples(global_step + global_step_offset) |
| 255 | 262 | ||
| 256 | if epoch % checkpoint_frequency == 0 and epoch != 0: | 263 | if epoch % checkpoint_frequency == 0 and epoch != 0: |
| 257 | checkpointer.checkpoint(global_step + global_step_offset, "training") | 264 | checkpointer.checkpoint(global_step + global_step_offset, "training") |
| @@ -353,7 +360,7 @@ def train_loop( | |||
| 353 | if accelerator.is_main_process: | 360 | if accelerator.is_main_process: |
| 354 | print("Finished!") | 361 | print("Finished!") |
| 355 | checkpointer.checkpoint(global_step + global_step_offset, "end") | 362 | checkpointer.checkpoint(global_step + global_step_offset, "end") |
| 356 | checkpointer.save_samples(global_step + global_step_offset, sample_steps) | 363 | checkpointer.save_samples(global_step + global_step_offset) |
| 357 | accelerator.end_training() | 364 | accelerator.end_training() |
| 358 | 365 | ||
| 359 | except KeyboardInterrupt: | 366 | except KeyboardInterrupt: |
diff --git a/training/util.py b/training/util.py index cc4cdee..1008021 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -44,32 +44,29 @@ class CheckpointerBase: | |||
| 44 | train_dataloader, | 44 | train_dataloader, |
| 45 | val_dataloader, | 45 | val_dataloader, |
| 46 | output_dir: Path, | 46 | output_dir: Path, |
| 47 | sample_image_size: int, | 47 | sample_steps: int = 20, |
| 48 | sample_batches: int, | 48 | sample_guidance_scale: float = 7.5, |
| 49 | sample_batch_size: int, | 49 | sample_image_size: int = 768, |
| 50 | sample_batches: int = 1, | ||
| 51 | sample_batch_size: int = 1, | ||
| 50 | seed: Optional[int] = None | 52 | seed: Optional[int] = None |
| 51 | ): | 53 | ): |
| 52 | self.train_dataloader = train_dataloader | 54 | self.train_dataloader = train_dataloader |
| 53 | self.val_dataloader = val_dataloader | 55 | self.val_dataloader = val_dataloader |
| 54 | self.output_dir = output_dir | 56 | self.output_dir = output_dir |
| 55 | self.sample_image_size = sample_image_size | 57 | self.sample_image_size = sample_image_size |
| 56 | self.seed = seed if seed is not None else torch.random.seed() | 58 | self.sample_steps = sample_steps |
| 59 | self.sample_guidance_scale = sample_guidance_scale | ||
| 57 | self.sample_batches = sample_batches | 60 | self.sample_batches = sample_batches |
| 58 | self.sample_batch_size = sample_batch_size | 61 | self.sample_batch_size = sample_batch_size |
| 62 | self.seed = seed if seed is not None else torch.random.seed() | ||
| 59 | 63 | ||
| 60 | @torch.no_grad() | 64 | @torch.no_grad() |
| 61 | def checkpoint(self, step: int, postfix: str): | 65 | def checkpoint(self, step: int, postfix: str): |
| 62 | pass | 66 | pass |
| 63 | 67 | ||
| 64 | @torch.inference_mode() | 68 | @torch.inference_mode() |
| 65 | def save_samples( | 69 | def save_samples(self, pipeline, step: int): |
| 66 | self, | ||
| 67 | pipeline, | ||
| 68 | step: int, | ||
| 69 | num_inference_steps: int, | ||
| 70 | guidance_scale: float = 7.5, | ||
| 71 | eta: float = 0.0 | ||
| 72 | ): | ||
| 73 | samples_path = Path(self.output_dir).joinpath("samples") | 70 | samples_path = Path(self.output_dir).joinpath("samples") |
| 74 | 71 | ||
| 75 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 72 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
| @@ -110,9 +107,8 @@ class CheckpointerBase: | |||
| 110 | height=self.sample_image_size, | 107 | height=self.sample_image_size, |
| 111 | width=self.sample_image_size, | 108 | width=self.sample_image_size, |
| 112 | generator=gen, | 109 | generator=gen, |
| 113 | guidance_scale=guidance_scale, | 110 | guidance_scale=self.sample_guidance_scale, |
| 114 | eta=eta, | 111 | num_inference_steps=self.sample_steps, |
| 115 | num_inference_steps=num_inference_steps, | ||
| 116 | output_type='pil' | 112 | output_type='pil' |
| 117 | ).images | 113 | ).images |
| 118 | 114 | ||
