-rw-r--r--  data/csv.py                                             3
-rw-r--r--  dreambooth.py                                          53
-rw-r--r--  environment.yaml                                        4
-rw-r--r--  infer.py                                               10
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   11
-rw-r--r--  schedulers/scheduling_euler_a.py                      210
-rw-r--r--  textual_inversion.py                                   74
7 files changed, 142 insertions, 223 deletions
diff --git a/data/csv.py b/data/csv.py
index 8637ac1..253ce9e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -68,13 +68,12 @@ class CSVDataModule(pl.LightningDataModule):
                 item.nprompt if "nprompt" in item else ""
             )
             for item in data
-            if "skip" not in item or item.skip != "x"
             for i in range(image_multiplier)
         ]

     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
-        metadata = list(metadata.itertuples())
+        metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"]
         num_images = len(metadata)

         valid_set_size = int(num_images * 0.2)
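For context, data/csv.py now drops skipped rows while reading the metadata rather than during item expansion. A rough sketch of the intended filtering, with an illustrative file name and using getattr to cope with a missing column:

    import pandas as pd

    # Rows whose optional "skip" column is marked "x" are excluded up front,
    # so the train/validation split is computed on the kept rows only.
    metadata = pd.read_csv("data.csv")  # illustrative path
    items = [
        item
        for item in metadata.itertuples()
        if getattr(item, "skip", None) != "x"
    ]
    num_images = len(items)
    valid_set_size = int(num_images * 0.2)
    print(f"{num_images} rows kept, {valid_set_size} reserved for validation")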
diff --git a/dreambooth.py b/dreambooth.py
index 02f83c6..775aea2 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,7 +150,7 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=1.0
+        default=7 / 8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -468,20 +468,20 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
-    unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
+
+    ema_unet = EMAModel(
+        unet,
+        inv_gamma=args.ema_inv_gamma,
+        power=args.ema_power,
+        max_value=args.ema_max_decay
+    ) if args.use_ema else None

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -538,7 +538,7 @@ def main():
         pixel_values += [example["class_images"] for example in examples]

         pixel_values = torch.stack(pixel_values)
-        pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
+        pixel_values = pixel_values.to(memory_format=torch.contiguous_format)

         input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids

@@ -629,16 +629,10 @@ def main():
         unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_unet = EMAModel(
-        unet,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
+    ema_unet.averaged_model.to(accelerator.device)

     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
@@ -698,7 +692,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -709,7 +703,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             unet.train()
@@ -720,9 +714,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(unet):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -737,8 +730,7 @@ def main():
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                     # Get the text embedding for conditioning
-                    with torch.no_grad():
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]

                     # Predict the noise residual
                     noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
@@ -840,7 +832,8 @@ def main():
             global_progress_bar.clear()

             if min_val_loss > val_loss:
-                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 min_val_loss = val_loss

             if sample_checkpoint and accelerator.is_main_process:
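A self-contained sketch of the loading pattern dreambooth.py moves to: components are read from the subfolders of a diffusers-style checkpoint instead of by concatenating paths, and the EMA copy of the UNet is created right next to the model. The model path and the literal EMA values are placeholders; the script itself takes them from args.

    import torch
    from diffusers import AutoencoderKL, UNet2DConditionModel
    from diffusers.training_utils import EMAModel
    from transformers import CLIPTextModel, CLIPTokenizer

    model_path = "path/to/stable-diffusion-model"  # placeholder

    # Each component lives in its own subfolder of the checkpoint, so
    # subfolder= replaces the earlier manual path concatenation.
    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(model_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(model_path, subfolder="unet")

    # EMA shadow of the UNet; power=7/8 mirrors the new --ema_power default.
    ema_unet = EMAModel(unet, inv_gamma=1.0, power=7 / 8, max_value=0.9999)
    ema_unet.averaged_model.to("cuda" if torch.cuda.is_available() else "cpu")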
diff --git a/environment.yaml b/environment.yaml
index 5ecc5a8..de35645 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -6,7 +6,7 @@ dependencies:
   - cudatoolkit=11.3
   - numpy=1.22.3
   - pip=20.3
-  - python=3.8.10
+  - python=3.9.13
   - pytorch=1.12.1
   - torchvision=0.13.1
   - pandas=1.4.3
@@ -32,6 +32,6 @@ dependencies:
   - test-tube>=0.7.5
   - torch-fidelity==0.3.0
   - torchmetrics==0.9.3
-  - transformers==4.22.2
+  - transformers==4.23.1
   - triton==2.0.0.dev20220924
   - xformers==0.0.13
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -22,7 +22,7 @@ torch.backends.cuda.matmul.allow_tf32 = True
 default_args = {
     "model": None,
     "scheduler": "euler_a",
-    "precision": "fp16",
+    "precision": "fp32",
     "embeddings_dir": "embeddings",
     "output_dir": "output/inference",
     "config": None,
@@ -205,10 +205,10 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
 def create_pipeline(model, scheduler, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")

-    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
+    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='/tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='/text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model, subfolder='/vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model, subfolder='/unet', torch_dtype=dtype)

     load_embeddings(tokenizer, text_encoder, embeddings_dir)

diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index bfecd1c..8927a78 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
 from transformers import CLIPTextModel, CLIPTokenizer
-from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward
+from schedulers.scheduling_euler_a import EulerAScheduler

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline):

             noise_pred = None
             if isinstance(self.scheduler, EulerAScheduler):
-                sigma = t.reshape(1)
-                sigma_in = torch.cat([sigma] * latent_model_input.shape[0])
-                noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in,
-                                                text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas)
+                c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size)
+                eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample
+                noise_pred = latent_model_input + eps * c_out
             else:
                 # predict the noise residual
                 noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample
@@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
         image = self.vae.decode(latents).sample

         image = (image / 2 + 0.5).clamp(0, 1)
-        image = image.cpu().permute(0, 2, 3, 1).numpy()
+        image = image.cpu().permute(0, 2, 3, 1).float().numpy()

         if output_type == "pil":
             image = self.numpy_to_pil(image)
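The Euler-A branch above follows the k-diffusion convention: the latents are scaled by c_in, the UNet is called with a sigma-derived timestep, and its eps output is mapped back to a denoised prediction through c_out. A small numeric sketch of that arithmetic (random tensors stand in for the UNet output):

    import torch

    def get_scalings(sigma, sigma_data=1.0):
        # Same scalings the scheduler uses: c_out = -sigma, c_in = 1 / sqrt(sigma^2 + sigma_data^2)
        c_out = -sigma
        c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
        return c_out, c_in

    sigma = torch.tensor([7.0])            # illustrative noise level
    latent = torch.randn(1, 4, 64, 64)     # illustrative latent batch
    c_out, c_in = get_scalings(sigma)

    eps = torch.randn_like(latent)         # stands in for unet(latent * c_in, t, ...).sample
    noise_pred = latent + eps * c_out.reshape(-1, 1, 1, 1)
    print(noise_pred.shape)                # torch.Size([1, 4, 64, 64])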
diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py
index 13ea6b3..6abe971 100644
--- a/schedulers/scheduling_euler_a.py
+++ b/schedulers/scheduling_euler_a.py
@@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config
 from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput


-'''
-helper functions: append_zero(),
-                  t_to_sigma(),
-                  get_sigmas(),
-                  append_dims(),
-                  CFGDenoiserForward(),
-                  get_scalings(),
-                  DSsigma_to_t(),
-                  DiscreteEpsDDPMDenoiserForward(),
-                  to_d(),
-                  get_ancestral_step()
-need cleaning
-'''
-
-
-def append_zero(x):
-    return torch.cat([x, x.new_zeros([1])])
-
-
-def t_to_sigma(t, sigmas):
-    t = t.float()
-    low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
-    return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
-
-
-def get_sigmas(sigmas, n=None):
-    if n is None:
-        return append_zero(sigmas.flip(0))
-    t_max = len(sigmas) - 1  # = 999
-    t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype)
-    return append_zero(t_to_sigma(t, sigmas))
-
-# from k_samplers utils.py
-
-
-def append_dims(x, target_dims):
-    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
-    dims_to_append = target_dims - x.ndim
-    if dims_to_append < 0:
-        raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
-    return x[(...,) + (None,) * dims_to_append]
-
-
-def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None):
-    # x_in = torch.cat([x] * 2)#A# concat the latent
-    # sigma_in = torch.cat([sigma] * 2) #A# concat sigma
-    # cond_in = torch.cat([uncond, cond])
-    # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2)
-    # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2)
-    # return uncond + (cond - uncond) * cond_scale
-    noise_pred = DiscreteEpsDDPMDenoiserForward(
-        Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in)
-    return noise_pred
-
-# from k_samplers sampling.py
-
-
-def to_d(x, sigma, denoised):
-    """Converts a denoiser output to a Karras ODE derivative."""
-    return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim)
-
-
-def get_scalings(sigma):
-    sigma_data = 1.
-    c_out = -sigma
-    c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
-    return c_out, c_in
-
-# DiscreteSchedule DS
-
-
-def DSsigma_to_t(sigma, quantize=False, DSsigmas=None):
-    dists = torch.abs(sigma - DSsigmas[:, None])
-    if quantize:
-        return torch.argmin(dists, dim=0).view(sigma.shape)
-    low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
-    low, high = DSsigmas[low_idx], DSsigmas[high_idx]
-    w = (low - sigma) / (low - high)
-    w = w.clamp(0, 1)
-    t = (1 - w) * low_idx + w * high_idx
-    return t.view(sigma.shape)
-
-
-def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs):
-    sigma = sigma.to(dtype=input.dtype, device=Unet.device)
-    DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device)
-    c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)]
-    # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}")
-    eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas),
-               encoder_hidden_states=kwargs['cond']).sample
-    return input + eps * c_out
-
-
-# from k_samplers sampling.py
-def get_ancestral_step(sigma_from, sigma_to):
-    """Calculates the noise level (sigma_down) to step down to and the amount
-    of noise to add (sigma_up) when doing an ancestral sampling step."""
-    sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
-    sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
-    return sigma_down, sigma_up
-
-
-'''
-Euler Ancestral Scheduler
-'''
-
-
 class EulerAScheduler(SchedulerMixin, ConfigMixin):
     """
     Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and
@@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         beta_end: float = 0.02,
         beta_schedule: str = "linear",
         trained_betas: Optional[np.ndarray] = None,
+        tensor_format: str = "pt",
+        num_inference_steps=None,
+        device='cuda'
     ):
         if trained_betas is not None:
-            self.betas = torch.from_numpy(trained_betas)
+            self.betas = torch.from_numpy(trained_betas).to(device)
         if beta_schedule == "linear":
-            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+            self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device)
         elif beta_schedule == "scaled_linear":
             # this schedule is very specific to the latent diffusion model.
-            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
-        elif beta_schedule == "squaredcos_cap_v2":
-            # Glide cosine schedule
-            self.betas = betas_for_alpha_bar(num_train_timesteps)
+            self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps,
+                                        dtype=torch.float32, device=device) ** 2
         else:
             raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")

+        self.device = device
+        self.tensor_format = tensor_format
+
         self.alphas = 1.0 - self.betas
         self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)

@@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         self.init_noise_sigma = 1.0

         # setable values
-        self.num_inference_steps = None
-        self.timesteps = np.arange(0, num_train_timesteps)[::-1]
+        self.num_inference_steps = num_inference_steps
+        self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy()
+        # get sigmas
+        self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
+        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
+        self.set_format(tensor_format=tensor_format)

         # A# take number of steps as input
         # A# store 1) number of steps 2) timesteps 3) schedule
@@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):

         self.num_inference_steps = num_inference_steps
         self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5
-        self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device)
+        self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps)
         self.timesteps = self.sigmas[:-1]
         self.is_scale_input_called = False

@@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         s_prev = self.sigmas[step_prev_index]
         latents = sample

-        sigma_down, sigma_up = get_ancestral_step(s, s_prev)
-        d = to_d(latents, s, model_output)
+        sigma_down, sigma_up = self.get_ancestral_step(s, s_prev)
+        d = self.to_d(latents, s, model_output)
         dt = sigma_down - s
         latents = latents + d * dt
         latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype,
@@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin):
         noisy_samples = original_samples + noise * sigma
         self.is_scale_input_called = True
         return noisy_samples
+
+    # from k_samplers sampling.py
+
+    def get_ancestral_step(self, sigma_from, sigma_to):
+        """Calculates the noise level (sigma_down) to step down to and the amount
+        of noise to add (sigma_up) when doing an ancestral sampling step."""
+        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
+        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
+        return sigma_down, sigma_up
+
+    def t_to_sigma(self, t, sigmas):
+        t = t.float()
+        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
+        return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx]
+
+    def append_zero(self, x):
+        return torch.cat([x, x.new_zeros([1])])
+
+    def get_sigmas(self, sigmas, n=None):
+        if n is None:
+            return self.append_zero(sigmas.flip(0))
+        t_max = len(sigmas) - 1  # = 999
+        device = self.device
+        t = torch.linspace(t_max, 0, n, device=device)
+        # t = torch.linspace(t_max, 0, n, device=sigmas.device)
+        return self.append_zero(self.t_to_sigma(t, sigmas))
+
+    # from k_samplers utils.py
+    def append_dims(self, x, target_dims):
+        """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+        dims_to_append = target_dims - x.ndim
+        if dims_to_append < 0:
+            raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+        return x[(...,) + (None,) * dims_to_append]
+
+    # from k_samplers sampling.py
+    def to_d(self, x, sigma, denoised):
+        """Converts a denoiser output to a Karras ODE derivative."""
+        return (x - denoised) / self.append_dims(sigma, x.ndim)
+
+    def get_scalings(self, sigma):
+        sigma_data = 1.
+        c_out = -sigma
+        c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5
+        return c_out, c_in
+
+    # DiscreteSchedule DS
+    def DSsigma_to_t(self, sigma, quantize=None):
+        # quantize = self.quantize if quantize is None else quantize
+        quantize = False
+        dists = torch.abs(sigma - self.DSsigmas[:, None])
+        if quantize:
+            return torch.argmin(dists, dim=0).view(sigma.shape)
+        low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0]
+        low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx]
+        w = (low - sigma) / (low - high)
+        w = w.clamp(0, 1)
+        t = (1 - w) * low_idx + w * high_idx
+        return t.view(sigma.shape)
+
+    def prepare_input(self, latent_in, t, batch_size):
+        sigma = t.reshape(1)  # A# potential bug: doesn't work on samples > 1
+
+        sigma_in = torch.cat([sigma] * 2 * batch_size)
+        # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas)
+        # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in)
+        c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)]
+
+        sigma_in = self.DSsigma_to_t(sigma_in)
+        # s_in = latent_in.new_ones([latent_in.shape[0]])
+        # sigma_in = sigma_in * s_in
+
+        return c_out, c_in, sigma_in
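For reference, the ancestral update these methods implement can be exercised standalone: sigma_down and sigma_up split each step into a deterministic move along the Karras ODE derivative plus re-injected noise. The sigma values and tensors below are made up purely for illustration.

    import torch

    def get_ancestral_step(sigma_from, sigma_to):
        # sigma_down: level to step down to; sigma_up: noise added back afterwards
        sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up

    s, s_prev = 7.0, 5.0                        # made-up consecutive sigmas
    latents = torch.randn(1, 4, 64, 64)         # current noisy latents
    denoised = torch.zeros_like(latents)        # stand-in for the model's denoised estimate

    sigma_down, sigma_up = get_ancestral_step(s, s_prev)
    d = (latents - denoised) / s                # to_d: Karras ODE derivative
    latents = latents + d * (sigma_down - s)    # deterministic Euler step
    latents = latents + torch.randn_like(latents) * sigma_up  # ancestral noise injection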
diff --git a/textual_inversion.py b/textual_inversion.py
index e6d856a..3a3741d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -17,7 +17,6 @@ from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.training_utils import EMAModel
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
@@ -112,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=5000,
+        default=3000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -150,31 +149,10 @@ def parse_args():
     parser.add_argument(
         "--lr_warmup_steps",
         type=int,
-        default=600,
+        default=500,
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
-        "--use_ema",
-        action="store_true",
-        default=True,
-        help="Whether to use EMA model."
-    )
-    parser.add_argument(
-        "--ema_inv_gamma",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_power",
-        type=float,
-        default=1.0
-    )
-    parser.add_argument(
-        "--ema_max_decay",
-        type=float,
-        default=0.9999
-    )
-    parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
         help="Whether or not to use 8-bit Adam from bitsandbytes."
@@ -348,7 +326,6 @@ class Checkpointer:
         unet,
         tokenizer,
         text_encoder,
-        ema_text_encoder,
         placeholder_token,
         placeholder_token_id,
         output_dir: Path,
@@ -363,7 +340,6 @@ class Checkpointer:
         self.unet = unet
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
-        self.ema_text_encoder = ema_text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -380,8 +356,7 @@ class Checkpointer:
         checkpoints_path = self.output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)

         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -400,8 +375,7 @@ class Checkpointer:
     def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")

-        unwrapped = self.accelerator.unwrap_model(
-            self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -507,9 +481,7 @@ def main():
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
     elif args.pretrained_model_name_or_path:
-        tokenizer = CLIPTokenizer.from_pretrained(
-            args.pretrained_model_name_or_path + '/tokenizer'
-        )
+        tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')

     # Add the placeholder token in tokenizer
     num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
@@ -530,15 +502,10 @@ def main():
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)

     # Load models and create wrapper for stable diffusion
-    text_encoder = CLIPTextModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/text_encoder',
-    )
-    vae = AutoencoderKL.from_pretrained(
-        args.pretrained_model_name_or_path + '/vae',
-    )
+    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(
-        args.pretrained_model_name_or_path + '/unet',
-    )
+        args.pretrained_model_name_or_path, subfolder='unet')

     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
@@ -707,13 +674,6 @@ def main():
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )

-    ema_text_encoder = EMAModel(
-        text_encoder,
-        inv_gamma=args.ema_inv_gamma,
-        power=args.ema_power,
-        max_value=args.ema_max_decay
-    ) if args.use_ema else None
-
     # Move vae and unet to device
     vae.to(accelerator.device)
     unet.to(accelerator.device)
@@ -757,7 +717,6 @@ def main():
         unet=unet,
         tokenizer=tokenizer,
         text_encoder=text_encoder,
-        ema_text_encoder=ema_text_encoder,
         placeholder_token=args.placeholder_token,
         placeholder_token_id=placeholder_token_id,
         output_dir=basepath,
@@ -777,7 +736,7 @@ def main():
         disable=not accelerator.is_local_main_process,
         dynamic_ncols=True
     )
-    local_progress_bar.set_description("Batch X out of Y")
+    local_progress_bar.set_description("Epoch X / Y")

     global_progress_bar = tqdm(
         range(args.max_train_steps + val_steps),
@@ -788,7 +747,7 @@ def main():

     try:
         for epoch in range(num_epochs):
-            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()

             text_encoder.train()
@@ -799,9 +758,8 @@ def main():
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
                     # Convert images to latent space
-                    with torch.no_grad():
-                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-                        latents = latents * 0.18215
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215

                     # Sample noise that we'll add to the latents
                     noise = torch.randn(latents.shape).to(latents.device)
@@ -859,9 +817,6 @@ def main():

                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_text_encoder.step(unet)
-
                     local_progress_bar.update(1)
                     global_progress_bar.update(1)

@@ -881,8 +836,6 @@ def main():
                     })

                 logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
-                if args.use_ema:
-                    logs["ema_decay"] = ema_text_encoder.decay

                 accelerator.log(logs, step=global_step)

@@ -937,7 +890,8 @@ def main():
             global_progress_bar.clear()

             if min_val_loss > val_loss:
-                accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss

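With the EMA copy of the text encoder removed, checkpointing in textual_inversion.py reduces to reading the placeholder token's row from the input-embedding matrix of the unwrapped text encoder. A hedged sketch of that path; the model path, token, and output file name are illustrative:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    model_path = "path/to/stable-diffusion-model"  # placeholder
    placeholder_token = "<my-token>"               # placeholder

    tokenizer = CLIPTokenizer.from_pretrained(model_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(model_path, subfolder="text_encoder")

    tokenizer.add_tokens(placeholder_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # The learned embedding is a single row of the input-embedding matrix.
    learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned_embeds.detach().cpu()}, "learned_embeds.bin")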
