 data/csv.py   |   2 +-
 dreambooth.py | 153 +++++++++++++++++++++++++++++----------------------------
 2 files changed, 83 insertions(+), 72 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 4c91ded..df15c5a 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -76,7 +76,7 @@ class CSVDataModule(pl.LightningDataModule):
 
     def prepare_data(self):
         metadata = pd.read_json(self.data_file)
-        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != True]
+        metadata = [item for item in metadata.itertuples() if not hasattr(item, "skip") or item.skip != True]
         num_images = len(metadata)
 
         valid_set_size = int(num_images * 0.2)
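Note on the data/csv.py fix: the rows yielded by DataFrame.itertuples() are namedtuples, and the `in` operator on a namedtuple tests membership among its values, not its field names. The old condition `"skip" not in item` was therefore effectively always true, so flagged rows were never filtered out; `hasattr(item, "skip")` tests for the field itself. A minimal sketch with hypothetical data, assuming pandas:

import pandas as pd

# Hypothetical metadata; the real code loads it with pd.read_json.
df = pd.DataFrame([{"image": "a.png", "skip": True}])
row = next(df.itertuples())

print("skip" in row)         # False: 'in' scans the tuple's values (0, "a.png", True)
print(hasattr(row, "skip"))  # True: the row exposes the column as an attribute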
diff --git a/dreambooth.py b/dreambooth.py
index 1539e81..72c56cd 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -810,93 +810,85 @@ def main():
     )
     global_progress_bar.set_description("Total progress")
 
-    def run_step(batch, train=False, class_images=False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,), device=latents.device)
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
-
-        # Predict the noise residual
-        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        if class_images:
-            # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
-            noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
-            noise, noise_prior = torch.chunk(noise, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
-
-        if train:
-            accelerator.backward(loss)
-
-            if args.initializer_token is not None:
-                # Keep the token embeddings fixed except the newly added
-                # embeddings for the concept, as we only want to optimize the concept embeddings
-                if accelerator.num_processes > 1:
-                    token_embeds = text_encoder.module.get_input_embeddings().weight
-                else:
-                    token_embeds = text_encoder.get_input_embeddings().weight
-
-                # Get the index for tokens that we want to freeze
-                index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
-                token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
-
-            if accelerator.sync_gradients:
-                params_to_clip = (
-                    unet.parameters()
-                    if args.initializer_token is not None
-                    else itertools.chain(unet.parameters(), text_encoder.parameters())
-                )
-                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-            optimizer.step()
-            if not accelerator.optimizer_step_was_skipped:
-                lr_scheduler.step()
-            optimizer.zero_grad(set_to_none=True)
-
-        loss = loss.detach().item()
-        return loss
-
-    try:
-        for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
-            local_progress_bar.reset()
-
-            unet.train()
-            text_encoder.train()
-            train_loss = 0.0
-
-            sample_checkpoint = False
-
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
-                    loss = run_step(
-                        batch,
-                        train=True,
-                        class_images=args.num_class_images != 0
-                    )
+    try:
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            unet.train()
+            text_encoder.train()
+            train_loss = 0.0
+
+            sample_checkpoint = False
+
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(itertools.chain(unet, text_encoder)):
+                    # Convert images to latent space
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    if args.num_class_images != 0:
+                        # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
+                        noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
+                        noise, noise_prior = torch.chunk(noise, 2, dim=0)
+
+                        # Compute instance loss
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
+
+                        # Compute prior loss
+                        prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
+
+                        # Add the prior loss to the instance loss.
+                        loss = loss + args.prior_loss_weight * prior_loss
+                    else:
+                        loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    accelerator.backward(loss)
+
+                    if args.initializer_token is not None:
+                        # Keep the token embeddings fixed except the newly added
+                        # embeddings for the concept, as we only want to optimize the concept embeddings
+                        if accelerator.num_processes > 1:
+                            token_embeds = text_encoder.module.get_input_embeddings().weight
+                        else:
+                            token_embeds = text_encoder.get_input_embeddings().weight
+
+                        # Get the index for tokens that we want to freeze
+                        index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
+                        token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
+
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            unet.parameters()
+                            if args.initializer_token is not None
+                            else itertools.chain(unet.parameters(), text_encoder.parameters())
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
+
+                    loss = loss.detach().item()
                     train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
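Note on this hunk: it dissolves the former run_step helper into the training loop; the computation is unchanged, with the `class_images` flag replaced by testing `args.num_class_images` directly. The prior-preservation branch relies on the dataloader concatenating instance and class images along the batch dimension, so `torch.chunk(..., 2, dim=0)` recovers the two halves before weighting the losses. A standalone sketch of that loss, using hypothetical shapes and a placeholder for args.prior_loss_weight:

import torch
import torch.nn.functional as F

# Hypothetical batch: 4 instance + 4 class latents stacked on dim 0,
# as the DreamBooth dataloader does when num_class_images != 0.
noise_pred = torch.randn(8, 4, 64, 64)
noise = torch.randn(8, 4, 64, 64)
prior_loss_weight = 1.0  # stand-in for args.prior_loss_weight

# Split each tensor into its instance half and its class (prior) half.
noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
noise, noise_prior = torch.chunk(noise, 2, dim=0)

# Per-image MSE over (C, H, W), then averaged over the instance half.
loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
# Prior-preservation term computed on the class half.
prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
loss = loss + prior_loss_weight * prior_loss
print(loss.item())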
@@ -937,7 +929,26 @@ def main():
 
             with torch.inference_mode():
                 for step, batch in enumerate(val_dataloader):
-                    loss = run_step(batch)
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
+
+                    noise = torch.randn_like(latents)
+                    bsz = latents.shape[0]
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
+
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+
+                    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
+
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+                    noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
+
+                    loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
+
+                    loss = loss.detach().item()
                     val_loss += loss
 
             if accelerator.sync_gradients:
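In the inlined validation loop, the one functional addition is `accelerator.gather_for_metrics`, which gathers the prediction and target tensors from all processes (and drops the samples Accelerate duplicated to pad the last batch) so the MSE reflects the whole validation set rather than one process's shard. A minimal sketch of the pattern, with a hypothetical model and dataset standing in for the real pipeline:

import torch
import torch.nn.functional as F
from accelerate import Accelerator

accelerator = Accelerator()

# Hypothetical stand-ins for the prepared model and dataloader.
model = accelerator.prepare(torch.nn.Linear(8, 8))
val_dataloader = accelerator.prepare(
    torch.utils.data.DataLoader(torch.randn(32, 8), batch_size=4)
)

val_loss = 0.0
with torch.inference_mode():
    for batch in val_dataloader:
        pred = model(batch)
        # Gather predictions and targets from every process; with one process
        # this is a no-op, with several it also deduplicates the padded last batch.
        pred, target = accelerator.gather_for_metrics((pred, batch))
        val_loss += F.mse_loss(pred.float(), target.float()).item()

print(val_loss / len(val_dataloader))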
