diff options
-rw-r--r-- | dreambooth.py | 101 | ||||
-rw-r--r-- | infer.py | 6 | ||||
-rw-r--r-- | textual_inversion.py | 15 |
3 files changed, 77 insertions, 45 deletions
diff --git a/dreambooth.py b/dreambooth.py index da8399f..72c56cd 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -170,14 +170,14 @@ def parse_args(): | |||
170 | parser.add_argument( | 170 | parser.add_argument( |
171 | "--lr_warmup_steps", | 171 | "--lr_warmup_steps", |
172 | type=int, | 172 | type=int, |
173 | default=300, | 173 | default=500, |
174 | help="Number of steps for the warmup in the lr scheduler." | 174 | help="Number of steps for the warmup in the lr scheduler." |
175 | ) | 175 | ) |
176 | parser.add_argument( | 176 | parser.add_argument( |
177 | "--lr_cycles", | 177 | "--lr_cycles", |
178 | type=int, | 178 | type=int, |
179 | default=None, | 179 | default=None, |
180 | help="Number of restart cycles in the lr scheduler." | 180 | help="Number of restart cycles in the lr scheduler (if supported)." |
181 | ) | 181 | ) |
182 | parser.add_argument( | 182 | parser.add_argument( |
183 | "--use_ema", | 183 | "--use_ema", |
@@ -506,11 +506,10 @@ def main(): | |||
506 | 506 | ||
507 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | 507 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) |
508 | 508 | ||
509 | save_args(basepath, args) | 509 | args.seed = args.seed or (torch.random.seed() >> 32) |
510 | set_seed(args.seed) | ||
510 | 511 | ||
511 | # If passed along, set the training seed now. | 512 | save_args(basepath, args) |
512 | if args.seed is not None: | ||
513 | set_seed(args.seed) | ||
514 | 513 | ||
515 | # Load the tokenizer and add the placeholder token as a additional special token | 514 | # Load the tokenizer and add the placeholder token as a additional special token |
516 | if args.tokenizer_name: | 515 | if args.tokenizer_name: |
@@ -523,13 +522,22 @@ def main(): | |||
523 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') | 522 | vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') |
524 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') | 523 | unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') |
525 | 524 | ||
526 | ema_unet = EMAModel( | 525 | ema_unet = None |
527 | unet, | 526 | if args.use_ema: |
528 | inv_gamma=args.ema_inv_gamma, | 527 | ema_unet = EMAModel( |
529 | power=args.ema_power, | 528 | unet, |
530 | max_value=args.ema_max_decay, | 529 | inv_gamma=args.ema_inv_gamma, |
531 | device=accelerator.device | 530 | power=args.ema_power, |
532 | ) if args.use_ema else None | 531 | max_value=args.ema_max_decay, |
532 | device=accelerator.device | ||
533 | ) | ||
534 | |||
535 | if args.gradient_checkpointing: | ||
536 | unet.enable_gradient_checkpointing() | ||
537 | text_encoder.gradient_checkpointing_enable() | ||
538 | |||
539 | # Freeze text_encoder and vae | ||
540 | freeze_params(vae.parameters()) | ||
533 | 541 | ||
534 | if args.initializer_token is not None: | 542 | if args.initializer_token is not None: |
535 | # Convert the initializer_token, placeholder_token to ids | 543 | # Convert the initializer_token, placeholder_token to ids |
@@ -545,22 +553,22 @@ def main(): | |||
545 | print(f"Training new token {args.placeholder_token}.") | 553 | print(f"Training new token {args.placeholder_token}.") |
546 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) | 554 | placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) |
547 | 555 | ||
556 | # Resize the token embeddings as we are adding new special tokens to the tokenizer | ||
548 | text_encoder.resize_token_embeddings(len(tokenizer)) | 557 | text_encoder.resize_token_embeddings(len(tokenizer)) |
549 | token_embeds = text_encoder.get_input_embeddings() | ||
550 | initializer_token_embeddings = token_embeds(initializer_token_ids) | ||
551 | token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings | ||
552 | |||
553 | prompt_processor = PromptProcessor(tokenizer, text_encoder) | ||
554 | 558 | ||
555 | if args.gradient_checkpointing: | 559 | # Initialise the newly added placeholder token with the embeddings of the initializer token |
556 | unet.enable_gradient_checkpointing() | 560 | token_embeds = text_encoder.get_input_embeddings().weight.data |
557 | text_encoder.gradient_checkpointing_enable() | 561 | original_token_embeds = token_embeds.detach().clone().to(accelerator.device) |
562 | initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids) | ||
563 | token_embeds[placeholder_token_id] = initializer_token_embeddings | ||
558 | 564 | ||
559 | # slice_size = unet.config.attention_head_dim // 2 | 565 | freeze_params(itertools.chain( |
560 | # unet.set_attention_slice(slice_size) | 566 | text_encoder.text_model.encoder.parameters(), |
567 | text_encoder.text_model.final_layer_norm.parameters(), | ||
568 | text_encoder.text_model.embeddings.position_embedding.parameters(), | ||
569 | )) | ||
561 | 570 | ||
562 | # Freeze text_encoder and vae | 571 | prompt_processor = PromptProcessor(tokenizer, text_encoder) |
563 | freeze_params(vae.parameters()) | ||
564 | 572 | ||
565 | if args.scale_lr: | 573 | if args.scale_lr: |
566 | args.learning_rate_unet = ( | 574 | args.learning_rate_unet = ( |
@@ -583,6 +591,11 @@ def main(): | |||
583 | else: | 591 | else: |
584 | optimizer_class = torch.optim.AdamW | 592 | optimizer_class = torch.optim.AdamW |
585 | 593 | ||
594 | if args.initializer_token is not None: | ||
595 | text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters() | ||
596 | else: | ||
597 | text_encoder_params_to_optimize = text_encoder.parameters() | ||
598 | |||
586 | # Initialize the optimizer | 599 | # Initialize the optimizer |
587 | optimizer = optimizer_class( | 600 | optimizer = optimizer_class( |
588 | [ | 601 | [ |
@@ -591,7 +604,7 @@ def main(): | |||
591 | 'lr': args.learning_rate_unet, | 604 | 'lr': args.learning_rate_unet, |
592 | }, | 605 | }, |
593 | { | 606 | { |
594 | 'params': text_encoder.parameters(), | 607 | 'params': text_encoder_params_to_optimize, |
595 | 'lr': args.learning_rate_text, | 608 | 'lr': args.learning_rate_text, |
596 | } | 609 | } |
597 | ], | 610 | ], |
@@ -849,9 +862,27 @@ def main(): | |||
849 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") | 862 | loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean") |
850 | 863 | ||
851 | accelerator.backward(loss) | 864 | accelerator.backward(loss) |
865 | |||
866 | if args.initializer_token is not None: | ||
867 | # Keep the token embeddings fixed except the newly added | ||
868 | # embeddings for the concept, as we only want to optimize the concept embeddings | ||
869 | if accelerator.num_processes > 1: | ||
870 | token_embeds = text_encoder.module.get_input_embeddings().weight | ||
871 | else: | ||
872 | token_embeds = text_encoder.get_input_embeddings().weight | ||
873 | |||
874 | # Get the index for tokens that we want to freeze | ||
875 | index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id | ||
876 | token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :] | ||
877 | |||
852 | if accelerator.sync_gradients: | 878 | if accelerator.sync_gradients: |
853 | accelerator.clip_grad_norm_(itertools.chain( | 879 | params_to_clip = ( |
854 | unet.parameters(), text_encoder.parameters()), args.max_grad_norm) | 880 | unet.parameters() |
881 | if args.initializer_token is not None | ||
882 | else itertools.chain(unet.parameters(), text_encoder.parameters()) | ||
883 | ) | ||
884 | accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm) | ||
885 | |||
855 | optimizer.step() | 886 | optimizer.step() |
856 | if not accelerator.optimizer_step_was_skipped: | 887 | if not accelerator.optimizer_step_was_skipped: |
857 | lr_scheduler.step() | 888 | lr_scheduler.step() |
@@ -896,8 +927,8 @@ def main(): | |||
896 | text_encoder.eval() | 927 | text_encoder.eval() |
897 | val_loss = 0.0 | 928 | val_loss = 0.0 |
898 | 929 | ||
899 | for step, batch in enumerate(val_dataloader): | 930 | with torch.inference_mode(): |
900 | with torch.no_grad(): | 931 | for step, batch in enumerate(val_dataloader): |
901 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 932 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
902 | latents = latents * 0.18215 | 933 | latents = latents * 0.18215 |
903 | 934 | ||
@@ -920,12 +951,12 @@ def main(): | |||
920 | loss = loss.detach().item() | 951 | loss = loss.detach().item() |
921 | val_loss += loss | 952 | val_loss += loss |
922 | 953 | ||
923 | if accelerator.sync_gradients: | 954 | if accelerator.sync_gradients: |
924 | local_progress_bar.update(1) | 955 | local_progress_bar.update(1) |
925 | global_progress_bar.update(1) | 956 | global_progress_bar.update(1) |
926 | 957 | ||
927 | logs = {"val/loss": loss} | 958 | logs = {"val/loss": loss} |
928 | local_progress_bar.set_postfix(**logs) | 959 | local_progress_bar.set_postfix(**logs) |
929 | 960 | ||
930 | val_loss /= len(val_dataloader) | 961 | val_loss /= len(val_dataloader) |
931 | 962 | ||
@@ -258,7 +258,7 @@ def generate(output_dir, pipeline, args): | |||
258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 258 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
259 | output_dir.mkdir(parents=True, exist_ok=True) | 259 | output_dir.mkdir(parents=True, exist_ok=True) |
260 | 260 | ||
261 | seed = args.seed or torch.random.seed() | 261 | args.seed = args.seed or torch.random.seed() |
262 | 262 | ||
263 | save_args(output_dir, args) | 263 | save_args(output_dir, args) |
264 | 264 | ||
@@ -276,7 +276,7 @@ def generate(output_dir, pipeline, args): | |||
276 | dynamic_ncols=True | 276 | dynamic_ncols=True |
277 | ) | 277 | ) |
278 | 278 | ||
279 | generator = torch.Generator(device="cuda").manual_seed(seed + i) | 279 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) |
280 | images = pipeline( | 280 | images = pipeline( |
281 | prompt=args.prompt * (args.batch_size // len(args.prompt)), | 281 | prompt=args.prompt * (args.batch_size // len(args.prompt)), |
282 | height=args.height, | 282 | height=args.height, |
@@ -290,7 +290,7 @@ def generate(output_dir, pipeline, args): | |||
290 | ).images | 290 | ).images |
291 | 291 | ||
292 | for j, image in enumerate(images): | 292 | for j, image in enumerate(images): |
293 | image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) | 293 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) |
294 | 294 | ||
295 | if torch.cuda.is_available(): | 295 | if torch.cuda.is_available(): |
296 | torch.cuda.empty_cache() | 296 | torch.cuda.empty_cache() |
diff --git a/textual_inversion.py b/textual_inversion.py index 8f266e0..fe56d36 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -521,6 +521,7 @@ def main(): | |||
521 | 521 | ||
522 | if args.gradient_checkpointing: | 522 | if args.gradient_checkpointing: |
523 | unet.enable_gradient_checkpointing() | 523 | unet.enable_gradient_checkpointing() |
524 | text_encoder.gradient_checkpointing_enable() | ||
524 | 525 | ||
525 | # slice_size = unet.config.attention_head_dim // 2 | 526 | # slice_size = unet.config.attention_head_dim // 2 |
526 | # unet.set_attention_slice(slice_size) | 527 | # unet.set_attention_slice(slice_size) |
@@ -875,8 +876,8 @@ def main(): | |||
875 | text_encoder.eval() | 876 | text_encoder.eval() |
876 | val_loss = 0.0 | 877 | val_loss = 0.0 |
877 | 878 | ||
878 | for step, batch in enumerate(val_dataloader): | 879 | with torch.inference_mode(): |
879 | with torch.no_grad(): | 880 | for step, batch in enumerate(val_dataloader): |
880 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 881 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
881 | latents = latents * 0.18215 | 882 | latents = latents * 0.18215 |
882 | 883 | ||
@@ -899,12 +900,12 @@ def main(): | |||
899 | loss = loss.detach().item() | 900 | loss = loss.detach().item() |
900 | val_loss += loss | 901 | val_loss += loss |
901 | 902 | ||
902 | if accelerator.sync_gradients: | 903 | if accelerator.sync_gradients: |
903 | local_progress_bar.update(1) | 904 | local_progress_bar.update(1) |
904 | global_progress_bar.update(1) | 905 | global_progress_bar.update(1) |
905 | 906 | ||
906 | logs = {"val/loss": loss} | 907 | logs = {"val/loss": loss} |
907 | local_progress_bar.set_postfix(**logs) | 908 | local_progress_bar.set_postfix(**logs) |
908 | 909 | ||
909 | val_loss /= len(val_dataloader) | 910 | val_loss /= len(val_dataloader) |
910 | 911 | ||