From fcbc11be99c011ab1003451ef72c95ca587902d8 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 15 Oct 2022 18:42:27 +0200 Subject: Update --- dreambooth_plus.py | 34 +++++++++++++++++----- .../stable_diffusion/vlpn_stable_diffusion.py | 8 ++--- textual_inversion.py | 33 +++++++++++++-------- 3 files changed, 49 insertions(+), 26 deletions(-) diff --git a/dreambooth_plus.py b/dreambooth_plus.py index b5ec2fc..eeee424 100644 --- a/dreambooth_plus.py +++ b/dreambooth_plus.py @@ -16,7 +16,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel -from diffusers.optimization import get_scheduler +from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from diffusers.training_utils import EMAModel from PIL import Image from tqdm.auto import tqdm @@ -118,7 +118,7 @@ def parse_args(): parser.add_argument( "--max_train_steps", type=int, - default=2300, + default=1300, help="Total number of training steps to perform. If provided, overrides num_train_epochs.", ) parser.add_argument( @@ -317,6 +317,13 @@ def parse_args(): return args +def save_args(basepath: Path, args, extra={}): + info = {"args": vars(args)} + info["args"].update(extra) + with open(basepath.joinpath("args.json"), "w") as f: + json.dump(info, f, indent=4) + + def freeze_params(params): for param in params: param.requires_grad = False @@ -503,6 +510,8 @@ def main(): logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) + save_args(basepath, args) + # If passed along, set the training seed now. if args.seed is not None: set_seed(args.seed) @@ -706,12 +715,21 @@ def main(): args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - ) + if args.lr_scheduler == "cosine_with_restarts": + lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=num_update_steps_per_epoch, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3e41f86..2656b28 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -4,7 +4,6 @@ from typing import List, Optional, Union import numpy as np import torch -import torch.optim as optim import PIL from diffusers.configuration_utils import FrozenDict @@ -59,9 +58,6 @@ class VlpnStableDiffusion(DiffusionPipeline): scheduler=scheduler, ) - def get_text_embeddings(self, text_input_ids): - return self.text_encoder(text_input_ids)[0] - def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"): r""" Enable sliced attention computation. @@ -199,7 +195,7 @@ class VlpnStableDiffusion(DiffusionPipeline): ) print(f"Too many tokens: {removed_text}") text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length] - text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device)) + text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0] # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` @@ -211,7 +207,7 @@ class VlpnStableDiffusion(DiffusionPipeline): uncond_input = self.tokenizer( negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt" ) - uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device)) + uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0] # For classifier free guidance, we need to do two forward passes. # Here we concatenate the unconditional and text embeddings into a single batch diff --git a/textual_inversion.py b/textual_inversion.py index 6627f1f..2109d13 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -16,7 +16,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel -from diffusers.optimization import get_scheduler +from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup from PIL import Image from tqdm.auto import tqdm from transformers import CLIPTextModel, CLIPTokenizer @@ -303,10 +303,10 @@ def freeze_params(params): param.requires_grad = False -def save_resume_file(basepath, args, extra={}): +def save_args(basepath: Path, args, extra={}): info = {"args": vars(args)} info["args"].update(extra) - with open(f"{basepath}/resume.json", "w") as f: + with open(basepath.joinpath("args.json"), "w") as f: json.dump(info, f, indent=4) @@ -660,12 +660,21 @@ def main(): args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch overrode_max_train_steps = True - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, - num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, - ) + if args.lr_scheduler == "cosine_with_restarts": + lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + num_cycles=num_update_steps_per_epoch, + ) + else: + lr_scheduler = get_scheduler( + args.lr_scheduler, + optimizer=optimizer, + num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps, + num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, + ) text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler @@ -827,7 +836,7 @@ def main(): global_progress_bar.clear() checkpointer.checkpoint(global_step + global_step_offset, "training") - save_resume_file(basepath, args, { + save_args(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" }) @@ -901,7 +910,7 @@ def main(): if accelerator.is_main_process: print("Finished! Saving final checkpoint and resume state.") checkpointer.checkpoint(global_step + global_step_offset, "end") - save_resume_file(basepath, args, { + save_args(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" }) @@ -911,7 +920,7 @@ def main(): if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") checkpointer.checkpoint(global_step + global_step_offset, "end") - save_resume_file(basepath, args, { + save_args(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" }) -- cgit v1.2.3-54-g00ecf