From 2ad46871e2ead985445da2848a4eb7072b6e48aa Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 14 Nov 2022 17:09:58 +0100 Subject: Update --- data/csv.py | 2 +- dreambooth.py | 71 +++--- infer.py | 21 +- .../stable_diffusion/vlpn_stable_diffusion.py | 33 ++- schedulers/scheduling_euler_ancestral_discrete.py | 261 --------------------- textual_inversion.py | 13 +- training/optimization.py | 2 +- 7 files changed, 87 insertions(+), 316 deletions(-) delete mode 100644 schedulers/scheduling_euler_ancestral_discrete.py diff --git a/data/csv.py b/data/csv.py index 793fbf8..67ac43b 100644 --- a/data/csv.py +++ b/data/csv.py @@ -93,7 +93,7 @@ class CSVDataModule(pl.LightningDataModule): items = [item for item in items if not "skip" in item or item["skip"] != True] num_images = len(items) - valid_set_size = int(num_images * 0.2) + valid_set_size = int(num_images * 0.1) if self.valid_set_size: valid_set_size = min(valid_set_size, self.valid_set_size) valid_set_size = max(valid_set_size, 1) diff --git a/dreambooth.py b/dreambooth.py index 8c4bf50..7b34fce 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -15,7 +15,7 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, PNDMScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from diffusers.training_utils import EMAModel from PIL import Image @@ -23,7 +23,6 @@ from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule @@ -144,7 +143,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=6000, + default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -211,7 +210,7 @@ def parse_args(): parser.add_argument( "--ema_power", type=float, - default=7 / 8 + default=6/7 ) parser.add_argument( "--ema_max_decay", @@ -283,6 +282,12 @@ def parse_args(): default=1, help="Number of samples to generate per batch", ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) parser.add_argument( "--train_batch_size", type=int, @@ -292,7 +297,7 @@ def parse_args(): parser.add_argument( "--sample_steps", type=int, - default=30, + default=25, help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", ) parser.add_argument( @@ -461,7 +466,7 @@ class Checkpointer: self.ema_unet.averaged_model if self.ema_unet is not None else self.unet) unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder) - scheduler = EulerAncestralDiscreteScheduler( + scheduler = DPMSolverMultistepScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -487,23 +492,30 @@ class Checkpointer: with torch.inference_mode(): for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: all_samples = [] - file_path = samples_path.joinpath(pool, f"step_{step}.png") + file_path = samples_path.joinpath(pool, f"step_{step}.jpg") file_path.parent.mkdir(parents=True, exist_ok=True) data_enum = enumerate(data) + batches = [ + batch + for j, batch in data_enum + if j * data.batch_size < self.sample_batch_size * self.sample_batches + ] + prompts = [ + prompt.format(identifier=self.instance_identifier) + for batch in batches + for prompt in batch["prompts"] + ] + nprompts = [ + prompt + for batch in batches + for prompt in batch["nprompts"] + ] + for i in range(self.sample_batches): - batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] - prompt = [ - prompt.format(identifier=self.instance_identifier) - for batch in batches - for prompt in batch["prompts"] - ][:self.sample_batch_size] - nprompt = [ - prompt - for batch in batches - for prompt in batch["nprompts"] - ][:self.sample_batch_size] + prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] + nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] samples = pipeline( prompt=prompt, @@ -523,7 +535,7 @@ class Checkpointer: del samples image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) - image_grid.save(file_path) + image_grid.save(file_path, quality=85) del all_samples del image_grid @@ -576,6 +588,12 @@ def main(): vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae') unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet') + unet.set_use_memory_efficient_attention_xformers(True) + + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + text_encoder.gradient_checkpointing_enable() + ema_unet = None if args.use_ema: ema_unet = EMAModel( @@ -586,12 +604,6 @@ def main(): device=accelerator.device ) - unet.set_use_memory_efficient_attention_xformers(True) - - if args.gradient_checkpointing: - unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() - # Freeze text_encoder and vae freeze_params(vae.parameters()) @@ -726,7 +738,7 @@ def main(): size=args.resolution, repeats=args.repeats, center_crop=args.center_crop, - valid_set_size=args.sample_batch_size*args.sample_batches, + valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, collate_fn=collate_fn ) @@ -743,7 +755,7 @@ def main(): for i in range(0, len(missing_data), args.sample_batch_size) ] - scheduler = EulerAncestralDiscreteScheduler( + scheduler = DPMSolverMultistepScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" ) @@ -962,6 +974,8 @@ def main(): optimizer.step() if not accelerator.optimizer_step_was_skipped: lr_scheduler.step() + if args.use_ema: + ema_unet.step(unet) optimizer.zero_grad(set_to_none=True) loss = loss.detach().item() @@ -969,9 +983,6 @@ def main(): # Checks if the accelerator has performed an optimization step behind the scenes if accelerator.sync_gradients: - if args.use_ema: - ema_unet.step(unet) - local_progress_bar.update(1) global_progress_bar.update(1) diff --git a/infer.py b/infer.py index 9bc9efe..9b0ec1f 100644 --- a/infer.py +++ b/infer.py @@ -8,11 +8,10 @@ from pathlib import Path import torch import json from PIL import Image -from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DDIMScheduler, LMSDiscreteScheduler +from diffusers import AutoencoderKL, UNet2DConditionModel, PNDMScheduler, DPMSolverMultistepScheduler, DDIMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion @@ -21,7 +20,7 @@ torch.backends.cuda.matmul.allow_tf32 = True default_args = { "model": None, - "scheduler": "euler_a", + "scheduler": "dpmpp", "precision": "fp32", "ti_embeddings_dir": "embeddings_ti", "output_dir": "output/inference", @@ -65,7 +64,7 @@ def create_args_parser(): parser.add_argument( "--scheduler", type=str, - choices=["plms", "ddim", "klms", "euler_a"], + choices=["plms", "ddim", "klms", "dpmpp", "euler_a"], ) parser.add_argument( "--precision", @@ -222,6 +221,10 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, dtype): scheduler = DDIMScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", clip_sample=False, set_alpha_to_one=False ) + elif scheduler == "dpmpp": + scheduler = DPMSolverMultistepScheduler( + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" + ) else: scheduler = EulerAncestralDiscreteScheduler( beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" @@ -282,7 +285,8 @@ def generate(output_dir, pipeline, args): ).images for j, image in enumerate(images): - image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) + image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) if torch.cuda.is_available(): torch.cuda.empty_cache() @@ -312,15 +316,16 @@ class CmdParse(cmd.Cmd): try: args = run_parser(self.parser, default_cmds, elements) + + if len(args.prompt) == 0: + print('Try again with a prompt!') + return except SystemExit: self.parser.print_help() except Exception as e: print(e) return - if len(args.prompt) == 0: - print('Try again with a prompt!') - try: generate(self.output_dir, self.pipeline, args) except KeyboardInterrupt: diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 36942f0..ba057ba 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -8,11 +8,20 @@ import PIL from diffusers.configuration_utils import FrozenDict from diffusers.utils import is_accelerate_available -from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel +from diffusers import ( + AutoencoderKL, + DiffusionPipeline, + UNet2DConditionModel, + DDIMScheduler, + DPMSolverMultistepScheduler, + EulerAncestralDiscreteScheduler, + EulerDiscreteScheduler, + LMSDiscreteScheduler, + PNDMScheduler, +) from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput from diffusers.utils import logging from transformers import CLIPTextModel, CLIPTokenizer -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from models.clip.prompt import PromptProcessor logger = logging.get_logger(__name__) # pylint: disable=invalid-name @@ -33,7 +42,14 @@ class VlpnStableDiffusion(DiffusionPipeline): text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, unet: UNet2DConditionModel, - scheduler: Union[DDIMScheduler, PNDMScheduler, LMSDiscreteScheduler, EulerAncestralDiscreteScheduler], + scheduler: Union[ + DDIMScheduler, + PNDMScheduler, + LMSDiscreteScheduler, + EulerDiscreteScheduler, + EulerAncestralDiscreteScheduler, + DPMSolverMultistepScheduler, + ], **kwargs, ): super().__init__() @@ -252,19 +268,14 @@ class VlpnStableDiffusion(DiffusionPipeline): latents = 0.18215 * latents # expand init_latents for batch_size - latents = torch.cat([latents] * batch_size) + latents = torch.cat([latents] * batch_size, dim=0) # get the original timestep using init_timestep init_timestep = int(num_inference_steps * strength) + offset init_timestep = min(init_timestep, num_inference_steps) - if not isinstance(self.scheduler, DDIMScheduler) and not isinstance(self.scheduler, PNDMScheduler): - timesteps = torch.tensor( - [num_inference_steps - init_timestep] * batch_size, dtype=torch.long, device=self.device - ) - else: - timesteps = self.scheduler.timesteps[-init_timestep] - timesteps = torch.tensor([timesteps] * batch_size, device=self.device) + timesteps = self.scheduler.timesteps[-init_timestep] + timesteps = torch.tensor([timesteps] * batch_size, device=self.device) # add noise to latents using the timesteps noise = torch.randn(latents.shape, generator=generator, device=self.device, dtype=latents_dtype) diff --git a/schedulers/scheduling_euler_ancestral_discrete.py b/schedulers/scheduling_euler_ancestral_discrete.py deleted file mode 100644 index cef50fe..0000000 --- a/schedulers/scheduling_euler_ancestral_discrete.py +++ /dev/null @@ -1,261 +0,0 @@ -# Copyright 2022 Katherine Crowson and The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from dataclasses import dataclass -from typing import Optional, Tuple, Union - -import numpy as np -import torch - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.utils import BaseOutput, deprecate, logging -from diffusers.schedulers.scheduling_utils import SchedulerMixin - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->EulerAncestralDiscrete -class EulerAncestralDiscreteSchedulerOutput(BaseOutput): - """ - Output class for the scheduler's step function output. - - Args: - prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the - denoising loop. - pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): - The predicted denoised sample (x_{0}) based on the model output from the current timestep. - `pred_original_sample` can be used to preview progress or for guidance. - """ - - prev_sample: torch.FloatTensor - pred_original_sample: Optional[torch.FloatTensor] = None - - -class EulerAncestralDiscreteScheduler(SchedulerMixin, ConfigMixin): - """ - Ancestral sampling with Euler method steps. Based on the original k-diffusion implementation by Katherine Crowson: - https://github.com/crowsonkb/k-diffusion/blob/481677d114f6ea445aa009cf5bd7a9cdee909e47/k_diffusion/sampling.py#L72 - - [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` - function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. - [`~ConfigMixin`] also provides general loading and saving functionality via the [`~ConfigMixin.save_config`] and - [`~ConfigMixin.from_config`] functions. - - Args: - num_train_timesteps (`int`): number of diffusion steps used to train the model. - beta_start (`float`): the starting `beta` value of inference. - beta_end (`float`): the final `beta` value. - beta_schedule (`str`): - the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from - `linear` or `scaled_linear`. - trained_betas (`np.ndarray`, optional): - option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. - - """ - - @register_to_config - def __init__( - self, - num_train_timesteps: int = 1000, - beta_start: float = 0.0001, - beta_end: float = 0.02, - beta_schedule: str = "linear", - trained_betas: Optional[np.ndarray] = None, - ): - if trained_betas is not None: - self.betas = torch.from_numpy(trained_betas) - elif beta_schedule == "linear": - self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) - elif beta_schedule == "scaled_linear": - # this schedule is very specific to the latent diffusion model. - self.betas = ( - torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 - ) - else: - raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") - - self.alphas = 1.0 - self.betas - self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) - - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas) - - # standard deviation of the initial noise distribution - self.init_noise_sigma = self.sigmas.max() - - # setable values - self.num_inference_steps = None - timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy() - self.timesteps = torch.from_numpy(timesteps) - self.is_scale_input_called = False - - def scale_model_input( - self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor] - ) -> torch.FloatTensor: - """ - Scales the denoising model input by `(sigma**2 + 1) ** 0.5` to match the Euler algorithm. - - Args: - sample (`torch.FloatTensor`): input sample - timestep (`float` or `torch.FloatTensor`): the current timestep in the diffusion chain - - Returns: - `torch.FloatTensor`: scaled input sample - """ - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - sample = sample / ((sigma**2 + 1) ** 0.5) - self.is_scale_input_called = True - return sample - - def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): - """ - Sets the timesteps used for the diffusion chain. Supporting function to be run before inference. - - Args: - num_inference_steps (`int`): - the number of diffusion steps used when generating samples with a pre-trained model. - device (`str` or `torch.device`, optional): - the device to which the timesteps should be moved to. If `None`, the timesteps are not moved. - """ - self.num_inference_steps = num_inference_steps - - timesteps = np.linspace(0, self.config.num_train_timesteps - 1, num_inference_steps, dtype=float)[::-1].copy() - sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5) - sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas) - sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32) - self.sigmas = torch.from_numpy(sigmas).to(device=device) - self.timesteps = torch.from_numpy(timesteps).to(device=device) - - def step( - self, - model_output: torch.FloatTensor, - timestep: Union[float, torch.FloatTensor], - sample: torch.FloatTensor, - generator: Optional[torch.Generator] = None, - return_dict: bool = True, - ) -> Union[EulerAncestralDiscreteSchedulerOutput, Tuple]: - """ - Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion - process from the learned model outputs (most often the predicted noise). - - Args: - model_output (`torch.FloatTensor`): direct output from learned diffusion model. - timestep (`float`): current timestep in the diffusion chain. - sample (`torch.FloatTensor`): - current instance of sample being created by diffusion process. - generator (`torch.Generator`, optional): Random number generator. - return_dict (`bool`): option for returning tuple rather than EulerAncestralDiscreteSchedulerOutput class - - Returns: - [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] or `tuple`: - [`~schedulers.scheduling_utils.EulerAncestralDiscreteSchedulerOutput`] if `return_dict` is True, otherwise - a `tuple`. When returning a tuple, the first element is the sample tensor. - - """ - - if ( - isinstance(timestep, int) - or isinstance(timestep, torch.IntTensor) - or isinstance(timestep, torch.LongTensor) - ): - raise ValueError( - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerDiscreteScheduler.step()` is not supported. Make sure to pass" - " one of the `scheduler.timesteps` as a timestep.", - ) - - if not self.is_scale_input_called: - logger.warn( - "The `scale_model_input` function should be called before `step` to ensure correct denoising. " - "See `StableDiffusionPipeline` for a usage example." - ) - - if isinstance(timestep, torch.Tensor): - timestep = timestep.to(self.timesteps.device) - - step_index = (self.timesteps == timestep).nonzero().item() - sigma = self.sigmas[step_index] - - # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise - pred_original_sample = sample - sigma * model_output - sigma_from = self.sigmas[step_index] - sigma_to = self.sigmas[step_index + 1] - sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5 - sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 - - # 2. Convert to an ODE derivative - derivative = (sample - pred_original_sample) / sigma - - dt = sigma_down - sigma - - prev_sample = sample + derivative * dt - - device = model_output.device if torch.is_tensor(model_output) else "cpu" - noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator) - prev_sample = prev_sample + noise * sigma_up - - if not return_dict: - return (prev_sample,) - - return EulerAncestralDiscreteSchedulerOutput( - prev_sample=prev_sample, pred_original_sample=pred_original_sample - ) - - def add_noise( - self, - original_samples: torch.FloatTensor, - noise: torch.FloatTensor, - timesteps: torch.FloatTensor, - ) -> torch.FloatTensor: - # Make sure sigmas and timesteps have the same device and dtype as original_samples - self.sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype) - if original_samples.device.type == "mps" and torch.is_floating_point(timesteps): - # mps does not support float64 - self.timesteps = self.timesteps.to(original_samples.device, dtype=torch.float32) - timesteps = timesteps.to(original_samples.device, dtype=torch.float32) - else: - self.timesteps = self.timesteps.to(original_samples.device) - timesteps = timesteps.to(original_samples.device) - - schedule_timesteps = self.timesteps - - if isinstance(timesteps, torch.IntTensor) or isinstance(timesteps, torch.LongTensor): - deprecate( - "timesteps as indices", - "0.8.0", - "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to" - " `EulerAncestralDiscreteScheduler.add_noise()` will not be supported in future versions. Make sure to" - " pass values from `scheduler.timesteps` as timesteps.", - standard_warn=False, - ) - step_indices = timesteps - else: - step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps] - - sigma = self.sigmas[step_indices].flatten() - while len(sigma.shape) < len(original_samples.shape): - sigma = sigma.unsqueeze(-1) - - noisy_samples = original_samples + noise * sigma - return noisy_samples - - def __len__(self): - return self.config.num_train_timesteps diff --git a/textual_inversion.py b/textual_inversion.py index 578c054..999161b 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -15,14 +15,13 @@ import torch.utils.checkpoint from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel +from diffusers import AutoencoderKL, DDPMScheduler, EulerAncestralDiscreteScheduler, UNet2DConditionModel from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer from slugify import slugify -from schedulers.scheduling_euler_ancestral_discrete import EulerAncestralDiscreteScheduler from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from data.csv import CSVDataModule from training.optimization import get_one_cycle_schedule @@ -134,7 +133,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=10000, + default=None, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -251,6 +250,12 @@ def parse_args(): default=1, help="Number of samples to generate per batch", ) + parser.add_argument( + "--valid_set_size", + type=int, + default=None, + help="Number of images in the validation dataset." + ) parser.add_argument( "--train_batch_size", type=int, @@ -637,7 +642,7 @@ def main(): size=args.resolution, repeats=args.repeats, center_crop=args.center_crop, - valid_set_size=args.sample_batch_size*args.sample_batches, + valid_set_size=args.valid_set_size, num_workers=args.dataloader_num_workers, collate_fn=collate_fn ) diff --git a/training/optimization.py b/training/optimization.py index 0fd7ec8..0e603fa 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -6,7 +6,7 @@ from diffusers.utils import logging logger = logging.get_logger(__name__) -def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.43, last_epoch=-1): +def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1): """ Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. -- cgit v1.2.3-54-g00ecf