 dreambooth.py        | 101
 infer.py             |   6
 textual_inversion.py |  15
 3 files changed, 77 insertions(+), 45 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index da8399f..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -170,14 +170,14 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=300,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
         "--lr_cycles",
         type=int,
         default=None,
-        help="Number of restart cycles in the lr scheduler."
+        help="Number of restart cycles in the lr scheduler (if supported)."
     )
     parser.add_argument(
         "--use_ema",
@@ -506,11 +506,10 @@ def main():
 
     logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
 
-    save_args(basepath, args)
+    args.seed = args.seed or (torch.random.seed() >> 32)
+    set_seed(args.seed)
 
-    # If passed along, set the training seed now.
-    if args.seed is not None:
-        set_seed(args.seed)
+    save_args(basepath, args)
 
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
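Note on the new seeding logic: torch.random.seed() re-seeds torch's default generator non-deterministically and returns the 64-bit seed it chose, and shifting it right by 32 bits keeps the value inside the 32-bit range that numpy-backed seeding accepts. A minimal sketch of the idea, assuming set_seed here is accelerate's accelerate.utils.set_seed:

import torch
from accelerate.utils import set_seed  # assumption: the same helper the script uses

seed64 = torch.random.seed()   # non-deterministic 64-bit seed
seed32 = seed64 >> 32          # keep the high 32 bits so the value stays below 2**32
assert 0 <= seed32 < 2**32
set_seed(seed32)               # seeds Python, NumPy and torch consistently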
@@ -523,13 +522,22 @@ def main():
     vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
 
-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay,
-        device=accelerator.device
-    ) if args.use_ema else None
+    ema_unet = None
+    if args.use_ema:
+        ema_unet = EMAModel(
+            unet,
+            inv_gamma=args.ema_inv_gamma,
+            power=args.ema_power,
+            max_value=args.ema_max_decay,
+            device=accelerator.device
+        )
+
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
+
+    # Freeze text_encoder and vae
+    freeze_params(vae.parameters())
 
     if args.initializer_token is not None:
         # Convert the initializer_token, placeholder_token to ids
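For readers unfamiliar with the EMA arguments that are now only constructed when --use_ema is set: inv_gamma, power and max_decay usually parameterize a warmed-up decay, decay = 1 - (1 + step / inv_gamma) ** -power, clamped to max_decay. A rough standalone sketch of that schedule (illustrative only, not the actual EMAModel internals, which may differ):

import torch

class SimpleEMA:
    """Toy EMA with the usual inv_gamma/power warmup schedule."""

    def __init__(self, params, inv_gamma=1.0, power=0.75, max_value=0.9999):
        self.shadow = [p.detach().clone() for p in params]
        self.inv_gamma, self.power, self.max_value = inv_gamma, power, max_value
        self.step = 0

    def get_decay(self):
        # decay ramps from 0 towards max_value as training progresses
        value = 1 - (1 + self.step / self.inv_gamma) ** -self.power
        return min(max(value, 0.0), self.max_value)

    @torch.no_grad()
    def update(self, params):
        self.step += 1
        decay = self.get_decay()
        for shadow, param in zip(self.shadow, params):
            shadow.mul_(decay).add_(param.detach(), alpha=1 - decay)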
@@ -545,22 +553,22 @@ def main():
         print(f"Training new token {args.placeholder_token}.")
         placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+        # Resize the token embeddings as we are adding new special tokens to the tokenizer
         text_encoder.resize_token_embeddings(len(tokenizer))
-        token_embeds = text_encoder.get_input_embeddings()
-        initializer_token_embeddings = token_embeds(initializer_token_ids)
-        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
-
-    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
-    if args.gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
+        # Initialise the newly added placeholder token with the embeddings of the initializer token
+        token_embeds = text_encoder.get_input_embeddings().weight.data
+        original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
+        initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
+        token_embeds[placeholder_token_id] = initializer_token_embeddings
 
-    # slice_size = unet.config.attention_head_dim // 2
-    # unet.set_attention_slice(slice_size)
+        freeze_params(itertools.chain(
+            text_encoder.text_model.encoder.parameters(),
+            text_encoder.text_model.final_layer_norm.parameters(),
+            text_encoder.text_model.embeddings.position_embedding.parameters(),
+        ))
 
-    # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.scale_lr:
         args.learning_rate_unet = (
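After this freeze, the only text-encoder parameters that should still require gradients are the resized token embeddings (assuming freeze_params sets requires_grad to False, as its name suggests). A quick sanity check that could be dropped in right after the freeze_params call; trainable_parameter_names is a hypothetical helper, and the expected output assumes a transformers CLIPTextModel:

def trainable_parameter_names(model):
    # names of the parameters that will still receive gradients
    return [name for name, param in model.named_parameters() if param.requires_grad]

# Expected for a CLIPTextModel after the freeze above:
#   ['text_model.embeddings.token_embedding.weight']
print(trainable_parameter_names(text_encoder))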
@@ -583,6 +591,11 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW
 
+    if args.initializer_token is not None:
+        text_encoder_params_to_optimize = text_encoder.get_input_embeddings().parameters()
+    else:
+        text_encoder_params_to_optimize = text_encoder.parameters()
+
     # Initialize the optimizer
     optimizer = optimizer_class(
         [
@@ -591,7 +604,7 @@ def main():
             'lr': args.learning_rate_unet,
         },
         {
-            'params': text_encoder.parameters(),
+            'params': text_encoder_params_to_optimize,
             'lr': args.learning_rate_text,
         }
     ],
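One detail behind text_encoder_params_to_optimize: on a transformers text encoder, get_input_embeddings() returns a single nn.Embedding, so its .parameters() yields exactly one tensor, the token-embedding matrix, and the second param group then trains only that matrix. A standalone check with a plain nn.Embedding (the sizes are merely illustrative of CLIP's vocabulary plus one added token):

import torch

embedding = torch.nn.Embedding(49409, 768)
params = list(embedding.parameters())
print(len(params), params[0].shape)  # -> 1 torch.Size([49409, 768])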
@@ -849,9 +862,27 @@ def main():
                 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
 
                 accelerator.backward(loss)
+
+                if args.initializer_token is not None:
+                    # Keep the token embeddings fixed except the newly added
+                    # embeddings for the concept, as we only want to optimize the concept embeddings
+                    if accelerator.num_processes > 1:
+                        token_embeds = text_encoder.module.get_input_embeddings().weight
+                    else:
+                        token_embeds = text_encoder.get_input_embeddings().weight
+
+                    # Get the index for tokens that we want to freeze
+                    index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                    token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
                 if accelerator.sync_gradients:
-                    accelerator.clip_grad_norm_(itertools.chain(
-                        unet.parameters(), text_encoder.parameters()), args.max_grad_norm)
+                    params_to_clip = (
+                        unet.parameters()
+                        if args.initializer_token is not None
+                        else itertools.chain(unet.parameters(), text_encoder.parameters())
+                    )
+                    accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
                 optimizer.step()
                 if not accelerator.optimizer_step_was_skipped:
                     lr_scheduler.step()
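The new post-backward block is a mask-and-restore trick: gradients still flow through the whole embedding matrix, but every row except the placeholder token's is copied back from original_token_embeds, so in effect only the new token's vector learns. A self-contained toy version of the same idea (tiny sizes and a crude SGD step, purely for illustration):

import torch

vocab_size, dim, placeholder_token_id = 10, 4, 7
embedding = torch.nn.Embedding(vocab_size, dim)
original_token_embeds = embedding.weight.detach().clone()

loss = embedding(torch.tensor([1, placeholder_token_id])).sum()
loss.backward()
with torch.no_grad():
    embedding.weight -= 0.1 * embedding.weight.grad  # update touches rows 1 and 7

index_fixed_tokens = torch.arange(vocab_size) != placeholder_token_id
embedding.weight.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]

# Only the placeholder row still differs from the original matrix.
assert torch.equal(embedding.weight[index_fixed_tokens], original_token_embeds[index_fixed_tokens])
assert not torch.equal(embedding.weight[placeholder_token_id], original_token_embeds[placeholder_token_id])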
@@ -896,8 +927,8 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            for step, batch in enumerate(val_dataloader):
-                with torch.no_grad():
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
@@ -920,12 +951,12 @@ def main():
                     loss = loss.detach().item()
                     val_loss += loss
 
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
-                local_progress_bar.set_postfix(**logs)
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
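Switching the per-batch torch.no_grad() for a single torch.inference_mode() around the loop keeps the no-gradient semantics with a little less autograd bookkeeping; the trade-off is that tensors created inside can never re-enter autograd later. A tiny illustration:

import torch

model = torch.nn.Linear(4, 2)
with torch.inference_mode():
    out = model(torch.randn(3, 4))

print(out.requires_grad)   # False
print(out.is_inference())  # True: this tensor cannot be used in autograd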
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -258,7 +258,7 @@ def generate(output_dir, pipeline, args):
     output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
-    seed = args.seed or torch.random.seed()
+    args.seed = args.seed or torch.random.seed()
 
     save_args(output_dir, args)
 
@@ -276,7 +276,7 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
-        generator = torch.Generator(device="cuda").manual_seed(seed + i)
+        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
         images = pipeline(
             prompt=args.prompt * (args.batch_size // len(args.prompt)),
             height=args.height,
@@ -290,7 +290,7 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg"))
+            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
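The infer.py changes write the generated seed back onto args so that save_args records it, and each batch then derives its generator from args.seed + i, the same value embedded in the filenames, which makes any single batch reproducible on its own. A minimal sketch of that per-batch seeding (the noise call and device handling are illustrative, not the pipeline's internals):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
base_seed = 1234  # stand-in for args.seed
for i in range(3):
    generator = torch.Generator(device=device).manual_seed(base_seed + i)
    latents = torch.randn(1, 4, 64, 64, generator=generator, device=device)
    # the real script passes `generator` to the diffusers pipeline instead
    print(f"batch {i}: output files would be prefixed with {base_seed + i}_")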
diff --git a/textual_inversion.py b/textual_inversion.py
index 8f266e0..fe56d36 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -521,6 +521,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     # slice_size = unet.config.attention_head_dim // 2
     # unet.set_attention_slice(slice_size)
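Same change as in dreambooth.py: checkpointing is now enabled on the text encoder as well as the UNet, trading extra forward-pass compute for lower activation memory. The two different spellings come from the two libraries involved (a sketch assuming a diffusers UNet and a transformers text encoder, as loaded elsewhere in the script):

if args.gradient_checkpointing:
    unet.enable_gradient_checkpointing()           # diffusers API
    text_encoder.gradient_checkpointing_enable()   # transformers API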
@@ -875,8 +876,8 @@ def main():
             text_encoder.eval()
             val_loss = 0.0
 
-            for step, batch in enumerate(val_dataloader):
-                with torch.no_grad():
+            with torch.inference_mode():
+                for step, batch in enumerate(val_dataloader):
                     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
                     latents = latents * 0.18215
 
@@ -899,12 +900,12 @@ def main():
                     loss = loss.detach().item()
                     val_loss += loss
 
-                if accelerator.sync_gradients:
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                logs = {"val/loss": loss}
-                local_progress_bar.set_postfix(**logs)
+                    logs = {"val/loss": loss}
+                    local_progress_bar.set_postfix(**logs)
 
             val_loss /= len(val_dataloader)
 
