-rw-r--r--  dreambooth.py                                        |  49
-rw-r--r--  infer.py                                             |   2
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  42
-rw-r--r--  textual_inversion.py                                 | 119
4 files changed, 101 insertions, 111 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 3110c6d..9a6f70a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
 from PIL import Image
@@ -204,7 +204,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_epochs",
         type=int,
-        default=20,
+        default=10,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -558,11 +558,11 @@ class Checkpointer:
 def main():
     args = parse_args()
 
-    # if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
-    #     raise ValueError(
-    #         "Gradient accumulation is not supported when training the text encoder in distributed training. "
-    #         "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
-    #     )
+    if args.train_text_encoder and args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
 
     instance_identifier = args.instance_identifier
 
@@ -645,9 +645,9 @@ def main():
 
     print(f"Token ID mappings:")
     for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
-        print(f"- {token_id} {token}")
-
         embedding_file = embeddings_dir.joinpath(f"{token}.bin")
+        embedding_source = "init"
+
         if embedding_file.exists() and embedding_file.is_file():
             embedding_data = torch.load(embedding_file, map_location="cpu")
 
@@ -656,8 +656,11 @@ def main():
                 emb = emb.unsqueeze(0)
 
             token_embeds[token_id] = emb
+            embedding_source = "file"
 
-    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+        print(f"- {token_id} {token} ({embedding_source})")
+
+    original_token_embeds = token_embeds.clone().to(accelerator.device)
     initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
 
     for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
@@ -946,7 +949,7 @@ def main():
         sample_checkpoint = False
 
         for step, batch in enumerate(train_dataloader):
-            with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+            with accelerator.accumulate(unet):
                 # Convert images to latent space
                 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                 latents = latents * 0.18215
@@ -997,16 +1000,6 @@ def main():
 
                 accelerator.backward(loss)
 
