From 2c525a0ddb0786b2f0652ab18e08fd4d0a1725d2 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 18 Feb 2023 13:00:13 +0100
Subject: Update

Make the noise scheduler used for training selectable in get_models()
("ddpm" by default, "deis" on request; the Dreambooth and LoRA scripts
keep DEIS by passing it explicitly), widen the scheduler type
annotations to SchedulerMixin, switch the default --optimizer from
"lion" to "adam", and raise explicit ValueErrors for unknown scheduler
and optimizer names.
---
 infer.py               |  2 ++
 train_dreambooth.py    |  8 +++++---
 train_lora.py          |  8 +++++---
 train_ti.py            |  6 ++++--
 training/functional.py | 21 ++++++++++++++-------
 5 files changed, 30 insertions(+), 15 deletions(-)

diff --git a/infer.py b/infer.py
index 8910e68..07dcd22 100644
--- a/infer.py
+++ b/infer.py
@@ -263,6 +263,8 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None)
             config,
             solver_p=create_scheduler(config, subscheduler),
         )
+    else:
+        raise ValueError(f"Unknown scheduler \"{scheduler}\"")
 
 
 def create_pipeline(model, dtype):
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 8f0c6ea..e039df0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -287,7 +287,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
         help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -459,7 +459,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, noise_scheduler="deis")
 
     tokenizer.set_use_vector_shuffle(args.vector_shuffle)
     tokenizer.set_dropout(args.vector_dropout)
@@ -513,13 +513,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     trainer = partial(
         train,
diff --git a/train_lora.py b/train_lora.py
index 368c29b..db5330a 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -247,7 +247,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
         help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -419,7 +419,7 @@ def main():
     save_args(output_dir, args)
 
     tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
-        args.pretrained_model_name_or_path)
+        args.pretrained_model_name_or_path, noise_scheduler="deis")
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
@@ -488,13 +488,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     trainer = partial(
         train,
diff --git a/train_ti.py b/train_ti.py
index 507d710..12e3644 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -292,7 +292,7 @@ def parse_args():
     parser.add_argument(
         "--optimizer",
         type=str,
-        default="lion",
+        default="adam",
         help='Optimizer to use ["adam", "adam8bit", "lion"]'
     )
     parser.add_argument(
@@ -586,13 +586,15 @@ def main():
             eps=args.adam_epsilon,
             amsgrad=args.adam_amsgrad,
         )
-    else:
+    elif args.optimizer == 'lion':
         try:
             from lion_pytorch import Lion
         except ImportError:
             raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion_pytorch`.")
 
         create_optimizer = partial(Lion, use_triton=True)
+    else:
+        raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"")
 
     checkpoint_output_dir = output_dir/"checkpoints"
 
diff --git a/training/functional.py b/training/functional.py
index 4d0cf0e..85dd884 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, UniPCMultistepScheduler
+from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
 from tqdm.auto import tqdm
 from PIL import Image
 
@@ -74,12 +74,19 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str):
+def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
+    if noise_scheduler == "deis":
+        noise_scheduler_cls = DEISMultistepScheduler
+    elif noise_scheduler == "ddpm":
+        noise_scheduler_cls = DDPMScheduler
+    else:
+        raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
+
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DEISMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
@@ -94,7 +101,7 @@ def save_samples(
     text_encoder: CLIPTextModel,
     tokenizer: MultiCLIPTokenizer,
     vae: AutoencoderKL,
-    sample_scheduler: UniPCMultistepScheduler,
+    sample_scheduler: SchedulerMixin,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     output_dir: Path,
@@ -181,7 +188,7 @@ def generate_class_images(
     vae: AutoencoderKL,
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
-    sample_scheduler: UniPCMultistepScheduler,
+    sample_scheduler: SchedulerMixin,
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
@@ -252,7 +259,7 @@ def add_placeholder_tokens(
 
 def loss_step(
     vae: AutoencoderKL,
-    noise_scheduler: DEISMultistepScheduler,
+    noise_scheduler: SchedulerMixin,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
@@ -552,7 +559,7 @@ def train(
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
     vae: AutoencoderKL,
-    noise_scheduler: DEISMultistepScheduler,
+    noise_scheduler: SchedulerMixin,
     dtype: torch.dtype,
     seed: int,
     project: str,
-- 
cgit v1.2.3-70-g09d2
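For reference, a minimal sketch of how the reworked get_models() is
called after this change; the checkpoint path below is a placeholder,
not part of the patch:

    from training.functional import get_models

    # noise_scheduler selects the training scheduler class: "ddpm"
    # (the new default) or "deis"; any other value raises a ValueError.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-model",  # placeholder, point at a real checkpoint
        noise_scheduler="deis",
    )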