From 01fee7d37a116265edb0f16e0b2f75d2116eb9f6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Wed, 4 Jan 2023 12:18:07 +0100
Subject: Various updates

---
 data/csv.py              | 45 ++++++++++++++++++++++----------------
 infer.py                 | 56 +++++++++++++++++++++++++++++++++++++-----------
 train_dreambooth.py      |  8 +++++++
 train_ti.py              |  8 +++++++
 training/optimization.py |  4 ----
 5 files changed, 87 insertions(+), 34 deletions(-)

diff --git a/data/csv.py b/data/csv.py
index e901ab4..c505230 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -165,19 +165,27 @@ class CSVDataModule():
         self.data_val = self.pad_items(data_val)
 
     def setup(self, stage=None):
-        train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size,
-                                   num_class_images=self.num_class_images,
-                                   size=self.size, interpolation=self.interpolation,
-                                   center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout)
-        val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size,
-                                 size=self.size, interpolation=self.interpolation,
-                                 center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
-                                            num_workers=self.num_workers)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
-                                          pin_memory=True, collate_fn=self.collate_fn,
-                                          num_workers=self.num_workers)
+        train_dataset = CSVDataset(
+            self.data_train, self.prompt_processor, batch_size=self.batch_size,
+            num_class_images=self.num_class_images,
+            size=self.size, interpolation=self.interpolation,
+            center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
+        )
+        val_dataset = CSVDataset(
+            self.data_val, self.prompt_processor, batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
+            center_crop=self.center_crop
+        )
+        self.train_dataloader_ = DataLoader(
+            train_dataset, batch_size=self.batch_size,
+            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
+            num_workers=self.num_workers
+        )
+        self.val_dataloader_ = DataLoader(
+            val_dataset, batch_size=self.batch_size,
+            pin_memory=True, collate_fn=self.collate_fn,
+            num_workers=self.num_workers
+        )
 
     def train_dataloader(self):
         return self.train_dataloader_
@@ -210,11 +218,12 @@ class CSVDataset(Dataset):
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
 
-        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
-                              "bilinear": transforms.InterpolationMode.BILINEAR,
-                              "bicubic": transforms.InterpolationMode.BICUBIC,
-                              "lanczos": transforms.InterpolationMode.LANCZOS,
-                              }[interpolation]
+        self.interpolation = {
+            "linear": transforms.InterpolationMode.NEAREST,
+            "bilinear": transforms.InterpolationMode.BILINEAR,
+            "bicubic": transforms.InterpolationMode.BICUBIC,
+            "lanczos": transforms.InterpolationMode.LANCZOS,
+        }[interpolation]
         self.image_transforms = transforms.Compose(
             [
                 transforms.Resize(size, interpolation=self.interpolation),
diff --git a/infer.py b/infer.py
index f88245a..c4d1e0d 100644
--- a/infer.py
+++ b/infer.py
@@ -45,6 +45,7 @@ default_args = {
 
 
 default_cmds = {
+    "project": "",
     "scheduler": "dpmsm",
     "prompt": None,
     "negative_prompt": None,
@@ -103,6 +104,12 @@ def create_cmd_parser():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
     )
+    parser.add_argument(
+        "--project",
+        type=str,
+        default=None,
+        help="The name of the current project.",
+    )
     parser.add_argument(
         "--scheduler",
         type=str,
@@ -184,7 +191,16 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def create_pipeline(model, embeddings_dir, dtype):
+def load_embeddings(pipeline, embeddings_dir):
+    added_tokens = load_embeddings_from_dir(
+        pipeline.tokenizer,
+        pipeline.text_encoder.text_model.embeddings,
+        Path(embeddings_dir)
+    )
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+
+
+def create_pipeline(model, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -193,10 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype):
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
     scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype)
 
-    embeddings = patch_managed_embeddings(text_encoder)
-    added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir))
-
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+    patch_managed_embeddings(text_encoder)
 
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
@@ -220,7 +233,14 @@ def generate(output_dir, pipeline, args):
         args.prompt = [args.prompt]
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
+    use_subdirs = len(args.prompt) != 1
+    if use_subdirs:
+        if len(args.project) != 0:
+            output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}")
+        else:
+            output_dir = output_dir.joinpath(now)
+    else:
+        output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
     output_dir.mkdir(parents=True, exist_ok=True)
 
     args.seed = args.seed or torch.random.seed()
@@ -257,7 +277,8 @@ def generate(output_dir, pipeline, args):
             dynamic_ncols=True
         )
 
-        generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
+        seed = args.seed + i
+        generator = torch.Generator(device="cuda").manual_seed(seed)
         images = pipeline(
             prompt=args.prompt,
             negative_prompt=args.negative_prompt,
@@ -272,8 +293,13 @@ def generate(output_dir, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png"))
-            image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85)
+            image_dir = output_dir
+            if use_subdirs:
+                idx = j % len(args.prompt)
+                image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100])
+                image_dir.mkdir(parents=True, exist_ok=True)
+            image.save(image_dir.joinpath(f"{seed}_{j}.png"))
+            image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85)
 
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -283,10 +309,11 @@ class CmdParse(cmd.Cmd):
     prompt = 'dream> '
     commands = []
 
-    def __init__(self, output_dir, pipeline, parser):
+    def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser):
         super().__init__()
 
         self.output_dir = output_dir
+        self.ti_embeddings_dir = ti_embeddings_dir
         self.pipeline = pipeline
         self.parser = parser
 
@@ -302,6 +329,10 @@ class CmdParse(cmd.Cmd):
         if elements[0] == 'q':
             return True
 
+        if elements[0] == 'reload_embeddings':
+            load_embeddings(self.pipeline, self.ti_embeddings_dir)
+            return
+
         try:
             args = run_parser(self.parser, default_cmds, elements)
 
@@ -337,9 +368,10 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, dtype)
+    load_embeddings(pipeline, args.ti_embeddings_dir)
     cmd_parser = create_cmd_parser()
-    cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
+    cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
 
 
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 5e6e35d..2e0696b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -268,6 +268,12 @@ def parse_args():
         default=3,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=None,
+        help="Minimum learning rate in the lr scheduler."
+    )
     parser.add_argument(
         "--use_ema",
         action="store_true",
@@ -799,6 +805,7 @@ def main():
     warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
 
     if args.lr_scheduler == "one_cycle":
+        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -806,6 +813,7 @@ def main():
             annealing=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
+            min_lr=lr_min_lr,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
diff --git a/train_ti.py b/train_ti.py
index 6f116c3..1b60f64 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -259,6 +259,12 @@ def parse_args():
         default=1,
         help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
     )
+    parser.add_argument(
+        "--lr_min_lr",
+        type=float,
+        default=None,
+        help="Minimum learning rate in the lr scheduler."
+    )
     parser.add_argument(
         "--use_8bit_adam",
         action="store_true",
@@ -744,6 +750,7 @@ def main():
     if args.find_lr:
         lr_scheduler = None
     elif args.lr_scheduler == "one_cycle":
+        lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
             num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
@@ -751,6 +758,7 @@ def main():
             annealing=args.lr_annealing_func,
             warmup_exp=args.lr_warmup_exp,
             annealing_exp=args.lr_annealing_exp,
+            min_lr=lr_min_lr,
         )
     elif args.lr_scheduler == "cosine_with_restarts":
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
diff --git a/training/optimization.py b/training/optimization.py
index 14c2bd5..dd84f9c 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,10 +5,6 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.utils import logging
-
-logger = logging.get_logger(__name__)
-
 
 class OneCyclePhase(NamedTuple):
     step_min: int
-- 
cgit v1.2.3-70-g09d2