-                if not args.train_text_encoder:
-                    # Keep the token embeddings fixed except the newly added
-                    # embeddings for the concept, as we only want to optimize the concept embeddings
-                    if accelerator.num_processes > 1:
-                        token_embeds = text_encoder.module.get_input_embeddings().weight
-                    else:
-                        token_embeds = text_encoder.get_input_embeddings().weight
-
-                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
                 if accelerator.sync_gradients:
                     params_to_clip = (
                         itertools.chain(unet.parameters(), text_encoder.parameters())
@@ -1022,6 +1015,12 @@ def main():
                         ema_unet.step(unet)
                     optimizer.zero_grad(set_to_none=True)
 
+                if not args.train_text_encoder:
+                    # Let's make sure we don't update any embedding weights besides the newly added token
+                    with torch.no_grad():
+                        text_encoder.get_input_embeddings(
+                        ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)
 
@@ -1032,9 +1031,6 @@ def main():
 
                 global_step += 1
 
-                if global_step % args.sample_frequency == 0:
-                    sample_checkpoint = True
-
                 logs = {
                     "train/loss": avg_loss.avg.item(),
                     "train/acc": avg_acc.avg.item(),
@@ -1117,8 +1113,9 @@ def main():
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     max_acc_val = avg_acc_val.avg.item()
 
-        if sample_checkpoint and accelerator.is_main_process:
-            checkpointer.save_samples(global_step, args.sample_steps)
+        if accelerator.is_main_process:
+            if epoch % args.sample_frequency == 0:
+                checkpointer.save_samples(global_step, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
diff --git a/infer.py b/infer.py
index 5bd926a..f607041 100644
--- a/infer.py
+++ b/infer.py
@@ -31,7 +31,7 @@ torch.backends.cudnn.benchmark = True
 
 
 default_args = {
-    "model": None,
+    "model": "stabilityai/stable-diffusion-2-1",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
     "output_dir": "output/inference",
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 141b9a7..707b639 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -421,25 +421,29 @@ class VlpnStableDiffusion(DiffusionPipeline):
         extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
 
         # 7. Denoising loop
-        for i, t in enumerate(self.progress_bar(timesteps)):
-            # expand the latents if we are doing classifier free guidance
-            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
-            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
-
-            # predict the noise residual
-            noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
-
-            # perform guidance
-            if do_classifier_free_guidance:
-                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-                noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-            # compute the previous noisy sample x_t -> x_t-1
-            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
-
-            # call the callback, if provided
-            if callback is not None and i % callback_steps == 0:
-                callback(i, t, latents)
+        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+        with self.progress_bar(total=num_inference_steps) as progress_bar:
+            for i, t in enumerate(timesteps):
+                # expand the latents if we are doing classifier free guidance
+                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+                # predict the noise residual
+                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
+
+                # perform guidance
+                if do_classifier_free_guidance:
+                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+                # compute the previous noisy sample x_t -> x_t-1
+                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+                # call the callback, if provided
+                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+                    progress_bar.update()
+                if callback is not None and i % callback_steps == 0:
+                    callback(i, t, latents)
 
         # 8. Post-processing
         image = self.decode_latents(latents)
diff --git a/textual_inversion.py b/textual_inversion.py
index a9c3326..11babd8 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -170,7 +170,7 @@ def parse_args():
     parser.add_argument(
         "--lr_scheduler",
         type=str,
-        default="one_cycle",
+        default="constant_with_warmup",
         help=(
             'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
             ' "constant", "constant_with_warmup", "one_cycle"]'
@@ -231,14 +231,14 @@ def parse_args():
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
-        default=500,
-        help="How often to save a checkpoint and sample image",
+        default=5,
+        help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
         "--sample_frequency",
         type=int,
-        default=100,
-        help="How often to save a checkpoint and sample image",
+        default=1,
+        help="How often to save a checkpoint and sample image (in epochs)",
     )
     parser.add_argument(
         "--sample_image_size",
@@ -294,10 +294,9 @@ def parse_args():
         help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
     )
     parser.add_argument(
-        "--resume_checkpoint",
-        type=str,
-        default=None,
-        help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
+        "--global_step",
+        type=int,
+        default=0,
     )
     parser.add_argument(
         "--config",
@@ -512,19 +511,10 @@ def main():
     if len(args.placeholder_token) != 0:
         instance_identifier = instance_identifier.format(args.placeholder_token[0])
 
-    global_step_offset = 0
-    if args.resume_from is not None:
-        basepath = Path(args.resume_from)
-        print("Resuming state from %s" % args.resume_from)
-        with open(basepath.joinpath("resume.json"), 'r') as f:
-            state = json.load(f)
-            global_step_offset = state["args"].get("global_step", 0)
-
-            print("We've trained %d steps so far" % global_step_offset)
-    else:
-        now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
-        basepath.mkdir(parents=True, exist_ok=True)
+    global_step_offset = args.global_step
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
+    basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
@@ -557,6 +547,7 @@ def main():
         set_use_memory_efficient_attention_xformers(vae, True)
 
     if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
     print(f"Adding text embeddings: {args.placeholder_token}")
@@ -577,14 +568,25 @@ def main():
 
     # Initialise the newly added placeholder token with the embeddings of the initializer token
     token_embeds = text_encoder.get_input_embeddings().weight.data
-    original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
 
-    if args.resume_checkpoint is not None:
-        token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[args.placeholder_token]
-    else:
-        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
-        for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
-            token_embeds[token_id] = embeddings
+    if args.resume_from:
+        resumepath = Path(args.resume_from).joinpath("checkpoints")
+
+        for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+            embedding_file = resumepath.joinpath(f"{token}_{args.global_step}_end.bin")
+            embedding_data = torch.load(embedding_file, map_location="cpu")
+
+            emb = next(iter(embedding_data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
+
+            token_embeds[token_id] = emb
+
+    original_token_embeds = token_embeds.clone().to(accelerator.device)
+
+    initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+    for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
+        token_embeds[token_id] = embeddings
 
     index_fixed_tokens = torch.arange(len(tokenizer))
     index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
@@ -891,21 +893,16 @@ def main():
 
                 accelerator.backward(loss)
 
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
                 optimizer.zero_grad(set_to_none=True)
 
+                # Let's make sure we don't update any embedding weights besides the newly added token
+                with torch.no_grad():
+                    text_encoder.get_input_embeddings(
+                    ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
+
                 loss = loss.detach().item()
                 train_loss += loss
 
@@ -916,19 +913,6 @@ def main():
 
                 global_step += 1
 
-                if global_step % args.sample_frequency == 0:
-                    sample_checkpoint = True
-
-                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                    local_progress_bar.clear()
-                    global_progress_bar.clear()
-
-                    checkpointer.checkpoint(global_step + global_step_offset, "training")
-                    save_args(basepath, args, {
-                        "global_step": global_step + global_step_offset,
-                        "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
-                    })
-
                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
 
                 accelerator.log(logs, step=global_step)
@@ -992,24 +976,30 @@ def main():
             local_progress_bar.clear()
             global_progress_bar.clear()
 
-            if min_val_loss > val_loss:
-                accelerator.print(
-                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
-                min_val_loss = val_loss
+            if accelerator.is_main_process:
+                if min_val_loss > val_loss:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    min_val_loss = val_loss
+
+                if epoch % args.checkpoint_frequency == 0:
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+                    save_args(basepath, args, {
+                        "global_step": global_step + global_step_offset
+                    })
 
-            if sample_checkpoint and accelerator.is_main_process:
+            if epoch % args.sample_frequency == 0:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
                     args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
         print("Finished! Saving final checkpoint and resume state.")
         checkpointer.checkpoint(global_step + global_step_offset, "end")
         save_args(basepath, args, {
-            "global_step": global_step + global_step_offset,
-            "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+            "global_step": global_step + global_step_offset
         })
     accelerator.end_training()
 
@@ -1018,8 +1008,7 @@ def main():
         print("Interrupted, saving checkpoint and resume state...")
         checkpointer.checkpoint(global_step + global_step_offset, "end")
         save_args(basepath, args, {
-            "global_step": global_step + global_step_offset,
-            "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
+            "global_step": global_step + global_step_offset
         })
         accelerator.end_training()
         quit()