From a1b8327085ddeab589be074d7e9df4291aba1210 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 1 Mar 2023 12:34:42 +0100 Subject: Update --- environment.yaml | 2 +- models/clip/embeddings.py | 2 +- schedulers/scheduling_deis_multistep.py | 500 -------------------------------- train_dreambooth.py | 6 +- train_lora.py | 2 +- train_ti.py | 2 +- training/functional.py | 50 ++-- training/optimization.py | 2 +- training/strategy/dreambooth.py | 6 +- training/strategy/ti.py | 2 +- 10 files changed, 39 insertions(+), 535 deletions(-) delete mode 100644 schedulers/scheduling_deis_multistep.py diff --git a/environment.yaml b/environment.yaml index 1e6ac60..4899709 100644 --- a/environment.yaml +++ b/environment.yaml @@ -13,7 +13,7 @@ dependencies: - python=3.10.8 - pytorch=1.13.1=*cuda* - torchvision=0.14.1 - - xformers=0.0.17.dev461 + - xformers=0.0.17.dev466 - pip: - -e . - -e git+https://github.com/huggingface/diffusers#egg=diffusers diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 734730e..6be6e9f 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -13,7 +13,7 @@ from transformers.models.clip.modeling_clip import CLIPTextEmbeddings def resize_embedding(old_embedding: nn.Embedding, new_num_embeddings: int, initializer_factor: float = 1.0) -> nn.Embedding: - old_num_embeddings, old_embedding_dim = old_embedding.weight.size() + old_num_embeddings, old_embedding_dim = old_embedding.weight.shape if old_num_embeddings == new_num_embeddings: return old_embedding diff --git a/schedulers/scheduling_deis_multistep.py b/schedulers/scheduling_deis_multistep.py deleted file mode 100644 index ea1281e..0000000 --- a/schedulers/scheduling_deis_multistep.py +++ /dev/null @@ -1,500 +0,0 @@ -# Copyright 2022 FLAIR Lab and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -# DISCLAIMER: check https://arxiv.org/abs/2204.13902 and https://github.com/qsh-zh/deis for more info -# The codebase is modified based on https://github.com/huggingface/diffusers/blob/main/src/diffusers/schedulers/scheduling_dpmsolver_multistep.py - -import math -from typing import List, Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin, SchedulerOutput - - -def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999): - """ - Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of - (1-beta) over time from t = [0,1]. - - Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up - to that part of the diffusion process. - - - Args: - num_diffusion_timesteps (`int`): the number of betas to produce. - max_beta (`float`): the maximum beta to use; use values lower than 1 to - prevent singularities. - - Returns: - betas (`np.ndarray`): the betas used by the scheduler to step the model outputs - """ - - def alpha_bar(time_step): - return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 - - betas = [] - for i in range(num_diffusion_timesteps): - t1 = i / num_diffusion_timesteps - t2 = (i + 1) / num_diffusion_timesteps - betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) - return torch.tensor(betas, dtype=torch.float32) - - -class DEISMultistepScheduler(SchedulerMixin, ConfigMixin): - """ - DEIS (https://arxiv.org/abs/2204.13902) is a fast high order solver for diffusion ODEs. We slightly modify the - polynomial fitting formula in log-rho space instead of the original linear t space in DEIS paper. The modification - enjoys closed-form coefficients for exponential multistep update instead of replying on the numerical solver. More - variants of DEIS can be found in https://github.com/qsh-zh/deis. - - Currently, we support the log-rho multistep DEIS. We recommend to use `solver_order=2 / 3` while `solver_order=1` - reduces to DDIM. - - We also support the "dynamic thresholding" method in Imagen (https://arxiv.org/abs/2205.11487). For pixel-space - diffusion models, you can set `thresholding=True` to use the dynamic thresholding. - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and - [`~SchedulerMixin.from_pretrained`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear`, `scaled_linear`, or `squaredcos_cap_v2`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - solver_order (`int`, default `2`): - the order of DEIS; can be `1` or `2` or `3`. We recommend to use `solver_order=2` for guided sampling, and - `solver_order=3` for unconditional sampling. - prediction_type (`str`, default `epsilon`): - indicates whether the model predicts the noise (epsilon), or the data / `x0`. One of `epsilon`, `sample`, - or `v-prediction`. - thresholding (`bool`, default `False`): - whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). - Note that the thresholding method is unsuitable for latent-space diffusion models (such as - stable-diffusion). - dynamic_thresholding_ratio (`float`, default `0.995`): - the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen - (https://arxiv.org/abs/2205.11487). - sample_max_value (`float`, default `1.0`): - the threshold value for dynamic thresholding. Valid woks when `thresholding=True` - algorithm_type (`str`, default `deis`): - the algorithm type for the solver. current we support multistep deis, we will add other variants of DEIS in - the future - lower_order_final (`bool`, default `True`): - whether to use lower-order solvers in the final steps. Only valid for < 15 inference steps. We empirically - find this trick can stabilize the sampling of DEIS for steps < 15, especially for steps <= 10. - - """ - - _compatibles = [e.name for e in KarrasDiffusionSchedulers] - order = 1 - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - solver_order: int = 2, - prediction_type: str = "epsilon", - thresholding: bool = False, - dynamic_thresholding_ratio: float = 0.995, - sample_max_value: float = 1.0, - algorithm_type: str = "deis", - solver_type: str = "logrho", - lower_order_final: bool = True, - ): - if trained_betas is not None: - self.betas = torch.tensor(trained_betas, dtype=torch.float32) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - # Currently we only support VP-type noise schedule - self.alpha_t = torch.sqrt(self.alphas_cumprod) - self.sigma_t = torch.sqrt(1 - self.alphas_cumprod) - self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = 1.0 - - # settings for DEIS - if algorithm_type not in ["deis"]: - if algorithm_type in ["dpmsolver", "dpmsolver++"]: - algorithm_type = "deis" - else: - raise NotImplementedError(f"{algorithm_type} does is not implemented for {self.__class__}") - - if solver_type not in ["logrho"]: - if solver_type in ["midpoint", "heun"]: - solver_type = "logrho" - else: - raise NotImplementedError(f"solver type {solver_type} does is not implemented for {self.__class__}") - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.model_outputs = [None] * solver_order - self.lower_order_nums = 0 - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - timesteps = ( - np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1) - .round()[::-1][:-1] - .copy() - .astype(np.int64) - ) - self.timesteps = torch.from_numpy(timesteps).to(device) - self.model_outputs = [ - None, - ] * self.config.solver_order - self.lower_order_nums = 0 - - def convert_model_output( - self, model_output: torch.FloatTensor, timestep: int, sample: torch.FloatTensor - ) -> torch.FloatTensor: - """ - Convert the model output to the corresponding type that the algorithm DEIS needs. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the converted model output. - """ - if self.config.prediction_type == "epsilon": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = (sample - sigma_t * model_output) / alpha_t - elif self.config.prediction_type == "sample": - x0_pred = model_output - elif self.config.prediction_type == "v_prediction": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - x0_pred = alpha_t * sample - sigma_t * model_output - else: - raise ValueError( - f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" - " `v_prediction` for the DEISMultistepScheduler." - ) - - if self.config.thresholding: - # Dynamic thresholding in https://arxiv.org/abs/2205.11487 - orig_dtype = x0_pred.dtype - if orig_dtype not in [torch.float, torch.double]: - x0_pred = x0_pred.float() - dynamic_max_val = torch.quantile( - torch.abs(x0_pred).reshape((x0_pred.shape[0], -1)), self.config.dynamic_thresholding_ratio, dim=1 - ) - dynamic_max_val = torch.maximum( - dynamic_max_val, - self.config.sample_max_value * torch.ones_like(dynamic_max_val).to(dynamic_max_val.device), - )[(...,) + (None,) * (x0_pred.ndim - 1)] - x0_pred = torch.clamp(x0_pred, -dynamic_max_val, dynamic_max_val) / dynamic_max_val - x0_pred = x0_pred.type(orig_dtype) - - if self.config.algorithm_type == "deis": - alpha_t, sigma_t = self.alpha_t[timestep], self.sigma_t[timestep] - return (sample - alpha_t * x0_pred) / sigma_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def deis_first_order_update( - self, - model_output: torch.FloatTensor, - timestep: int, - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the first-order DEIS (equivalent to DDIM). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - lambda_t, lambda_s = self.lambda_t[prev_timestep], self.lambda_t[timestep] - alpha_t, alpha_s = self.alpha_t[prev_timestep], self.alpha_t[timestep] - sigma_t, _ = self.sigma_t[prev_timestep], self.sigma_t[timestep] - h = lambda_t - lambda_s - if self.config.algorithm_type == "deis": - x_t = (alpha_t / alpha_s) * sample - (sigma_t * (torch.exp(h) - 1.0)) * model_output - else: - raise NotImplementedError("only support log-rho multistep deis now") - return x_t - - def multistep_deis_second_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the second-order multistep DEIS. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - t, s0, s1 = prev_timestep, timestep_list[-1], timestep_list[-2] - m0, m1 = model_output_list[-1], model_output_list[-2] - alpha_t, alpha_s0, alpha_s1 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1] - sigma_t, sigma_s0, sigma_s1 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1] - - rho_t, rho_s0, rho_s1 = sigma_t / alpha_t, sigma_s0 / alpha_s0, sigma_s1 / alpha_s1 - - if self.config.algorithm_type == "deis": - - def ind_fn(t, b, c): - # Integrate[(log(t) - log(c)) / (log(b) - log(c)), {t}] - return t * (-np.log(c) + np.log(t) - 1) / (np.log(b) - np.log(c)) - - coef1 = ind_fn(rho_t, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s0, rho_s1) - coef2 = ind_fn(rho_t, rho_s1, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s0) - - x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1) - return x_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def multistep_deis_third_order_update( - self, - model_output_list: List[torch.FloatTensor], - timestep_list: List[int], - prev_timestep: int, - sample: torch.FloatTensor, - ) -> torch.FloatTensor: - """ - One step for the third-order multistep DEIS. - - Args: - model_output_list (`List[torch.FloatTensor]`): - direct outputs from learned diffusion model at current and latter timesteps. - timestep (`int`): current and latter discrete timestep in the diffusion chain. - prev_timestep (`int`): previous discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - - Returns: - `torch.FloatTensor`: the sample tensor at the previous timestep. - """ - t, s0, s1, s2 = prev_timestep, timestep_list[-1], timestep_list[-2], timestep_list[-3] - m0, m1, m2 = model_output_list[-1], model_output_list[-2], model_output_list[-3] - alpha_t, alpha_s0, alpha_s1, alpha_s2 = self.alpha_t[t], self.alpha_t[s0], self.alpha_t[s1], self.alpha_t[s2] - sigma_t, sigma_s0, sigma_s1, simga_s2 = self.sigma_t[t], self.sigma_t[s0], self.sigma_t[s1], self.sigma_t[s2] - rho_t, rho_s0, rho_s1, rho_s2 = ( - sigma_t / alpha_t, - sigma_s0 / alpha_s0, - sigma_s1 / alpha_s1, - simga_s2 / alpha_s2, - ) - - if self.config.algorithm_type == "deis": - - def ind_fn(t, b, c, d): - # Integrate[(log(t) - log(c))(log(t) - log(d)) / (log(b) - log(c))(log(b) - log(d)), {t}] - numerator = t * ( - np.log(c) * (np.log(d) - np.log(t) + 1) - - np.log(d) * np.log(t) - + np.log(d) - + np.log(t) ** 2 - - 2 * np.log(t) - + 2 - ) - denominator = (np.log(b) - np.log(c)) * (np.log(b) - np.log(d)) - return numerator / denominator - - coef1 = ind_fn(rho_t, rho_s0, rho_s1, rho_s2) - ind_fn(rho_s0, rho_s0, rho_s1, rho_s2) - coef2 = ind_fn(rho_t, rho_s1, rho_s2, rho_s0) - ind_fn(rho_s0, rho_s1, rho_s2, rho_s0) - coef3 = ind_fn(rho_t, rho_s2, rho_s0, rho_s1) - ind_fn(rho_s0, rho_s2, rho_s0, rho_s1) - - x_t = alpha_t * (sample / alpha_s0 + coef1 * m0 + coef2 * m1 + coef3 * m2) - - return x_t - else: - raise NotImplementedError("only support log-rho multistep deis now") - - def step( - self, - model_output: torch.FloatTensor, - timestep: int, - sample: torch.FloatTensor, - return_dict: bool = True, - ) -> Union[SchedulerOutput, Tuple]: - """ - Step function propagating the sample with the multistep DEIS. - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`int`): current discrete timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - return_dict (`bool`): option for returning tuple rather than SchedulerOutput class - - Returns: - [`~scheduling_utils.SchedulerOutput`] or `tuple`: [`~scheduling_utils.SchedulerOutput`] if `return_dict` is - True, otherwise a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - if self.num_inference_steps is None: - raise ValueError( - "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero() - if len(step_index) == 0: - step_index = len(self.timesteps) - 1 - else: - step_index = step_index.item() - prev_timestep = 0 if step_index == len(self.timesteps) - 1 else self.timesteps[step_index + 1] - lower_order_final = ( - (step_index == len(self.timesteps) - 1) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - lower_order_second = ( - (step_index == len(self.timesteps) - 2) and self.config.lower_order_final and len(self.timesteps) < 15 - ) - - model_output = self.convert_model_output(model_output, timestep, sample) - for i in range(self.config.solver_order - 1): - self.model_outputs[i] = self.model_outputs[i + 1] - self.model_outputs[-1] = model_output - - if self.config.solver_order == 1 or self.lower_order_nums < 1 or lower_order_final: - prev_sample = self.deis_first_order_update(model_output, timestep, prev_timestep, sample) - elif self.config.solver_order == 2 or self.lower_order_nums < 2 or lower_order_second: - timestep_list = [self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_second_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - else: - timestep_list = [self.timesteps[step_index - 2], self.timesteps[step_index - 1], timestep] - prev_sample = self.multistep_deis_third_order_update( - self.model_outputs, timestep_list, prev_timestep, sample - ) - - if self.lower_order_nums < self.config.solver_order: - self.lower_order_nums += 1 - - if not return_dict: - return (prev_sample,) - - return SchedulerOutput(prev_sample=prev_sample) - - def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor: - """ - Ensures interchangeability with schedulers that need to scale the denoising model input depending on the - current timestep. - - Args: - sample (`torch.FloatTensor`): input sample - - Returns: - `torch.FloatTensor`: scaled input sample - """ - return sample - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.IntTensor, - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as original_samples - self.alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) - timesteps = timesteps.to(original_samples.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(original_samples.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise - return noisy_samples - - def get_velocity( - self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor - ) -> torch.FloatTensor: - # Make sure alphas_cumprod and timestep have same device and dtype as sample - self.alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) - timesteps = timesteps.to(sample.device) - - sqrt_alpha_prod = self.alphas_cumprod[timesteps] ** 0.5 - sqrt_alpha_prod = sqrt_alpha_prod.flatten() - while len(sqrt_alpha_prod.shape) < len(sample.shape): - sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) - - sqrt_one_minus_alpha_prod = (1 - self.alphas_cumprod[timesteps]) ** 0.5 - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() - while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): - sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) - - velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample - return velocity - - def __len__(self): - return self.config.num_train_timesteps diff --git a/train_dreambooth.py b/train_dreambooth.py index 280cf77..6d699f3 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -87,7 +87,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=4, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -305,7 +305,7 @@ def parse_args(): parser.add_argument( "--adam_weight_decay", type=float, - default=0, + default=1e-2, help="Weight decay to use." ) parser.add_argument( @@ -526,6 +526,7 @@ def main(): with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, no_val=args.valid_set_size == 0, + # low_freq_noise=0, ) checkpoint_output_dir = output_dir / "model" @@ -587,7 +588,6 @@ def main(): seed=args.seed, optimizer=optimizer, lr_scheduler=lr_scheduler, - prepare_unet=True, num_train_epochs=args.num_train_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, sample_frequency=args.sample_frequency, diff --git a/train_lora.py b/train_lora.py index d7c2de0..0a3d4c9 100644 --- a/train_lora.py +++ b/train_lora.py @@ -84,7 +84,7 @@ def parse_args(): parser.add_argument( "--num_buckets", type=int, - default=4, + default=0, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( diff --git a/train_ti.py b/train_ti.py index 68783ea..394711f 100644 --- a/train_ti.py +++ b/train_ti.py @@ -607,7 +607,7 @@ def main(): with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, no_val=args.valid_set_size == 0, - low_freq_noise=0, + # low_freq_noise=0, strategy=textual_inversion_strategy, num_train_epochs=args.num_train_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, diff --git a/training/functional.py b/training/functional.py index b830261..990c4cd 100644 --- a/training/functional.py +++ b/training/functional.py @@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings from models.clip.tokenizer import MultiCLIPTokenizer -from schedulers.scheduling_deis_multistep import DEISMultistepScheduler from training.util import AverageMeter @@ -74,19 +73,12 @@ def make_grid(images, rows, cols): return grid -def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): - if noise_scheduler == "deis": - noise_scheduler_cls = DEISMultistepScheduler - elif noise_scheduler == "ddpm": - noise_scheduler_cls = DDPMScheduler - else: - raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") - +def get_models(pretrained_model_name_or_path: str): tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') - noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') sample_scheduler = UniPCMultistepScheduler.from_pretrained( pretrained_model_name_or_path, subfolder='scheduler') @@ -232,9 +224,6 @@ def generate_class_images( del pipeline - if torch.cuda.is_available(): - torch.cuda.empty_cache() - def add_placeholder_tokens( tokenizer: MultiCLIPTokenizer, @@ -274,26 +263,41 @@ def loss_step( latents = vae.encode(batch["pixel_values"]).latent_dist.sample() latents = latents * vae.config.scaling_factor + bsz = latents.shape[0] + generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None # Sample noise that we'll add to the latents - noise = torch.randn( - latents.shape, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - if low_freq_noise > 0: - noise += low_freq_noise * torch.randn( + if low_freq_noise == 0: + noise = torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) + else: + noise = (1 - low_freq_noise) * torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) + low_freq_noise * torch.randn( latents.shape[0], latents.shape[1], 1, 1, dtype=latents.dtype, layout=latents.layout, device=latents.device, generator=generator ) + # noise += low_freq_noise * torch.randn( + # bsz, 1, 1, 1, + # dtype=latents.dtype, + # layout=latents.layout, + # device=latents.device, + # generator=generator + # ) - bsz = latents.shape[0] # Sample a random timestep for each image timesteps = torch.randint( 0, diff --git a/training/optimization.py b/training/optimization.py index 6c9a35d..7d8d55a 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -113,7 +113,7 @@ def get_scheduler( ): num_training_steps_per_epoch = math.ceil( num_training_steps_per_epoch / gradient_accumulation_steps - ) * gradient_accumulation_steps + ) # * gradient_accumulation_steps num_training_steps = train_epochs * num_training_steps_per_epoch num_warmup_steps = warmup_epochs * num_training_steps_per_epoch diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0290327..e5e84c8 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -88,8 +88,8 @@ def dreambooth_strategy_callbacks( def on_prepare(): unet.requires_grad_(True) - text_encoder.requires_grad_(True) - text_encoder.text_model.embeddings.requires_grad_(False) + text_encoder.text_model.encoder.requires_grad_(True) + text_encoder.text_model.final_layer_norm.requires_grad_(True) if ema_unet is not None: ema_unet.to(accelerator.device) @@ -203,7 +203,7 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({}) + return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) dreambooth_strategy = TrainingStrategy( diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 732cd74..bd0d178 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -130,7 +130,7 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: w = text_encoder.text_model.embeddings.temp_token_embedding.weight - mask = torch.zeros(w.size(0), dtype=torch.bool) + mask = torch.zeros(w.shape[0], dtype=torch.bool) mask[text_encoder.text_model.embeddings.temp_token_ids] = True mask[zero_ids] = False -- cgit v1.2.3-70-g09d2