From f5b656d21c5b449eed6ce212e909043c124f79ee Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 12 Oct 2022 08:18:22 +0200 Subject: Various updates --- data/csv.py | 3 +- dreambooth.py | 53 +++--- environment.yaml | 4 +- infer.py | 10 +- .../stable_diffusion/vlpn_stable_diffusion.py | 11 +- schedulers/scheduling_euler_a.py | 210 +++++++++------------ textual_inversion.py | 74 ++------ 7 files changed, 142 insertions(+), 223 deletions(-) diff --git a/data/csv.py b/data/csv.py index 8637ac1..253ce9e 100644 --- a/data/csv.py +++ b/data/csv.py @@ -68,13 +68,12 @@ class CSVDataModule(pl.LightningDataModule): item.nprompt if "nprompt" in item else "" ) for item in data - if "skip" not in item or item.skip != "x" for i in range(image_multiplier) ] def prepare_data(self): metadata = pd.read_csv(self.data_file) - metadata = list(metadata.itertuples()) + metadata = [item for item in metadata.itertuples() if "skip" not in item or item.skip != "x"] num_images = len(metadata) valid_set_size = int(num_images * 0.2) diff --git a/dreambooth.py b/dreambooth.py index 02f83c6..775aea2 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -112,7 +112,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=5000, + default=3000, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -150,7 +150,7 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, - default=600, + default=500, help="Number of steps for the warmup in the lr scheduler." ) parser.add_argument( @@ -167,7 +167,7 @@ def parse_args(): parser.add_argument( "--ema_power", type=float, - default=1.0 + default=7 / 8 ) parser.add_argument( "--ema_max_decay", @@ -468,20 +468,20 @@ def main(): if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path + '/tokenizer' - ) + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') # Load models and create wrapper for stable diffusion text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path + '/text_encoder', - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path + '/vae', - ) - unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path + '/unet', - ) + args.pretrained_model_name_or_path, subfolder='text_encoder') + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') + unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + + ema_unet = EMAModel( + unet, + inv_gamma=args.ema_inv_gamma, + power=args.ema_power, + max_value=args.ema_max_decay + ) if args.use_ema else None if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -538,7 +538,7 @@ def main(): pixel_values += [example["class_images"] for example in examples] pixel_values = torch.stack(pixel_values) - pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format) + pixel_values = pixel_values.to(memory_format=torch.contiguous_format) input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids @@ -629,16 +629,10 @@ def main(): unet, optimizer, train_dataloader, val_dataloader, lr_scheduler ) - ema_unet = EMAModel( - unet, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay - ) if args.use_ema else None - # Move text_encoder and vae to device text_encoder.to(accelerator.device) vae.to(accelerator.device) + ema_unet.averaged_model.to(accelerator.device) # Keep text_encoder and vae in eval mode as we don't train these text_encoder.eval() @@ -698,7 +692,7 @@ def main(): disable=not accelerator.is_local_main_process, dynamic_ncols=True ) - local_progress_bar.set_description("Batch X out of Y") + local_progress_bar.set_description("Epoch X / Y") global_progress_bar = tqdm( range(args.max_train_steps + val_steps), @@ -709,7 +703,7 @@ def main(): try: for epoch in range(num_epochs): - local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() unet.train() @@ -720,9 +714,8 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(unet): # Convert images to latent space - with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn(latents.shape).to(latents.device) @@ -737,8 +730,7 @@ def main(): noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) # Get the text embedding for conditioning - with torch.no_grad(): - encoder_hidden_states = text_encoder(batch["input_ids"])[0] + encoder_hidden_states = text_encoder(batch["input_ids"])[0] # Predict the noise residual noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample @@ -840,7 +832,8 @@ def main(): global_progress_bar.clear() if min_val_loss > val_loss: - accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + accelerator.print( + f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") min_val_loss = val_loss if sample_checkpoint and accelerator.is_main_process: diff --git a/environment.yaml b/environment.yaml index 5ecc5a8..de35645 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,7 +6,7 @@ dependencies: - cudatoolkit=11.3 - numpy=1.22.3 - pip=20.3 - - python=3.8.10 + - python=3.9.13 - pytorch=1.12.1 - torchvision=0.13.1 - pandas=1.4.3 @@ -32,6 +32,6 @@ dependencies: - test-tube>=0.7.5 - torch-fidelity==0.3.0 - torchmetrics==0.9.3 - - transformers==4.22.2 + - transformers==4.23.1 - triton==2.0.0.dev20220924 - xformers==0.0.13 diff --git a/infer.py b/infer.py index 70851fd..5bd4abc 100644 --- a/infer.py +++ b/infer.py @@ -22,7 +22,7 @@ torch.backends.cuda.matmul.allow_tf32 = True default_args = { "model": None, "scheduler": "euler_a", - "precision": "fp16", + "precision": "fp32", "embeddings_dir": "embeddings", "output_dir": "output/inference", "config": None, @@ -205,10 +205,10 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir): def create_pipeline(model, scheduler, embeddings_dir, dtype): print("Loading Stable Diffusion pipeline...") - tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer', torch_dtype=dtype) - text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype) - vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype) - unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype) + tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='/tokenizer', torch_dtype=dtype) + text_encoder = CLIPTextModel.from_pretrained(model, subfolder='/text_encoder', torch_dtype=dtype) + vae = AutoencoderKL.from_pretrained(model, subfolder='/vae', torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained(model, subfolder='/unet', torch_dtype=dtype) load_embeddings(tokenizer, text_encoder, embeddings_dir) diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index bfecd1c..8927a78 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -11,7 +11,7 @@ from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscre from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer -from schedulers.scheduling_euler_a import EulerAScheduler, CFGDenoiserForward +from schedulers.scheduling_euler_a import EulerAScheduler logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -284,10 +284,9 @@ class VlpnStableDiffusion(DiffusionPipeline): noise_pred = None if isinstance(self.scheduler, EulerAScheduler): - sigma = t.reshape(1) - sigma_in = torch.cat([sigma] * latent_model_input.shape[0]) - noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, - text_embeddings, guidance_scale, quantize=True, DSsigmas=self.scheduler.DSsigmas) + c_out, c_in, sigma_in = self.scheduler.prepare_input(latent_model_input, t, batch_size) + eps = self.unet(latent_model_input * c_in, sigma_in, encoder_hidden_states=text_embeddings).sample + noise_pred = latent_model_input + eps * c_out else: # predict the noise residual noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample @@ -305,7 +304,7 @@ class VlpnStableDiffusion(DiffusionPipeline): image = self.vae.decode(latents).sample image = (image / 2 + 0.5).clamp(0, 1) - image = image.cpu().permute(0, 2, 3, 1).numpy() + image = image.cpu().permute(0, 2, 3, 1).float().numpy() if output_type == "pil": image = self.numpy_to_pil(image) diff --git a/schedulers/scheduling_euler_a.py b/schedulers/scheduling_euler_a.py index 13ea6b3..6abe971 100644 --- a/schedulers/scheduling_euler_a.py +++ b/schedulers/scheduling_euler_a.py @@ -7,113 +7,6 @@ from diffusers.configuration_utils import ConfigMixin, register_to_config from diffusers.schedulers.scheduling_utils import SchedulerMixin, SchedulerOutput -''' -helper functions: append_zero(), - t_to_sigma(), - get_sigmas(), - append_dims(), - CFGDenoiserForward(), - get_scalings(), - DSsigma_to_t(), - DiscreteEpsDDPMDenoiserForward(), - to_d(), - get_ancestral_step() -need cleaning -''' - - -def append_zero(x): - return torch.cat([x, x.new_zeros([1])]) - - -def t_to_sigma(t, sigmas): - t = t.float() - low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() - return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] - - -def get_sigmas(sigmas, n=None): - if n is None: - return append_zero(sigmas.flip(0)) - t_max = len(sigmas) - 1 # = 999 - t = torch.linspace(t_max, 0, n, device=sigmas.device, dtype=sigmas.dtype) - return append_zero(t_to_sigma(t, sigmas)) - -# from k_samplers utils.py - - -def append_dims(x, target_dims): - """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" - dims_to_append = target_dims - x.ndim - if dims_to_append < 0: - raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') - return x[(...,) + (None,) * dims_to_append] - - -def CFGDenoiserForward(Unet, x_in, sigma_in, cond_in, cond_scale, quantize=False, DSsigmas=None): - # x_in = torch.cat([x] * 2)#A# concat the latent - # sigma_in = torch.cat([sigma] * 2) #A# concat sigma - # cond_in = torch.cat([uncond, cond]) - # uncond, cond = self.inner_model(x_in, sigma_in, cond=cond_in).chunk(2) - # uncond, cond = DiscreteEpsDDPMDenoiserForward(Unet,x_in, sigma_in,DSsigmas=DSsigmas, cond=cond_in).chunk(2) - # return uncond + (cond - uncond) * cond_scale - noise_pred = DiscreteEpsDDPMDenoiserForward( - Unet, x_in, sigma_in, quantize=quantize, DSsigmas=DSsigmas, cond=cond_in) - return noise_pred - -# from k_samplers sampling.py - - -def to_d(x, sigma, denoised): - """Converts a denoiser output to a Karras ODE derivative.""" - return (x - denoised) / append_dims(sigma.to(denoised.device), x.ndim) - - -def get_scalings(sigma): - sigma_data = 1. - c_out = -sigma - c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 - return c_out, c_in - -# DiscreteSchedule DS - - -def DSsigma_to_t(sigma, quantize=False, DSsigmas=None): - dists = torch.abs(sigma - DSsigmas[:, None]) - if quantize: - return torch.argmin(dists, dim=0).view(sigma.shape) - low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] - low, high = DSsigmas[low_idx], DSsigmas[high_idx] - w = (low - sigma) / (low - high) - w = w.clamp(0, 1) - t = (1 - w) * low_idx + w * high_idx - return t.view(sigma.shape) - - -def DiscreteEpsDDPMDenoiserForward(Unet, input, sigma, DSsigmas=None, quantize=False, **kwargs): - sigma = sigma.to(dtype=input.dtype, device=Unet.device) - DSsigmas = DSsigmas.to(dtype=input.dtype, device=Unet.device) - c_out, c_in = [append_dims(x, input.ndim) for x in get_scalings(sigma)] - # print(f">>>>>>>>>>> {input.dtype} {c_in.dtype} {sigma.dtype} {DSsigmas.dtype}") - eps = Unet(input * c_in, DSsigma_to_t(sigma, quantize=quantize, DSsigmas=DSsigmas), - encoder_hidden_states=kwargs['cond']).sample - return input + eps * c_out - - -# from k_samplers sampling.py -def get_ancestral_step(sigma_from, sigma_to): - """Calculates the noise level (sigma_down) to step down to and the amount - of noise to add (sigma_up) when doing an ancestral sampling step.""" - sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 - sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 - return sigma_down, sigma_up - - -''' -Euler Ancestral Scheduler -''' - - class EulerAScheduler(SchedulerMixin, ConfigMixin): """ Stochastic sampling from Karras et al. [1] tailored to the Variance-Expanding (VE) models [2]. Use Algorithm 2 and @@ -154,20 +47,24 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): beta_end: float = 0.02, beta_schedule: str = "linear", trained_betas: Optional[np.ndarray] = None, + tensor_format: str = "pt", + num_inference_steps=None, + device='cuda' ): if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas) + self.betas = torch.from_numpy(trained_betas).to(device) if beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32, device=device) elif beta_schedule == "scaled_linear": # this schedule is very specific to the latent diffusion model. - self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - elif beta_schedule == "squaredcos_cap_v2": - # Glide cosine schedule - self.betas = betas_for_alpha_bar(num_train_timesteps) + self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, + dtype=torch.float32, device=device) ** 2 else: raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + self.device = device + self.tensor_format = tensor_format + self.alphas = 1.0 - self.betas self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) @@ -175,8 +72,12 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.init_noise_sigma = 1.0 # setable values - self.num_inference_steps = None - self.timesteps = np.arange(0, num_train_timesteps)[::-1] + self.num_inference_steps = num_inference_steps + self.timesteps = np.arange(0, num_train_timesteps)[::-1].copy() + # get sigmas + self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 + self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) + self.set_format(tensor_format=tensor_format) # A# take number of steps as input # A# store 1) number of steps 2) timesteps 3) schedule @@ -192,7 +93,7 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): self.num_inference_steps = num_inference_steps self.DSsigmas = ((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5 - self.sigmas = get_sigmas(self.DSsigmas, self.num_inference_steps).to(device=device) + self.sigmas = self.get_sigmas(self.DSsigmas, self.num_inference_steps) self.timesteps = self.sigmas[:-1] self.is_scale_input_called = False @@ -251,8 +152,8 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): s_prev = self.sigmas[step_prev_index] latents = sample - sigma_down, sigma_up = get_ancestral_step(s, s_prev) - d = to_d(latents, s, model_output) + sigma_down, sigma_up = self.get_ancestral_step(s, s_prev) + d = self.to_d(latents, s, model_output) dt = sigma_down - s latents = latents + d * dt latents = latents + torch.randn(latents.shape, layout=latents.layout, device=latents.device, dtype=latents.dtype, @@ -313,3 +214,76 @@ class EulerAScheduler(SchedulerMixin, ConfigMixin): noisy_samples = original_samples + noise * sigma self.is_scale_input_called = True return noisy_samples + + # from k_samplers sampling.py + + def get_ancestral_step(self, sigma_from, sigma_to): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + sigma_up = (sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2) ** 0.5 + sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5 + return sigma_down, sigma_up + + def t_to_sigma(self, t, sigmas): + t = t.float() + low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac() + return (1 - w) * sigmas[low_idx] + w * sigmas[high_idx] + + def append_zero(self, x): + return torch.cat([x, x.new_zeros([1])]) + + def get_sigmas(self, sigmas, n=None): + if n is None: + return self.append_zero(sigmas.flip(0)) + t_max = len(sigmas) - 1 # = 999 + device = self.device + t = torch.linspace(t_max, 0, n, device=device) + # t = torch.linspace(t_max, 0, n, device=sigmas.device) + return self.append_zero(self.t_to_sigma(t, sigmas)) + + # from k_samplers utils.py + def append_dims(self, x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + return x[(...,) + (None,) * dims_to_append] + + # from k_samplers sampling.py + def to_d(self, x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / self.append_dims(sigma, x.ndim) + + def get_scalings(self, sigma): + sigma_data = 1. + c_out = -sigma + c_in = 1 / (sigma ** 2 + sigma_data ** 2) ** 0.5 + return c_out, c_in + + # DiscreteSchedule DS + def DSsigma_to_t(self, sigma, quantize=None): + # quantize = self.quantize if quantize is None else quantize + quantize = False + dists = torch.abs(sigma - self.DSsigmas[:, None]) + if quantize: + return torch.argmin(dists, dim=0).view(sigma.shape) + low_idx, high_idx = torch.sort(torch.topk(dists, dim=0, k=2, largest=False).indices, dim=0)[0] + low, high = self.DSsigmas[low_idx], self.DSsigmas[high_idx] + w = (low - sigma) / (low - high) + w = w.clamp(0, 1) + t = (1 - w) * low_idx + w * high_idx + return t.view(sigma.shape) + + def prepare_input(self, latent_in, t, batch_size): + sigma = t.reshape(1) # A# potential bug: doesn't work on samples > 1 + + sigma_in = torch.cat([sigma] * 2 * batch_size) + # noise_pred = CFGDenoiserForward(self.unet, latent_model_input, sigma_in, text_embeddings , guidance_scale,DSsigmas=self.scheduler.DSsigmas) + # noise_pred = DiscreteEpsDDPMDenoiserForward(self.unet,latent_model_input, sigma_in,DSsigmas=self.scheduler.DSsigmas, cond=cond_in) + c_out, c_in = [self.append_dims(x, latent_in.ndim) for x in self.get_scalings(sigma_in)] + + sigma_in = self.DSsigma_to_t(sigma_in) + # s_in = latent_in.new_ones([latent_in.shape[0]]) + # sigma_in = sigma_in * s_in + + return c_out, c_in, sigma_in diff --git a/textual_inversion.py b/textual_inversion.py index e6d856a..3a3741d 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -17,7 +17,6 @@ from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler -from diffusers.training_utils import EMAModel from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -112,7 +111,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=5000, + default=3000, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -150,30 +149,9 @@ def parse_args(): parser.add_argument( "--lr_warmup_steps", type=int, - default=600, + default=500, help="Number of steps for the warmup in the lr scheduler." ) - parser.add_argument( - "--use_ema", - action="store_true", - default=True, - help="Whether to use EMA model." - ) - parser.add_argument( - "--ema_inv_gamma", - type=float, - default=1.0 - ) - parser.add_argument( - "--ema_power", - type=float, - default=1.0 - ) - parser.add_argument( - "--ema_max_decay", - type=float, - default=0.9999 - ) parser.add_argument( "--use_8bit_adam", action="store_true", @@ -348,7 +326,6 @@ class Checkpointer: unet, tokenizer, text_encoder, - ema_text_encoder, placeholder_token, placeholder_token_id, output_dir: Path, @@ -363,7 +340,6 @@ class Checkpointer: self.unet = unet self.tokenizer = tokenizer self.text_encoder = text_encoder - self.ema_text_encoder = ema_text_encoder self.placeholder_token = placeholder_token self.placeholder_token_id = placeholder_token_id self.output_dir = output_dir @@ -380,8 +356,7 @@ class Checkpointer: checkpoints_path = self.output_dir.joinpath("checkpoints") checkpoints_path.mkdir(parents=True, exist_ok=True) - unwrapped = self.accelerator.unwrap_model( - self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder) + unwrapped = self.accelerator.unwrap_model(self.text_encoder) # Save a checkpoint learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] @@ -400,8 +375,7 @@ class Checkpointer: def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): samples_path = Path(self.output_dir).joinpath("samples") - unwrapped = self.accelerator.unwrap_model( - self.ema_text_encoder.averaged_model if self.ema_text_encoder is not None else self.text_encoder) + unwrapped = self.accelerator.unwrap_model(self.text_encoder) scheduler = EulerAScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -507,9 +481,7 @@ def main(): if args.tokenizer_name: tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name) elif args.pretrained_model_name_or_path: - tokenizer = CLIPTokenizer.from_pretrained( - args.pretrained_model_name_or_path + '/tokenizer' - ) + tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer') # Add the placeholder token in tokenizer num_added_tokens = tokenizer.add_tokens(args.placeholder_token) @@ -530,15 +502,10 @@ def main(): placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) # Load models and create wrapper for stable diffusion - text_encoder = CLIPTextModel.from_pretrained( - args.pretrained_model_name_or_path + '/text_encoder', - ) - vae = AutoencoderKL.from_pretrained( - args.pretrained_model_name_or_path + '/vae', - ) + text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder') + vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained( - args.pretrained_model_name_or_path + '/unet', - ) + args.pretrained_model_name_or_path, subfolder='unet') if args.gradient_checkpointing: unet.enable_gradient_checkpointing() @@ -707,13 +674,6 @@ def main(): text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler ) - ema_text_encoder = EMAModel( - text_encoder, - inv_gamma=args.ema_inv_gamma, - power=args.ema_power, - max_value=args.ema_max_decay - ) if args.use_ema else None - # Move vae and unet to device vae.to(accelerator.device) unet.to(accelerator.device) @@ -757,7 +717,6 @@ def main(): unet=unet, tokenizer=tokenizer, text_encoder=text_encoder, - ema_text_encoder=ema_text_encoder, placeholder_token=args.placeholder_token, placeholder_token_id=placeholder_token_id, output_dir=basepath, @@ -777,7 +736,7 @@ def main(): disable=not accelerator.is_local_main_process, dynamic_ncols=True ) - local_progress_bar.set_description("Batch X out of Y") + local_progress_bar.set_description("Epoch X / Y") global_progress_bar = tqdm( range(args.max_train_steps + val_steps), @@ -788,7 +747,7 @@ def main(): try: for epoch in range(num_epochs): - local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") + local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") local_progress_bar.reset() text_encoder.train() @@ -799,9 +758,8 @@ def main(): for step, batch in enumerate(train_dataloader): with accelerator.accumulate(text_encoder): # Convert images to latent space - with torch.no_grad(): - latents = vae.encode(batch["pixel_values"]).latent_dist.sample() - latents = latents * 0.18215 + latents = vae.encode(batch["pixel_values"]).latent_dist.sample() + latents = latents * 0.18215 # Sample noise that we'll add to the latents noise = torch.randn(latents.shape).to(latents.device) @@ -859,9 +817,6 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - if args.use_ema: - ema_text_encoder.step(unet) - local_progress_bar.update(1) global_progress_bar.update(1) @@ -881,8 +836,6 @@ def main(): }) logs = {"train/loss": loss, "lr": lr_scheduler.get_last_lr()[0]} - if args.use_ema: - logs["ema_decay"] = ema_text_encoder.decay accelerator.log(logs, step=global_step) @@ -937,7 +890,8 @@ def main(): global_progress_bar.clear() if min_val_loss > val_loss: - accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + accelerator.print( + f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") checkpointer.checkpoint(global_step + global_step_offset, "milestone") min_val_loss = val_loss -- cgit v1.2.3-70-g09d2