| -rw-r--r-- | environment.yaml | 4 | ||||
| -rw-r--r-- | infer.py | 21 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 14 | ||||
| -rw-r--r-- | train_dreambooth.py | 46 | ||||
| -rw-r--r-- | train_lora.py | 566 | ||||
| -rw-r--r-- | train_ti.py | 10 | ||||
| -rw-r--r-- | training/functional.py | 31 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 35 | ||||
| -rw-r--r-- | training/strategy/lora.py | 147 | ||||
| -rw-r--r-- | training/strategy/ti.py | 38 |
10 files changed, 819 insertions, 93 deletions
diff --git a/environment.yaml b/environment.yaml index c992759..f5632bf 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -18,11 +18,11 @@ dependencies: | |||
| 18 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 18 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
| 19 | - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion | 19 | - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion |
| 20 | - accelerate==0.15.0 | 20 | - accelerate==0.15.0 |
| 21 | - bitsandbytes==0.36.0.post2 | 21 | - bitsandbytes==0.37.0 |
| 22 | - python-slugify>=6.1.2 | 22 | - python-slugify>=6.1.2 |
| 23 | - safetensors==0.2.7 | 23 | - safetensors==0.2.7 |
| 24 | - setuptools==65.6.3 | 24 | - setuptools==65.6.3 |
| 25 | - test-tube>=0.7.5 | 25 | - test-tube>=0.7.5 |
| 26 | - transformers==4.25.1 | 26 | - transformers==4.25.1 |
| 27 | - triton==2.0.0.dev20221202 | 27 | - triton==2.0.0.dev20221202 |
| 28 | - xformers==0.0.16.dev430 | 28 | - xformers==0.0.17.dev443 |
diff --git a/infer.py b/infer.py --- a/infer.py +++ b/infer.py | |||
| @@ -39,7 +39,8 @@ torch.backends.cudnn.benchmark = True | |||
| 39 | default_args = { | 39 | default_args = { |
| 40 | "model": "stabilityai/stable-diffusion-2-1", | 40 | "model": "stabilityai/stable-diffusion-2-1", |
| 41 | "precision": "fp32", | 41 | "precision": "fp32", |
| 42 | "ti_embeddings_dir": "embeddings", | 42 | "ti_embeddings_dir": "embeddings_ti", |
| 43 | "lora_embeddings_dir": "embeddings_lora", | ||
| 43 | "output_dir": "output/inference", | 44 | "output_dir": "output/inference", |
| 44 | "config": None, | 45 | "config": None, |
| 45 | } | 46 | } |
| @@ -60,6 +61,7 @@ default_cmds = { | |||
| 60 | "batch_num": 1, | 61 | "batch_num": 1, |
| 61 | "steps": 30, | 62 | "steps": 30, |
| 62 | "guidance_scale": 7.0, | 63 | "guidance_scale": 7.0, |
| 64 | "lora_scale": 0.5, | ||
| 63 | "seed": None, | 65 | "seed": None, |
| 64 | "config": None, | 66 | "config": None, |
| 65 | } | 67 | } |
| @@ -92,6 +94,10 @@ def create_args_parser(): | |||
| 92 | type=str, | 94 | type=str, |
| 93 | ) | 95 | ) |
| 94 | parser.add_argument( | 96 | parser.add_argument( |
| 97 | "--lora_embeddings_dir", | ||
| 98 | type=str, | ||
| 99 | ) | ||
| 100 | parser.add_argument( | ||
| 95 | "--output_dir", | 101 | "--output_dir", |
| 96 | type=str, | 102 | type=str, |
| 97 | ) | 103 | ) |
| @@ -169,6 +175,10 @@ def create_cmd_parser(): | |||
| 169 | type=float, | 175 | type=float, |
| 170 | ) | 176 | ) |
| 171 | parser.add_argument( | 177 | parser.add_argument( |
| 178 | "--lora_scale", | ||
| 179 | type=float, | ||
| 180 | ) | ||
| 181 | parser.add_argument( | ||
| 172 | "--seed", | 182 | "--seed", |
| 173 | type=int, | 183 | type=int, |
| 174 | ) | 184 | ) |
| @@ -315,6 +325,7 @@ def generate(output_dir: Path, pipeline, args): | |||
| 315 | generator=generator, | 325 | generator=generator, |
| 316 | image=init_image, | 326 | image=init_image, |
| 317 | strength=args.image_noise, | 327 | strength=args.image_noise, |
| 328 | cross_attention_kwargs={"scale": args.lora_scale}, | ||
| 318 | ).images | 329 | ).images |
| 319 | 330 | ||
| 320 | for j, image in enumerate(images): | 331 | for j, image in enumerate(images): |
| @@ -334,11 +345,12 @@ class CmdParse(cmd.Cmd): | |||
| 334 | prompt = 'dream> ' | 345 | prompt = 'dream> ' |
| 335 | commands = [] | 346 | commands = [] |
| 336 | 347 | ||
| 337 | def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): | 348 | def __init__(self, output_dir, ti_embeddings_dir, lora_embeddings_dir, pipeline, parser): |
| 338 | super().__init__() | 349 | super().__init__() |
| 339 | 350 | ||
| 340 | self.output_dir = output_dir | 351 | self.output_dir = output_dir |
| 341 | self.ti_embeddings_dir = ti_embeddings_dir | 352 | self.ti_embeddings_dir = ti_embeddings_dir |
| 353 | self.lora_embeddings_dir = lora_embeddings_dir | ||
| 342 | self.pipeline = pipeline | 354 | self.pipeline = pipeline |
| 343 | self.parser = parser | 355 | self.parser = parser |
| 344 | 356 | ||
| @@ -394,9 +406,12 @@ def main(): | |||
| 394 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 406 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
| 395 | 407 | ||
| 396 | pipeline = create_pipeline(args.model, dtype) | 408 | pipeline = create_pipeline(args.model, dtype) |
| 409 | |||
| 397 | load_embeddings(pipeline, args.ti_embeddings_dir) | 410 | load_embeddings(pipeline, args.ti_embeddings_dir) |
| 411 | pipeline.unet.load_attn_procs(args.lora_embeddings_dir) | ||
| 412 | |||
| 398 | cmd_parser = create_cmd_parser() | 413 | cmd_parser = create_cmd_parser() |
| 399 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) | 414 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, args.lora_embeddings_dir, pipeline, cmd_parser) |
| 400 | cmd_prompt.cmdloop() | 415 | cmd_prompt.cmdloop() |
| 401 | 416 | ||
| 402 | 417 | ||
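For reference, a minimal sketch of how the new pieces combine at inference time. It assumes a stock diffusers `StableDiffusionPipeline` for brevity (infer.py actually builds a `VlpnStableDiffusion` via `create_pipeline`); the prompt, paths, and scale value are illustrative:

```python
import torch
from diffusers import StableDiffusionPipeline

# Build a pipeline in fp32, matching the "precision" default above.
pipeline = StableDiffusionPipeline.from_pretrained(
    "stabilityai/stable-diffusion-2-1", torch_dtype=torch.float32
)

# Attach the LoRA attention processors saved during training.
pipeline.unet.load_attn_procs("embeddings_lora")

# "scale" is forwarded to every LoRA attention processor and controls how
# strongly the LoRA delta is blended into the base attention output.
images = pipeline(
    "a prompt",
    num_inference_steps=30,
    guidance_scale=7.0,
    cross_attention_kwargs={"scale": 0.5},
).images
```

A `lora_scale` of 0 reproduces the base model; 1 applies the trained delta at full strength.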
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index 3027421..dab7878 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -1,6 +1,6 @@ | |||
| 1 | import inspect | 1 | import inspect |
| 2 | import warnings | 2 | import warnings |
| 3 | from typing import List, Optional, Union, Callable | 3 | from typing import List, Dict, Any, Optional, Union, Callable |
| 4 | 4 | ||
| 5 | import numpy as np | 5 | import numpy as np |
| 6 | import torch | 6 | import torch |
| @@ -337,6 +337,7 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 337 | return_dict: bool = True, | 337 | return_dict: bool = True, |
| 338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, | 338 | callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, |
| 339 | callback_steps: int = 1, | 339 | callback_steps: int = 1, |
| 340 | cross_attention_kwargs: Optional[Dict[str, Any]] = None, | ||
| 340 | ): | 341 | ): |
| 341 | r""" | 342 | r""" |
| 342 | Function invoked when calling the pipeline for generation. | 343 | Function invoked when calling the pipeline for generation. |
| @@ -379,6 +380,10 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 379 | return_dict (`bool`, *optional*, defaults to `True`): | 380 | return_dict (`bool`, *optional*, defaults to `True`): |
| 380 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a | 381 | Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a |
| 381 | plain tuple. | 382 | plain tuple. |
| 383 | cross_attention_kwargs (`dict`, *optional*): | ||
| 384 | A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under | ||
| 385 | `self.processor` in | ||
| 386 | [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). | ||
| 382 | 387 | ||
| 383 | Returns: | 388 | Returns: |
| 384 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: | 389 | [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: |
| @@ -450,7 +455,12 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 450 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) | 455 | latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) |
| 451 | 456 | ||
| 452 | # predict the noise residual | 457 | # predict the noise residual |
| 453 | noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=text_embeddings).sample | 458 | noise_pred = self.unet( |
| 459 | latent_model_input, | ||
| 460 | t, | ||
| 461 | encoder_hidden_states=text_embeddings, | ||
| 462 | cross_attention_kwargs=cross_attention_kwargs, | ||
| 463 | ).sample | ||
| 454 | 464 | ||
| 455 | # perform guidance | 465 | # perform guidance |
| 456 | if do_classifier_free_guidance: | 466 | if do_classifier_free_guidance: |
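The guidance branch that follows this hunk combines the two halves of the batched prediction; for reference, the standard classifier-free guidance step (a sketch of the code just below the visible hunk):

```python
# Split the batched prediction into unconditional and text-conditioned halves,
# then push the prediction away from the unconditional direction.
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
```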
diff --git a/train_dreambooth.py b/train_dreambooth.py index a70c80e..5a4c47b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -74,26 +74,6 @@ def parse_args(): | |||
| 74 | help="The name of the current project.", | 74 | help="The name of the current project.", |
| 75 | ) | 75 | ) |
| 76 | parser.add_argument( | 76 | parser.add_argument( |
| 77 | "--placeholder_tokens", | ||
| 78 | type=str, | ||
| 79 | nargs='*', | ||
| 80 | default=[], | ||
| 81 | help="A token to use as a placeholder for the concept.", | ||
| 82 | ) | ||
| 83 | parser.add_argument( | ||
| 84 | "--initializer_tokens", | ||
| 85 | type=str, | ||
| 86 | nargs='*', | ||
| 87 | default=[], | ||
| 88 | help="A token to use as initializer word." | ||
| 89 | ) | ||
| 90 | parser.add_argument( | ||
| 91 | "--num_vectors", | ||
| 92 | type=int, | ||
| 93 | nargs='*', | ||
| 94 | help="Number of vectors per embedding." | ||
| 95 | ) | ||
| 96 | parser.add_argument( | ||
| 97 | "--exclude_collections", | 77 | "--exclude_collections", |
| 98 | type=str, | 78 | type=str, |
| 99 | nargs='*', | 79 | nargs='*', |
| @@ -436,30 +416,6 @@ def parse_args(): | |||
| 436 | if args.project is None: | 416 | if args.project is None: |
| 437 | raise ValueError("You must specify --project") | 417 | raise ValueError("You must specify --project") |
| 438 | 418 | ||
| 439 | if isinstance(args.placeholder_tokens, str): | ||
| 440 | args.placeholder_tokens = [args.placeholder_tokens] | ||
| 441 | |||
| 442 | if isinstance(args.initializer_tokens, str): | ||
| 443 | args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens) | ||
| 444 | |||
| 445 | if len(args.initializer_tokens) == 0: | ||
| 446 | raise ValueError("You must specify --initializer_tokens") | ||
| 447 | |||
| 448 | if len(args.placeholder_tokens) == 0: | ||
| 449 | args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))] | ||
| 450 | |||
| 451 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
| 452 | raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items") | ||
| 453 | |||
| 454 | if args.num_vectors is None: | ||
| 455 | args.num_vectors = 1 | ||
| 456 | |||
| 457 | if isinstance(args.num_vectors, int): | ||
| 458 | args.num_vectors = [args.num_vectors] * len(args.initializer_tokens) | ||
| 459 | |||
| 460 | if len(args.placeholder_tokens) != len(args.num_vectors): | ||
| 461 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | ||
| 462 | |||
| 463 | if isinstance(args.collection, str): | 419 | if isinstance(args.collection, str): |
| 464 | args.collection = [args.collection] | 420 | args.collection = [args.collection] |
| 465 | 421 | ||
| @@ -503,7 +459,7 @@ def main(): | |||
| 503 | 459 | ||
| 504 | vae.enable_slicing() | 460 | vae.enable_slicing() |
| 505 | vae.set_use_memory_efficient_attention_xformers(True) | 461 | vae.set_use_memory_efficient_attention_xformers(True) |
| 506 | unet.set_use_memory_efficient_attention_xformers(True) | 462 | unet.enable_xformers_memory_efficient_attention() |
| 507 | 463 | ||
| 508 | if args.gradient_checkpointing: | 464 | if args.gradient_checkpointing: |
| 509 | unet.enable_gradient_checkpointing() | 465 | unet.enable_gradient_checkpointing() |
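Both calls switch the UNet to xFormers memory-efficient attention; `enable_xformers_memory_efficient_attention()` is the public diffusers helper that wraps the lower-level setter used before. A guarded variant, assuming xFormers may be missing at runtime:

```python
# Only enable xFormers if the package is importable; otherwise keep the
# default attention implementation.
try:
    import xformers  # noqa: F401
    unet.enable_xformers_memory_efficient_attention()
except ImportError:
    pass
```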
diff --git a/train_lora.py b/train_lora.py new file mode 100644 index 0000000..2cb85cc --- /dev/null +++ b/train_lora.py | |||
| @@ -0,0 +1,566 @@ | |||
| 1 | import argparse | ||
| 2 | import datetime | ||
| 3 | import logging | ||
| 4 | import itertools | ||
| 5 | from pathlib import Path | ||
| 6 | from functools import partial | ||
| 7 | |||
| 8 | import torch | ||
| 9 | import torch.utils.checkpoint | ||
| 10 | |||
| 11 | from accelerate import Accelerator | ||
| 12 | from accelerate.logging import get_logger | ||
| 13 | from accelerate.utils import LoggerType, set_seed | ||
| 14 | from slugify import slugify | ||
| 15 | from diffusers.loaders import AttnProcsLayers | ||
| 16 | from diffusers.models.cross_attention import LoRACrossAttnProcessor | ||
| 17 | |||
| 18 | from util import load_config, load_embeddings_from_dir | ||
| 19 | from data.csv import VlpnDataModule, keyword_filter | ||
| 20 | from training.functional import train, get_models | ||
| 21 | from training.lr import plot_metrics | ||
| 22 | from training.strategy.lora import lora_strategy | ||
| 23 | from training.optimization import get_scheduler | ||
| 24 | from training.util import save_args | ||
| 25 | |||
| 26 | logger = get_logger(__name__) | ||
| 27 | |||
| 28 | |||
| 29 | torch.backends.cuda.matmul.allow_tf32 = True | ||
| 30 | torch.backends.cudnn.benchmark = True | ||
| 31 | |||
| 32 | |||
| 33 | def parse_args(): | ||
| 34 | parser = argparse.ArgumentParser( | ||
| 35 | description="Simple example of a training script." | ||
| 36 | ) | ||
| 37 | parser.add_argument( | ||
| 38 | "--pretrained_model_name_or_path", | ||
| 39 | type=str, | ||
| 40 | default=None, | ||
| 41 | help="Path to pretrained model or model identifier from huggingface.co/models.", | ||
| 42 | ) | ||
| 43 | parser.add_argument( | ||
| 44 | "--tokenizer_name", | ||
| 45 | type=str, | ||
| 46 | default=None, | ||
| 47 | help="Pretrained tokenizer name or path if not the same as model_name", | ||
| 48 | ) | ||
| 49 | parser.add_argument( | ||
| 50 | "--train_data_file", | ||
| 51 | type=str, | ||
| 52 | default=None, | ||
| 53 | help="A folder containing the training data." | ||
| 54 | ) | ||
| 55 | parser.add_argument( | ||
| 56 | "--train_data_template", | ||
| 57 | type=str, | ||
| 58 | default="template", | ||
| 59 | ) | ||
| 60 | parser.add_argument( | ||
| 61 | "--train_set_pad", | ||
| 62 | type=int, | ||
| 63 | default=None, | ||
| 64 | help="The number to fill train dataset items up to." | ||
| 65 | ) | ||
| 66 | parser.add_argument( | ||
| 67 | "--valid_set_pad", | ||
| 68 | type=int, | ||
| 69 | default=None, | ||
| 70 | help="The number to fill validation dataset items up to." | ||
| 71 | ) | ||
| 72 | parser.add_argument( | ||
| 73 | "--project", | ||
| 74 | type=str, | ||
| 75 | default=None, | ||
| 76 | help="The name of the current project.", | ||
| 77 | ) | ||
| 78 | parser.add_argument( | ||
| 79 | "--exclude_collections", | ||
| 80 | type=str, | ||
| 81 | nargs='*', | ||
| 82 | help="Exclude all items with a listed collection.", | ||
| 83 | ) | ||
| 84 | parser.add_argument( | ||
| 85 | "--num_buckets", | ||
| 86 | type=int, | ||
| 87 | default=4, | ||
| 88 | help="Number of aspect ratio buckets in either direction.", | ||
| 89 | ) | ||
| 90 | parser.add_argument( | ||
| 91 | "--progressive_buckets", | ||
| 92 | action="store_true", | ||
| 93 | help="Include images in smaller buckets as well.", | ||
| 94 | ) | ||
| 95 | parser.add_argument( | ||
| 96 | "--bucket_step_size", | ||
| 97 | type=int, | ||
| 98 | default=64, | ||
| 99 | help="Step size between buckets.", | ||
| 100 | ) | ||
| 101 | parser.add_argument( | ||
| 102 | "--bucket_max_pixels", | ||
| 103 | type=int, | ||
| 104 | default=None, | ||
| 105 | help="Maximum pixels per bucket.", | ||
| 106 | ) | ||
| 107 | parser.add_argument( | ||
| 108 | "--tag_dropout", | ||
| 109 | type=float, | ||
| 110 | default=0.1, | ||
| 111 | help="Tag dropout probability.", | ||
| 112 | ) | ||
| 113 | parser.add_argument( | ||
| 114 | "--no_tag_shuffle", | ||
| 115 | action="store_true", | ||
| 116 | help="Shuffle tags.", | ||
| 117 | ) | ||
| 118 | parser.add_argument( | ||
| 119 | "--num_class_images", | ||
| 120 | type=int, | ||
| 121 | default=0, | ||
| 122 | help="How many class images to generate." | ||
| 123 | ) | ||
| 124 | parser.add_argument( | ||
| 125 | "--class_image_dir", | ||
| 126 | type=str, | ||
| 127 | default="cls", | ||
| 128 | help="The directory where class images will be saved.", | ||
| 129 | ) | ||
| 130 | parser.add_argument( | ||
| 131 | "--output_dir", | ||
| 132 | type=str, | ||
| 133 | default="output/lora", | ||
| 134 | help="The output directory where the model predictions and checkpoints will be written.", | ||
| 135 | ) | ||
| 136 | parser.add_argument( | ||
| 137 | "--embeddings_dir", | ||
| 138 | type=str, | ||
| 139 | default=None, | ||
| 140 | help="The embeddings directory where Textual Inversion embeddings are stored.", | ||
| 141 | ) | ||
| 142 | parser.add_argument( | ||
| 143 | "--collection", | ||
| 144 | type=str, | ||
| 145 | nargs='*', | ||
| 146 | help="A collection to filter the dataset.", | ||
| 147 | ) | ||
| 148 | parser.add_argument( | ||
| 149 | "--seed", | ||
| 150 | type=int, | ||
| 151 | default=None, | ||
| 152 | help="A seed for reproducible training." | ||
| 153 | ) | ||
| 154 | parser.add_argument( | ||
| 155 | "--resolution", | ||
| 156 | type=int, | ||
| 157 | default=768, | ||
| 158 | help=( | ||
| 159 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" | ||
| 160 | " resolution" | ||
| 161 | ), | ||
| 162 | ) | ||
| 163 | parser.add_argument( | ||
| 164 | "--num_train_epochs", | ||
| 165 | type=int, | ||
| 166 | default=100 | ||
| 167 | ) | ||
| 168 | parser.add_argument( | ||
| 169 | "--max_train_steps", | ||
| 170 | type=int, | ||
| 171 | default=None, | ||
| 172 | help="Total number of training steps to perform. If provided, overrides num_train_epochs.", | ||
| 173 | ) | ||
| 174 | parser.add_argument( | ||
| 175 | "--gradient_accumulation_steps", | ||
| 176 | type=int, | ||
| 177 | default=1, | ||
| 178 | help="Number of updates steps to accumulate before performing a backward/update pass.", | ||
| 179 | ) | ||
| 180 | parser.add_argument( | ||
| 181 | "--find_lr", | ||
| 182 | action="store_true", | ||
| 183 | help="Automatically find a learning rate (no training).", | ||
| 184 | ) | ||
| 185 | parser.add_argument( | ||
| 186 | "--learning_rate", | ||
| 187 | type=float, | ||
| 188 | default=2e-6, | ||
| 189 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 190 | ) | ||
| 191 | parser.add_argument( | ||
| 192 | "--scale_lr", | ||
| 193 | action="store_true", | ||
| 194 | help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.", | ||
| 195 | ) | ||
| 196 | parser.add_argument( | ||
| 197 | "--lr_scheduler", | ||
| 198 | type=str, | ||
| 199 | default="one_cycle", | ||
| 200 | help=( | ||
| 201 | 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",' | ||
| 202 | ' "constant", "constant_with_warmup", "one_cycle"]' | ||
| 203 | ), | ||
| 204 | ) | ||
| 205 | parser.add_argument( | ||
| 206 | "--lr_warmup_epochs", | ||
| 207 | type=int, | ||
| 208 | default=10, | ||
| 209 | help="Number of steps for the warmup in the lr scheduler." | ||
| 210 | ) | ||
| 211 | parser.add_argument( | ||
| 212 | "--lr_cycles", | ||
| 213 | type=int, | ||
| 214 | default=None, | ||
| 215 | help="Number of restart cycles in the lr scheduler (if supported)." | ||
| 216 | ) | ||
| 217 | parser.add_argument( | ||
| 218 | "--lr_warmup_func", | ||
| 219 | type=str, | ||
| 220 | default="cos", | ||
| 221 | help='Choose between ["linear", "cos"]' | ||
| 222 | ) | ||
| 223 | parser.add_argument( | ||
| 224 | "--lr_warmup_exp", | ||
| 225 | type=int, | ||
| 226 | default=1, | ||
| 227 | help='If lr_warmup_func is "cos", exponent to modify the function' | ||
| 228 | ) | ||
| 229 | parser.add_argument( | ||
| 230 | "--lr_annealing_func", | ||
| 231 | type=str, | ||
| 232 | default="cos", | ||
| 233 | help='Choose between ["linear", "half_cos", "cos"]' | ||
| 234 | ) | ||
| 235 | parser.add_argument( | ||
| 236 | "--lr_annealing_exp", | ||
| 237 | type=int, | ||
| 238 | default=3, | ||
| 239 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | ||
| 240 | ) | ||
| 241 | parser.add_argument( | ||
| 242 | "--lr_min_lr", | ||
| 243 | type=float, | ||
| 244 | default=0.04, | ||
| 245 | help="Minimum learning rate in the lr scheduler." | ||
| 246 | ) | ||
| 247 | parser.add_argument( | ||
| 248 | "--use_8bit_adam", | ||
| 249 | action="store_true", | ||
| 250 | help="Whether or not to use 8-bit Adam from bitsandbytes." | ||
| 251 | ) | ||
| 252 | parser.add_argument( | ||
| 253 | "--adam_beta1", | ||
| 254 | type=float, | ||
| 255 | default=0.9, | ||
| 256 | help="The beta1 parameter for the Adam optimizer." | ||
| 257 | ) | ||
| 258 | parser.add_argument( | ||
| 259 | "--adam_beta2", | ||
| 260 | type=float, | ||
| 261 | default=0.999, | ||
| 262 | help="The beta2 parameter for the Adam optimizer." | ||
| 263 | ) | ||
| 264 | parser.add_argument( | ||
| 265 | "--adam_weight_decay", | ||
| 266 | type=float, | ||
| 267 | default=1e-2, | ||
| 268 | help="Weight decay to use." | ||
| 269 | ) | ||
| 270 | parser.add_argument( | ||
| 271 | "--adam_epsilon", | ||
| 272 | type=float, | ||
| 273 | default=1e-08, | ||
| 274 | help="Epsilon value for the Adam optimizer" | ||
| 275 | ) | ||
| 276 | parser.add_argument( | ||
| 277 | "--adam_amsgrad", | ||
| 278 | type=bool, | ||
| 279 | default=False, | ||
| 280 | help="Amsgrad value for the Adam optimizer" | ||
| 281 | ) | ||
| 282 | parser.add_argument( | ||
| 283 | "--mixed_precision", | ||
| 284 | type=str, | ||
| 285 | default="no", | ||
| 286 | choices=["no", "fp16", "bf16"], | ||
| 287 | help=( | ||
| 288 | "Whether to use mixed precision. Choose" | ||
| 289 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." | ||
| 290 | "and an Nvidia Ampere GPU." | ||
| 291 | ), | ||
| 292 | ) | ||
| 293 | parser.add_argument( | ||
| 294 | "--sample_frequency", | ||
| 295 | type=int, | ||
| 296 | default=1, | ||
| 297 | help="How often to save a checkpoint and sample image", | ||
| 298 | ) | ||
| 299 | parser.add_argument( | ||
| 300 | "--sample_image_size", | ||
| 301 | type=int, | ||
| 302 | default=768, | ||
| 303 | help="Size of sample images", | ||
| 304 | ) | ||
| 305 | parser.add_argument( | ||
| 306 | "--sample_batches", | ||
| 307 | type=int, | ||
| 308 | default=1, | ||
| 309 | help="Number of sample batches to generate per checkpoint", | ||
| 310 | ) | ||
| 311 | parser.add_argument( | ||
| 312 | "--sample_batch_size", | ||
| 313 | type=int, | ||
| 314 | default=1, | ||
| 315 | help="Number of samples to generate per batch", | ||
| 316 | ) | ||
| 317 | parser.add_argument( | ||
| 318 | "--valid_set_size", | ||
| 319 | type=int, | ||
| 320 | default=None, | ||
| 321 | help="Number of images in the validation dataset." | ||
| 322 | ) | ||
| 323 | parser.add_argument( | ||
| 324 | "--valid_set_repeat", | ||
| 325 | type=int, | ||
| 326 | default=1, | ||
| 327 | help="Times the images in the validation dataset are repeated." | ||
| 328 | ) | ||
| 329 | parser.add_argument( | ||
| 330 | "--train_batch_size", | ||
| 331 | type=int, | ||
| 332 | default=1, | ||
| 333 | help="Batch size (per device) for the training dataloader." | ||
| 334 | ) | ||
| 335 | parser.add_argument( | ||
| 336 | "--sample_steps", | ||
| 337 | type=int, | ||
| 338 | default=20, | ||
| 339 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | ||
| 340 | ) | ||
| 341 | parser.add_argument( | ||
| 342 | "--prior_loss_weight", | ||
| 343 | type=float, | ||
| 344 | default=1.0, | ||
| 345 | help="The weight of prior preservation loss." | ||
| 346 | ) | ||
| 347 | parser.add_argument( | ||
| 348 | "--max_grad_norm", | ||
| 349 | default=1.0, | ||
| 350 | type=float, | ||
| 351 | help="Max gradient norm." | ||
| 352 | ) | ||
| 353 | parser.add_argument( | ||
| 354 | "--noise_timesteps", | ||
| 355 | type=int, | ||
| 356 | default=1000, | ||
| 357 | ) | ||
| 358 | parser.add_argument( | ||
| 359 | "--config", | ||
| 360 | type=str, | ||
| 361 | default=None, | ||
| 362 | help="Path to a JSON configuration file containing arguments for invoking this script." | ||
| 363 | ) | ||
| 364 | |||
| 365 | args = parser.parse_args() | ||
| 366 | if args.config is not None: | ||
| 367 | args = load_config(args.config) | ||
| 368 | args = parser.parse_args(namespace=argparse.Namespace(**args)) | ||
| 369 | |||
| 370 | if args.train_data_file is None: | ||
| 371 | raise ValueError("You must specify --train_data_file") | ||
| 372 | |||
| 373 | if args.pretrained_model_name_or_path is None: | ||
| 374 | raise ValueError("You must specify --pretrained_model_name_or_path") | ||
| 375 | |||
| 376 | if args.project is None: | ||
| 377 | raise ValueError("You must specify --project") | ||
| 378 | |||
| 379 | if isinstance(args.collection, str): | ||
| 380 | args.collection = [args.collection] | ||
| 381 | |||
| 382 | if isinstance(args.exclude_collections, str): | ||
| 383 | args.exclude_collections = [args.exclude_collections] | ||
| 384 | |||
| 385 | if args.output_dir is None: | ||
| 386 | raise ValueError("You must specify --output_dir") | ||
| 387 | |||
| 388 | return args | ||
| 389 | |||
| 390 | |||
| 391 | def main(): | ||
| 392 | args = parse_args() | ||
| 393 | |||
| 394 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | ||
| 395 | output_dir = Path(args.output_dir).joinpath(slugify(args.project), now) | ||
| 396 | output_dir.mkdir(parents=True, exist_ok=True) | ||
| 397 | |||
| 398 | accelerator = Accelerator( | ||
| 399 | log_with=LoggerType.TENSORBOARD, | ||
| 400 | logging_dir=f"{output_dir}", | ||
| 401 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 402 | mixed_precision=args.mixed_precision | ||
| 403 | ) | ||
| 404 | |||
| 405 | logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG) | ||
| 406 | |||
| 407 | if args.seed is None: | ||
| 408 | args.seed = torch.random.seed() >> 32 | ||
| 409 | |||
| 410 | set_seed(args.seed) | ||
| 411 | |||
| 412 | save_args(output_dir, args) | ||
| 413 | |||
| 414 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | ||
| 415 | args.pretrained_model_name_or_path) | ||
| 416 | |||
| 417 | vae.enable_slicing() | ||
| 418 | vae.set_use_memory_efficient_attention_xformers(True) | ||
| 419 | unet.enable_xformers_memory_efficient_attention() | ||
| 420 | |||
| 421 | lora_attn_procs = {} | ||
| 422 | for name in unet.attn_processors.keys(): | ||
| 423 | cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim | ||
| 424 | if name.startswith("mid_block"): | ||
| 425 | hidden_size = unet.config.block_out_channels[-1] | ||
| 426 | elif name.startswith("up_blocks"): | ||
| 427 | block_id = int(name[len("up_blocks.")]) | ||
| 428 | hidden_size = list(reversed(unet.config.block_out_channels))[block_id] | ||
| 429 | elif name.startswith("down_blocks"): | ||
| 430 | block_id = int(name[len("down_blocks.")]) | ||
| 431 | hidden_size = unet.config.block_out_channels[block_id] | ||
| 432 | |||
| 433 | lora_attn_procs[name] = LoRACrossAttnProcessor( | ||
| 434 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim | ||
| 435 | ) | ||
| 436 | |||
| 437 | unet.set_attn_processor(lora_attn_procs) | ||
| 438 | lora_layers = AttnProcsLayers(unet.attn_processors) | ||
| 439 | |||
| 440 | if args.embeddings_dir is not None: | ||
| 441 | embeddings_dir = Path(args.embeddings_dir) | ||
| 442 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | ||
| 443 | raise ValueError("--embeddings_dir must point to an existing directory") | ||
| 444 | |||
| 445 | embeddings.persist() | ||
| 446 | |||
| 447 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | ||
| 448 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | ||
| 449 | |||
| 450 | if args.scale_lr: | ||
| 451 | args.learning_rate = ( | ||
| 452 | args.learning_rate * args.gradient_accumulation_steps * | ||
| 453 | args.train_batch_size * accelerator.num_processes | ||
| 454 | ) | ||
| 455 | |||
| 456 | if args.find_lr: | ||
| 457 | args.learning_rate = 1e-6 | ||
| 458 | args.lr_scheduler = "exponential_growth" | ||
| 459 | |||
| 460 | if args.use_8bit_adam: | ||
| 461 | try: | ||
| 462 | import bitsandbytes as bnb | ||
| 463 | except ImportError: | ||
| 464 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | ||
| 465 | |||
| 466 | optimizer_class = bnb.optim.AdamW8bit | ||
| 467 | else: | ||
| 468 | optimizer_class = torch.optim.AdamW | ||
| 469 | |||
| 470 | weight_dtype = torch.float32 | ||
| 471 | if args.mixed_precision == "fp16": | ||
| 472 | weight_dtype = torch.float16 | ||
| 473 | elif args.mixed_precision == "bf16": | ||
| 474 | weight_dtype = torch.bfloat16 | ||
| 475 | |||
| 476 | trainer = partial( | ||
| 477 | train, | ||
| 478 | accelerator=accelerator, | ||
| 479 | unet=unet, | ||
| 480 | text_encoder=text_encoder, | ||
| 481 | vae=vae, | ||
| 482 | lora_layers=lora_layers, | ||
| 483 | noise_scheduler=noise_scheduler, | ||
| 484 | dtype=weight_dtype, | ||
| 485 | with_prior_preservation=args.num_class_images != 0, | ||
| 486 | prior_loss_weight=args.prior_loss_weight, | ||
| 487 | ) | ||
| 488 | |||
| 489 | checkpoint_output_dir = output_dir.joinpath("model") | ||
| 490 | sample_output_dir = output_dir.joinpath("samples") | ||
| 491 | |||
| 492 | datamodule = VlpnDataModule( | ||
| 493 | data_file=args.train_data_file, | ||
| 494 | batch_size=args.train_batch_size, | ||
| 495 | tokenizer=tokenizer, | ||
| 496 | class_subdir=args.class_image_dir, | ||
| 497 | num_class_images=args.num_class_images, | ||
| 498 | size=args.resolution, | ||
| 499 | num_buckets=args.num_buckets, | ||
| 500 | progressive_buckets=args.progressive_buckets, | ||
| 501 | bucket_step_size=args.bucket_step_size, | ||
| 502 | bucket_max_pixels=args.bucket_max_pixels, | ||
| 503 | dropout=args.tag_dropout, | ||
| 504 | shuffle=not args.no_tag_shuffle, | ||
| 505 | template_key=args.train_data_template, | ||
| 506 | valid_set_size=args.valid_set_size, | ||
| 507 | train_set_pad=args.train_set_pad, | ||
| 508 | valid_set_pad=args.valid_set_pad, | ||
| 509 | seed=args.seed, | ||
| 510 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | ||
| 511 | dtype=weight_dtype | ||
| 512 | ) | ||
| 513 | datamodule.setup() | ||
| 514 | |||
| 515 | optimizer = optimizer_class( | ||
| 516 | lora_layers.parameters(), | ||
| 517 | lr=args.learning_rate, | ||
| 518 | betas=(args.adam_beta1, args.adam_beta2), | ||
| 519 | weight_decay=args.adam_weight_decay, | ||
| 520 | eps=args.adam_epsilon, | ||
| 521 | amsgrad=args.adam_amsgrad, | ||
| 522 | ) | ||
| 523 | |||
| 524 | lr_scheduler = get_scheduler( | ||
| 525 | args.lr_scheduler, | ||
| 526 | optimizer=optimizer, | ||
| 527 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 528 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 529 | min_lr=args.lr_min_lr, | ||
| 530 | warmup_func=args.lr_warmup_func, | ||
| 531 | annealing_func=args.lr_annealing_func, | ||
| 532 | warmup_exp=args.lr_warmup_exp, | ||
| 533 | annealing_exp=args.lr_annealing_exp, | ||
| 534 | cycles=args.lr_cycles, | ||
| 535 | end_lr=1e2, | ||
| 536 | train_epochs=args.num_train_epochs, | ||
| 537 | warmup_epochs=args.lr_warmup_epochs, | ||
| 538 | ) | ||
| 539 | |||
| 540 | metrics = trainer( | ||
| 541 | strategy=lora_strategy, | ||
| 542 | project="lora", | ||
| 543 | train_dataloader=datamodule.train_dataloader, | ||
| 544 | val_dataloader=datamodule.val_dataloader, | ||
| 545 | seed=args.seed, | ||
| 546 | optimizer=optimizer, | ||
| 547 | lr_scheduler=lr_scheduler, | ||
| 548 | num_train_epochs=args.num_train_epochs, | ||
| 549 | sample_frequency=args.sample_frequency, | ||
| 550 | # -- | ||
| 551 | tokenizer=tokenizer, | ||
| 552 | sample_scheduler=sample_scheduler, | ||
| 553 | sample_output_dir=sample_output_dir, | ||
| 554 | checkpoint_output_dir=checkpoint_output_dir, | ||
| 555 | max_grad_norm=args.max_grad_norm, | ||
| 556 | sample_batch_size=args.sample_batch_size, | ||
| 557 | sample_num_batches=args.sample_batches, | ||
| 558 | sample_num_steps=args.sample_steps, | ||
| 559 | sample_image_size=args.sample_image_size, | ||
| 560 | ) | ||
| 561 | |||
| 562 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | ||
| 563 | |||
| 564 | |||
| 565 | if __name__ == "__main__": | ||
| 566 | main() | ||
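The processors registered in `main()` implement the usual low-rank update: every attention projection keeps its frozen pretrained weight and gains a trainable rank-r delta, which is why only `lora_layers.parameters()` is handed to the optimizer. A minimal sketch of the idea, independent of the diffusers implementation:

```python
import torch
import torch.nn as nn


class LoRALinear(nn.Module):
    """Sketch: rank-r additive update on top of a frozen linear projection."""

    def __init__(self, base: nn.Linear, rank: int = 4, scale: float = 1.0):
        super().__init__()
        self.base = base.requires_grad_(False)  # frozen pretrained weight W
        self.down = nn.Linear(base.in_features, rank, bias=False)   # A
        self.up = nn.Linear(rank, base.out_features, bias=False)    # B
        nn.init.zeros_(self.up.weight)  # delta starts at zero, so training begins at W
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (W + scale * B A) x, with only A and B trainable
        return self.base(x) + self.scale * self.up(self.down(x))
```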
diff --git a/train_ti.py b/train_ti.py index c118aab..56f9e97 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -166,7 +166,7 @@ def parse_args(): | |||
| 166 | parser.add_argument( | 166 | parser.add_argument( |
| 167 | "--tag_dropout", | 167 | "--tag_dropout", |
| 168 | type=float, | 168 | type=float, |
| 169 | default=0, | 169 | default=0.1, |
| 170 | help="Tag dropout probability.", | 170 | help="Tag dropout probability.", |
| 171 | ) | 171 | ) |
| 172 | parser.add_argument( | 172 | parser.add_argument( |
| @@ -414,7 +414,7 @@ def parse_args(): | |||
| 414 | ) | 414 | ) |
| 415 | parser.add_argument( | 415 | parser.add_argument( |
| 416 | "--emb_decay", | 416 | "--emb_decay", |
| 417 | default=1e0, | 417 | default=1e-2, |
| 418 | type=float, | 418 | type=float, |
| 419 | help="Embedding decay factor." | 419 | help="Embedding decay factor." |
| 420 | ) | 420 | ) |
| @@ -530,7 +530,7 @@ def main(): | |||
| 530 | 530 | ||
| 531 | vae.enable_slicing() | 531 | vae.enable_slicing() |
| 532 | vae.set_use_memory_efficient_attention_xformers(True) | 532 | vae.set_use_memory_efficient_attention_xformers(True) |
| 533 | unet.set_use_memory_efficient_attention_xformers(True) | 533 | unet.enable_xformers_memory_efficient_attention() |
| 534 | 534 | ||
| 535 | if args.gradient_checkpointing: | 535 | if args.gradient_checkpointing: |
| 536 | unet.enable_gradient_checkpointing() | 536 | unet.enable_gradient_checkpointing() |
| @@ -612,8 +612,10 @@ def main(): | |||
| 612 | 612 | ||
| 613 | if len(placeholder_tokens) == 1: | 613 | if len(placeholder_tokens) == 1: |
| 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") | 614 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}") |
| 615 | metrics_output_file = output_dir.joinpath(f"{placeholder_tokens[0]}.png") | ||
| 615 | else: | 616 | else: |
| 616 | sample_output_dir = output_dir.joinpath("samples") | 617 | sample_output_dir = output_dir.joinpath("samples") |
| 618 | metrics_output_file = output_dir.joinpath("lr.png") | ||
| 617 | 619 | ||
| 618 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 620 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 619 | tokenizer=tokenizer, | 621 | tokenizer=tokenizer, |
| @@ -687,7 +689,7 @@ def main(): | |||
| 687 | placeholder_token_ids=placeholder_token_ids, | 689 | placeholder_token_ids=placeholder_token_ids, |
| 688 | ) | 690 | ) |
| 689 | 691 | ||
| 690 | plot_metrics(metrics, output_dir.joinpath("lr.png")) | 692 | plot_metrics(metrics, metrics_output_file) |
| 691 | 693 | ||
| 692 | if args.simultaneous: | 694 | if args.simultaneous: |
| 693 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | 695 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) |
diff --git a/training/functional.py b/training/functional.py index c373ac9..8f47734 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -34,7 +34,7 @@ def const(result=None): | |||
| 34 | @dataclass | 34 | @dataclass |
| 35 | class TrainingCallbacks(): | 35 | class TrainingCallbacks(): |
| 36 | on_prepare: Callable[[], None] = const() | 36 | on_prepare: Callable[[], None] = const() |
| 37 | on_model: Callable[[], torch.nn.Module] = const(None) | 37 | on_accum_model: Callable[[], torch.nn.Module] = const(None) |
| 38 | on_log: Callable[[], dict[str, Any]] = const({}) | 38 | on_log: Callable[[], dict[str, Any]] = const({}) |
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[float, int], None] = const() | 40 | on_before_optimize: Callable[[float, int], None] = const() |
| @@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol): | |||
| 51 | accelerator: Accelerator, | 51 | accelerator: Accelerator, |
| 52 | text_encoder: CLIPTextModel, | 52 | text_encoder: CLIPTextModel, |
| 53 | unet: UNet2DConditionModel, | 53 | unet: UNet2DConditionModel, |
| 54 | *args | 54 | optimizer: torch.optim.Optimizer, |
| 55 | train_dataloader: DataLoader, | ||
| 56 | val_dataloader: Optional[DataLoader], | ||
| 57 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 58 | **kwargs | ||
| 55 | ) -> Tuple: ... | 59 | ) -> Tuple: ... |
| 56 | 60 | ||
| 57 | 61 | ||
| @@ -92,7 +96,6 @@ def save_samples( | |||
| 92 | sample_scheduler: DPMSolverMultistepScheduler, | 96 | sample_scheduler: DPMSolverMultistepScheduler, |
| 93 | train_dataloader: DataLoader, | 97 | train_dataloader: DataLoader, |
| 94 | val_dataloader: Optional[DataLoader], | 98 | val_dataloader: Optional[DataLoader], |
| 95 | dtype: torch.dtype, | ||
| 96 | output_dir: Path, | 99 | output_dir: Path, |
| 97 | seed: int, | 100 | seed: int, |
| 98 | step: int, | 101 | step: int, |
| @@ -107,15 +110,6 @@ def save_samples( | |||
| 107 | grid_cols = min(batch_size, 4) | 110 | grid_cols = min(batch_size, 4) |
| 108 | grid_rows = (num_batches * batch_size) // grid_cols | 111 | grid_rows = (num_batches * batch_size) // grid_cols |
| 109 | 112 | ||
| 110 | unet = accelerator.unwrap_model(unet) | ||
| 111 | text_encoder = accelerator.unwrap_model(text_encoder) | ||
| 112 | |||
| 113 | orig_unet_dtype = unet.dtype | ||
| 114 | orig_text_encoder_dtype = text_encoder.dtype | ||
| 115 | |||
| 116 | unet.to(dtype=dtype) | ||
| 117 | text_encoder.to(dtype=dtype) | ||
| 118 | |||
| 119 | pipeline = VlpnStableDiffusion( | 113 | pipeline = VlpnStableDiffusion( |
| 120 | text_encoder=text_encoder, | 114 | text_encoder=text_encoder, |
| 121 | vae=vae, | 115 | vae=vae, |
| @@ -172,11 +166,6 @@ def save_samples( | |||
| 172 | image_grid = make_grid(all_samples, grid_rows, grid_cols) | 166 | image_grid = make_grid(all_samples, grid_rows, grid_cols) |
| 173 | image_grid.save(file_path, quality=85) | 167 | image_grid.save(file_path, quality=85) |
| 174 | 168 | ||
| 175 | unet.to(dtype=orig_unet_dtype) | ||
| 176 | text_encoder.to(dtype=orig_text_encoder_dtype) | ||
| 177 | |||
| 178 | del unet | ||
| 179 | del text_encoder | ||
| 180 | del generator | 169 | del generator |
| 181 | del pipeline | 170 | del pipeline |
| 182 | 171 | ||
| @@ -393,7 +382,7 @@ def train_loop( | |||
| 393 | ) | 382 | ) |
| 394 | global_progress_bar.set_description("Total progress") | 383 | global_progress_bar.set_description("Total progress") |
| 395 | 384 | ||
| 396 | model = callbacks.on_model() | 385 | model = callbacks.on_accum_model() |
| 397 | on_log = callbacks.on_log | 386 | on_log = callbacks.on_log |
| 398 | on_train = callbacks.on_train | 387 | on_train = callbacks.on_train |
| 399 | on_before_optimize = callbacks.on_before_optimize | 388 | on_before_optimize = callbacks.on_before_optimize |
| @@ -559,8 +548,10 @@ def train( | |||
| 559 | prior_loss_weight: float = 1.0, | 548 | prior_loss_weight: float = 1.0, |
| 560 | **kwargs, | 549 | **kwargs, |
| 561 | ): | 550 | ): |
| 562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare( | 551 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
| 563 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) | 552 | accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs) |
| 553 | |||
| 554 | kwargs.update(extra) | ||
| 564 | 555 | ||
| 565 | vae.to(accelerator.device, dtype=dtype) | 556 | vae.to(accelerator.device, dtype=dtype) |
| 566 | 557 | ||
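`strategy.prepare` now returns a seventh element: a dict of strategy-specific objects that `train()` merges back into its kwargs so they reach the strategy callbacks (the LoRA strategy uses this to pass the prepared `lora_layers` through). A minimal conforming implementation, as a sketch:

```python
def noop_prepare(accelerator, text_encoder, unet, optimizer,
                 train_dataloader, val_dataloader, lr_scheduler, **kwargs):
    # Wrap everything with accelerate, then return an empty extras dict.
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = \
        accelerator.prepare(text_encoder, unet, optimizer,
                            train_dataloader, val_dataloader, lr_scheduler)
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
```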
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index e88bf90..b4c77f3 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -61,14 +61,11 @@ def dreambooth_strategy_callbacks( | |||
| 61 | save_samples_ = partial( | 61 | save_samples_ = partial( |
| 62 | save_samples, | 62 | save_samples, |
| 63 | accelerator=accelerator, | 63 | accelerator=accelerator, |
| 64 | unet=unet, | ||
| 65 | text_encoder=text_encoder, | ||
| 66 | tokenizer=tokenizer, | 64 | tokenizer=tokenizer, |
| 67 | vae=vae, | 65 | vae=vae, |
| 68 | sample_scheduler=sample_scheduler, | 66 | sample_scheduler=sample_scheduler, |
| 69 | train_dataloader=train_dataloader, | 67 | train_dataloader=train_dataloader, |
| 70 | val_dataloader=val_dataloader, | 68 | val_dataloader=val_dataloader, |
| 71 | dtype=weight_dtype, | ||
| 72 | output_dir=sample_output_dir, | 69 | output_dir=sample_output_dir, |
| 73 | seed=seed, | 70 | seed=seed, |
| 74 | batch_size=sample_batch_size, | 71 | batch_size=sample_batch_size, |
| @@ -94,7 +91,7 @@ def dreambooth_strategy_callbacks( | |||
| 94 | else: | 91 | else: |
| 95 | return nullcontext() | 92 | return nullcontext() |
| 96 | 93 | ||
| 97 | def on_model(): | 94 | def on_accum_model(): |
| 98 | return unet | 95 | return unet |
| 99 | 96 | ||
| 100 | def on_prepare(): | 97 | def on_prepare(): |
| @@ -172,11 +169,29 @@ def dreambooth_strategy_callbacks( | |||
| 172 | @torch.no_grad() | 169 | @torch.no_grad() |
| 173 | def on_sample(step): | 170 | def on_sample(step): |
| 174 | with ema_context(): | 171 | with ema_context(): |
| 175 | save_samples_(step=step) | 172 | unet_ = accelerator.unwrap_model(unet) |
| 173 | text_encoder_ = accelerator.unwrap_model(text_encoder) | ||
| 174 | |||
| 175 | orig_unet_dtype = unet_.dtype | ||
| 176 | orig_text_encoder_dtype = text_encoder_.dtype | ||
| 177 | |||
| 178 | unet_.to(dtype=weight_dtype) | ||
| 179 | text_encoder_.to(dtype=weight_dtype) | ||
| 180 | |||
| 181 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | ||
| 182 | |||
| 183 | unet_.to(dtype=orig_unet_dtype) | ||
| 184 | text_encoder_.to(dtype=orig_text_encoder_dtype) | ||
| 185 | |||
| 186 | del unet_ | ||
| 187 | del text_encoder_ | ||
| 188 | |||
| 189 | if torch.cuda.is_available(): | ||
| 190 | torch.cuda.empty_cache() | ||
| 176 | 191 | ||
| 177 | return TrainingCallbacks( | 192 | return TrainingCallbacks( |
| 178 | on_prepare=on_prepare, | 193 | on_prepare=on_prepare, |
| 179 | on_model=on_model, | 194 | on_accum_model=on_accum_model, |
| 180 | on_train=on_train, | 195 | on_train=on_train, |
| 181 | on_eval=on_eval, | 196 | on_eval=on_eval, |
| 182 | on_before_optimize=on_before_optimize, | 197 | on_before_optimize=on_before_optimize, |
| @@ -191,9 +206,13 @@ def dreambooth_prepare( | |||
| 191 | accelerator: Accelerator, | 206 | accelerator: Accelerator, |
| 192 | text_encoder: CLIPTextModel, | 207 | text_encoder: CLIPTextModel, |
| 193 | unet: UNet2DConditionModel, | 208 | unet: UNet2DConditionModel, |
| 194 | *args | 209 | optimizer: torch.optim.Optimizer, |
| 210 | train_dataloader: DataLoader, | ||
| 211 | val_dataloader: Optional[DataLoader], | ||
| 212 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 213 | **kwargs | ||
| 195 | ): | 214 | ): |
| 196 | return accelerator.prepare(text_encoder, unet, *args) | 215 | return accelerator.prepare(text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler) + ({},) | ||
| 197 | 216 | ||
| 198 | 217 | ||
| 199 | dreambooth_strategy = TrainingStrategy( | 218 | dreambooth_strategy = TrainingStrategy( |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py new file mode 100644 index 0000000..88d1824 --- /dev/null +++ b/training/strategy/lora.py | |||
| @@ -0,0 +1,147 @@ | |||
| 1 | from typing import Optional | ||
| 2 | from functools import partial | ||
| 3 | from contextlib import contextmanager, nullcontext | ||
| 4 | from pathlib import Path | ||
| 5 | |||
| 6 | |||
| 7 | import torch | ||
| 8 | import torch.nn as nn | ||
| 9 | from torch.utils.data import DataLoader | ||
| 10 | |||
| 11 | from accelerate import Accelerator | ||
| 12 | from transformers import CLIPTextModel | ||
| 13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
| 14 | from diffusers.loaders import AttnProcsLayers | ||
| 15 | |||
| 16 | from slugify import slugify | ||
| 17 | |||
| 18 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 19 | from training.util import EMAModel | ||
| 20 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples | ||
| 21 | |||
| 22 | |||
| 23 | def lora_strategy_callbacks( | ||
| 24 | accelerator: Accelerator, | ||
| 25 | unet: UNet2DConditionModel, | ||
| 26 | text_encoder: CLIPTextModel, | ||
| 27 | tokenizer: MultiCLIPTokenizer, | ||
| 28 | vae: AutoencoderKL, | ||
| 29 | sample_scheduler: DPMSolverMultistepScheduler, | ||
| 30 | train_dataloader: DataLoader, | ||
| 31 | val_dataloader: Optional[DataLoader], | ||
| 32 | sample_output_dir: Path, | ||
| 33 | checkpoint_output_dir: Path, | ||
| 34 | seed: int, | ||
| 35 | lora_layers: AttnProcsLayers, | ||
| 36 | max_grad_norm: float = 1.0, | ||
| 37 | sample_batch_size: int = 1, | ||
| 38 | sample_num_batches: int = 1, | ||
| 39 | sample_num_steps: int = 20, | ||
| 40 | sample_guidance_scale: float = 7.5, | ||
| 41 | sample_image_size: Optional[int] = None, | ||
| 42 | ): | ||
| 43 | sample_output_dir.mkdir(parents=True, exist_ok=True) | ||
| 44 | checkpoint_output_dir.mkdir(parents=True, exist_ok=True) | ||
| 45 | |||
| 46 | weight_dtype = torch.float32 | ||
| 47 | if accelerator.state.mixed_precision == "fp16": | ||
| 48 | weight_dtype = torch.float16 | ||
| 49 | elif accelerator.state.mixed_precision == "bf16": | ||
| 50 | weight_dtype = torch.bfloat16 | ||
| 51 | |||
| 52 | save_samples_ = partial( | ||
| 53 | save_samples, | ||
| 54 | accelerator=accelerator, | ||
| 55 | unet=unet, | ||
| 56 | text_encoder=text_encoder, | ||
| 57 | tokenizer=tokenizer, | ||
| 58 | vae=vae, | ||
| 59 | sample_scheduler=sample_scheduler, | ||
| 60 | train_dataloader=train_dataloader, | ||
| 61 | val_dataloader=val_dataloader, | ||
| 62 | output_dir=sample_output_dir, | ||
| 63 | seed=seed, | ||
| 64 | batch_size=sample_batch_size, | ||
| 65 | num_batches=sample_num_batches, | ||
| 66 | num_steps=sample_num_steps, | ||
| 67 | guidance_scale=sample_guidance_scale, | ||
| 68 | image_size=sample_image_size, | ||
| 69 | ) | ||
| 70 | |||
| 71 | def on_prepare(): | ||
| 72 | lora_layers.requires_grad_(True) | ||
| 73 | |||
| 74 | def on_accum_model(): | ||
| 75 | return unet | ||
| 76 | |||
| 77 | @contextmanager | ||
| 78 | def on_train(epoch: int): | ||
| 79 | tokenizer.train() | ||
| 80 | yield | ||
| 81 | |||
| 82 | @contextmanager | ||
| 83 | def on_eval(): | ||
| 84 | tokenizer.eval() | ||
| 85 | yield | ||
| 86 | |||
| 87 | def on_before_optimize(lr: float, epoch: int): | ||
| 88 | if accelerator.sync_gradients: | ||
| 89 | accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm) | ||
| 90 | |||
| 91 | @torch.no_grad() | ||
| 92 | def on_checkpoint(step, postfix): | ||
| 93 | print(f"Saving checkpoint for step {step}...") | ||
| 94 | orig_unet_dtype = unet.dtype | ||
| 95 | unet.to(dtype=torch.float32) | ||
| 96 | unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) | ||
| 97 | unet.to(dtype=orig_unet_dtype) | ||
| 98 | |||
| 99 | @torch.no_grad() | ||
| 100 | def on_sample(step): | ||
| 101 | orig_unet_dtype = unet.dtype | ||
| 102 | unet.to(dtype=weight_dtype) | ||
| 103 | save_samples_(step=step) | ||
| 104 | unet.to(dtype=orig_unet_dtype) | ||
| 105 | |||
| 106 | if torch.cuda.is_available(): | ||
| 107 | torch.cuda.empty_cache() | ||
| 108 | |||
| 109 | return TrainingCallbacks( | ||
| 110 | on_prepare=on_prepare, | ||
| 111 | on_accum_model=on_accum_model, | ||
| 112 | on_train=on_train, | ||
| 113 | on_eval=on_eval, | ||
| 114 | on_before_optimize=on_before_optimize, | ||
| 115 | on_checkpoint=on_checkpoint, | ||
| 116 | on_sample=on_sample, | ||
| 117 | ) | ||
| 118 | |||
| 119 | |||
| 120 | def lora_prepare( | ||
| 121 | accelerator: Accelerator, | ||
| 122 | text_encoder: CLIPTextModel, | ||
| 123 | unet: UNet2DConditionModel, | ||
| 124 | optimizer: torch.optim.Optimizer, | ||
| 125 | train_dataloader: DataLoader, | ||
| 126 | val_dataloader: Optional[DataLoader], | ||
| 127 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 128 | lora_layers: AttnProcsLayers, | ||
| 129 | **kwargs | ||
| 130 | ): | ||
| 131 | weight_dtype = torch.float32 | ||
| 132 | if accelerator.state.mixed_precision == "fp16": | ||
| 133 | weight_dtype = torch.float16 | ||
| 134 | elif accelerator.state.mixed_precision == "bf16": | ||
| 135 | weight_dtype = torch.bfloat16 | ||
| 136 | |||
| 137 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | ||
| 138 | lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
| 139 | unet.to(accelerator.device, dtype=weight_dtype) | ||
| 140 | text_encoder.to(accelerator.device, dtype=weight_dtype) | ||
| 141 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers} | ||
| 142 | |||
| 143 | |||
| 144 | lora_strategy = TrainingStrategy( | ||
| 145 | callbacks=lora_strategy_callbacks, | ||
| 146 | prepare=lora_prepare, | ||
| 147 | ) | ||
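Checkpoints written by `on_checkpoint` are plain attention-processor weights, so infer.py can consume them directly via `load_attn_procs`. The round trip, with illustrative paths:

```python
# During training (on_checkpoint): persist only the LoRA processors, in fp32.
unet.save_attn_procs("output/lora/model/500_training")

# At inference: reattach them to a freshly loaded UNet and pick a blend scale.
pipeline.unet.load_attn_procs("output/lora/model/500_training")
images = pipeline("a prompt", cross_attention_kwargs={"scale": 0.5}).images
```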
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 14bdafd..d306f18 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -59,14 +59,11 @@ def textual_inversion_strategy_callbacks( | |||
| 59 | save_samples_ = partial( | 59 | save_samples_ = partial( |
| 60 | save_samples, | 60 | save_samples, |
| 61 | accelerator=accelerator, | 61 | accelerator=accelerator, |
| 62 | unet=unet, | ||
| 63 | text_encoder=text_encoder, | ||
| 64 | tokenizer=tokenizer, | 62 | tokenizer=tokenizer, |
| 65 | vae=vae, | 63 | vae=vae, |
| 66 | sample_scheduler=sample_scheduler, | 64 | sample_scheduler=sample_scheduler, |
| 67 | train_dataloader=train_dataloader, | 65 | train_dataloader=train_dataloader, |
| 68 | val_dataloader=val_dataloader, | 66 | val_dataloader=val_dataloader, |
| 69 | dtype=weight_dtype, | ||
| 70 | output_dir=sample_output_dir, | 67 | output_dir=sample_output_dir, |
| 71 | seed=seed, | 68 | seed=seed, |
| 72 | batch_size=sample_batch_size, | 69 | batch_size=sample_batch_size, |
| @@ -94,7 +91,7 @@ def textual_inversion_strategy_callbacks( | |||
| 94 | else: | 91 | else: |
| 95 | return nullcontext() | 92 | return nullcontext() |
| 96 | 93 | ||
| 97 | def on_model(): | 94 | def on_accum_model(): |
| 98 | return text_encoder.text_model.embeddings.temp_token_embedding | 95 | return text_encoder.text_model.embeddings.temp_token_embedding |
| 99 | 96 | ||
| 100 | def on_prepare(): | 97 | def on_prepare(): |
| @@ -149,11 +146,29 @@ def textual_inversion_strategy_callbacks( | |||
| 149 | @torch.no_grad() | 146 | @torch.no_grad() |
| 150 | def on_sample(step): | 147 | def on_sample(step): |
| 151 | with ema_context(): | 148 | with ema_context(): |
| 152 | save_samples_(step=step) | 149 | unet_ = accelerator.unwrap_model(unet) |
| 150 | text_encoder_ = accelerator.unwrap_model(text_encoder) | ||
| 151 | |||
| 152 | orig_unet_dtype = unet_.dtype | ||
| 153 | orig_text_encoder_dtype = text_encoder_.dtype | ||
| 154 | |||
| 155 | unet_.to(dtype=weight_dtype) | ||
| 156 | text_encoder_.to(dtype=weight_dtype) | ||
| 157 | |||
| 158 | save_samples_(step=step, unet=unet_, text_encoder=text_encoder_) | ||
| 159 | |||
| 160 | unet_.to(dtype=orig_unet_dtype) | ||
| 161 | text_encoder_.to(dtype=orig_text_encoder_dtype) | ||
| 162 | |||
| 163 | del unet_ | ||
| 164 | del text_encoder_ | ||
| 165 | |||
| 166 | if torch.cuda.is_available(): | ||
| 167 | torch.cuda.empty_cache() | ||
| 153 | 168 | ||
| 154 | return TrainingCallbacks( | 169 | return TrainingCallbacks( |
| 155 | on_prepare=on_prepare, | 170 | on_prepare=on_prepare, |
| 156 | on_model=on_model, | 171 | on_accum_model=on_accum_model, |
| 157 | on_train=on_train, | 172 | on_train=on_train, |
| 158 | on_eval=on_eval, | 173 | on_eval=on_eval, |
| 159 | on_before_optimize=on_before_optimize, | 174 | on_before_optimize=on_before_optimize, |
| @@ -168,7 +183,11 @@ def textual_inversion_prepare( | |||
| 168 | accelerator: Accelerator, | 183 | accelerator: Accelerator, |
| 169 | text_encoder: CLIPTextModel, | 184 | text_encoder: CLIPTextModel, |
| 170 | unet: UNet2DConditionModel, | 185 | unet: UNet2DConditionModel, |
| 171 | *args | 186 | optimizer: torch.optim.Optimizer, |
| 187 | train_dataloader: DataLoader, | ||
| 188 | val_dataloader: Optional[DataLoader], | ||
| 189 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | ||
| 190 | **kwargs | ||
| 172 | ): | 191 | ): |
| 173 | weight_dtype = torch.float32 | 192 | weight_dtype = torch.float32 |
| 174 | if accelerator.state.mixed_precision == "fp16": | 193 | if accelerator.state.mixed_precision == "fp16": |
| @@ -176,9 +195,10 @@ def textual_inversion_prepare( | |||
| 176 | elif accelerator.state.mixed_precision == "bf16": | 195 | elif accelerator.state.mixed_precision == "bf16": |
| 177 | weight_dtype = torch.bfloat16 | 196 | weight_dtype = torch.bfloat16 |
| 178 | 197 | ||
| 179 | prepped = accelerator.prepare(text_encoder, *args) | 198 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( |
| 199 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler) | ||
| 180 | unet.to(accelerator.device, dtype=weight_dtype) | 200 | unet.to(accelerator.device, dtype=weight_dtype) |
| 181 | return (prepped[0], unet) + prepped[1:] | 201 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {} |
| 182 | 202 | ||
| 183 | 203 | ||
| 184 | textual_inversion_strategy = TrainingStrategy( | 204 | textual_inversion_strategy = TrainingStrategy( |
