From 8364ce697ddf6117fdd4f7222832d546d63880de Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 21 Jun 2023 13:28:49 +0200 Subject: Update --- .vscode/settings.json | 6 + data/csv.py | 183 +++-- data/keywords.py | 8 +- environment.yaml | 11 +- infer.py | 124 ++-- models/attention/control.py | 106 ++- models/attention/hook.py | 5 +- models/attention/structured.py | 65 +- models/clip/embeddings.py | 29 +- models/clip/tokenizer.py | 23 +- models/clip/util.py | 17 +- models/convnext/discriminator.py | 11 +- models/sparse.py | 12 +- .../stable_diffusion/vlpn_stable_diffusion.py | 262 +++++-- train_dreambooth.py | 770 +++++++++++++++------ train_lora.py | 489 +++++++------ train_ti.py | 379 +++++----- training/functional.py | 221 ++++-- training/lr.py | 4 +- training/optimization.py | 38 +- training/sampler.py | 2 +- training/strategy/dreambooth.py | 29 +- training/strategy/lora.py | 41 +- training/strategy/ti.py | 27 +- 24 files changed, 1873 insertions(+), 989 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..cee7b74 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,6 @@ +{ + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter" + }, + "python.formatting.provider": "none" +} diff --git a/data/csv.py b/data/csv.py index 14380e8..d726033 100644 --- a/data/csv.py +++ b/data/csv.py @@ -49,7 +49,7 @@ def generate_buckets( max_pixels: Optional[int] = None, num_buckets: int = 4, progressive_buckets: bool = False, - return_tensor: bool = True + return_tensor: bool = True, ): if max_pixels is None: max_pixels = (base_size + step_size) ** 2 @@ -62,7 +62,11 @@ def generate_buckets( for i in range(1, num_buckets + 1): long_side = base_size + i * step_size - short_side = min(base_size - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, base_size) + short_side = min( + base_size + - math.ceil((base_size - max_pixels / long_side) / step_size) * step_size, + base_size, + ) buckets.append(long_side / short_side) buckets.append(short_side / long_side) @@ -106,7 +110,7 @@ def collate_fn( max_token_id_length: Optional[int], with_guidance: bool, with_prior_preservation: bool, - examples + examples, ): prompt_ids = [example["prompt_ids"] for example in examples] nprompt_ids = [example["nprompt_ids"] for example in examples] @@ -125,7 +129,9 @@ def collate_fn( prompts = unify_input_ids(tokenizer, prompt_ids, max_token_id_length) nprompts = unify_input_ids(tokenizer, nprompt_ids, max_token_id_length) inputs = unify_input_ids(tokenizer, input_ids, max_token_id_length) - negative_inputs = unify_input_ids(tokenizer, negative_input_ids, max_token_id_length) + negative_inputs = unify_input_ids( + tokenizer, negative_input_ids, max_token_id_length + ) batch = { "prompt_ids": prompts.input_ids, @@ -149,35 +155,39 @@ class VlpnDataItem(NamedTuple): nprompt: str collection: list[str] - def full_prompt(self, dropout: float = 0, shuffle: bool = False, npgenerator: Optional[np.random.Generator] = None): - return keywords_to_str(self.keywords, [self.prompt], dropout, shuffle, npgenerator) + def full_prompt( + self, + dropout: float = 0, + shuffle: bool = False, + npgenerator: Optional[np.random.Generator] = None, + ): + return keywords_to_str( + self.keywords, [self.prompt], dropout, shuffle, npgenerator + ) def keyword_filter( placeholder_tokens: Optional[list[str]], collections: Optional[list[str]], exclude_collections: Optional[list[str]], - item: VlpnDataItem + item: VlpnDataItem, ): 
full_prompt = item.full_prompt() cond1 = placeholder_tokens is None or any( - token in full_prompt - for token in placeholder_tokens + token in full_prompt for token in placeholder_tokens ) cond2 = collections is None or any( - collection in item.collection - for collection in collections + collection in item.collection for collection in collections ) cond3 = exclude_collections is None or not any( - collection in item.collection - for collection in exclude_collections + collection in item.collection for collection in exclude_collections ) return cond1 and cond2 and cond3 -class VlpnDataModule(): +class VlpnDataModule: def __init__( self, batch_size: int, @@ -222,7 +232,7 @@ class VlpnDataModule(): self.constant_prompt_length = constant_prompt_length self.max_token_id_length = None - + self.tokenizer = tokenizer self.size = size self.num_buckets = num_buckets @@ -259,23 +269,29 @@ class VlpnDataModule(): nprompt = prepare_tpl_slots(item["nprompt"] if "nprompt" in item else "") collection = item["collection"].split(", ") if "collection" in item else [] - saturated_keywords = str_to_keywords(tpl_keywords.format(**keywords), expansions) + saturated_keywords = str_to_keywords( + tpl_keywords.format(**keywords), expansions + ) - inverted_tokens = keywords_to_str([ - f"inv_{token}" - for token in self.placeholder_tokens - if token in saturated_keywords - ]) + inverted_tokens = keywords_to_str( + [ + f"inv_{token}" + for token in self.placeholder_tokens + if token in saturated_keywords + ] + ) - items.append(VlpnDataItem( - self.data_root / image, - None, - saturated_keywords, - tpl_prompt.format(**prompt), - tpl_cprompt.format(**prompt), - tpl_nprompt.format(_inv=inverted_tokens, **nprompt), - collection - )) + items.append( + VlpnDataItem( + self.data_root / image, + None, + saturated_keywords, + tpl_prompt.format(**prompt), + tpl_cprompt.format(**prompt), + tpl_nprompt.format(_inv=inverted_tokens, **nprompt), + collection, + ) + ) return items @@ -285,13 +301,16 @@ class VlpnDataModule(): return [item for item in items if self.filter(item)] - def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: + def pad_items( + self, items: list[VlpnDataItem], num_class_images: int = 1 + ) -> list[VlpnDataItem]: image_multiplier = max(num_class_images, 1) return [ VlpnDataItem( item.instance_image_path, - self.class_root / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", + self.class_root + / f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}", item.keywords, item.prompt, item.cprompt, @@ -303,7 +322,7 @@ class VlpnDataModule(): ] def setup(self): - with open(self.data_file, 'rt') as f: + with open(self.data_file, "rt") as f: metadata = json.load(f) template = metadata[self.template_key] if self.template_key in metadata else {} expansions = metadata["expansions"] if "expansions" in metadata else {} @@ -312,25 +331,41 @@ class VlpnDataModule(): items = self.prepare_items(template, expansions, items) items = self.filter_items(items) self.npgenerator.shuffle(items) - + if self.constant_prompt_length: all_input_ids = unify_input_ids( self.tokenizer, - [self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids for item in items] + [ + self.tokenizer(item.full_prompt(), padding="do_not_pad").input_ids + for item in items + ], ).input_ids self.max_token_id_length = all_input_ids.shape[1] num_images = len(items) - valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 
10 + valid_set_size = ( + min(self.valid_set_size, num_images) + if self.valid_set_size is not None + else num_images // 10 + ) train_set_size = max(num_images - valid_set_size, 1) valid_set_size = num_images - train_set_size - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.max_token_id_length, self.with_guidance, self.num_class_images != 0) + collate_fn_ = partial( + collate_fn, + self.dtype, + self.tokenizer, + self.max_token_id_length, + self.with_guidance, + self.num_class_images != 0, + ) if valid_set_size == 0: data_train, data_val = items, items else: - data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=self.generator) + data_train, data_val = random_split( + items, [train_set_size, valid_set_size], generator=self.generator + ) data_train = self.pad_items(data_train, self.num_class_images) @@ -338,17 +373,25 @@ class VlpnDataModule(): data_train *= math.ceil(self.train_set_pad / len(data_train)) self.train_dataset = VlpnDataset( - data_train, self.tokenizer, - num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, - bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - batch_size=self.batch_size, fill_batch=True, generator=self.generator, - size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, - num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, + data_train, + self.tokenizer, + num_buckets=self.num_buckets, + progressive_buckets=self.progressive_buckets, + bucket_step_size=self.bucket_step_size, + bucket_max_pixels=self.bucket_max_pixels, + batch_size=self.batch_size, + fill_batch=True, + generator=self.generator, + size=self.size, + interpolation=self.interpolation, + color_jitter=self.color_jitter, + num_class_images=self.num_class_images, + dropout=self.dropout, + shuffle=self.shuffle, ) self.train_dataloader = DataLoader( - self.train_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_ + self.train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) if len(data_val) != 0: @@ -358,16 +401,24 @@ class VlpnDataModule(): data_val *= math.ceil(self.valid_set_pad / len(data_val)) self.val_dataset = VlpnDataset( - data_val, self.tokenizer, - num_buckets=self.num_buckets, progressive_buckets=True, - bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - batch_size=self.batch_size, generator=self.generator, - size=self.size, interpolation=self.interpolation, color_jitter=self.color_jitter, + data_val, + self.tokenizer, + num_buckets=self.num_buckets, + progressive_buckets=True, + bucket_step_size=self.bucket_step_size, + bucket_max_pixels=self.bucket_max_pixels, + batch_size=self.batch_size, + generator=self.generator, + size=self.size, + interpolation=self.interpolation, + color_jitter=self.color_jitter, ) self.val_dataloader = DataLoader( self.val_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_ + batch_size=None, + pin_memory=True, + collate_fn=collate_fn_, ) else: self.val_dataloader = None @@ -418,7 +469,13 @@ class VlpnDataset(IterableDataset): self.bucket_item_range = torch.arange(len(self.bucket_items)) - self.length_ = (self.bucket_assignments.bincount() / self.batch_size).ceil().long().sum().item() + self.length_ = ( + (self.bucket_assignments.bincount() / self.batch_size) + .ceil() + .long() + .sum() + .item() + ) def get_input_ids(self, text: str): return self.tokenizer(text, padding="do_not_pad").input_ids @@ -430,7 +487,9 @@ 
class VlpnDataset(IterableDataset): worker_info = torch.utils.data.get_worker_info() if self.shuffle: - perm = torch.randperm(len(self.bucket_assignments), generator=self.generator) + perm = torch.randperm( + len(self.bucket_assignments), generator=self.generator + ) self.bucket_items = self.bucket_items[perm] self.bucket_assignments = self.bucket_assignments[perm] @@ -488,7 +547,9 @@ class VlpnDataset(IterableDataset): if len(bucket_items) == 0: bucket_items = self.bucket_items[self.bucket_assignments == bucket] - item_index = bucket_items[torch.randint(len(bucket_items), (1,), generator=self.generator)] + item_index = bucket_items[ + torch.randint(len(bucket_items), (1,), generator=self.generator) + ] else: item_index = bucket_items[0] mask[self.bucket_item_range[bucket_mask][0]] = False @@ -500,12 +561,18 @@ class VlpnDataset(IterableDataset): example["prompt_ids"] = self.get_input_ids(item.full_prompt()) example["nprompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_prompt_ids"] = self.get_input_ids(item.full_prompt(self.dropout, True, self.npgenerator)) + example["instance_prompt_ids"] = self.get_input_ids( + item.full_prompt(self.dropout, True, self.npgenerator) + ) example["negative_prompt_ids"] = self.get_input_ids(item.nprompt) - example["instance_images"] = image_transforms(get_image(item.instance_image_path)) + example["instance_images"] = image_transforms( + get_image(item.instance_image_path) + ) if self.num_class_images != 0: example["class_prompt_ids"] = self.get_input_ids(item.cprompt) - example["class_images"] = image_transforms(get_image(item.class_image_path)) + example["class_images"] = image_transforms( + get_image(item.class_image_path) + ) batch.append(example) diff --git a/data/keywords.py b/data/keywords.py index 8632d67..83fe9ff 100644 --- a/data/keywords.py +++ b/data/keywords.py @@ -8,7 +8,7 @@ def keywords_to_str( undroppable_keywords: list[str] = [], dropout: float = 0, shuffle: bool = False, - npgenerator: Optional[np.random.Generator] = None + npgenerator: Optional[np.random.Generator] = None, ) -> str: if dropout != 0: keywords = [keyword for keyword in keywords if np.random.random() > dropout] @@ -23,7 +23,11 @@ def keywords_to_str( def str_to_keywords(s: str, expansions: dict[str, str] = {}) -> list[str]: def expand_keyword(keyword: str) -> list[str]: - return [keyword] + expansions[keyword].split(", ") if keyword in expansions else [keyword] + return ( + [keyword] + expansions[keyword].split(", ") + if keyword in expansions + else [keyword] + ) return [ kw diff --git a/environment.yaml b/environment.yaml index 1a55967..2c81a90 100644 --- a/environment.yaml +++ b/environment.yaml @@ -14,16 +14,17 @@ dependencies: - numpy=1.24.3 - pip=22.3.1 - python=3.10.8 - - pytorch=2.0.0=*cuda11.8* - - torchvision=0.15.0 - - xformers=0.0.20.dev528 + - pytorch=2.0.1=*cuda11.8* + - scipy=1.10.1 + - torchvision=0.15.2 + - xformers=0.0.21.dev542+git.a205b24 - pip: - -e . 
- -e git+https://github.com/huggingface/accelerate#egg=accelerate - -e git+https://github.com/huggingface/diffusers#egg=diffusers - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation - --pre --extra-index-url https://download.hidet.org/whl hidet - - bitsandbytes==0.38.1 + - bitsandbytes==0.39.1 - lion-pytorch==0.0.7 - peft==0.3.0 - python-slugify>=6.1.2 @@ -31,4 +32,4 @@ dependencies: - setuptools==65.6.3 - test-tube>=0.7.5 - timm==0.9.2 - - transformers==4.29.1 + - transformers==4.30.1 diff --git a/infer.py b/infer.py index 7346de9..3b3b595 100644 --- a/infer.py +++ b/infer.py @@ -24,7 +24,7 @@ from diffusers import ( KDPM2DiscreteScheduler, KDPM2AncestralDiscreteScheduler, DEISMultistepScheduler, - UniPCMultistepScheduler + UniPCMultistepScheduler, ) from peft import LoraConfig, LoraModel, set_peft_model_state_dict from safetensors.torch import load_file @@ -61,7 +61,7 @@ default_cmds = { "negative_prompt": None, "shuffle": False, "image": None, - "image_noise": .7, + "image_noise": 0.7, "width": 768, "height": 768, "batch_size": 1, @@ -69,7 +69,6 @@ default_cmds = { "steps": 30, "guidance_scale": 7.0, "sag_scale": 0, - "brightness_offset": 0, "seed": None, "config": None, } @@ -85,9 +84,7 @@ def merge_dicts(d1, *args): def create_args_parser(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--model", type=str, @@ -118,9 +115,7 @@ def create_args_parser(): def create_cmd_parser(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--project", type=str, @@ -130,13 +125,34 @@ def create_cmd_parser(): parser.add_argument( "--scheduler", type=str, - choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis", "unipc"], + choices=[ + "plms", + "ddim", + "klms", + "dpmsm", + "dpmss", + "euler_a", + "kdpm2", + "kdpm2_a", + "deis", + "unipc", + ], ) parser.add_argument( "--subscheduler", type=str, default=None, - choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a", "deis"], + choices=[ + "plms", + "ddim", + "klms", + "dpmsm", + "dpmss", + "euler_a", + "kdpm2", + "kdpm2_a", + "deis", + ], ) parser.add_argument( "--template", @@ -192,10 +208,6 @@ def create_cmd_parser(): "--sag_scale", type=float, ) - parser.add_argument( - "--brightness_offset", - type=float, - ) parser.add_argument( "--seed", type=int, @@ -214,7 +226,9 @@ def run_parser(parser, defaults, input=None): if args.config is not None: conf_args = load_config(args.config) - conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[0] + conf_args = parser.parse_known_args(namespace=argparse.Namespace(**conf_args))[ + 0 + ] res = defaults.copy() for dict in [vars(conf_args), vars(args)]: @@ -234,10 +248,12 @@ def load_embeddings_dir(pipeline, embeddings_dir): added_tokens, added_ids = load_embeddings_from_dir( pipeline.tokenizer, pipeline.text_encoder.text_model.embeddings, - Path(embeddings_dir) + Path(embeddings_dir), ) pipeline.text_encoder.text_model.embeddings.persist() - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) def load_lora(pipeline, path): @@ -255,9 +271,13 @@ def 
load_lora(pipeline, path): return lora_checkpoint_sd = load_file(path / tensor_files[0]) - unet_lora_ds = {k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k} + unet_lora_ds = { + k: v for k, v in lora_checkpoint_sd.items() if "text_encoder_" not in k + } text_encoder_lora_ds = { - k.replace("text_encoder_", ""): v for k, v in lora_checkpoint_sd.items() if "text_encoder_" in k + k.replace("text_encoder_", ""): v + for k, v in lora_checkpoint_sd.items() + if "text_encoder_" in k } ti_lora_ds = { k.replace("ti_", ""): v for k, v in lora_checkpoint_sd.items() if "ti_" in k @@ -282,7 +302,9 @@ def load_lora(pipeline, path): token_embeddings=token_embeddings, ) pipeline.text_encoder.text_model.embeddings.persist() - print(f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} tokens from LoRA: {list(zip(added_tokens, added_ids))}" + ) return @@ -315,17 +337,25 @@ def create_scheduler(config, scheduler: str, subscheduler: Optional[str] = None) solver_p=create_scheduler(config, subscheduler), ) else: - raise ValueError(f"Unknown scheduler \"{scheduler}\"") + raise ValueError(f'Unknown scheduler "{scheduler}"') def create_pipeline(model, dtype): print("Loading Stable Diffusion pipeline...") - tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) - text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype) - vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype) - unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) - scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) + tokenizer = MultiCLIPTokenizer.from_pretrained( + model, subfolder="tokenizer", torch_dtype=dtype + ) + text_encoder = CLIPTextModel.from_pretrained( + model, subfolder="text_encoder", torch_dtype=dtype + ) + vae = AutoencoderKL.from_pretrained(model, subfolder="vae", torch_dtype=dtype) + unet = UNet2DConditionModel.from_pretrained( + model, subfolder="unet", torch_dtype=dtype + ) + scheduler = DDIMScheduler.from_pretrained( + model, subfolder="scheduler", torch_dtype=dtype + ) patch_managed_embeddings(text_encoder) @@ -347,7 +377,9 @@ def create_pipeline(model, dtype): def shuffle_prompts(prompts: list[str]) -> list[str]: - return [keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts] + return [ + keywords_to_str(str_to_keywords(prompt), shuffle=True) for prompt in prompts + ] @torch.inference_mode() @@ -386,12 +418,13 @@ def generate(output_dir: Path, pipeline, args): else: init_image = None - pipeline.scheduler = create_scheduler(pipeline.scheduler.config, args.scheduler, args.subscheduler) + pipeline.scheduler = create_scheduler( + pipeline.scheduler.config, args.scheduler, args.subscheduler + ) for i in range(args.batch_num): pipeline.set_progress_bar_config( - desc=f"Batch {i + 1} of {args.batch_num}", - dynamic_ncols=True + desc=f"Batch {i + 1} of {args.batch_num}", dynamic_ncols=True ) seed = args.seed + i @@ -409,7 +442,6 @@ def generate(output_dir: Path, pipeline, args): generator=generator, image=init_image, strength=args.image_noise, - brightness_offset=args.brightness_offset, ).images for j, image in enumerate(images): @@ -418,7 +450,7 @@ def generate(output_dir: Path, pipeline, args): image.save(dir / f"{basename}.png") image.save(dir / f"{basename}.jpg", quality=85) - with open(dir / f"{basename}.txt", 'w') as f: + with open(dir / 
f"{basename}.txt", "w") as f: f.write(prompt[j % len(args.prompt)]) if torch.cuda.is_available(): @@ -426,10 +458,12 @@ def generate(output_dir: Path, pipeline, args): class CmdParse(cmd.Cmd): - prompt = 'dream> ' + prompt = "dream> " commands = [] - def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): + def __init__( + self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser + ): super().__init__() self.output_dir = output_dir @@ -447,10 +481,10 @@ class CmdParse(cmd.Cmd): print(str(e)) return - if elements[0] == 'q': + if elements[0] == "q": return True - if elements[0] == 'reload_embeddings': + if elements[0] == "reload_embeddings": load_embeddings_dir(self.pipeline, self.ti_embeddings_dir) return @@ -458,7 +492,7 @@ class CmdParse(cmd.Cmd): args = run_parser(self.parser, default_cmds, elements) if len(args.prompt) == 0: - print('Try again with a prompt!') + print("Try again with a prompt!") return except SystemExit: traceback.print_exc() @@ -471,7 +505,7 @@ class CmdParse(cmd.Cmd): try: generate(self.output_dir, self.pipeline, args) except KeyboardInterrupt: - print('Generation cancelled.') + print("Generation cancelled.") except Exception as e: traceback.print_exc() return @@ -487,7 +521,9 @@ def main(): args = run_parser(args_parser, default_args) output_dir = Path(args.output_dir) - dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] + dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[ + args.precision + ] pipeline = create_pipeline(args.model, dtype) @@ -496,7 +532,13 @@ def main(): # pipeline.unet.load_attn_procs(args.lora_embeddings_dir) cmd_parser = create_cmd_parser() - cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) + cmd_prompt = CmdParse( + output_dir, + args.ti_embeddings_dir, + args.lora_embeddings_dir, + pipeline, + cmd_parser, + ) cmd_prompt.cmdloop() diff --git a/models/attention/control.py b/models/attention/control.py index 248bd9f..ec378c4 100644 --- a/models/attention/control.py +++ b/models/attention/control.py @@ -23,7 +23,7 @@ class AttentionControl(abc.ABC): attn = self.forward(attn, is_cross, place_in_unet) else: h = attn.shape[0] - attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + attn[h // 2 :] = self.forward(attn[h // 2 :], is_cross, place_in_unet) self.cur_att_layer += 1 if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: self.cur_att_layer = 0 @@ -49,12 +49,18 @@ class EmptyControl(AttentionControl): class AttentionStore(AttentionControl): @staticmethod def get_empty_store(): - return {"down_cross": [], "mid_cross": [], "up_cross": [], - "down_self": [], "mid_self": [], "up_self": []} + return { + "down_cross": [], + "mid_cross": [], + "up_cross": [], + "down_self": [], + "mid_self": [], + "up_self": [], + } def forward(self, attn, is_cross: bool, place_in_unet: str): key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" - if attn.shape[1] <= 32 ** 2: # avoid memory overhead + if attn.shape[1] <= 32**2: # avoid memory overhead self.step_store[key].append(attn) return attn @@ -68,8 +74,10 @@ class AttentionStore(AttentionControl): self.step_store = self.get_empty_store() def get_average_attention(self): - average_attention = {key: [item / self.cur_step for item in self.attention_store[key]] - for key in self.attention_store} + average_attention = { + key: [item / self.cur_step for item in self.attention_store[key]] + for key in 
self.attention_store + } return average_attention def reset(self): @@ -90,7 +98,7 @@ class AttentionControlEdit(AttentionStore, abc.ABC): return x_t def replace_self_attention(self, attn_base, att_replace): - if att_replace.shape[2] <= 16 ** 2: + if att_replace.shape[2] <= 16**2: return attn_base.unsqueeze(0).expand(att_replace.shape[0], *attn_base.shape) else: return att_replace @@ -101,41 +109,62 @@ class AttentionControlEdit(AttentionStore, abc.ABC): def forward(self, attn, is_cross: bool, place_in_unet: str): super(AttentionControlEdit, self).forward(attn, is_cross, place_in_unet) - if is_cross or (self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1]): + if is_cross or ( + self.num_self_replace[0] <= self.cur_step < self.num_self_replace[1] + ): h = attn.shape[0] // (self.batch_size) attn = attn.reshape(self.batch_size, h, *attn.shape[1:]) attn_base, attn_repalce = attn[0], attn[1:] if is_cross: alpha_words = self.cross_replace_alpha[self.cur_step] - attn_repalce_new = self.replace_cross_attention( - attn_base, attn_repalce) * alpha_words + (1 - alpha_words) * attn_repalce + attn_repalce_new = ( + self.replace_cross_attention(attn_base, attn_repalce) * alpha_words + + (1 - alpha_words) * attn_repalce + ) attn[1:] = attn_repalce_new else: attn[1:] = self.replace_self_attention(attn_base, attn_repalce) attn = attn.reshape(self.batch_size * h, *attn.shape[2:]) return attn - def __init__(self, prompts, num_steps: int, - cross_replace_steps: Union[float, Tuple[float, float], Dict[str, Tuple[float, float]]], - self_replace_steps: Union[float, Tuple[float, float]], - local_blend: Optional[LocalBlend]): + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: Union[ + float, Tuple[float, float], Dict[str, Tuple[float, float]] + ], + self_replace_steps: Union[float, Tuple[float, float]], + local_blend: Optional[LocalBlend], + ): super(AttentionControlEdit, self).__init__() self.batch_size = len(prompts) self.cross_replace_alpha = ptp_utils.get_time_words_attention_alpha( - prompts, num_steps, cross_replace_steps, tokenizer).to(device) + prompts, num_steps, cross_replace_steps, tokenizer + ).to(device) if type(self_replace_steps) is float: self_replace_steps = 0, self_replace_steps - self.num_self_replace = int(num_steps * self_replace_steps[0]), int(num_steps * self_replace_steps[1]) + self.num_self_replace = int(num_steps * self_replace_steps[0]), int( + num_steps * self_replace_steps[1] + ) self.local_blend = local_blend class AttentionReplace(AttentionControlEdit): def replace_cross_attention(self, attn_base, att_replace): - return torch.einsum('hpw,bwn->bhpn', attn_base, self.mapper) - - def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, - local_blend: Optional[LocalBlend] = None): - super(AttentionReplace, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + return torch.einsum("hpw,bwn->bhpn", attn_base, self.mapper) + + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + ): + super(AttentionReplace, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend + ) self.mapper = seq_aligner.get_replacement_mapper(prompts, tokenizer).to(device) @@ -145,9 +174,17 @@ class AttentionRefine(AttentionControlEdit): attn_replace = attn_base_replace * self.alphas + att_replace * (1 - self.alphas) return attn_replace - def __init__(self, prompts, 
num_steps: int, cross_replace_steps: float, self_replace_steps: float, - local_blend: Optional[LocalBlend] = None): - super(AttentionRefine, self).__init__(prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend) + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + local_blend: Optional[LocalBlend] = None, + ): + super(AttentionRefine, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend + ) self.mapper, alphas = seq_aligner.get_refinement_mapper(prompts, tokenizer) self.mapper, alphas = self.mapper.to(device), alphas.to(device) self.alphas = alphas.reshape(alphas.shape[0], 1, 1, alphas.shape[1]) @@ -156,13 +193,24 @@ class AttentionRefine(AttentionControlEdit): class AttentionReweight(AttentionControlEdit): def replace_cross_attention(self, attn_base, att_replace): if self.prev_controller is not None: - attn_base = self.prev_controller.replace_cross_attention(attn_base, att_replace) + attn_base = self.prev_controller.replace_cross_attention( + attn_base, att_replace + ) attn_replace = attn_base[None, :, :, :] * self.equalizer[:, None, None, :] return attn_replace - def __init__(self, prompts, num_steps: int, cross_replace_steps: float, self_replace_steps: float, equalizer, - local_blend: Optional[LocalBlend] = None, controller: Optional[AttentionControlEdit] = None): - super(AttentionReweight, self).__init__(prompts, num_steps, - cross_replace_steps, self_replace_steps, local_blend) + def __init__( + self, + prompts, + num_steps: int, + cross_replace_steps: float, + self_replace_steps: float, + equalizer, + local_blend: Optional[LocalBlend] = None, + controller: Optional[AttentionControlEdit] = None, + ): + super(AttentionReweight, self).__init__( + prompts, num_steps, cross_replace_steps, self_replace_steps, local_blend + ) self.equalizer = equalizer.to(device) self.prev_controller = controller diff --git a/models/attention/hook.py b/models/attention/hook.py index 903de02..6b5fb68 100644 --- a/models/attention/hook.py +++ b/models/attention/hook.py @@ -3,6 +3,7 @@ import torch try: import xformers.ops + xformers._is_functorch_available = True MEM_EFFICIENT_ATTN = True except ImportError: @@ -42,10 +43,10 @@ def register_attention_control(model, controller): return forward def register_recr(net_, count, place_in_unet): - if net_.__class__.__name__ == 'CrossAttention': + if net_.__class__.__name__ == "CrossAttention": net_.forward = ca_forward(net_, place_in_unet) return count + 1 - elif hasattr(net_, 'children'): + elif hasattr(net_, "children"): for net__ in net_.children(): count = register_recr(net__, count, place_in_unet) return count diff --git a/models/attention/structured.py b/models/attention/structured.py index 24d889f..5bbbc06 100644 --- a/models/attention/structured.py +++ b/models/attention/structured.py @@ -16,7 +16,9 @@ class StructuredAttentionControl(AttentionControl): if self.struct_attn: out = self.struct_qkv(q, context, mask) else: - context = torch.cat([context[0], context[1]['k'][0]], dim=0) # use key tensor for context + context = torch.cat( + [context[0], context[1]["k"][0]], dim=0 + ) # use key tensor for context out = self.normal_qkv(q, context, mask) else: context = default(context, x) @@ -29,11 +31,13 @@ class StructuredAttentionControl(AttentionControl): context: list of [uc, list of conditional context] """ uc_context = context[0] - context_k, context_v = context[1]['k'], context[1]['v'] + context_k, context_v = context[1]["k"], context[1]["v"] 
if isinstance(context_k, list) and isinstance(context_v, list): out = self.multi_qkv(q, uc_context, context_k, context_v, mask) - elif isinstance(context_k, torch.Tensor) and isinstance(context_v, torch.Tensor): + elif isinstance(context_k, torch.Tensor) and isinstance( + context_v, torch.Tensor + ): out = self.heterogeous_qkv(q, uc_context, context_k, context_v, mask) else: raise NotImplementedError @@ -50,36 +54,45 @@ class StructuredAttentionControl(AttentionControl): k_c = [self.to_k(c_k) for c_k in context_k] v_c = [self.to_v(c_v) for c_v in context_v] - q = rearrange(q, 'b n (h d) -> (b h) n d', h=h) + q = rearrange(q, "b n (h d) -> (b h) n d", h=h) - k_uc = rearrange(k_uc, 'b n (h d) -> (b h) n d', h=h) - v_uc = rearrange(v_uc, 'b n (h d) -> (b h) n d', h=h) + k_uc = rearrange(k_uc, "b n (h d) -> (b h) n d", h=h) + v_uc = rearrange(v_uc, "b n (h d) -> (b h) n d", h=h) - k_c = [rearrange(k, 'b n (h d) -> (b h) n d', h=h) for k in k_c] # NOTE: modification point - v_c = [rearrange(v, 'b n (h d) -> (b h) n d', h=h) for v in v_c] + k_c = [ + rearrange(k, "b n (h d) -> (b h) n d", h=h) for k in k_c + ] # NOTE: modification point + v_c = [rearrange(v, "b n (h d) -> (b h) n d", h=h) for v in v_c] # get composition - sim_uc = einsum('b i d, b j d -> b i j', q[:true_bs], k_uc) * self.scale - sim_c = [einsum('b i d, b j d -> b i j', q[true_bs:], k) * self.scale for k in k_c] + sim_uc = einsum("b i d, b j d -> b i j", q[:true_bs], k_uc) * self.scale + sim_c = [ + einsum("b i d, b j d -> b i j", q[true_bs:], k) * self.scale for k in k_c + ] attn_uc = sim_uc.softmax(dim=-1) attn_c = [sim.softmax(dim=-1) for sim in sim_c] # get uc output - out_uc = einsum('b i j, b j d -> b i d', attn_uc, v_uc) + out_uc = einsum("b i j, b j d -> b i d", attn_uc, v_uc) # get c output if len(v_c) == 1: out_c_collect = [] for attn in attn_c: for v in v_c: - out_c_collect.append(einsum('b i j, b j d -> b i d', attn, v)) + out_c_collect.append(einsum("b i j, b j d -> b i d", attn, v)) out_c = sum(out_c_collect) / len(out_c_collect) else: - out_c = sum([einsum('b i j, b j d -> b i d', attn, v) for attn, v in zip(attn_c, v_c)]) / len(v_c) + out_c = sum( + [ + einsum("b i j, b j d -> b i d", attn, v) + for attn, v in zip(attn_c, v_c) + ] + ) / len(v_c) out = torch.cat([out_uc, out_c], dim=0) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return out @@ -88,21 +101,21 @@ class StructuredAttentionControl(AttentionControl): k = self.to_k(context) v = self.to_v(context) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... 
-> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return out @@ -111,21 +124,21 @@ class StructuredAttentionControl(AttentionControl): k = self.to_k(torch.cat([uc_context, context_k], dim=0)) v = self.to_v(torch.cat([uc_context, context_v], dim=0)) - q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v)) + q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v)) - sim = einsum('b i d, b j d -> b i j', q, k) * self.scale + sim = einsum("b i d, b j d -> b i j", q, k) * self.scale if exists(mask): - mask = rearrange(mask, 'b ... -> b (...)') + mask = rearrange(mask, "b ... -> b (...)") max_neg_value = -torch.finfo(sim.dtype).max - mask = repeat(mask, 'b j -> (b h) () j', h=h) + mask = repeat(mask, "b j -> (b h) () j", h=h) sim.masked_fill_(~mask, max_neg_value) # attention, what we cannot get enough of attn = sim.softmax(dim=-1) - out = einsum('b i j, b j d -> b i d', attn, v) - out = rearrange(out, '(b h) n d -> b n (h d)', h=h) + out = einsum("b i j, b j d -> b i d", attn, v) + out = rearrange(out, "(b h) n d -> b n (h d)", h=h) return out def get_kv(self, context): diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py index 7c7f2ac..8c3c6d4 100644 --- a/models/clip/embeddings.py +++ b/models/clip/embeddings.py @@ -14,7 +14,13 @@ from models.sparse import SparseEmbedding class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): - def __init__(self, config: CLIPTextConfig, embeddings: CLIPTextEmbeddings, alpha: int = 8, dropout: float = 0.0): + def __init__( + self, + config: CLIPTextConfig, + embeddings: CLIPTextEmbeddings, + alpha: int = 8, + dropout: float = 0.0, + ): super().__init__(config) self.position_embedding = embeddings.position_embedding @@ -28,7 +34,9 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): self.token_embedding.weight = embeddings.token_embedding.weight def resize(self, size: int): - self.token_embedding = self.token_embedding.new_resized(size, self.initializer_factor) + self.token_embedding = self.token_embedding.new_resized( + size, self.initializer_factor + ) def add_embed( self, @@ -46,7 +54,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): initializer = [initializer] if isinstance(initializer, list): - initializer = (initializer * len(token_ids))[:len(token_ids)] + initializer = (initializer * len(token_ids))[: len(token_ids)] with torch.no_grad(): initializer = self.get_embed(initializer) @@ -76,24 +84,21 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings): def get_embed(self, input_ids: Union[list[int], torch.LongTensor]): if isinstance(input_ids, list): - input_ids = torch.tensor(input_ids, device=self.token_embedding.weight.device, dtype=torch.long) + input_ids = torch.tensor( + input_ids, device=self.token_embedding.weight.device, dtype=torch.long + ) return self.token_embedding(input_ids) def patch_managed_embeddings( - text_encoder: CLIPTextModel, - alpha: int = 8, - dropout: float = 0.0 + text_encoder: CLIPTextModel, alpha: int = 8, dropout: float = 0.0 ) -> ManagedCLIPTextEmbeddings: if isinstance(text_encoder.text_model.embeddings, ManagedCLIPTextEmbeddings): return 
text_encoder.text_model.embeddings - + text_embeddings = ManagedCLIPTextEmbeddings( - text_encoder.config, - text_encoder.text_model.embeddings, - alpha, - dropout + text_encoder.config, text_encoder.text_model.embeddings, alpha, dropout ) text_encoder.text_model.embeddings = text_embeddings return text_embeddings diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 789b525..a866641 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -91,18 +91,21 @@ class MultiCLIPTokenizer(CLIPTokenizer): self.vector_shuffle = shuffle_none def add_multi_tokens( - self, - new_tokens: Union[str, list[str]], - num_vectors: Union[int, list[int]] = 1 + self, new_tokens: Union[str, list[str]], num_vectors: Union[int, list[int]] = 1 ) -> Union[list[int], list[list[int]]]: if isinstance(new_tokens, list): if isinstance(num_vectors, int): num_vectors = [num_vectors] * len(new_tokens) if len(num_vectors) != len(new_tokens): - raise ValueError("Expected new_tokens and num_vectors to have the same len") + raise ValueError( + "Expected new_tokens and num_vectors to have the same len" + ) - return [self.add_multi_tokens(new_token, vecs) for new_token, vecs in zip(new_tokens, num_vectors)] + return [ + self.add_multi_tokens(new_token, vecs) + for new_token, vecs in zip(new_tokens, num_vectors) + ] if isinstance(num_vectors, list): raise ValueError("Expected num_vectors to be int for single token") @@ -129,13 +132,11 @@ class MultiCLIPTokenizer(CLIPTokenizer): return [id] def expand_ids(self, ids: list[int]): - return [ - new_id - for id in ids - for new_id in self.expand_id(id) - ] + return [new_id for id in ids for new_id in self.expand_id(id)] - def expand_batched_ids(self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]]): + def expand_batched_ids( + self, input_ids: Union[list[int], list[list[int]], tuple[list[int]]] + ): if isinstance(input_ids, (list, tuple)) and isinstance(input_ids[0], list): return [self.expand_ids(batch) for batch in input_ids] else: diff --git a/models/clip/util.py b/models/clip/util.py index f94fbc7..7196bb6 100644 --- a/models/clip/util.py +++ b/models/clip/util.py @@ -5,27 +5,32 @@ import torch from transformers import CLIPTokenizer, CLIPTextModel -def unify_input_ids(tokenizer: CLIPTokenizer, input_ids: list[list[int]], max_length: Optional[int] = None): +def unify_input_ids( + tokenizer: CLIPTokenizer, + input_ids: list[list[int]], + max_length: Optional[int] = None, +): if max_length is None: return tokenizer.pad( {"input_ids": input_ids}, padding=True, pad_to_multiple_of=tokenizer.model_max_length, - return_tensors="pt" + return_tensors="pt", ) else: return tokenizer.pad( {"input_ids": input_ids}, padding="max_length", max_length=max_length, - return_tensors="pt" + return_tensors="pt", ) + def get_extended_embeddings( text_encoder: CLIPTextModel, input_ids: torch.LongTensor, position_ids: Optional[torch.LongTensor] = None, - attention_mask=None + attention_mask=None, ): model_max_length = text_encoder.config.max_position_embeddings prompts = input_ids.shape[0] @@ -36,6 +41,8 @@ def get_extended_embeddings( if attention_mask is not None: attention_mask = attention_mask.view((-1, model_max_length)) - text_embeddings = text_encoder(input_ids, position_ids=position_ids, attention_mask=attention_mask)[0] + text_embeddings = text_encoder( + input_ids, position_ids=position_ids, attention_mask=attention_mask + )[0] text_embeddings = text_embeddings.view((prompts, -1, text_embeddings.shape[2])) return text_embeddings diff --git 
a/models/convnext/discriminator.py b/models/convnext/discriminator.py index 571b915..5798bcf 100644 --- a/models/convnext/discriminator.py +++ b/models/convnext/discriminator.py @@ -5,7 +5,7 @@ from timm.data.constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD from torch.nn import functional as F -class ConvNeXtDiscriminator(): +class ConvNeXtDiscriminator: def __init__(self, model: ConvNeXt, input_size: int) -> None: self.net = model @@ -22,8 +22,13 @@ class ConvNeXtDiscriminator(): img_mean = self.img_mean.to(device=img.device, dtype=img.dtype) img_std = self.img_std.to(device=img.device, dtype=img.dtype) - img = ((img + 1.) / 2.).sub(img_mean).div(img_std) + img = ((img + 1.0) / 2.0).sub(img_mean).div(img_std) - img = F.interpolate(img, size=(self.input_size, self.input_size), mode='bicubic', align_corners=True) + img = F.interpolate( + img, + size=(self.input_size, self.input_size), + mode="bicubic", + align_corners=True, + ) pred = self.net(img) return pred diff --git a/models/sparse.py b/models/sparse.py index bd45696..e5897c9 100644 --- a/models/sparse.py +++ b/models/sparse.py @@ -15,21 +15,25 @@ class SparseEmbedding(nn.Embedding): ): nn.Embedding.__init__(self, num_embeddings, embedding_dim, **kwargs) - self.register_buffer('trainable_ids', self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1) + self.register_buffer( + "trainable_ids", self.weight.new_zeros(num_embeddings, dtype=torch.long) - 1 + ) self.trainable = nn.ParameterList() self.scaling = alpha self.dropout_p = dropout self.weight.requires_grad = False - if dropout > 0.: + if dropout > 0.0: self.dropout = nn.Dropout(p=dropout) else: self.dropout = nn.Identity() self.reset_parameters() - def new_resized(self, new_num_embeddings: int, initializer_factor: Optional[float] = None): + def new_resized( + self, new_num_embeddings: int, initializer_factor: Optional[float] = None + ): n = min(self.num_embeddings, new_num_embeddings) new_emb = SparseEmbedding( @@ -38,7 +42,7 @@ class SparseEmbedding(nn.Embedding): self.scaling, self.dropout_p, device=self.weight.device, - dtype=self.weight.dtype + dtype=self.weight.dtype, ) if initializer_factor is not None: new_emb.weight.data.normal_(mean=0.0, std=initializer_factor * 0.02) diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index aa446ec..16b8456 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py @@ -21,7 +21,9 @@ from diffusers import ( LMSDiscreteScheduler, PNDMScheduler, ) -from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput +from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import ( + StableDiffusionPipelineOutput, +) from diffusers.utils import logging, randn_tensor from transformers import CLIPTextModel, CLIPTokenizer @@ -62,13 +64,35 @@ def gaussian_blur_2d(img, kernel_size, sigma): return img +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
See Section 3.4 + """ + std_text = noise_pred_text.std( + dim=list(range(1, noise_pred_text.ndim)), keepdim=True + ) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = ( + guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + ) + return noise_cfg + + class CrossAttnStoreProcessor: def __init__(self): self.attention_probs = None - def __call__(self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None): + def __call__( + self, attn, hidden_states, encoder_hidden_states=None, attention_mask=None + ): batch_size, sequence_length, _ = hidden_states.shape - attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) query = attn.to_q(hidden_states) if encoder_hidden_states is None: @@ -113,7 +137,10 @@ class VlpnStableDiffusion(DiffusionPipeline): ): super().__init__() - if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + if ( + hasattr(scheduler.config, "steps_offset") + and scheduler.config.steps_offset != 1 + ): warnings.warn( f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " @@ -179,7 +206,12 @@ class VlpnStableDiffusion(DiffusionPipeline): device = torch.device("cuda") - for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae, self.safety_checker]: + for cpu_offloaded_model in [ + self.unet, + self.text_encoder, + self.vae, + self.safety_checker, + ]: if cpu_offloaded_model is not None: cpu_offload(cpu_offloaded_model, device) @@ -223,35 +255,47 @@ class VlpnStableDiffusion(DiffusionPipeline): width: int, height: int, strength: float, - callback_steps: Optional[int] + callback_steps: Optional[int], ): - if isinstance(prompt, str) or (isinstance(prompt, list) and isinstance(prompt[0], int)): + if isinstance(prompt, str) or ( + isinstance(prompt, list) and isinstance(prompt[0], int) + ): prompt = [prompt] if negative_prompt is None: negative_prompt = "" - if isinstance(negative_prompt, str) or (isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int)): + if isinstance(negative_prompt, str) or ( + isinstance(negative_prompt, list) and isinstance(negative_prompt[0], int) + ): negative_prompt = [negative_prompt] * len(prompt) if not isinstance(prompt, list): - raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) if not isinstance(negative_prompt, list): - raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}") + raise ValueError( + f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}" + ) if len(negative_prompt) != len(prompt): raise ValueError( - f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}") + f"`prompt` and `negative_prompt` have to be the same length, but are {len(prompt)} and {len(negative_prompt)}" + ) if strength < 0 or strength > 1: raise ValueError(f"`strength` should in [0.0, 1.0] but is {strength}") if height % 8 != 
0 or width % 8 != 0: - raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) if (callback_steps is None) or ( - callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + callback_steps is not None + and (not isinstance(callback_steps, int) or callback_steps <= 0) ): raise ValueError( f"`callback_steps` has to be a positive integer but is {callback_steps} of type" @@ -266,7 +310,7 @@ class VlpnStableDiffusion(DiffusionPipeline): negative_prompt: Union[List[str], List[List[int]]], num_images_per_prompt: int, do_classifier_free_guidance: bool, - device + device, ): if isinstance(prompt[0], str): text_input_ids = self.tokenizer(prompt, padding="do_not_pad").input_ids @@ -277,7 +321,9 @@ class VlpnStableDiffusion(DiffusionPipeline): if do_classifier_free_guidance: if isinstance(prompt[0], str): - unconditional_input_ids = self.tokenizer(negative_prompt, padding="do_not_pad").input_ids + unconditional_input_ids = self.tokenizer( + negative_prompt, padding="do_not_pad" + ).input_ids else: unconditional_input_ids = negative_prompt unconditional_input_ids *= num_images_per_prompt @@ -286,12 +332,17 @@ class VlpnStableDiffusion(DiffusionPipeline): text_inputs = unify_input_ids(self.tokenizer, text_input_ids) text_input_ids = text_inputs.input_ids - if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + if ( + hasattr(self.text_encoder.config, "use_attention_mask") + and self.text_encoder.config.use_attention_mask + ): attention_mask = text_inputs.attention_mask.to(device) else: attention_mask = None - prompt_embeds = get_extended_embeddings(self.text_encoder, text_input_ids.to(device), attention_mask) + prompt_embeds = get_extended_embeddings( + self.text_encoder, text_input_ids.to(device), attention_mask + ) prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) return prompt_embeds @@ -301,25 +352,21 @@ class VlpnStableDiffusion(DiffusionPipeline): init_timestep = min(int(num_inference_steps * strength), num_inference_steps) t_start = max(num_inference_steps - init_timestep, 0) - timesteps = self.scheduler.timesteps[t_start * self.scheduler.order:] + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] timesteps = timesteps.to(device) return timesteps, num_inference_steps - t_start - def prepare_brightness_offset(self, batch_size, height, width, dtype, device, generator=None): - offset_image = perlin_noise( - (batch_size, 1, width, height), - res=1, - generator=generator, - dtype=dtype, - device=device - ) - offset_latents = self.vae.encode(offset_image).latent_dist.sample(generator=generator) - offset_latents = self.vae.config.scaling_factor * offset_latents - return offset_latents - - def prepare_latents_from_image(self, init_image, timestep, batch_size, brightness_offset, dtype, device, generator=None): + def prepare_latents_from_image( + self, + init_image, + timestep, + batch_size, + dtype, + device, + generator=None, + ): init_image = init_image.to(device=device, dtype=dtype) latents = self.vae.encode(init_image).latent_dist.sample(generator=generator) latents = self.vae.config.scaling_factor * latents @@ -333,20 +380,32 @@ class VlpnStableDiffusion(DiffusionPipeline): latents = torch.cat([latents] * batch_multiplier, dim=0) # add noise to latents using the timesteps - noise = torch.randn(latents.shape, 
generator=generator, device=device, dtype=dtype) - - if brightness_offset != 0: - noise += brightness_offset * self.prepare_brightness_offset( - batch_size, init_image.shape[3], init_image.shape[2], dtype, device, generator - ) + noise = torch.randn( + latents.shape, generator=generator, device=device, dtype=dtype + ) # get latents latents = self.scheduler.add_noise(latents, noise, timestep) return latents - def prepare_latents(self, batch_size, num_channels_latents, height, width, brightness_offset, dtype, device, generator, latents=None): - shape = (batch_size, num_channels_latents, height // self.vae_scale_factor, width // self.vae_scale_factor) + def prepare_latents( + self, + batch_size, + num_channels_latents, + height, + width, + dtype, + device, + generator, + latents=None, + ): + shape = ( + batch_size, + num_channels_latents, + height // self.vae_scale_factor, + width // self.vae_scale_factor, + ) if isinstance(generator, list) and len(generator) != batch_size: raise ValueError( f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" @@ -354,15 +413,12 @@ class VlpnStableDiffusion(DiffusionPipeline): ) if latents is None: - latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + latents = randn_tensor( + shape, generator=generator, device=device, dtype=dtype + ) else: latents = latents.to(device) - if brightness_offset != 0: - latents += brightness_offset * self.prepare_brightness_offset( - batch_size, height, width, dtype, device, generator - ) - # scale the initial noise by the standard deviation required by the scheduler latents = latents * self.scheduler.init_noise_sigma return latents @@ -373,13 +429,17 @@ class VlpnStableDiffusion(DiffusionPipeline): # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 # and should be between [0, 1] - accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) extra_step_kwargs = {} if accepts_eta: extra_step_kwargs["eta"] = eta # check if the scheduler accepts generator - accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) if accepts_generator: extra_step_kwargs["generator"] = generator return extra_step_kwargs @@ -396,7 +456,9 @@ class VlpnStableDiffusion(DiffusionPipeline): def __call__( self, prompt: Union[str, List[str], List[int], List[List[int]]], - negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None, + negative_prompt: Optional[ + Union[str, List[str], List[int], List[List[int]]] + ] = None, num_images_per_prompt: int = 1, strength: float = 1.0, height: Optional[int] = None, @@ -407,12 +469,12 @@ class VlpnStableDiffusion(DiffusionPipeline): eta: float = 0.0, generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None, - brightness_offset: Union[float, torch.FloatTensor] = 0, output_type: str = "pil", return_dict: bool = True, callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, callback_steps: int = 1, cross_attention_kwargs: Optional[Dict[str, Any]] = None, + guidance_rescale: float = 0.0, ): r""" Function invoked when calling the pipeline for generation. 
@@ -472,7 +534,9 @@ class VlpnStableDiffusion(DiffusionPipeline): width = width or self.unet.config.sample_size * self.vae_scale_factor # 1. Check inputs. Raise error if not correct - prompt, negative_prompt = self.check_inputs(prompt, negative_prompt, width, height, strength, callback_steps) + prompt, negative_prompt = self.check_inputs( + prompt, negative_prompt, width, height, strength, callback_steps + ) # 2. Define call parameters batch_size = len(prompt) @@ -488,7 +552,7 @@ class VlpnStableDiffusion(DiffusionPipeline): negative_prompt, num_images_per_prompt, do_classifier_free_guidance, - device + device, ) # 4. Prepare latent variables @@ -497,7 +561,9 @@ class VlpnStableDiffusion(DiffusionPipeline): # 5. Prepare timesteps self.scheduler.set_timesteps(num_inference_steps, device=device) - timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device) + timesteps, num_inference_steps = self.get_timesteps( + num_inference_steps, strength, device + ) latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) # 6. Prepare latent variables @@ -506,7 +572,6 @@ class VlpnStableDiffusion(DiffusionPipeline): image, latent_timestep, batch_size * num_images_per_prompt, - brightness_offset, prompt_embeds.dtype, device, generator, @@ -517,7 +582,6 @@ class VlpnStableDiffusion(DiffusionPipeline): num_channels_latents, height, width, - brightness_offset, prompt_embeds.dtype, device, generator, @@ -530,14 +594,20 @@ class VlpnStableDiffusion(DiffusionPipeline): # 8. Denoising loo if do_self_attention_guidance: store_processor = CrossAttnStoreProcessor() - self.unet.mid_block.attentions[0].transformer_blocks[0].attn1.processor = store_processor + self.unet.mid_block.attentions[0].transformer_blocks[ + 0 + ].attn1.processor = store_processor num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order with self.progress_bar(total=num_inference_steps) as progress_bar: for i, t in enumerate(timesteps): # expand the latents if we are doing classifier free guidance - latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents - latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + latent_model_input = ( + torch.cat([latents] * 2) if do_classifier_free_guidance else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) # predict the noise residual noise_pred = self.unet( @@ -551,7 +621,12 @@ class VlpnStableDiffusion(DiffusionPipeline): # perform guidance if do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) - noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + noise_pred = rescale_noise_cfg( + noise_pred, noise_pred_text, guidance_rescale=guidance_rescale + ) if do_self_attention_guidance: # classifier-free guidance produces two chunks of attention map @@ -561,15 +636,24 @@ class VlpnStableDiffusion(DiffusionPipeline): # DDIM-like prediction of x0 pred_x0 = self.pred_x0(latents, noise_pred_uncond, t) # get the stored attention maps - uncond_attn, cond_attn = store_processor.attention_probs.chunk(2) + uncond_attn, cond_attn = store_processor.attention_probs.chunk( + 2 + ) # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, uncond_attn, t, self.pred_epsilon(latents, noise_pred_uncond, t) + pred_x0, + uncond_attn, + t, + 
self.pred_epsilon(latents, noise_pred_uncond, t), ) uncond_emb, _ = prompt_embeds.chunk(2) # forward and give guidance degraded_pred = self.unet( - degraded_latents, t, encoder_hidden_states=uncond_emb, return_dict=False)[0] + degraded_latents, + t, + encoder_hidden_states=uncond_emb, + return_dict=False, + )[0] noise_pred += sag_scale * (noise_pred_uncond - degraded_pred) else: # DDIM-like prediction of x0 @@ -578,18 +662,29 @@ class VlpnStableDiffusion(DiffusionPipeline): cond_attn = store_processor.attention_probs # self-attention-based degrading of latents degraded_latents = self.sag_masking( - pred_x0, cond_attn, t, self.pred_epsilon(latents, noise_pred, t) + pred_x0, + cond_attn, + t, + self.pred_epsilon(latents, noise_pred, t), ) # forward and give guidance degraded_pred = self.unet( - degraded_latents, t, encoder_hidden_states=prompt_embeds, return_dict=False)[0] + degraded_latents, + t, + encoder_hidden_states=prompt_embeds, + return_dict=False, + )[0] noise_pred += sag_scale * (noise_pred - degraded_pred) # compute the previous noisy sample x_t -> x_t-1 - latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0] + latents = self.scheduler.step( + noise_pred, t, latents, **extra_step_kwargs, return_dict=False + )[0] # call the callback, if provided - if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): progress_bar.update() if callback is not None and i % callback_steps == 0: callback(i, t, latents) @@ -615,7 +710,9 @@ class VlpnStableDiffusion(DiffusionPipeline): if not return_dict: return (image, has_nsfw_concept) - return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=has_nsfw_concept) + return StableDiffusionPipelineOutput( + images=image, nsfw_content_detected=has_nsfw_concept + ) # Self-Attention-Guided (SAG) Stable Diffusion @@ -632,16 +729,23 @@ class VlpnStableDiffusion(DiffusionPipeline): attn_map = attn_map.reshape(b, h, hw1, hw2) attn_mask = attn_map.mean(1, keepdim=False).sum(1, keepdim=False) > 1.0 attn_mask = ( - attn_mask.reshape(b, map_size, map_size).unsqueeze(1).repeat(1, latent_channel, 1, 1).type(attn_map.dtype) + attn_mask.reshape(b, map_size, map_size) + .unsqueeze(1) + .repeat(1, latent_channel, 1, 1) + .type(attn_map.dtype) ) attn_mask = torch.nn.functional.interpolate(attn_mask, (latent_h, latent_w)) # Blur according to the self-attention mask degraded_latents = gaussian_blur_2d(original_latents, kernel_size=9, sigma=1.0) - degraded_latents = degraded_latents * attn_mask + original_latents * (1 - attn_mask) + degraded_latents = degraded_latents * attn_mask + original_latents * ( + 1 - attn_mask + ) # Noise it again to match the noise level - degraded_latents = self.scheduler.add_noise(degraded_latents, noise=eps, timesteps=t) + degraded_latents = self.scheduler.add_noise( + degraded_latents, noise=eps, timesteps=t + ) return degraded_latents @@ -652,13 +756,19 @@ class VlpnStableDiffusion(DiffusionPipeline): beta_prod_t = 1 - alpha_prod_t if self.scheduler.config.prediction_type == "epsilon": - pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_original_sample = ( + sample - beta_prod_t ** (0.5) * model_output + ) / alpha_prod_t ** (0.5) elif self.scheduler.config.prediction_type == "sample": pred_original_sample = model_output elif self.scheduler.config.prediction_type == "v_prediction": - 
pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_original_sample = (alpha_prod_t**0.5) * sample - ( + beta_prod_t**0.5 + ) * model_output # predict V - model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + model_output = (alpha_prod_t**0.5) * model_output + ( + beta_prod_t**0.5 + ) * sample else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," @@ -674,9 +784,13 @@ class VlpnStableDiffusion(DiffusionPipeline): if self.scheduler.config.prediction_type == "epsilon": pred_eps = model_output elif self.scheduler.config.prediction_type == "sample": - pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / (beta_prod_t**0.5) + pred_eps = (sample - (alpha_prod_t**0.5) * model_output) / ( + beta_prod_t**0.5 + ) elif self.scheduler.config.prediction_type == "v_prediction": - pred_eps = (beta_prod_t**0.5) * sample + (alpha_prod_t**0.5) * model_output + pred_eps = (beta_prod_t**0.5) * sample + ( + alpha_prod_t**0.5 + ) * model_output else: raise ValueError( f"prediction_type given as {self.scheduler.config.prediction_type} must be one of `epsilon`, `sample`," diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -5,34 +5,70 @@ import itertools from pathlib import Path from functools import partial import math +import warnings import torch +import torch._dynamo import torch.utils.checkpoint +import hidet from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed -from slugify import slugify + +# from diffusers.models.attention_processor import AttnProcessor +from diffusers.utils.import_utils import is_xformers_available import transformers -from util.files import load_config, load_embeddings_from_dir +import numpy as np +from slugify import slugify + from data.csv import VlpnDataModule, keyword_filter -from training.functional import train, get_models +from models.clip.embeddings import patch_managed_embeddings +from training.functional import train, add_placeholder_tokens, get_models from training.strategy.dreambooth import dreambooth_strategy from training.optimization import get_scheduler -from training.util import save_args +from training.sampler import create_named_schedule_sampler +from training.util import AverageMeter, save_args +from util.files import load_config, load_embeddings_from_dir + logger = get_logger(__name__) +warnings.filterwarnings("ignore") + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True +# torch._dynamo.config.log_level = logging.WARNING +torch._dynamo.config.suppress_errors = True + +hidet.torch.dynamo_config.use_tensor_core(True) +hidet.torch.dynamo_config.search_space(0) + + +def patch_xformers(dtype): + if is_xformers_available(): + import xformers + import xformers.ops + + orig_xformers_memory_efficient_attention = ( + xformers.ops.memory_efficient_attention + ) + + def xformers_memory_efficient_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs + ): + return orig_xformers_memory_efficient_attention( + query.to(dtype), key.to(dtype), value.to(dtype), **kwargs + ) + + xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention + def parse_args(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." 
- ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -49,7 +85,7 @@ def parse_args(): "--train_data_file", type=str, default=None, - help="A folder containing the training data." + help="A folder containing the training data.", ) parser.add_argument( "--train_data_template", @@ -60,13 +96,13 @@ def parse_args(): "--train_set_pad", type=int, default=None, - help="The number to fill train dataset items up to." + help="The number to fill train dataset items up to.", ) parser.add_argument( "--valid_set_pad", type=int, default=None, - help="The number to fill validation dataset items up to." + help="The number to fill validation dataset items up to.", ) parser.add_argument( "--project", @@ -75,20 +111,58 @@ def parse_args(): help="The name of the current project.", ) parser.add_argument( - "--exclude_collections", + "--auto_cycles", type=str, default="o", help="Cycles to run automatically." + ) + parser.add_argument( + "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." + ) + parser.add_argument( + "--placeholder_tokens", type=str, - nargs='*', - help="Exclude all items with a listed collection.", + nargs="*", + help="A token to use as a placeholder for the concept.", ) parser.add_argument( - "--train_text_encoder_epochs", - default=999999, - help="Number of epochs the text encoder will be trained." + "--initializer_tokens", + type=str, + nargs="*", + help="A token to use as initializer word.", + ) + parser.add_argument( + "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." + ) + parser.add_argument( + "--initializer_noise", + type=float, + default=0, + help="Noise to apply to the initializer word", + ) + parser.add_argument( + "--alias_tokens", + type=str, + nargs="*", + default=[], + help="Tokens to create an alias for.", + ) + parser.add_argument( + "--inverted_initializer_tokens", + type=str, + nargs="*", + help="A token to use as initializer word.", + ) + parser.add_argument( + "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." + ) + parser.add_argument( + "--exclude_collections", + type=str, + nargs="*", + help="Exclude all items with a listed collection.", ) parser.add_argument( "--num_buckets", type=int, - default=0, + default=2, help="Number of aspect ratio buckets in either direction.", ) parser.add_argument( @@ -119,19 +193,6 @@ def parse_args(): action="store_true", help="Shuffle tags.", ) - parser.add_argument( - "--vector_dropout", - type=int, - default=0, - help="Vector dropout probability.", - ) - parser.add_argument( - "--vector_shuffle", - type=str, - default="auto", - choices=["all", "trailing", "leading", "between", "auto", "off"], - help='Vector shuffling algorithm.', - ) parser.add_argument( "--guidance_scale", type=float, @@ -141,7 +202,7 @@ def parse_args(): "--num_class_images", type=int, default=0, - help="How many class images to generate." 
+ help="How many class images to generate.", ) parser.add_argument( "--class_image_dir", @@ -161,17 +222,19 @@ def parse_args(): default=None, help="The embeddings directory where Textual Inversion embeddings are stored.", ) + parser.add_argument( + "--train_dir_embeddings", + action="store_true", + help="Train embeddings loaded from embeddings directory.", + ) parser.add_argument( "--collection", type=str, - nargs='*', + nargs="*", help="A collection to filter the dataset.", ) parser.add_argument( - "--seed", - type=int, - default=None, - help="A seed for reproducible training." + "--seed", type=int, default=None, help="A seed for reproducible training." ) parser.add_argument( "--resolution", @@ -189,15 +252,13 @@ def parse_args(): help="Perlin offset noise strength.", ) parser.add_argument( - "--num_train_epochs", - type=int, - default=None - ) - parser.add_argument( - "--num_train_steps", - type=int, - default=2000 + "--input_pertubation", + type=float, + default=0, + help="The scale of input pretubation. Recommended 0.1.", ) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--num_train_steps", type=int, default=2000) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -205,9 +266,9 @@ def parse_args(): help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( - "--gradient_checkpointing", - action="store_true", - help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + "--train_text_encoder_cycles", + default=999999, + help="Number of epochs the text encoder will be trained.", ) parser.add_argument( "--find_lr", @@ -215,9 +276,15 @@ def parse_args(): help="Automatically find a learning rate (no training).", ) parser.add_argument( - "--learning_rate", + "--learning_rate_unet", + type=float, + default=1e-4, + help="Initial learning rate (after the potential warmup period) to use.", + ) + parser.add_argument( + "--learning_rate_text", type=float, - default=2e-6, + default=5e-5, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -229,27 +296,31 @@ def parse_args(): "--lr_scheduler", type=str, default="one_cycle", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup", "one_cycle"], - help='The scheduler type to use.', + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + "one_cycle", + ], + help="The scheduler type to use.", ) parser.add_argument( "--lr_warmup_epochs", type=int, default=10, - help="Number of steps for the warmup in the lr scheduler." + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( - "--lr_mid_point", - type=float, - default=0.3, - help="OneCycle schedule mid point." + "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." ) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler (if supported)." 
+ help="Number of restart cycles in the lr scheduler (if supported).", ) parser.add_argument( "--lr_warmup_func", @@ -261,7 +332,7 @@ def parse_args(): "--lr_warmup_exp", type=int, default=1, - help='If lr_warmup_func is "cos", exponent to modify the function' + help='If lr_warmup_func is "cos", exponent to modify the function', ) parser.add_argument( "--lr_annealing_func", @@ -273,76 +344,76 @@ def parse_args(): "--lr_annealing_exp", type=int, default=3, - help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', ) parser.add_argument( "--lr_min_lr", type=float, default=0.04, - help="Minimum learning rate in the lr scheduler." - ) - parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use EMA model." - ) - parser.add_argument( - "--ema_inv_gamma", - type=float, - default=1.0 + help="Minimum learning rate in the lr scheduler.", ) + parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") parser.add_argument( - "--ema_power", - type=float, - default=6/7 - ) - parser.add_argument( - "--ema_max_decay", - type=float, - default=0.9999 + "--schedule_sampler", + type=str, + default="uniform", + choices=["uniform", "loss-second-moment"], + help="Noise schedule sampler.", ) parser.add_argument( "--optimizer", type=str, - default="dadan", - choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], - help='Optimizer to use' + default="adan", + choices=[ + "adam", + "adam8bit", + "adan", + "lion", + "dadam", + "dadan", + "dlion", + "adafactor", + ], + help="Optimizer to use", ) parser.add_argument( "--dadaptation_d0", type=float, default=1e-6, - help="The d0 parameter for Dadaptation optimizers." + help="The d0 parameter for Dadaptation optimizers.", + ) + parser.add_argument( + "--dadaptation_growth_rate", + type=float, + default=math.inf, + help="The growth_rate parameter for Dadaptation optimizers.", ) parser.add_argument( "--adam_beta1", type=float, default=None, - help="The beta1 parameter for the Adam optimizer." + help="The beta1 parameter for the Adam optimizer.", ) parser.add_argument( "--adam_beta2", type=float, default=None, - help="The beta2 parameter for the Adam optimizer." + help="The beta2 parameter for the Adam optimizer.", ) parser.add_argument( - "--adam_weight_decay", - type=float, - default=1e-2, - help="Weight decay to use." + "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." ) parser.add_argument( "--adam_epsilon", type=float, default=1e-08, - help="Epsilon value for the Adam optimizer" + help="Epsilon value for the Adam optimizer", ) parser.add_argument( "--adam_amsgrad", type=bool, default=False, - help="Amsgrad value for the Adam optimizer" + help="Amsgrad value for the Adam optimizer", ) parser.add_argument( "--mixed_precision", @@ -355,12 +426,28 @@ def parse_args(): "and an Nvidia Ampere GPU." 
), ) + parser.add_argument( + "--compile_unet", + action="store_true", + help="Compile UNet with Torch Dynamo.", + ) + parser.add_argument( + "--use_xformers", + action="store_true", + help="Use xformers.", + ) parser.add_argument( "--sample_frequency", type=int, default=1, help="How often to save a checkpoint and sample image", ) + parser.add_argument( + "--sample_num", + type=int, + default=None, + help="How often to save a checkpoint and sample image (in number of samples)", + ) parser.add_argument( "--sample_image_size", type=int, @@ -383,19 +470,19 @@ def parse_args(): "--valid_set_size", type=int, default=None, - help="Number of images in the validation dataset." + help="Number of images in the validation dataset.", ) parser.add_argument( "--valid_set_repeat", type=int, default=1, - help="Times the images in the validation dataset are repeated." + help="Times the images in the validation dataset are repeated.", ) parser.add_argument( "--train_batch_size", type=int, default=1, - help="Batch size (per device) for the training dataloader." + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( "--sample_steps", @@ -407,13 +494,18 @@ def parse_args(): "--prior_loss_weight", type=float, default=1.0, - help="The weight of prior preservation loss." + help="The weight of prior preservation loss.", ) + parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") + parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") parser.add_argument( - "--max_grad_norm", - default=1.0, + "--emb_dropout", type=float, - help="Max gradient norm." + default=0, + help="Embedding dropout probability.", + ) + parser.add_argument( + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) parser.add_argument( "--noise_timesteps", @@ -424,7 +516,7 @@ def parse_args(): "--config", type=str, default=None, - help="Path to a JSON configuration file containing arguments for invoking this script." 
+ help="Path to a JSON configuration file containing arguments for invoking this script.", ) args = parser.parse_args() @@ -441,6 +533,67 @@ def parse_args(): if args.project is None: raise ValueError("You must specify --project") + if args.initializer_tokens is None: + args.initializer_tokens = [] + + if args.placeholder_tokens is None: + args.placeholder_tokens = [] + + if isinstance(args.placeholder_tokens, str): + args.placeholder_tokens = [args.placeholder_tokens] + + if isinstance(args.initializer_tokens, str): + args.initializer_tokens = [args.initializer_tokens] * len( + args.placeholder_tokens + ) + + if len(args.placeholder_tokens) == 0: + args.placeholder_tokens = [ + f"<*{i}>" for i in range(len(args.initializer_tokens)) + ] + + if len(args.initializer_tokens) == 0: + args.initializer_tokens = args.placeholder_tokens.copy() + + if len(args.placeholder_tokens) != len(args.initializer_tokens): + raise ValueError( + "--placeholder_tokens and --initializer_tokens must have the same number of items" + ) + + if isinstance(args.inverted_initializer_tokens, str): + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( + args.placeholder_tokens + ) + + if ( + isinstance(args.inverted_initializer_tokens, list) + and len(args.inverted_initializer_tokens) != 0 + ): + args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] + args.initializer_tokens += args.inverted_initializer_tokens + + if isinstance(args.num_vectors, int): + args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) + + if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( + args.num_vectors + ): + raise ValueError( + "--placeholder_tokens and --num_vectors must have the same number of items" + ) + + if args.alias_tokens is None: + args.alias_tokens = [] + + if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: + raise ValueError("--alias_tokens must be a list with an even number of items") + + if args.filter_tokens is None: + args.filter_tokens = args.placeholder_tokens.copy() + + if isinstance(args.filter_tokens, str): + args.filter_tokens = [args.filter_tokens] + if isinstance(args.collection, str): args.collection = [args.collection] @@ -451,15 +604,15 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta1 = 0.9 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta1 = 0.95 if args.adam_beta2 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta2 = 0.999 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta2 = 0.98 return args @@ -475,7 +628,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, project_dir=f"{output_dir}", - mixed_precision=args.mixed_precision + mixed_precision=args.mixed_precision, ) weight_dtype = torch.float32 @@ -484,6 +637,8 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + patch_xformers(weight_dtype) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: @@ -493,44 +648,125 @@ def main(): save_args(output_dir, args) - tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( - args.pretrained_model_name_or_path) - - 
tokenizer.set_use_vector_shuffle(args.vector_shuffle) - tokenizer.set_dropout(args.vector_dropout) + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( + args.pretrained_model_name_or_path + ) + embeddings = patch_managed_embeddings( + text_encoder, args.emb_alpha, args.emb_dropout + ) + schedule_sampler = create_named_schedule_sampler( + args.schedule_sampler, noise_scheduler.config.num_train_timesteps + ) vae.enable_slicing() - vae.set_use_memory_efficient_attention_xformers(True) - unet.enable_xformers_memory_efficient_attention() + + if args.use_xformers: + vae.set_use_memory_efficient_attention_xformers(True) + unet.enable_xformers_memory_efficient_attention() + # elif args.compile_unet: + # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False + # + # proc = AttnProcessor() + # + # def fn_recursive_set_proc(module: torch.nn.Module): + # if hasattr(module, "processor"): + # module.processor = proc + # + # for child in module.children(): + # fn_recursive_set_proc(child) + # + # fn_recursive_set_proc(unet) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() - text_encoder.gradient_checkpointing_enable() + + if len(args.alias_tokens) != 0: + alias_placeholder_tokens = args.alias_tokens[::2] + alias_initializer_tokens = args.alias_tokens[1::2] + + added_tokens, added_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=alias_placeholder_tokens, + initializer_tokens=alias_initializer_tokens, + ) + embeddings.persist() + print( + f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" + ) + + placeholder_tokens = [] + placeholder_token_ids = [] if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - embeddings.persist() - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + added_tokens, added_ids = load_embeddings_from_dir( + tokenizer, embeddings, embeddings_dir + ) + + placeholder_tokens = added_tokens + placeholder_token_ids = added_ids + + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) + + if args.train_dir_embeddings: + print("Training embeddings from embeddings dir") + else: + embeddings.persist() + + if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( + tokenizer=tokenizer, + embeddings=embeddings, + placeholder_tokens=args.placeholder_tokens, + initializer_tokens=args.initializer_tokens, + num_vectors=args.num_vectors, + initializer_noise=args.initializer_noise, + ) + + placeholder_tokens = args.placeholder_tokens + + stats = list( + zip( + placeholder_tokens, + placeholder_token_ids, + args.initializer_tokens, + initializer_token_ids, + ) + ) + print(f"Training embeddings: {stats}") if args.scale_lr: - args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate_unet = ( + args.learning_rate_unet + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes + ) + args.learning_rate_text = ( + args.learning_rate_text + * 
args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) if args.find_lr: - args.learning_rate = 1e-6 + args.learning_rate_unet = 1e-6 + args.learning_rate_text = 1e-6 args.lr_scheduler = "exponential_growth" - if args.optimizer == 'adam8bit': + if args.optimizer == "adam8bit": try: import bitsandbytes as bnb except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) create_optimizer = partial( bnb.optim.AdamW8bit, @@ -539,7 +775,7 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adam': + elif args.optimizer == "adam": create_optimizer = partial( torch.optim.AdamW, betas=(args.adam_beta1, args.adam_beta2), @@ -547,22 +783,27 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adan': + elif args.optimizer == "adan": try: import timm.optim except ImportError: - raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + raise ImportError( + "To use Adan, please install the PyTorch Image Models library: `pip install timm`." + ) create_optimizer = partial( timm.optim.Adan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, + no_prox=True, ) - elif args.optimizer == 'lion': + elif args.optimizer == "lion": try: import lion_pytorch except ImportError: - raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") + raise ImportError( + "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." + ) create_optimizer = partial( lion_pytorch.Lion, @@ -570,7 +811,7 @@ def main(): weight_decay=args.adam_weight_decay, use_triton=True, ) - elif args.optimizer == 'adafactor': + elif args.optimizer == "adafactor": create_optimizer = partial( transformers.optimization.Adafactor, weight_decay=args.adam_weight_decay, @@ -580,13 +821,16 @@ def main(): ) args.lr_scheduler = "adafactor" - args.lr_min_lr = args.learning_rate - args.learning_rate = None - elif args.optimizer == 'dadam': + args.lr_min_lr = args.learning_rate_unet + args.learning_rate_unet = None + args.learning_rate_text = None + elif args.optimizer == "dadam": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." + ) create_optimizer = partial( dadaptation.DAdaptAdam, @@ -595,46 +839,65 @@ def main(): eps=args.adam_epsilon, decouple=True, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) - args.learning_rate = 1.0 - elif args.optimizer == 'dadan': + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 + elif args.optimizer == "dadan": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." 
+ ) create_optimizer = partial( dadaptation.DAdaptAdan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) - args.learning_rate = 1.0 + args.learning_rate_unet = 1.0 + args.learning_rate_text = 1.0 + elif args.optimizer == "dlion": + raise ImportError("DLion has not been merged into dadaptation yet") else: - raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") + raise ValueError(f'Unknown --optimizer "{args.optimizer}"') trainer = partial( train, accelerator=accelerator, unet=unet, text_encoder=text_encoder, + tokenizer=tokenizer, vae=vae, noise_scheduler=noise_scheduler, + schedule_sampler=schedule_sampler, + min_snr_gamma=args.min_snr_gamma, dtype=weight_dtype, + seed=args.seed, + compile_unet=args.compile_unet, guidance_scale=args.guidance_scale, prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, - no_val=args.valid_set_size == 0, + sample_scheduler=sample_scheduler, + sample_batch_size=args.sample_batch_size, + sample_num_batches=args.sample_batches, + sample_num_steps=args.sample_steps, + sample_image_size=args.sample_image_size, + max_grad_norm=args.max_grad_norm, ) - checkpoint_output_dir = output_dir / "model" - sample_output_dir = output_dir / "samples" + data_generator = torch.Generator(device="cpu").manual_seed(args.seed) + data_npgenerator = np.random.default_rng(args.seed) - datamodule = VlpnDataModule( + create_datamodule = partial( + VlpnDataModule, data_file=args.train_data_file, - batch_size=args.train_batch_size, tokenizer=tokenizer, + constant_prompt_length=args.compile_unet, class_subdir=args.class_image_dir, with_guidance=args.guidance_scale != 0, num_class_images=args.num_class_images, @@ -643,83 +906,186 @@ def main(): progressive_buckets=args.progressive_buckets, bucket_step_size=args.bucket_step_size, bucket_max_pixels=args.bucket_max_pixels, - dropout=args.tag_dropout, shuffle=not args.no_tag_shuffle, template_key=args.train_data_template, - valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, - seed=args.seed, - filter=partial(keyword_filter, None, args.collection, args.exclude_collections), - dtype=weight_dtype - ) - datamodule.setup() - - num_train_epochs = args.num_train_epochs - sample_frequency = args.sample_frequency - if num_train_epochs is None: - num_train_epochs = math.ceil( - args.num_train_steps / len(datamodule.train_dataset) - ) * args.gradient_accumulation_steps - sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) - - params_to_optimize = (unet.parameters(), ) - if args.train_text_encoder_epochs != 0: - params_to_optimize += ( - text_encoder.text_model.encoder.parameters(), - text_encoder.text_model.final_layer_norm.parameters(), - ) - - optimizer = create_optimizer( - itertools.chain(*params_to_optimize), - lr=args.learning_rate, + dtype=weight_dtype, + generator=data_generator, + npgenerator=data_npgenerator, ) - lr_scheduler = get_scheduler( - args.lr_scheduler, - optimizer=optimizer, - num_training_steps_per_epoch=len(datamodule.train_dataloader), - gradient_accumulation_steps=args.gradient_accumulation_steps, + create_lr_scheduler = partial( + get_scheduler, min_lr=args.lr_min_lr, warmup_func=args.lr_warmup_func, annealing_func=args.lr_annealing_func, warmup_exp=args.lr_warmup_exp, annealing_exp=args.lr_annealing_exp, - cycles=args.lr_cycles, end_lr=1e2, - train_epochs=num_train_epochs, - warmup_epochs=args.lr_warmup_epochs, 
mid_point=args.lr_mid_point, ) - trainer( - strategy=dreambooth_strategy, - project="dreambooth", - train_dataloader=datamodule.train_dataloader, - val_dataloader=datamodule.val_dataloader, - seed=args.seed, - optimizer=optimizer, - lr_scheduler=lr_scheduler, - num_train_epochs=num_train_epochs, - gradient_accumulation_steps=args.gradient_accumulation_steps, - sample_frequency=sample_frequency, - offset_noise_strength=args.offset_noise_strength, - # -- - tokenizer=tokenizer, - sample_scheduler=sample_scheduler, - sample_output_dir=sample_output_dir, - checkpoint_output_dir=checkpoint_output_dir, - train_text_encoder_epochs=args.train_text_encoder_epochs, - max_grad_norm=args.max_grad_norm, - use_ema=args.use_ema, - ema_inv_gamma=args.ema_inv_gamma, - ema_power=args.ema_power, - ema_max_decay=args.ema_max_decay, - sample_batch_size=args.sample_batch_size, - sample_num_batches=args.sample_batches, - sample_num_steps=args.sample_steps, - sample_image_size=args.sample_image_size, + # Dreambooth + # -------------------------------------------------------------------------------- + + dreambooth_datamodule = create_datamodule( + valid_set_size=args.valid_set_size, + batch_size=args.train_batch_size, + dropout=args.tag_dropout, + filter=partial(keyword_filter, None, args.collection, args.exclude_collections), + ) + dreambooth_datamodule.setup() + + num_train_epochs = args.num_train_epochs + dreambooth_sample_frequency = args.sample_frequency + if num_train_epochs is None: + num_train_epochs = ( + math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) + * args.gradient_accumulation_steps + ) + dreambooth_sample_frequency = math.ceil( + num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) + ) + num_training_steps_per_epoch = math.ceil( + len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps ) + num_train_steps = num_training_steps_per_epoch * num_train_epochs + if args.sample_num is not None: + dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) + + dreambooth_project = "dreambooth" + + if accelerator.is_main_process: + accelerator.init_trackers(dreambooth_project) + + dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" + + training_iter = 0 + auto_cycles = list(args.auto_cycles) + learning_rate_unet = args.learning_rate_unet + learning_rate_text = args.learning_rate_text + lr_scheduler = args.lr_scheduler + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + + avg_loss = AverageMeter() + avg_acc = AverageMeter() + avg_loss_val = AverageMeter() + avg_acc_val = AverageMeter() + + params_to_optimize = [ + { + "params": (param for param in unet.parameters() if param.requires_grad), + "lr": learning_rate_unet, + }, + { + "params": ( + param for param in text_encoder.parameters() if param.requires_grad + ), + "lr": learning_rate_text, + }, + ] + group_labels = ["unet", "text"] + + dreambooth_optimizer = create_optimizer(params_to_optimize) + + while True: + if len(auto_cycles) != 0: + response = auto_cycles.pop(0) + else: + response = input( + "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " + ) + + if response.lower().strip() == "o": + if args.learning_rate_unet is not None: + learning_rate_unet = ( + args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) + ) + if args.learning_rate_text is not None: + learning_rate_text = ( + args.learning_rate_text * 2 * (args.cycle_decay**training_iter) + ) + else: + learning_rate_unet = 
args.learning_rate_unet * ( + args.cycle_decay**training_iter + ) + learning_rate_text = args.learning_rate_text * ( + args.cycle_decay**training_iter + ) + + if response.lower().strip() == "o": + lr_scheduler = "one_cycle" + lr_warmup_epochs = args.lr_warmup_epochs + lr_cycles = args.lr_cycles + elif response.lower().strip() == "w": + lr_scheduler = "constant_with_warmup" + lr_warmup_epochs = num_train_epochs + elif response.lower().strip() == "c": + lr_scheduler = "constant" + elif response.lower().strip() == "d": + lr_scheduler = "cosine" + lr_warmup_epochs = 0 + lr_cycles = 1 + elif response.lower().strip() == "s": + break + else: + continue + + print("") + print( + f"============ Dreambooth cycle {training_iter + 1}: {response} ============" + ) + print("") + + for group, lr in zip( + dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] + ): + group["lr"] = lr + + dreambooth_lr_scheduler = create_lr_scheduler( + lr_scheduler, + gradient_accumulation_steps=args.gradient_accumulation_steps, + optimizer=dreambooth_optimizer, + num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), + train_epochs=num_train_epochs, + cycles=lr_cycles, + warmup_epochs=lr_warmup_epochs, + ) + + dreambooth_checkpoint_output_dir = ( + output_dir / dreambooth_project / f"model_{training_iter}" + ) + + trainer( + strategy=dreambooth_strategy, + train_dataloader=dreambooth_datamodule.train_dataloader, + val_dataloader=dreambooth_datamodule.val_dataloader, + optimizer=dreambooth_optimizer, + lr_scheduler=dreambooth_lr_scheduler, + num_train_epochs=num_train_epochs, + gradient_accumulation_steps=args.gradient_accumulation_steps, + global_step_offset=training_iter * num_train_steps, + cycle=training_iter, + train_text_encoder_cycles=args.train_text_encoder_cycles, + # -- + group_labels=group_labels, + sample_output_dir=dreambooth_sample_output_dir, + checkpoint_output_dir=dreambooth_checkpoint_output_dir, + sample_frequency=dreambooth_sample_frequency, + offset_noise_strength=args.offset_noise_strength, + input_pertubation=args.input_pertubation, + no_val=args.valid_set_size == 0, + avg_loss=avg_loss, + avg_acc=avg_acc, + avg_loss_val=avg_loss_val, + avg_acc_val=avg_acc_val, + ) + + training_iter += 1 + + accelerator.end_training() if __name__ == "__main__": diff --git a/train_lora.py b/train_lora.py index c74dd8f..fccf48d 100644 --- a/train_lora.py +++ b/train_lora.py @@ -16,6 +16,7 @@ from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed from peft import LoraConfig, get_peft_model + # from diffusers.models.attention_processor import AttnProcessor from diffusers.utils.import_utils import is_xformers_available import transformers @@ -34,15 +35,20 @@ from util.files import load_config, load_embeddings_from_dir # https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py UNET_TARGET_MODULES_ORIG = ["to_q", "to_v", "query", "value"] -UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0"] # ["to_k", "key"] +UNET_TARGET_MODULES = UNET_TARGET_MODULES_ORIG + ["to_out.0", "to_k", "key"] # [] TEXT_ENCODER_TARGET_MODULES_ORIG = ["q_proj", "v_proj"] -TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + ["out_proj"] # ["k_proj"] -TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + ["token_embedding"] +TEXT_ENCODER_TARGET_MODULES = TEXT_ENCODER_TARGET_MODULES_ORIG + [ + "out_proj", + "k_proj", +] # [] 
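The expanded target-module lists above now include the key projections as well (to_k / key for the UNet, k_proj for the text encoder), so LoRA adapters are attached to all attention projections rather than only query/value/output. peft matches these entries roughly as suffixes of module names; an illustrative way to see which layers a list would select (assumes unet is an already loaded UNet2DConditionModel, not code from this patch):

targets = UNET_TARGET_MODULES  # ["to_q", "to_v", "query", "value", "to_out.0", "to_k", "key"]
selected = [
    name
    for name, _ in unet.named_modules()
    if any(name == t or name.endswith("." + t) for t in targets)
]
print(f"{len(selected)} modules would receive LoRA adapters")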
+TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING = TEXT_ENCODER_TARGET_MODULES + [ + "token_embedding" +] logger = get_logger(__name__) -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") torch.backends.cuda.matmul.allow_tf32 = True @@ -55,20 +61,27 @@ hidet.torch.dynamo_config.use_tensor_core(True) hidet.torch.dynamo_config.search_space(0) -if is_xformers_available(): - import xformers - import xformers.ops - - orig_xformers_memory_efficient_attention = xformers.ops.memory_efficient_attention - def xformers_memory_efficient_attention(query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs): - return orig_xformers_memory_efficient_attention(query.to(key.dtype), key, value.to(key.dtype), **kwargs) - xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention +def patch_xformers(dtype): + if is_xformers_available(): + import xformers + import xformers.ops + + orig_xformers_memory_efficient_attention = ( + xformers.ops.memory_efficient_attention + ) + + def xformers_memory_efficient_attention( + query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs + ): + return orig_xformers_memory_efficient_attention( + query.to(dtype), key.to(dtype), value.to(dtype), **kwargs + ) + + xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention def parse_args(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." - ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -85,7 +98,7 @@ def parse_args(): "--train_data_file", type=str, default=None, - help="A folder containing the training data." + help="A folder containing the training data.", ) parser.add_argument( "--train_data_template", @@ -96,13 +109,13 @@ def parse_args(): "--train_set_pad", type=int, default=None, - help="The number to fill train dataset items up to." + help="The number to fill train dataset items up to.", ) parser.add_argument( "--valid_set_pad", type=int, default=None, - help="The number to fill validation dataset items up to." + help="The number to fill validation dataset items up to.", ) parser.add_argument( "--project", @@ -111,64 +124,52 @@ def parse_args(): help="The name of the current project.", ) parser.add_argument( - "--auto_cycles", - type=str, - default="o", - help="Cycles to run automatically." + "--auto_cycles", type=str, default="o", help="Cycles to run automatically." ) parser.add_argument( - "--cycle_decay", - type=float, - default=1.0, - help="Learning rate decay per cycle." + "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." ) parser.add_argument( "--placeholder_tokens", type=str, - nargs='*', + nargs="*", help="A token to use as a placeholder for the concept.", ) parser.add_argument( "--initializer_tokens", type=str, - nargs='*', - help="A token to use as initializer word." + nargs="*", + help="A token to use as initializer word.", ) parser.add_argument( - "--filter_tokens", - type=str, - nargs='*', - help="Tokens to filter the dataset by." + "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." ) parser.add_argument( "--initializer_noise", type=float, default=0, - help="Noise to apply to the initializer word" + help="Noise to apply to the initializer word", ) parser.add_argument( "--alias_tokens", type=str, - nargs='*', + nargs="*", default=[], - help="Tokens to create an alias for." 
+ help="Tokens to create an alias for.", ) parser.add_argument( "--inverted_initializer_tokens", type=str, - nargs='*', - help="A token to use as initializer word." + nargs="*", + help="A token to use as initializer word.", ) parser.add_argument( - "--num_vectors", - type=int, - nargs='*', - help="Number of vectors per embedding." + "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) parser.add_argument( "--exclude_collections", type=str, - nargs='*', + nargs="*", help="Exclude all items with a listed collection.", ) parser.add_argument( @@ -214,7 +215,7 @@ def parse_args(): "--num_class_images", type=int, default=0, - help="How many class images to generate." + help="How many class images to generate.", ) parser.add_argument( "--class_image_dir", @@ -242,14 +243,11 @@ def parse_args(): parser.add_argument( "--collection", type=str, - nargs='*', + nargs="*", help="A collection to filter the dataset.", ) parser.add_argument( - "--seed", - type=int, - default=None, - help="A seed for reproducible training." + "--seed", type=int, default=None, help="A seed for reproducible training." ) parser.add_argument( "--resolution", @@ -270,18 +268,10 @@ def parse_args(): "--input_pertubation", type=float, default=0, - help="The scale of input pretubation. Recommended 0.1." - ) - parser.add_argument( - "--num_train_epochs", - type=int, - default=None - ) - parser.add_argument( - "--num_train_steps", - type=int, - default=2000 + help="The scale of input pretubation. Recommended 0.1.", ) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--num_train_steps", type=int, default=2000) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -289,22 +279,19 @@ def parse_args(): help="Number of updates steps to accumulate before performing a backward/update pass.", ) parser.add_argument( - "--lora_r", - type=int, - default=8, - help="Lora rank, only used if use_lora is True" + "--lora_r", type=int, default=8, help="Lora rank, only used if use_lora is True" ) parser.add_argument( "--lora_alpha", type=int, default=32, - help="Lora alpha, only used if use_lora is True" + help="Lora alpha, only used if use_lora is True", ) parser.add_argument( "--lora_dropout", type=float, default=0.0, - help="Lora dropout, only used if use_lora is True" + help="Lora dropout, only used if use_lora is True", ) parser.add_argument( "--lora_bias", @@ -344,7 +331,7 @@ def parse_args(): parser.add_argument( "--train_text_encoder_cycles", default=999999, - help="Number of epochs the text encoder will be trained." + help="Number of epochs the text encoder will be trained.", ) parser.add_argument( "--find_lr", @@ -378,27 +365,31 @@ def parse_args(): "--lr_scheduler", type=str, default="one_cycle", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup", "one_cycle"], - help='The scheduler type to use.', + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + "one_cycle", + ], + help="The scheduler type to use.", ) parser.add_argument( "--lr_warmup_epochs", type=int, default=10, - help="Number of steps for the warmup in the lr scheduler." + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( - "--lr_mid_point", - type=float, - default=0.3, - help="OneCycle schedule mid point." + "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." 
) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler (if supported)." + help="Number of restart cycles in the lr scheduler (if supported).", ) parser.add_argument( "--lr_warmup_func", @@ -410,7 +401,7 @@ def parse_args(): "--lr_warmup_exp", type=int, default=1, - help='If lr_warmup_func is "cos", exponent to modify the function' + help='If lr_warmup_func is "cos", exponent to modify the function', ) parser.add_argument( "--lr_annealing_func", @@ -422,69 +413,76 @@ def parse_args(): "--lr_annealing_exp", type=int, default=3, - help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', ) parser.add_argument( "--lr_min_lr", type=float, default=0.04, - help="Minimum learning rate in the lr scheduler." - ) - parser.add_argument( - "--min_snr_gamma", - type=int, - default=5, - help="MinSNR gamma." + help="Minimum learning rate in the lr scheduler.", ) + parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") parser.add_argument( "--schedule_sampler", type=str, default="uniform", choices=["uniform", "loss-second-moment"], - help="Noise schedule sampler." + help="Noise schedule sampler.", ) parser.add_argument( "--optimizer", type=str, default="adan", - choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], - help='Optimizer to use' + choices=[ + "adam", + "adam8bit", + "adan", + "lion", + "dadam", + "dadan", + "dlion", + "adafactor", + ], + help="Optimizer to use", ) parser.add_argument( "--dadaptation_d0", type=float, default=1e-6, - help="The d0 parameter for Dadaptation optimizers." + help="The d0 parameter for Dadaptation optimizers.", + ) + parser.add_argument( + "--dadaptation_growth_rate", + type=float, + default=math.inf, + help="The growth_rate parameter for Dadaptation optimizers.", ) parser.add_argument( "--adam_beta1", type=float, default=None, - help="The beta1 parameter for the Adam optimizer." + help="The beta1 parameter for the Adam optimizer.", ) parser.add_argument( "--adam_beta2", type=float, default=None, - help="The beta2 parameter for the Adam optimizer." + help="The beta2 parameter for the Adam optimizer.", ) parser.add_argument( - "--adam_weight_decay", - type=float, - default=2e-2, - help="Weight decay to use." + "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." ) parser.add_argument( "--adam_epsilon", type=float, default=1e-08, - help="Epsilon value for the Adam optimizer" + help="Epsilon value for the Adam optimizer", ) parser.add_argument( "--adam_amsgrad", type=bool, default=False, - help="Amsgrad value for the Adam optimizer" + help="Amsgrad value for the Adam optimizer", ) parser.add_argument( "--mixed_precision", @@ -547,19 +545,19 @@ def parse_args(): "--valid_set_size", type=int, default=None, - help="Number of images in the validation dataset." + help="Number of images in the validation dataset.", ) parser.add_argument( "--valid_set_repeat", type=int, default=1, - help="Times the images in the validation dataset are repeated." + help="Times the images in the validation dataset are repeated.", ) parser.add_argument( "--train_batch_size", type=int, default=1, - help="Batch size (per device) for the training dataloader." 
+ help="Batch size (per device) for the training dataloader.", ) parser.add_argument( "--sample_steps", @@ -571,19 +569,10 @@ def parse_args(): "--prior_loss_weight", type=float, default=1.0, - help="The weight of prior preservation loss." - ) - parser.add_argument( - "--run_pti", - action="store_true", - help="Whether to run PTI." - ) - parser.add_argument( - "--emb_alpha", - type=float, - default=1.0, - help="Embedding alpha" + help="The weight of prior preservation loss.", ) + parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") + parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") parser.add_argument( "--emb_dropout", type=float, @@ -591,27 +580,16 @@ def parse_args(): help="Embedding dropout probability.", ) parser.add_argument( - "--use_emb_decay", - action="store_true", - help="Whether to use embedding decay." + "--use_emb_decay", action="store_true", help="Whether to use embedding decay." ) parser.add_argument( - "--emb_decay_target", - default=0.4, - type=float, - help="Embedding decay target." + "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." ) parser.add_argument( - "--emb_decay", - default=1e+2, - type=float, - help="Embedding decay factor." + "--emb_decay", default=1e2, type=float, help="Embedding decay factor." ) parser.add_argument( - "--max_grad_norm", - default=1.0, - type=float, - help="Max gradient norm." + "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." ) parser.add_argument( "--noise_timesteps", @@ -622,7 +600,7 @@ def parse_args(): "--config", type=str, default=None, - help="Path to a JSON configuration file containing arguments for invoking this script." + help="Path to a JSON configuration file containing arguments for invoking this script.", ) args = parser.parse_args() @@ -649,29 +627,44 @@ def parse_args(): args.placeholder_tokens = [args.placeholder_tokens] if isinstance(args.initializer_tokens, str): - args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) + args.initializer_tokens = [args.initializer_tokens] * len( + args.placeholder_tokens + ) if len(args.placeholder_tokens) == 0: - args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] + args.placeholder_tokens = [ + f"<*{i}>" for i in range(len(args.initializer_tokens)) + ] if len(args.initializer_tokens) == 0: args.initializer_tokens = args.placeholder_tokens.copy() if len(args.placeholder_tokens) != len(args.initializer_tokens): - raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") + raise ValueError( + "--placeholder_tokens and --initializer_tokens must have the same number of items" + ) if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( + args.placeholder_tokens + ) - if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: + if ( + isinstance(args.inverted_initializer_tokens, list) + and len(args.inverted_initializer_tokens) != 0 + ): args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] args.initializer_tokens += args.inverted_initializer_tokens if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) - if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != 
len(args.num_vectors): - raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( + args.num_vectors + ): + raise ValueError( + "--placeholder_tokens and --num_vectors must have the same number of items" + ) if args.alias_tokens is None: args.alias_tokens = [] @@ -695,15 +688,15 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta1 = 0.9 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta1 = 0.95 if args.adam_beta2 is None: - if args.optimizer in ('adam', 'adam8bit'): + if args.optimizer in ("adam", "adam8bit", "dadam"): args.adam_beta2 = 0.999 - elif args.optimizer == 'lion': + elif args.optimizer in ("lion", "dlion"): args.adam_beta2 = 0.98 return args @@ -719,7 +712,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, project_dir=f"{output_dir}", - mixed_precision=args.mixed_precision + mixed_precision=args.mixed_precision, ) weight_dtype = torch.float32 @@ -728,6 +721,8 @@ def main(): elif args.mixed_precision == "bf16": weight_dtype = torch.bfloat16 + patch_xformers(weight_dtype) + logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) if args.seed is None: @@ -737,12 +732,18 @@ def main(): save_args(output_dir, args) - tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) - schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) - + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( + args.pretrained_model_name_or_path + ) + schedule_sampler = create_named_schedule_sampler( + args.schedule_sampler, noise_scheduler.config.num_train_timesteps + ) + def ensure_embeddings(): if args.lora_text_encoder_emb: - raise ValueError("Can't use TI options when training token embeddings with LoRA") + raise ValueError( + "Can't use TI options when training token embeddings with LoRA" + ) return patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) unet_config = LoraConfig( @@ -757,7 +758,9 @@ def main(): text_encoder_config = LoraConfig( r=args.lora_text_encoder_r, lora_alpha=args.lora_text_encoder_alpha, - target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING if args.lora_text_encoder_emb else TEXT_ENCODER_TARGET_MODULES, + target_modules=TEXT_ENCODER_TARGET_MODULES_WITH_EMBEDDING + if args.lora_text_encoder_emb + else TEXT_ENCODER_TARGET_MODULES, lora_dropout=args.lora_text_encoder_dropout, bias=args.lora_text_encoder_bias, ) @@ -787,7 +790,7 @@ def main(): if len(args.alias_tokens) != 0: embeddings = ensure_embeddings() - + alias_placeholder_tokens = args.alias_tokens[::2] alias_initializer_tokens = args.alias_tokens[1::2] @@ -795,27 +798,33 @@ def main(): tokenizer=tokenizer, embeddings=embeddings, placeholder_tokens=alias_placeholder_tokens, - initializer_tokens=alias_initializer_tokens + initializer_tokens=alias_initializer_tokens, ) embeddings.persist() - print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" + ) placeholder_tokens = [] 
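The two LoraConfig objects above (unet_config and text_encoder_config) are only defined in this hunk; the wrapping step itself falls outside it. Presumably they are applied with peft's get_peft_model, roughly as sketched here (variable names assumed from the definitions above):

from peft import get_peft_model

unet = get_peft_model(unet, unet_config)
text_encoder = get_peft_model(text_encoder, text_encoder_config)
unet.print_trainable_parameters()  # reports trainable vs. total parameter counts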
placeholder_token_ids = [] if args.embeddings_dir is not None: embeddings = ensure_embeddings() - + embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) + added_tokens, added_ids = load_embeddings_from_dir( + tokenizer, embeddings, embeddings_dir + ) placeholder_tokens = added_tokens placeholder_token_ids = added_ids - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) if args.train_dir_embeddings: print("Training embeddings from embeddings dir") @@ -824,7 +833,7 @@ def main(): if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: embeddings = ensure_embeddings() - + placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, embeddings=embeddings, @@ -836,23 +845,34 @@ def main(): placeholder_tokens = args.placeholder_tokens - stats = list(zip( - placeholder_tokens, placeholder_token_ids, args.initializer_tokens, initializer_token_ids - )) + stats = list( + zip( + placeholder_tokens, + placeholder_token_ids, + args.initializer_tokens, + initializer_token_ids, + ) + ) print(f"Training embeddings: {stats}") if args.scale_lr: args.learning_rate_unet = ( - args.learning_rate_unet * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate_unet + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) args.learning_rate_text = ( - args.learning_rate_text * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate_text + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) args.learning_rate_emb = ( - args.learning_rate_emb * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate_emb + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) if args.find_lr: @@ -861,11 +881,13 @@ def main(): args.learning_rate_emb = 1e-6 args.lr_scheduler = "exponential_growth" - if args.optimizer == 'adam8bit': + if args.optimizer == "adam8bit": try: import bitsandbytes as bnb except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) create_optimizer = partial( bnb.optim.AdamW8bit, @@ -874,7 +896,7 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adam': + elif args.optimizer == "adam": create_optimizer = partial( torch.optim.AdamW, betas=(args.adam_beta1, args.adam_beta2), @@ -882,11 +904,13 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adan': + elif args.optimizer == "adan": try: import timm.optim except ImportError: - raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + raise ImportError( + "To use Adan, please install the PyTorch Image Models library: `pip install timm`." 
+ ) create_optimizer = partial( timm.optim.Adan, @@ -894,11 +918,13 @@ def main(): eps=args.adam_epsilon, no_prox=True, ) - elif args.optimizer == 'lion': + elif args.optimizer == "lion": try: import lion_pytorch except ImportError: - raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") + raise ImportError( + "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." + ) create_optimizer = partial( lion_pytorch.Lion, @@ -906,7 +932,7 @@ def main(): weight_decay=args.adam_weight_decay, use_triton=True, ) - elif args.optimizer == 'adafactor': + elif args.optimizer == "adafactor": create_optimizer = partial( transformers.optimization.Adafactor, weight_decay=args.adam_weight_decay, @@ -920,11 +946,13 @@ def main(): args.learning_rate_unet = None args.learning_rate_text = None args.learning_rate_emb = None - elif args.optimizer == 'dadam': + elif args.optimizer == "dadam": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." + ) create_optimizer = partial( dadaptation.DAdaptAdam, @@ -933,29 +961,35 @@ def main(): eps=args.adam_epsilon, decouple=True, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) args.learning_rate_unet = 1.0 args.learning_rate_text = 1.0 args.learning_rate_emb = 1.0 - elif args.optimizer == 'dadan': + elif args.optimizer == "dadan": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." 
+ ) create_optimizer = partial( dadaptation.DAdaptAdan, weight_decay=args.adam_weight_decay, eps=args.adam_epsilon, d0=args.dadaptation_d0, + growth_rate=args.dadaptation_growth_rate, ) args.learning_rate_unet = 1.0 args.learning_rate_text = 1.0 args.learning_rate_emb = 1.0 + elif args.optimizer == "dlion": + raise ImportError("DLion has not been merged into dadaptation yet") else: - raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") + raise ValueError(f'Unknown --optimizer "{args.optimizer}"') trainer = partial( train, @@ -1026,25 +1060,33 @@ def main(): if args.run_pti and len(placeholder_tokens) != 0: embeddings = ensure_embeddings() - - filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] + + filter_tokens = [ + token for token in args.filter_tokens if token in placeholder_tokens + ] pti_datamodule = create_datamodule( valid_set_size=0, batch_size=args.train_batch_size, - filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), + filter=partial( + keyword_filter, filter_tokens, args.collection, args.exclude_collections + ), ) pti_datamodule.setup() num_train_epochs = args.num_train_epochs pti_sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil( - args.num_train_steps / len(pti_datamodule.train_dataset) - ) * args.gradient_accumulation_steps - pti_sample_frequency = math.ceil(num_train_epochs * (pti_sample_frequency / args.num_train_steps)) + num_train_epochs = ( + math.ceil(args.num_train_steps / len(pti_datamodule.train_dataset)) + * args.gradient_accumulation_steps + ) + pti_sample_frequency = math.ceil( + num_train_epochs * (pti_sample_frequency / args.num_train_steps) + ) num_training_steps_per_epoch = math.ceil( - len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps) + len(pti_datamodule.train_dataset) / args.gradient_accumulation_steps + ) num_train_steps = num_training_steps_per_epoch * num_train_epochs if args.sample_num is not None: pti_sample_frequency = math.ceil(num_train_epochs / args.sample_num) @@ -1060,11 +1102,15 @@ def main(): print(f"============ PTI ============") print("") - pti_optimizer = create_optimizer([{ - "params": text_encoder.text_model.embeddings.token_embedding.parameters(), - "lr": args.learning_rate_emb, - "weight_decay": 0, - }]) + pti_optimizer = create_optimizer( + [ + { + "params": text_encoder.text_model.embeddings.token_embedding.parameters(), + "lr": args.learning_rate_emb, + "weight_decay": 0, + } + ] + ) pti_lr_scheduler = create_lr_scheduler( "constant_with_warmup", @@ -1113,11 +1159,16 @@ def main(): num_train_epochs = args.num_train_epochs lora_sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil( - args.num_train_steps / len(lora_datamodule.train_dataset) - ) * args.gradient_accumulation_steps - lora_sample_frequency = math.ceil(num_train_epochs * (lora_sample_frequency / args.num_train_steps)) - num_training_steps_per_epoch = math.ceil(len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps) + num_train_epochs = ( + math.ceil(args.num_train_steps / len(lora_datamodule.train_dataset)) + * args.gradient_accumulation_steps + ) + lora_sample_frequency = math.ceil( + num_train_epochs * (lora_sample_frequency / args.num_train_steps) + ) + num_training_steps_per_epoch = math.ceil( + len(lora_datamodule.train_dataset) / args.gradient_accumulation_steps + ) num_train_steps = num_training_steps_per_epoch * num_train_epochs if args.sample_num is 
not None: lora_sample_frequency = math.ceil(num_train_epochs / args.sample_num) @@ -1131,7 +1182,6 @@ def main(): training_iter = 0 auto_cycles = list(args.auto_cycles) - learning_rate_emb = args.learning_rate_emb learning_rate_unet = args.learning_rate_unet learning_rate_text = args.learning_rate_text lr_scheduler = args.lr_scheduler @@ -1145,21 +1195,15 @@ def main(): params_to_optimize = [ { - "params": ( - param - for param in unet.parameters() - if param.requires_grad - ), + "params": (param for param in unet.parameters() if param.requires_grad), "lr": learning_rate_unet, }, { "params": ( - param - for param in text_encoder.parameters() - if param.requires_grad + param for param in text_encoder.parameters() if param.requires_grad ), "lr": learning_rate_text, - } + }, ] group_labels = ["unet", "text"] @@ -1169,19 +1213,26 @@ def main(): if len(auto_cycles) != 0: response = auto_cycles.pop(0) else: - response = input("\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + response = input( + "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " + ) if response.lower().strip() == "o": - if args.learning_rate_emb is not None: - learning_rate_emb = args.learning_rate_emb * 2 if args.learning_rate_unet is not None: - learning_rate_unet = args.learning_rate_unet * 2 + learning_rate_unet = ( + args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) + ) if args.learning_rate_text is not None: - learning_rate_text = args.learning_rate_text * 2 + learning_rate_text = ( + args.learning_rate_text * 2 * (args.cycle_decay**training_iter) + ) else: - learning_rate_emb = args.learning_rate_emb - learning_rate_unet = args.learning_rate_unet - learning_rate_text = args.learning_rate_text + learning_rate_unet = args.learning_rate_unet * ( + args.cycle_decay**training_iter + ) + learning_rate_text = args.learning_rate_text * ( + args.cycle_decay**training_iter + ) if response.lower().strip() == "o": lr_scheduler = "one_cycle" @@ -1204,9 +1255,11 @@ def main(): print("") print(f"============ LoRA cycle {training_iter + 1}: {response} ============") print("") - - for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): - group['lr'] = lr + + for group, lr in zip( + lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text] + ): + group["lr"] = lr lora_lr_scheduler = create_lr_scheduler( lr_scheduler, @@ -1218,7 +1271,9 @@ def main(): warmup_epochs=lr_warmup_epochs, ) - lora_checkpoint_output_dir = output_dir / lora_project / f"model_{training_iter}" + lora_checkpoint_output_dir = ( + output_dir / lora_project / f"model_{training_iter}" + ) trainer( strategy=lora_strategy, @@ -1246,12 +1301,6 @@ def main(): ) training_iter += 1 - if learning_rate_emb is not None: - learning_rate_emb *= args.cycle_decay - if learning_rate_unet is not None: - learning_rate_unet *= args.cycle_decay - if learning_rate_text is not None: - learning_rate_text *= args.cycle_decay accelerator.end_training() diff --git a/train_ti.py b/train_ti.py index f60e3e5..c6f0b3a 100644 --- a/train_ti.py +++ b/train_ti.py @@ -32,7 +32,7 @@ from util.files import load_config, load_embeddings_from_dir logger = get_logger(__name__) -warnings.filterwarnings('ignore') +warnings.filterwarnings("ignore") torch.backends.cuda.matmul.allow_tf32 = True @@ -46,9 +46,7 @@ hidet.torch.dynamo_config.search_space(0) def parse_args(): - parser = argparse.ArgumentParser( - description="Simple example of a training script." 
- ) + parser = argparse.ArgumentParser(description="Simple example of a training script.") parser.add_argument( "--pretrained_model_name_or_path", type=str, @@ -65,12 +63,12 @@ def parse_args(): "--train_data_file", type=str, default=None, - help="A CSV file containing the training data." + help="A CSV file containing the training data.", ) parser.add_argument( "--train_data_template", type=str, - nargs='*', + nargs="*", default="template", ) parser.add_argument( @@ -80,59 +78,47 @@ def parse_args(): help="The name of the current project.", ) parser.add_argument( - "--auto_cycles", - type=str, - default="o", - help="Cycles to run automatically." + "--auto_cycles", type=str, default="o", help="Cycles to run automatically." ) parser.add_argument( - "--cycle_decay", - type=float, - default=1.0, - help="Learning rate decay per cycle." + "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." ) parser.add_argument( "--placeholder_tokens", type=str, - nargs='*', + nargs="*", help="A token to use as a placeholder for the concept.", ) parser.add_argument( "--initializer_tokens", type=str, - nargs='*', - help="A token to use as initializer word." + nargs="*", + help="A token to use as initializer word.", ) parser.add_argument( - "--filter_tokens", - type=str, - nargs='*', - help="Tokens to filter the dataset by." + "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." ) parser.add_argument( "--initializer_noise", type=float, default=0, - help="Noise to apply to the initializer word" + help="Noise to apply to the initializer word", ) parser.add_argument( "--alias_tokens", type=str, - nargs='*', + nargs="*", default=[], - help="Tokens to create an alias for." + help="Tokens to create an alias for.", ) parser.add_argument( "--inverted_initializer_tokens", type=str, - nargs='*', - help="A token to use as initializer word." + nargs="*", + help="A token to use as initializer word.", ) parser.add_argument( - "--num_vectors", - type=int, - nargs='*', - help="Number of vectors per embedding." + "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." ) parser.add_argument( "--sequential", @@ -147,7 +133,7 @@ def parse_args(): "--num_class_images", type=int, default=0, - help="How many class images to generate." + help="How many class images to generate.", ) parser.add_argument( "--class_image_dir", @@ -158,7 +144,7 @@ def parse_args(): parser.add_argument( "--exclude_collections", type=str, - nargs='*', + nargs="*", help="Exclude all items with a listed collection.", ) parser.add_argument( @@ -181,14 +167,11 @@ def parse_args(): parser.add_argument( "--collection", type=str, - nargs='*', + nargs="*", help="A collection to filter the dataset.", ) parser.add_argument( - "--seed", - type=int, - default=None, - help="A seed for reproducible training." + "--seed", type=int, default=None, help="A seed for reproducible training." ) parser.add_argument( "--resolution", @@ -244,7 +227,7 @@ def parse_args(): type=str, default="auto", choices=["all", "trailing", "leading", "between", "auto", "off"], - help='Vector shuffling algorithm.', + help="Vector shuffling algorithm.", ) parser.add_argument( "--offset_noise_strength", @@ -256,18 +239,10 @@ def parse_args(): "--input_pertubation", type=float, default=0, - help="The scale of input pretubation. Recommended 0.1." 
- ) - parser.add_argument( - "--num_train_epochs", - type=int, - default=None - ) - parser.add_argument( - "--num_train_steps", - type=int, - default=2000 + help="The scale of input pretubation. Recommended 0.1.", ) + parser.add_argument("--num_train_epochs", type=int, default=None) + parser.add_argument("--num_train_steps", type=int, default=2000) parser.add_argument( "--gradient_accumulation_steps", type=int, @@ -299,27 +274,31 @@ def parse_args(): "--lr_scheduler", type=str, default="one_cycle", - choices=["linear", "cosine", "cosine_with_restarts", "polynomial", - "constant", "constant_with_warmup", "one_cycle"], - help='The scheduler type to use.', + choices=[ + "linear", + "cosine", + "cosine_with_restarts", + "polynomial", + "constant", + "constant_with_warmup", + "one_cycle", + ], + help="The scheduler type to use.", ) parser.add_argument( "--lr_warmup_epochs", type=int, default=10, - help="Number of steps for the warmup in the lr scheduler." + help="Number of steps for the warmup in the lr scheduler.", ) parser.add_argument( - "--lr_mid_point", - type=float, - default=0.3, - help="OneCycle schedule mid point." + "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." ) parser.add_argument( "--lr_cycles", type=int, default=None, - help="Number of restart cycles in the lr scheduler." + help="Number of restart cycles in the lr scheduler.", ) parser.add_argument( "--lr_warmup_func", @@ -331,7 +310,7 @@ def parse_args(): "--lr_warmup_exp", type=int, default=1, - help='If lr_warmup_func is "cos", exponent to modify the function' + help='If lr_warmup_func is "cos", exponent to modify the function', ) parser.add_argument( "--lr_annealing_func", @@ -343,89 +322,67 @@ def parse_args(): "--lr_annealing_exp", type=int, default=1, - help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' + help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', ) parser.add_argument( "--lr_min_lr", type=float, default=0.04, - help="Minimum learning rate in the lr scheduler." + help="Minimum learning rate in the lr scheduler.", ) parser.add_argument( - "--use_ema", - action="store_true", - help="Whether to use EMA model." - ) - parser.add_argument( - "--ema_inv_gamma", - type=float, - default=1.0 - ) - parser.add_argument( - "--ema_power", - type=float, - default=4/5 - ) - parser.add_argument( - "--ema_max_decay", - type=float, - default=0.9999 - ) - parser.add_argument( - "--min_snr_gamma", - type=int, - default=5, - help="MinSNR gamma." + "--use_ema", action="store_true", help="Whether to use EMA model." ) + parser.add_argument("--ema_inv_gamma", type=float, default=1.0) + parser.add_argument("--ema_power", type=float, default=4 / 5) + parser.add_argument("--ema_max_decay", type=float, default=0.9999) + parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") parser.add_argument( "--schedule_sampler", type=str, default="uniform", choices=["uniform", "loss-second-moment"], - help="Noise schedule sampler." + help="Noise schedule sampler.", ) parser.add_argument( "--optimizer", type=str, default="adan", choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], - help='Optimizer to use' + help="Optimizer to use", ) parser.add_argument( "--dadaptation_d0", type=float, default=1e-6, - help="The d0 parameter for Dadaptation optimizers." 
+ help="The d0 parameter for Dadaptation optimizers.", ) parser.add_argument( "--adam_beta1", type=float, default=None, - help="The beta1 parameter for the Adam optimizer." + help="The beta1 parameter for the Adam optimizer.", ) parser.add_argument( "--adam_beta2", type=float, default=None, - help="The beta2 parameter for the Adam optimizer." + help="The beta2 parameter for the Adam optimizer.", ) parser.add_argument( - "--adam_weight_decay", - type=float, - default=2e-2, - help="Weight decay to use." + "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." ) parser.add_argument( "--adam_epsilon", type=float, default=1e-08, - help="Epsilon value for the Adam optimizer" + help="Epsilon value for the Adam optimizer", ) parser.add_argument( "--adam_amsgrad", type=bool, default=False, - help="Amsgrad value for the Adam optimizer" + help="Amsgrad value for the Adam optimizer", ) parser.add_argument( "--mixed_precision", @@ -456,7 +413,7 @@ def parse_args(): ) parser.add_argument( "--no_milestone_checkpoints", - action='store_true', + action="store_true", help="If checkpoints are saved on maximum accuracy", ) parser.add_argument( @@ -493,25 +450,25 @@ def parse_args(): "--valid_set_size", type=int, default=None, - help="Number of images in the validation dataset." + help="Number of images in the validation dataset.", ) parser.add_argument( "--train_set_pad", type=int, default=None, - help="The number to fill train dataset items up to." + help="The number to fill train dataset items up to.", ) parser.add_argument( "--valid_set_pad", type=int, default=None, - help="The number to fill validation dataset items up to." + help="The number to fill validation dataset items up to.", ) parser.add_argument( "--train_batch_size", type=int, default=1, - help="Batch size (per device) for the training dataloader." + help="Batch size (per device) for the training dataloader.", ) parser.add_argument( "--sample_steps", @@ -523,14 +480,9 @@ def parse_args(): "--prior_loss_weight", type=float, default=1.0, - help="The weight of prior preservation loss." - ) - parser.add_argument( - "--emb_alpha", - type=float, - default=1.0, - help="Embedding alpha" + help="The weight of prior preservation loss.", ) + parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") parser.add_argument( "--emb_dropout", type=float, @@ -538,21 +490,13 @@ def parse_args(): help="Embedding dropout probability.", ) parser.add_argument( - "--use_emb_decay", - action="store_true", - help="Whether to use embedding decay." + "--use_emb_decay", action="store_true", help="Whether to use embedding decay." ) parser.add_argument( - "--emb_decay_target", - default=0.4, - type=float, - help="Embedding decay target." + "--emb_decay_target", default=0.4, type=float, help="Embedding decay target." ) parser.add_argument( - "--emb_decay", - default=1e+2, - type=float, - help="Embedding decay factor." + "--emb_decay", default=1e2, type=float, help="Embedding decay factor." ) parser.add_argument( "--noise_timesteps", @@ -563,7 +507,7 @@ def parse_args(): "--resume_from", type=str, default=None, - help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)" + help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)", ) parser.add_argument( "--global_step", @@ -574,7 +518,7 @@ def parse_args(): "--config", type=str, default=None, - help="Path to a JSON configuration file containing arguments for invoking this script." 
+ help="Path to a JSON configuration file containing arguments for invoking this script.", ) args = parser.parse_args() @@ -595,29 +539,44 @@ def parse_args(): args.placeholder_tokens = [args.placeholder_tokens] if isinstance(args.initializer_tokens, str): - args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) + args.initializer_tokens = [args.initializer_tokens] * len( + args.placeholder_tokens + ) if len(args.placeholder_tokens) == 0: - args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] + args.placeholder_tokens = [ + f"<*{i}>" for i in range(len(args.initializer_tokens)) + ] if len(args.initializer_tokens) == 0: args.initializer_tokens = args.placeholder_tokens.copy() if len(args.placeholder_tokens) != len(args.initializer_tokens): - raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") + raise ValueError( + "--placeholder_tokens and --initializer_tokens must have the same number of items" + ) if isinstance(args.inverted_initializer_tokens, str): - args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len(args.placeholder_tokens) + args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( + args.placeholder_tokens + ) - if isinstance(args.inverted_initializer_tokens, list) and len(args.inverted_initializer_tokens) != 0: + if ( + isinstance(args.inverted_initializer_tokens, list) + and len(args.inverted_initializer_tokens) != 0 + ): args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] args.initializer_tokens += args.inverted_initializer_tokens if isinstance(args.num_vectors, int): args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) - if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len(args.num_vectors): - raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") + if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( + args.num_vectors + ): + raise ValueError( + "--placeholder_tokens and --num_vectors must have the same number of items" + ) if args.alias_tokens is None: args.alias_tokens = [] @@ -639,16 +598,22 @@ def parse_args(): ] if isinstance(args.train_data_template, str): - args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) + args.train_data_template = [args.train_data_template] * len( + args.placeholder_tokens + ) if len(args.placeholder_tokens) != len(args.train_data_template): - raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") + raise ValueError( + "--placeholder_tokens and --train_data_template must have the same number of items" + ) if args.num_vectors is None: args.num_vectors = [None] * len(args.placeholder_tokens) else: if isinstance(args.train_data_template, list): - raise ValueError("--train_data_template can't be a list in simultaneous mode") + raise ValueError( + "--train_data_template can't be a list in simultaneous mode" + ) if isinstance(args.collection, str): args.collection = [args.collection] @@ -660,13 +625,13 @@ def parse_args(): raise ValueError("You must specify --output_dir") if args.adam_beta1 is None: - if args.optimizer == 'lion': + if args.optimizer == "lion": args.adam_beta1 = 0.95 else: args.adam_beta1 = 0.9 if args.adam_beta2 is None: - if args.optimizer == 'lion': + if args.optimizer == "lion": args.adam_beta2 = 0.98 else: args.adam_beta2 = 0.999 @@ -679,13 +644,13 @@ def main(): global_step_offset = 
args.global_step now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") - output_dir = Path(args.output_dir)/slugify(args.project)/now + output_dir = Path(args.output_dir) / slugify(args.project) / now output_dir.mkdir(parents=True, exist_ok=True) accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, project_dir=f"{output_dir}", - mixed_precision=args.mixed_precision + mixed_precision=args.mixed_precision, ) weight_dtype = torch.float32 @@ -703,9 +668,15 @@ def main(): save_args(output_dir, args) - tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(args.pretrained_model_name_or_path) - embeddings = patch_managed_embeddings(text_encoder, args.emb_alpha, args.emb_dropout) - schedule_sampler = create_named_schedule_sampler(args.schedule_sampler, noise_scheduler.config.num_train_timesteps) + tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( + args.pretrained_model_name_or_path + ) + embeddings = patch_managed_embeddings( + text_encoder, args.emb_alpha, args.emb_dropout + ) + schedule_sampler = create_named_schedule_sampler( + args.schedule_sampler, noise_scheduler.config.num_train_timesteps + ) tokenizer.set_use_vector_shuffle(args.vector_shuffle) tokenizer.set_dropout(args.vector_dropout) @@ -717,16 +688,16 @@ def main(): unet.enable_xformers_memory_efficient_attention() elif args.compile_unet: unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False - + proc = AttnProcessor() - + def fn_recursive_set_proc(module: torch.nn.Module): if hasattr(module, "processor"): module.processor = proc - + for child in module.children(): fn_recursive_set_proc(child) - + fn_recursive_set_proc(unet) if args.gradient_checkpointing: @@ -751,18 +722,24 @@ def main(): tokenizer=tokenizer, embeddings=embeddings, placeholder_tokens=alias_placeholder_tokens, - initializer_tokens=alias_initializer_tokens + initializer_tokens=alias_initializer_tokens, ) embeddings.persist() - print(f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}") + print( + f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" + ) if args.embeddings_dir is not None: embeddings_dir = Path(args.embeddings_dir) if not embeddings_dir.exists() or not embeddings_dir.is_dir(): raise ValueError("--embeddings_dir must point to an existing directory") - added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) - print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") + added_tokens, added_ids = load_embeddings_from_dir( + tokenizer, embeddings, embeddings_dir + ) + print( + f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" + ) if args.train_dir_embeddings: args.placeholder_tokens = added_tokens @@ -772,19 +749,23 @@ def main(): if args.scale_lr: args.learning_rate = ( - args.learning_rate * args.gradient_accumulation_steps * - args.train_batch_size * accelerator.num_processes + args.learning_rate + * args.gradient_accumulation_steps + * args.train_batch_size + * accelerator.num_processes ) if args.find_lr: args.learning_rate = 1e-5 args.lr_scheduler = "exponential_growth" - if args.optimizer == 'adam8bit': + if args.optimizer == "adam8bit": try: import bitsandbytes as bnb except ImportError: - raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install 
bitsandbytes`.") + raise ImportError( + "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." + ) create_optimizer = partial( bnb.optim.AdamW8bit, @@ -793,7 +774,7 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adam': + elif args.optimizer == "adam": create_optimizer = partial( torch.optim.AdamW, betas=(args.adam_beta1, args.adam_beta2), @@ -801,11 +782,13 @@ def main(): eps=args.adam_epsilon, amsgrad=args.adam_amsgrad, ) - elif args.optimizer == 'adan': + elif args.optimizer == "adan": try: import timm.optim except ImportError: - raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") + raise ImportError( + "To use Adan, please install the PyTorch Image Models library: `pip install timm`." + ) create_optimizer = partial( timm.optim.Adan, @@ -813,11 +796,13 @@ def main(): eps=args.adam_epsilon, no_prox=True, ) - elif args.optimizer == 'lion': + elif args.optimizer == "lion": try: import lion_pytorch except ImportError: - raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") + raise ImportError( + "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." + ) create_optimizer = partial( lion_pytorch.Lion, @@ -825,7 +810,7 @@ def main(): weight_decay=args.adam_weight_decay, use_triton=True, ) - elif args.optimizer == 'adafactor': + elif args.optimizer == "adafactor": create_optimizer = partial( transformers.optimization.Adafactor, weight_decay=args.adam_weight_decay, @@ -837,11 +822,13 @@ def main(): args.lr_scheduler = "adafactor" args.lr_min_lr = args.learning_rate args.learning_rate = None - elif args.optimizer == 'dadam': + elif args.optimizer == "dadam": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." + ) create_optimizer = partial( dadaptation.DAdaptAdam, @@ -851,11 +838,13 @@ def main(): decouple=True, d0=args.dadaptation_d0, ) - elif args.optimizer == 'dadan': + elif args.optimizer == "dadan": try: import dadaptation except ImportError: - raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") + raise ImportError( + "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." 
+ ) create_optimizer = partial( dadaptation.DAdaptAdan, @@ -864,7 +853,7 @@ def main(): d0=args.dadaptation_d0, ) else: - raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") + raise ValueError(f'Unknown --optimizer "{args.optimizer}"') trainer = partial( train, @@ -904,10 +893,21 @@ def main(): sample_image_size=args.sample_image_size, ) + optimizer = create_optimizer( + text_encoder.text_model.embeddings.token_embedding.parameters(), + lr=learning_rate, + ) + data_generator = torch.Generator(device="cpu").manual_seed(args.seed) data_npgenerator = np.random.default_rng(args.seed) - def run(i: int, placeholder_tokens: list[str], initializer_tokens: list[str], num_vectors: Union[int, list[int]], data_template: str): + def run( + i: int, + placeholder_tokens: list[str], + initializer_tokens: list[str], + num_vectors: Union[int, list[int]], + data_template: str, + ): placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( tokenizer=tokenizer, embeddings=embeddings, @@ -917,14 +917,23 @@ def main(): initializer_noise=args.initializer_noise, ) - stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) + stats = list( + zip( + placeholder_tokens, + placeholder_token_ids, + initializer_tokens, + initializer_token_ids, + ) + ) print("") print(f"============ TI batch {i + 1} ============") print("") print(stats) - filter_tokens = [token for token in args.filter_tokens if token in placeholder_tokens] + filter_tokens = [ + token for token in args.filter_tokens if token in placeholder_tokens + ] datamodule = VlpnDataModule( data_file=args.train_data_file, @@ -945,7 +954,9 @@ def main(): valid_set_size=args.valid_set_size, train_set_pad=args.train_set_pad, valid_set_pad=args.valid_set_pad, - filter=partial(keyword_filter, filter_tokens, args.collection, args.exclude_collections), + filter=partial( + keyword_filter, filter_tokens, args.collection, args.exclude_collections + ), dtype=weight_dtype, generator=data_generator, npgenerator=data_npgenerator, @@ -955,11 +966,16 @@ def main(): num_train_epochs = args.num_train_epochs sample_frequency = args.sample_frequency if num_train_epochs is None: - num_train_epochs = math.ceil( - args.num_train_steps / len(datamodule.train_dataset) - ) * args.gradient_accumulation_steps - sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) - num_training_steps_per_epoch = math.ceil(len(datamodule.train_dataset) / args.gradient_accumulation_steps) + num_train_epochs = ( + math.ceil(args.num_train_steps / len(datamodule.train_dataset)) + * args.gradient_accumulation_steps + ) + sample_frequency = math.ceil( + num_train_epochs * (sample_frequency / args.num_train_steps) + ) + num_training_steps_per_epoch = math.ceil( + len(datamodule.train_dataset) / args.gradient_accumulation_steps + ) num_train_steps = num_training_steps_per_epoch * num_train_epochs if args.sample_num is not None: sample_frequency = math.ceil(num_train_epochs / args.sample_num) @@ -988,7 +1004,8 @@ def main(): response = auto_cycles.pop(0) else: response = input( - "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> ") + "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " + ) if response.lower().strip() == "o": if args.learning_rate is not None: @@ -1018,10 +1035,8 @@ def main(): print(f"------------ TI cycle {training_iter + 1}: {response} ------------") print("") - optimizer = create_optimizer( - 
text_encoder.text_model.embeddings.token_embedding.parameters(), - lr=learning_rate, - ) + for group, lr in zip(optimizer.param_groups, [learning_rate]): + group["lr"] = lr lr_scheduler = get_scheduler( lr_scheduler, @@ -1040,7 +1055,9 @@ def main(): mid_point=args.lr_mid_point, ) - checkpoint_output_dir = output_dir / project / f"checkpoints_{training_iter}" + checkpoint_output_dir = ( + output_dir / project / f"checkpoints_{training_iter}" + ) trainer( train_dataloader=datamodule.train_dataloader, @@ -1070,14 +1087,20 @@ def main(): accelerator.end_training() if not args.sequential: - run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) + run( + 0, + args.placeholder_tokens, + args.initializer_tokens, + args.num_vectors, + args.train_data_template, + ) else: for i, placeholder_token, initializer_token, num_vectors, data_template in zip( range(len(args.placeholder_tokens)), args.placeholder_tokens, args.initializer_tokens, args.num_vectors, - args.train_data_template + args.train_data_template, ): run(i, [placeholder_token], [initializer_token], num_vectors, data_template) embeddings.persist() diff --git a/training/functional.py b/training/functional.py index fd3f9f4..f68faf9 100644 --- a/training/functional.py +++ b/training/functional.py @@ -14,7 +14,13 @@ import numpy as np from accelerate import Accelerator from transformers import CLIPTextModel -from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin +from diffusers import ( + AutoencoderKL, + UNet2DConditionModel, + DDPMScheduler, + UniPCMultistepScheduler, + SchedulerMixin, +) from tqdm.auto import tqdm @@ -33,11 +39,12 @@ from util.noise import perlin_noise def const(result=None): def fn(*args, **kwargs): return result + return fn @dataclass -class TrainingCallbacks(): +class TrainingCallbacks: on_log: Callable[[], dict[str, Any]] = const({}) on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) on_before_optimize: Callable[[int], Any] = const() @@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol): train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - **kwargs - ) -> Tuple: ... + **kwargs, + ) -> Tuple: + ... 
@dataclass -class TrainingStrategy(): +class TrainingStrategy: callbacks: Callable[..., TrainingCallbacks] prepare: TrainingStrategyPrepareCallable def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32): - tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') - text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype) - vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype) - unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype) - noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') - sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') + tokenizer = MultiCLIPTokenizer.from_pretrained( + pretrained_model_name_or_path, subfolder="tokenizer" + ) + text_encoder = CLIPTextModel.from_pretrained( + pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype + ) + vae = AutoencoderKL.from_pretrained( + pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype + ) + unet = UNet2DConditionModel.from_pretrained( + pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype + ) + noise_scheduler = DDPMScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) + sample_scheduler = UniPCMultistepScheduler.from_pretrained( + pretrained_model_name_or_path, subfolder="scheduler" + ) return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler @@ -113,7 +133,9 @@ def save_samples( generator = torch.Generator(device=accelerator.device).manual_seed(seed) - datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] + datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [ + ("train", train_dataloader, None) + ] if val_dataloader is not None: datasets.append(("stable", val_dataloader, generator)) @@ -124,17 +146,11 @@ def save_samples( file_path = output_dir / pool / f"step_{cycle}_{step}.jpg" file_path.parent.mkdir(parents=True, exist_ok=True) - batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) - prompt_ids = [ - prompt - for batch in batches - for prompt in batch["prompt_ids"] - ] - nprompt_ids = [ - prompt - for batch in batches - for prompt in batch["nprompt_ids"] - ] + batches = list( + itertools.islice(itertools.cycle(data), batch_size * num_batches) + ) + prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]] + nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]] with torch.inference_mode(): for i in range(num_batches): @@ -165,7 +181,9 @@ def save_samples( pass image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols) - image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0] + image_grid = pipeline.numpy_to_pil( + image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy() + )[0] image_grid.save(file_path, quality=85) del generator, pipeline @@ -184,15 +202,17 @@ def generate_class_images( train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, - sample_steps: int + sample_steps: int, ): - missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] + missing_data = [ + item for item in train_dataset.items if not 
item.class_image_path.exists() + ] if len(missing_data) == 0: return batched_data = [ - missing_data[i:i+sample_batch_size] + missing_data[i : i + sample_batch_size] for i in range(0, len(missing_data), sample_batch_size) ] @@ -216,7 +236,7 @@ def generate_class_images( negative_prompt=nprompt, height=sample_image_size, width=sample_image_size, - num_inference_steps=sample_steps + num_inference_steps=sample_steps, ).images for i, image in enumerate(images): @@ -245,8 +265,12 @@ def add_placeholder_tokens( embeddings.resize(len(tokenizer)) - for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids): - embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise) + for placeholder_token_id, initializer_token_id in zip( + placeholder_token_ids, initializer_token_ids + ): + embeddings.add_embed( + placeholder_token_id, initializer_token_id, initializer_noise + ) return placeholder_token_ids, initializer_token_ids @@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler): # Expand the tensors. # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) @@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler): def get_original( - noise_scheduler, - model_output, - sample: torch.FloatTensor, - timesteps: torch.IntTensor + noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor ): alphas_cumprod = noise_scheduler.alphas_cumprod sqrt_alphas_cumprod = alphas_cumprod**0.5 sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[ + timesteps + ].float() while len(sqrt_alphas_cumprod.shape) < len(sample.shape): sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] alpha = sqrt_alphas_cumprod.expand(sample.shape) - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to( + device=timesteps.device + )[timesteps].float() while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape): sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape) @@ -329,7 +358,9 @@ def loss_step( eval: bool = False, ): images = batch["pixel_values"] - generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None + generator = ( + torch.Generator(device=images.device).manual_seed(seed + step) if eval else None + ) bsz = images.shape[0] # Convert images to 
latent space @@ -342,7 +373,7 @@ def loss_step( dtype=latents.dtype, layout=latents.layout, device=latents.device, - generator=generator + generator=generator, ) applied_noise = noise @@ -353,7 +384,7 @@ def loss_step( octaves=4, dtype=latents.dtype, device=latents.device, - generator=generator + generator=generator, ) if input_pertubation != 0: @@ -362,7 +393,7 @@ def loss_step( dtype=latents.dtype, layout=latents.layout, device=latents.device, - generator=generator + generator=generator, ) # Sample a random timestep for each image @@ -375,25 +406,27 @@ def loss_step( # Get the text embedding for conditioning encoder_hidden_states = get_extended_embeddings( - text_encoder, - batch["input_ids"], - batch["attention_mask"] + text_encoder, batch["input_ids"], batch["attention_mask"] ) encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype) # Predict the noise residual - model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0] + model_pred = unet( + noisy_latents, timesteps, encoder_hidden_states, return_dict=False + )[0] if guidance_scale != 0: uncond_encoder_hidden_states = get_extended_embeddings( - text_encoder, - batch["negative_input_ids"], - batch["negative_attention_mask"] + text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"] ) uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype) - model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0] - model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond) + model_pred_uncond = unet( + noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False + )[0] + model_pred = model_pred_uncond + guidance_scale * ( + model_pred - model_pred_uncond + ) # Get the target for loss depending on the prediction type if noise_scheduler.config.prediction_type == "epsilon": @@ -401,7 +434,9 @@ def loss_step( elif noise_scheduler.config.prediction_type == "v_prediction": target = noise_scheduler.get_velocity(latents, noise, timesteps) else: - raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}") + raise ValueError( + f"Unknown prediction type {noise_scheduler.config.prediction_type}" + ) acc = (model_pred == target).float().mean() @@ -414,7 +449,9 @@ def loss_step( loss = F.mse_loss(model_pred.float(), target.float(), reduction="none") # Compute prior loss - prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none") + prior_loss = F.mse_loss( + model_pred_prior.float(), target_prior.float(), reduction="none" + ) # Add the prior loss to the instance loss. loss = loss + prior_loss_weight * prior_loss @@ -433,7 +470,10 @@ def loss_step( if min_snr_gamma != 0: snr = compute_snr(timesteps, noise_scheduler) mse_loss_weights = ( - torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr + torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min( + dim=1 + )[0] + / snr ) loss = loss * mse_loss_weights @@ -447,8 +487,14 @@ def loss_step( class LossCallable(Protocol): - def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any], - eval: bool = False) -> Tuple[Any, Any, int]: ... + def __call__( + self, + step: int, + batch: dict[Any, Any], + cache: dict[str, Any], + eval: bool = False, + ) -> Tuple[Any, Any, int]: + ... 
def train_loop( @@ -472,9 +518,14 @@ def train_loop( avg_acc_val: AverageMeter = AverageMeter(), callbacks: TrainingCallbacks = TrainingCallbacks(), ): - num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) - num_val_steps_per_epoch = math.ceil( - len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0 + num_training_steps_per_epoch = math.ceil( + len(train_dataloader) / gradient_accumulation_steps + ) + num_val_steps_per_epoch = ( + math.ceil(len(val_dataloader) / gradient_accumulation_steps) + if val_dataloader is not None + else 0 + ) num_training_steps = num_training_steps_per_epoch * num_epochs num_val_steps = num_val_steps_per_epoch * num_epochs @@ -488,14 +539,14 @@ def train_loop( local_progress_bar = tqdm( range(num_training_steps_per_epoch + num_val_steps_per_epoch), disable=not accelerator.is_local_main_process, - dynamic_ncols=True + dynamic_ncols=True, ) local_progress_bar.set_description(f"Epoch 1 / {num_epochs}") global_progress_bar = tqdm( range(num_training_steps + num_val_steps), disable=not accelerator.is_local_main_process, - dynamic_ncols=True + dynamic_ncols=True, ) global_progress_bar.set_description("Total progress") @@ -513,7 +564,9 @@ def train_loop( try: import dadaptation - isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)) + isDadaptation = isinstance( + optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan) + ) except ImportError: pass @@ -565,7 +618,10 @@ def train_loop( label = group_labels[i] if i < len(group_labels) else f"{i}" logs[f"lr/{label}"] = lr if isDadaptation: - lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"] + lr = ( + optimizer.param_groups[i]["d"] + * optimizer.param_groups[i]["lr"] + ) logs[f"d*lr/{label}"] = lr lrs[label] = lr @@ -573,8 +629,10 @@ def train_loop( local_progress_bar.set_postfix(**logs) - if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)): - before_optimize_result = on_before_optimize(epoch) + if ((step + 1) % gradient_accumulation_steps == 0) or ( + (step + 1) == len(train_dataloader) + ): + before_optimize_result = on_before_optimize(cycle) optimizer.step() lr_scheduler.step() @@ -614,7 +672,9 @@ def train_loop( } local_progress_bar.set_postfix(**logs) - if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)): + if ((step + 1) % gradient_accumulation_steps == 0) or ( + (step + 1) == len(val_dataloader) + ): local_progress_bar.update(1) global_progress_bar.update(1) @@ -634,7 +694,8 @@ def train_loop( global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}") + f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}" + ) on_checkpoint(global_step, "milestone") best_acc_val = avg_acc_val.max else: @@ -644,7 +705,8 @@ def train_loop( global_progress_bar.clear() accelerator.print( - f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}") + f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}" + ) on_checkpoint(global_step, "milestone") best_acc = avg_acc.max @@ -700,17 +762,32 @@ def train( avg_acc_val: AverageMeter = AverageMeter(), **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = 
strategy.prepare( - accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = strategy.prepare( + accelerator, + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + **kwargs, + ) vae.to(accelerator.device, dtype=dtype) vae.requires_grad_(False) vae.eval() - vae = torch.compile(vae, backend='hidet') + vae = torch.compile(vae, backend="hidet") if compile_unet: - unet = torch.compile(unet, backend='hidet') + unet = torch.compile(unet, backend="hidet") # unet = torch.compile(unet, mode="reduce-overhead") callbacks = strategy.callbacks( diff --git a/training/lr.py b/training/lr.py index f5b362f..a75078f 100644 --- a/training/lr.py +++ b/training/lr.py @@ -23,12 +23,12 @@ def plot_metrics( fig, ax_loss = plt.subplots() ax_acc = ax_loss.twinx() - ax_loss.plot(lrs, losses, color='red') + ax_loss.plot(lrs, losses, color="red") ax_loss.set_xscale("log") ax_loss.set_xlabel(f"Learning rate") ax_loss.set_ylabel("Loss") - ax_acc.plot(lrs, accs, color='blue') + ax_acc.plot(lrs, accs, color="blue") ax_acc.set_xscale("log") ax_acc.set_ylabel("Accuracy") diff --git a/training/optimization.py b/training/optimization.py index d22a900..55531bf 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -5,7 +5,10 @@ from functools import partial import torch from torch.optim.lr_scheduler import LambdaLR -from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup +from diffusers.optimization import ( + get_scheduler as get_scheduler_, + get_cosine_with_hard_restarts_schedule_with_warmup, +) from transformers.optimization import get_adafactor_schedule @@ -52,7 +55,7 @@ def get_one_cycle_schedule( annealing_exp: int = 1, min_lr: float = 0.04, mid_point: float = 0.3, - last_epoch: int = -1 + last_epoch: int = -1, ): if warmup == "linear": warmup_func = warmup_linear @@ -83,12 +86,16 @@ def get_one_cycle_schedule( def lr_lambda(current_step: int): phase = [p for p in phases if current_step >= p.step_min][-1] - return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) + return phase.min + phase.func( + (current_step - phase.step_min) / (phase.step_max - phase.step_min) + ) * (phase.max - phase.min) return LambdaLR(optimizer, lr_lambda, last_epoch) -def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1): +def get_exponential_growing_schedule( + optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1 +): def lr_lambda(base_lr: float, current_step: int): return (end_lr / base_lr) ** (current_step / num_training_steps) @@ -132,7 +139,14 @@ def get_scheduler( ) elif id == "exponential_growth": if cycles is None: - cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + cycles = math.ceil( + math.sqrt( + ( + (num_training_steps - num_warmup_steps) + / num_training_steps_per_epoch + ) + ) + ) lr_scheduler = get_exponential_growing_schedule( optimizer=optimizer, @@ -141,7 +155,14 @@ def get_scheduler( ) elif id == "cosine_with_restarts": if cycles is None: - cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch))) + cycles = math.ceil( + math.sqrt( + ( + (num_training_steps - num_warmup_steps) + / num_training_steps_per_epoch + ) + ) + ) 
lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( optimizer=optimizer, @@ -150,10 +171,7 @@ def get_scheduler( num_cycles=cycles, ) elif id == "adafactor": - lr_scheduler = get_adafactor_schedule( - optimizer, - initial_lr=min_lr - ) + lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr) else: lr_scheduler = get_scheduler_( id, diff --git a/training/sampler.py b/training/sampler.py index bdb3e90..0487d66 100644 --- a/training/sampler.py +++ b/training/sampler.py @@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler): def weights(self): if not self._warmed_up(): return np.ones([self.num_timesteps], dtype=np.float64) - weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights = np.sqrt(np.mean(self._loss_history**2, axis=-1)) weights /= np.sum(weights) weights *= 1 - self.uniform_prob weights += self.uniform_prob / len(weights) diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e6fcc89..88b441b 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks( sample_output_dir: Path, checkpoint_output_dir: Path, seed: int, - train_text_encoder_epochs: int, + train_text_encoder_cycles: int, max_grad_norm: float = 1.0, use_ema: bool = False, ema_inv_gamma: float = 1.0, @@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks( return nullcontext() @contextmanager - def on_train(epoch: int): + def on_train(cycle: int): unet.train() tokenizer.train() - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: text_encoder.train() - elif epoch == train_text_encoder_epochs: - text_encoder.requires_grad_(False) - text_encoder.eval() + tokenizer.train() yield @@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks( with ema_context(): yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): params_to_clip = [unet.parameters()] - if epoch < train_text_encoder_epochs: + if cycle < train_text_encoder_cycles: params_to_clip.append(text_encoder.parameters()) accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) @@ -189,8 +187,16 @@ def dreambooth_prepare( lr_scheduler: torch.optim.lr_scheduler._LRScheduler, **kwargs ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) text_encoder.text_model.embeddings.requires_grad_(False) @@ -198,6 +204,5 @@ def dreambooth_prepare( dreambooth_strategy = TrainingStrategy( - callbacks=dreambooth_strategy_callbacks, - prepare=dreambooth_prepare + callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare ) diff --git a/training/strategy/lora.py b/training/strategy/lora.py index f942b76..14e3384 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py @@ -81,7 +81,7 @@ def lora_strategy_callbacks( tokenizer.eval() yield - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if not pti_mode: accelerator.clip_grad_norm_( itertools.chain( @@ -89,7 +89,7 @@ def lora_strategy_callbacks( text_encoder.text_model.encoder.parameters(), text_encoder.text_model.final_layer_norm.parameters(), ), - max_grad_norm + max_grad_norm, ) if len(placeholder_tokens) != 0 and use_emb_decay: @@ 
-108,7 +108,9 @@ def lora_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) @torch.no_grad() def on_checkpoint(step, postfix): @@ -128,25 +130,32 @@ def lora_strategy_callbacks( if not pti_mode: lora_config = {} - state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_)) + state_dict = get_peft_model_state_dict( + unet_, state_dict=accelerator.get_state_dict(unet_) + ) lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True) text_encoder_state_dict = get_peft_model_state_dict( text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_) ) - text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()} + text_encoder_state_dict = { + f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items() + } state_dict.update(text_encoder_state_dict) - lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True) + lora_config[ + "text_encoder_peft_config" + ] = text_encoder_.get_peft_config_as_dict(inference=True) if len(placeholder_tokens) != 0: ti_state_dict = { f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids) - for (token, ids) - in zip(placeholder_tokens, placeholder_token_ids) + for (token, ids) in zip(placeholder_tokens, placeholder_token_ids) } state_dict.update(ti_state_dict) - save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors") + save_file( + state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors" + ) with open(checkpoint_output_dir / "lora_config.json", "w") as f: json.dump(lora_config, f) @@ -185,10 +194,18 @@ def lora_prepare( train_dataloader: DataLoader, val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, - **kwargs + **kwargs, ): - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + unet, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True) diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 6bc1d7d..7373982 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks( yield @torch.no_grad() - def on_before_optimize(epoch: int): + def on_before_optimize(cycle: int): if use_emb_decay: params = [ p @@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks( @torch.no_grad() def on_after_optimize(w, lrs: dict[str, float]): if ema_embeddings is not None: - ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters()) + ema_embeddings.step( + text_encoder.text_model.embeddings.token_embedding.parameters() + ) if use_emb_decay and w is not None: lr = lrs["emb"] if "emb" in lrs else lrs["0"] @@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks( if lambda_ != 0: norm = w[:, :].norm(dim=-1, keepdim=True) - w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)) + w[:].add_( + (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm) + ) def on_log(): if ema_embeddings is not None: @@ 
-136,10 +140,10 @@ def textual_inversion_strategy_callbacks( print(f"Saving checkpoint for step {step}...") with ema_context(): - for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): + for token, ids in zip(placeholder_tokens, placeholder_token_ids): text_encoder.text_model.embeddings.save_embed( ids, - checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" + checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin", ) @torch.no_grad() @@ -183,7 +187,7 @@ def textual_inversion_prepare( val_dataloader: Optional[DataLoader], lr_scheduler: torch.optim.lr_scheduler._LRScheduler, gradient_checkpointing: bool = False, - **kwargs + **kwargs, ): weight_dtype = torch.float32 if accelerator.state.mixed_precision == "fp16": @@ -191,8 +195,15 @@ def textual_inversion_prepare( elif accelerator.state.mixed_precision == "bf16": weight_dtype = torch.bfloat16 - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( - text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ( + text_encoder, + optimizer, + train_dataloader, + val_dataloader, + lr_scheduler, + ) = accelerator.prepare( + text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler + ) unet.to(accelerator.device, dtype=weight_dtype) unet.requires_grad_(False) -- cgit v1.2.3-54-g00ecf