-rw-r--r--  .gitignore           |  1
-rw-r--r--  dreambooth.py        | 45
-rw-r--r--  textual_inversion.py | 70
3 files changed, 59 insertions(+), 57 deletions(-)
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -161,4 +161,5 @@ cython_debug/
 
 output/
 conf/
+embeddings/
 v1-inference.yaml*
diff --git a/dreambooth.py b/dreambooth.py
index 24e6091..a26bea7 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -68,7 +68,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=2,
+        default=4,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -106,7 +106,8 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100)
+        default=100
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -293,7 +294,7 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        output_dir,
+        output_dir: Path,
         instance_identifier,
         sample_image_size,
         sample_batches,
@@ -321,14 +322,14 @@ class Checkpointer:
         pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
-            unet=self.accelerator.unwrap_model(self.unet),
+            unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
         )
         pipeline.enable_attention_slicing()
-        pipeline.save_pretrained(f"{self.output_dir}/model")
+        pipeline.save_pretrained(self.output_dir.joinpath("model"))
 
         del unwrapped
         del pipeline
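Note: with output_dir now typed as Path, the model is saved via joinpath rather than an f-string, and the pipeline reuses an unwrapped model computed once instead of calling unwrap_model inline. A minimal sketch of the resulting save pattern, where build_pipeline is a hypothetical stand-in for the VlpnStableDiffusion(...) construction above:

    from pathlib import Path

    def save_model(accelerator, unet, output_dir: Path, build_pipeline):
        # Unwrap the prepared model once and reuse the reference.
        unwrapped = accelerator.unwrap_model(unet)
        pipeline = build_pipeline(unwrapped)  # e.g. VlpnStableDiffusion(..., unet=unwrapped, ...)
        pipeline.enable_attention_slicing()
        # Path.joinpath replaces the old f"{self.output_dir}/model" handling.
        pipeline.save_pretrained(output_dir.joinpath("model"))
        del unwrapped
        del pipeline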
@@ -524,7 +525,7 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.instance_identifier,
         class_identifier=args.class_identifier,
-        class_subdir="db_cls",
+        class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
@@ -580,21 +581,6 @@ def main():
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        output_dir=basepath,
-        instance_identifier=args.instance_identifier,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -613,7 +599,7 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
-    # Move vae and unet to device
+    # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
 
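Note: the corrected comment now matches the code: only the trained unet goes through accelerator.prepare (together with the optimizer, dataloaders and LR scheduler), while the frozen text_encoder and vae are moved to the device manually. A runnable sketch of that split, with small placeholder modules standing in for the real models:

    import torch
    from accelerate import Accelerator

    accelerator = Accelerator()

    unet = torch.nn.Linear(4, 4)          # placeholder for the trained UNet
    vae = torch.nn.Linear(4, 4)           # placeholder, stays frozen
    text_encoder = torch.nn.Linear(4, 4)  # placeholder, stays frozen

    optimizer = torch.optim.AdamW(unet.parameters(), lr=5e-5)

    # Only the optimized module and training plumbing are prepared...
    unet, optimizer = accelerator.prepare(unet, optimizer)

    # ...frozen components are simply moved to the right device.
    text_encoder.to(accelerator.device)
    vae.to(accelerator.device)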
@@ -649,6 +635,21 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        output_dir=basepath,
+        instance_identifier=args.instance_identifier,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         checkpointer.save_samples(
             0,
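Note: the Checkpointer construction moves from before accelerator.prepare(...) to after it, just ahead of the first save_samples call (the same relocation appears in textual_inversion.py below). Presumably this is so the instance captures the prepared unet that it unwraps at save time, since prepare rebinds the unet name. A toy sketch of the ordering, with a placeholder model:

    import torch
    from accelerate import Accelerator

    class Checkpointer:
        # Toy stand-in: stores whatever model it is handed, unwraps on save.
        def __init__(self, accelerator, model):
            self.accelerator = accelerator
            self.model = model

        def save(self):
            return self.accelerator.unwrap_model(self.model)

    accelerator = Accelerator()
    model = torch.nn.Linear(4, 4)  # placeholder for the UNet
    optimizer = torch.optim.AdamW(model.parameters())

    # prepare() rebinds model; constructing the checkpointer afterwards means
    # it holds the prepared object, not the raw module the old code captured.
    model, optimizer = accelerator.prepare(model, optimizer)
    checkpointer = Checkpointer(accelerator, model)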
diff --git a/textual_inversion.py b/textual_inversion.py
index 86fcdfe..4f2de9e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -19,7 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -70,7 +70,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=2,
+        default=4,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -107,7 +107,8 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100)
+        default=100
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
134 | parser.add_argument( | 135 | parser.add_argument( |
@@ -325,9 +326,10 @@ class Checkpointer: | |||
325 | vae, | 326 | vae, |
326 | unet, | 327 | unet, |
327 | tokenizer, | 328 | tokenizer, |
329 | text_encoder, | ||
328 | placeholder_token, | 330 | placeholder_token, |
329 | placeholder_token_id, | 331 | placeholder_token_id, |
330 | output_dir, | 332 | output_dir: Path, |
331 | sample_image_size, | 333 | sample_image_size, |
332 | sample_batches, | 334 | sample_batches, |
333 | sample_batch_size, | 335 | sample_batch_size, |
@@ -338,6 +340,7 @@ class Checkpointer:
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -347,14 +350,14 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, path=None):
         print("Saving checkpoint for step %d..." % step)
 
         if path is None:
-            checkpoints_path = f"{self.output_dir}/checkpoints"
-            os.makedirs(checkpoints_path, exist_ok=True)
+            checkpoints_path = self.output_dir.joinpath("checkpoints")
+            checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -364,17 +367,16 @@ class Checkpointer:
         if path is not None:
             torch.save(learned_embeds_dict, path)
         else:
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
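Note: checkpoint() now writes only the per-step file and drops the extra last.bin copy, although the resume files written further down still reference checkpoints/last.bin; resuming therefore appears to depend on a last.bin left behind by an earlier version of the script. For reference, a hedged sketch of the payload this method saves (the dict construction sits outside the hunk; the {token: embedding} layout is the usual textual-inversion checkpoint format):

    from pathlib import Path

    import torch

    def save_learned_embeds(text_encoder, placeholder_token, placeholder_token_id,
                            checkpoints_path: Path, filename: str):
        # One embedding row for the learned token, keyed by the token string.
        learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
        learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
        checkpoints_path.mkdir(parents=True, exist_ok=True)
        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))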
@@ -608,7 +610,7 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
         class_identifier=args.initializer_token,
-        class_subdir="ti_cls",
+        class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
@@ -664,21 +666,6 @@ def main():
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -733,10 +720,25 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         checkpointer.save_samples(
             0,
-            text_encoder,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(
@@ -838,7 +840,7 @@ def main():
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -897,13 +899,12 @@ def main():
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
-                    text_encoder,
                     args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         # Create the pipeline using the trained modules and save it.
@@ -912,7 +913,6 @@ def main():
         checkpointer.checkpoint(
             global_step + global_step_offset,
             "end",
-            text_encoder,
             path=f"{basepath}/learned_embeds.bin"
         )
 
@@ -926,7 +926,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
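Note: the learned_embeds.bin written at the end can be loaded back into a tokenizer/text-encoder pair for inference. A hedged usage sketch assuming the standard {token: embedding} layout saved above; the model path is illustrative:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    learned = torch.load("learned_embeds.bin", map_location="cpu")
    token, embedding = next(iter(learned.items()))

    tokenizer = CLIPTokenizer.from_pretrained("model", subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained("model", subfolder="text_encoder")

    # Register the token and copy its embedding into the (resized) table.
    tokenizer.add_tokens(token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    token_id = tokenizer.convert_tokens_to_ids(token)
    with torch.no_grad():
        text_encoder.get_input_embeddings().weight[token_id] = embedding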