| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de (patch) | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /train_dreambooth.py | |
| parent | Fix LoRA training with DAdan | |
Update
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 770 |
1 file changed, 568 insertions, 202 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2aca1e7..659b84c 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -5,34 +5,70 @@ import itertools | |||
| 5 | from pathlib import Path | 5 | from pathlib import Path |
| 6 | from functools import partial | 6 | from functools import partial |
| 7 | import math | 7 | import math |
| 8 | import warnings | ||
| 8 | 9 | ||
| 9 | import torch | 10 | import torch |
| 11 | import torch._dynamo | ||
| 10 | import torch.utils.checkpoint | 12 | import torch.utils.checkpoint |
| 13 | import hidet | ||
| 11 | 14 | ||
| 12 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
| 13 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
| 14 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
| 15 | from slugify import slugify | 18 | |
| 19 | # from diffusers.models.attention_processor import AttnProcessor | ||
| 20 | from diffusers.utils.import_utils import is_xformers_available | ||
| 16 | import transformers | 21 | import transformers |
| 17 | 22 | ||
| 18 | from util.files import load_config, load_embeddings_from_dir | 23 | import numpy as np |
| 24 | from slugify import slugify | ||
| 25 | |||
| 19 | from data.csv import VlpnDataModule, keyword_filter | 26 | from data.csv import VlpnDataModule, keyword_filter |
| 20 | from training.functional import train, get_models | 27 | from models.clip.embeddings import patch_managed_embeddings |
| 28 | from training.functional import train, add_placeholder_tokens, get_models | ||
| 21 | from training.strategy.dreambooth import dreambooth_strategy | 29 | from training.strategy.dreambooth import dreambooth_strategy |
| 22 | from training.optimization import get_scheduler | 30 | from training.optimization import get_scheduler |
| 23 | from training.util import save_args | 31 | from training.sampler import create_named_schedule_sampler |
| 32 | from training.util import AverageMeter, save_args | ||
| 33 | from util.files import load_config, load_embeddings_from_dir | ||
| 34 | |||
| 24 | 35 | ||
| 25 | logger = get_logger(__name__) | 36 | logger = get_logger(__name__) |
| 26 | 37 | ||
| 38 | warnings.filterwarnings("ignore") | ||
| 39 | |||
| 27 | 40 | ||
| 28 | torch.backends.cuda.matmul.allow_tf32 = True | 41 | torch.backends.cuda.matmul.allow_tf32 = True |
| 29 | torch.backends.cudnn.benchmark = True | 42 | torch.backends.cudnn.benchmark = True |
| 30 | 43 | ||
| 44 | # torch._dynamo.config.log_level = logging.WARNING | ||
| 45 | torch._dynamo.config.suppress_errors = True | ||
| 46 | |||
| 47 | hidet.torch.dynamo_config.use_tensor_core(True) | ||
| 48 | hidet.torch.dynamo_config.search_space(0) | ||
| 49 | |||
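
The two `hidet.torch.dynamo_config` calls above tune the hidet `torch.compile` backend: `use_tensor_core(True)` lets generated kernels use tensor cores, and `search_space(0)` picks the smallest kernel search space (fastest compilation, least tuning). A minimal sketch of how the `--compile_unet` flag added below presumably feeds into this backend; the actual compile call lives in `training/functional.py`, so treat this as an assumption rather than the repo's code:

```python
# Hedged sketch (assumptions: PyTorch 2.x, and that importing `hidet`
# registers the "hidet" dynamo backend, which train() then uses when
# compile_unet=True is passed further down).
import torch
import hidet  # noqa: F401  -- the import registers the backend

def maybe_compile(unet: torch.nn.Module, enabled: bool) -> torch.nn.Module:
    return torch.compile(unet, backend="hidet") if enabled else unet
```
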
| 50 | |||
| 51 | def patch_xformers(dtype): | ||
| 52 | if is_xformers_available(): | ||
| 53 | import xformers | ||
| 54 | import xformers.ops | ||
| 55 | |||
| 56 | orig_xformers_memory_efficient_attention = ( | ||
| 57 | xformers.ops.memory_efficient_attention | ||
| 58 | ) | ||
| 59 | |||
| 60 | def xformers_memory_efficient_attention( | ||
| 61 | query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, **kwargs | ||
| 62 | ): | ||
| 63 | return orig_xformers_memory_efficient_attention( | ||
| 64 | query.to(dtype), key.to(dtype), value.to(dtype), **kwargs | ||
| 65 | ) | ||
| 66 | |||
| 67 | xformers.ops.memory_efficient_attention = xformers_memory_efficient_attention | ||
| 68 | |||
| 31 | 69 | ||
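
`patch_xformers` above monkey-patches `xformers.ops.memory_efficient_attention` so that query, key, and value are all cast to one dtype before the kernel runs; the fused kernels reject mixed-dtype inputs, which otherwise crop up under mixed-precision training. The script calls it once in `main()` as `patch_xformers(weight_dtype)`. An illustrative use, assuming a CUDA machine with xformers installed:

```python
# Illustrative only: after patching, a float32 query no longer has to be
# cast by the caller -- the wrapper does it before invoking the real kernel.
import torch
import xformers.ops

patch_xformers(torch.float16)

q = torch.randn(1, 64, 8, 64, device="cuda")                       # float32
k = torch.randn(1, 64, 8, 64, device="cuda", dtype=torch.float16)
v = torch.randn(1, 64, 8, 64, device="cuda", dtype=torch.float16)

out = xformers.ops.memory_efficient_attention(q, k, v)  # all cast to fp16
```
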
| 32 | def parse_args(): | 70 | def parse_args(): |
| 33 | parser = argparse.ArgumentParser( | 71 | parser = argparse.ArgumentParser(description="Simple example of a training script.") |
| 34 | description="Simple example of a training script." | ||
| 35 | ) | ||
| 36 | parser.add_argument( | 72 | parser.add_argument( |
| 37 | "--pretrained_model_name_or_path", | 73 | "--pretrained_model_name_or_path", |
| 38 | type=str, | 74 | type=str, |
| @@ -49,7 +85,7 @@ def parse_args(): | |||
| 49 | "--train_data_file", | 85 | "--train_data_file", |
| 50 | type=str, | 86 | type=str, |
| 51 | default=None, | 87 | default=None, |
| 52 | help="A folder containing the training data." | 88 | help="A folder containing the training data.", |
| 53 | ) | 89 | ) |
| 54 | parser.add_argument( | 90 | parser.add_argument( |
| 55 | "--train_data_template", | 91 | "--train_data_template", |
| @@ -60,13 +96,13 @@ def parse_args(): | |||
| 60 | "--train_set_pad", | 96 | "--train_set_pad", |
| 61 | type=int, | 97 | type=int, |
| 62 | default=None, | 98 | default=None, |
| 63 | help="The number to fill train dataset items up to." | 99 | help="The number to fill train dataset items up to.", |
| 64 | ) | 100 | ) |
| 65 | parser.add_argument( | 101 | parser.add_argument( |
| 66 | "--valid_set_pad", | 102 | "--valid_set_pad", |
| 67 | type=int, | 103 | type=int, |
| 68 | default=None, | 104 | default=None, |
| 69 | help="The number to fill validation dataset items up to." | 105 | help="The number to fill validation dataset items up to.", |
| 70 | ) | 106 | ) |
| 71 | parser.add_argument( | 107 | parser.add_argument( |
| 72 | "--project", | 108 | "--project", |
| @@ -75,20 +111,58 @@ def parse_args(): | |||
| 75 | help="The name of the current project.", | 111 | help="The name of the current project.", |
| 76 | ) | 112 | ) |
| 77 | parser.add_argument( | 113 | parser.add_argument( |
| 78 | "--exclude_collections", | 114 | "--auto_cycles", type=str, default="o", help="Cycles to run automatically." |
| 115 | ) | ||
| 116 | parser.add_argument( | ||
| 117 | "--cycle_decay", type=float, default=1.0, help="Learning rate decay per cycle." | ||
| 118 | ) | ||
| 119 | parser.add_argument( | ||
| 120 | "--placeholder_tokens", | ||
| 79 | type=str, | 121 | type=str, |
| 80 | nargs='*', | 122 | nargs="*", |
| 81 | help="Exclude all items with a listed collection.", | 123 | help="A token to use as a placeholder for the concept.", |
| 82 | ) | 124 | ) |
| 83 | parser.add_argument( | 125 | parser.add_argument( |
| 84 | "--train_text_encoder_epochs", | 126 | "--initializer_tokens", |
| 85 | default=999999, | 127 | type=str, |
| 86 | help="Number of epochs the text encoder will be trained." | 128 | nargs="*", |
| 129 | help="A token to use as initializer word.", | ||
| 130 | ) | ||
| 131 | parser.add_argument( | ||
| 132 | "--filter_tokens", type=str, nargs="*", help="Tokens to filter the dataset by." | ||
| 133 | ) | ||
| 134 | parser.add_argument( | ||
| 135 | "--initializer_noise", | ||
| 136 | type=float, | ||
| 137 | default=0, | ||
| 138 | help="Noise to apply to the initializer word", | ||
| 139 | ) | ||
| 140 | parser.add_argument( | ||
| 141 | "--alias_tokens", | ||
| 142 | type=str, | ||
| 143 | nargs="*", | ||
| 144 | default=[], | ||
| 145 | help="Tokens to create an alias for.", | ||
| 146 | ) | ||
| 147 | parser.add_argument( | ||
| 148 | "--inverted_initializer_tokens", | ||
| 149 | type=str, | ||
| 150 | nargs="*", | ||
| 151 | help="A token to use as initializer word.", | ||
| 152 | ) | ||
| 153 | parser.add_argument( | ||
| 154 | "--num_vectors", type=int, nargs="*", help="Number of vectors per embedding." | ||
| 155 | ) | ||
| 156 | parser.add_argument( | ||
| 157 | "--exclude_collections", | ||
| 158 | type=str, | ||
| 159 | nargs="*", | ||
| 160 | help="Exclude all items with a listed collection.", | ||
| 87 | ) | 161 | ) |
| 88 | parser.add_argument( | 162 | parser.add_argument( |
| 89 | "--num_buckets", | 163 | "--num_buckets", |
| 90 | type=int, | 164 | type=int, |
| 91 | default=0, | 165 | default=2, |
| 92 | help="Number of aspect ratio buckets in either direction.", | 166 | help="Number of aspect ratio buckets in either direction.", |
| 93 | ) | 167 | ) |
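
`--num_buckets` now defaults to 2, enabling aspect-ratio bucketing by default. The exact bucket scheme lives in `data/csv.py`, not in this script; a hedged sketch of what "buckets in either direction" plausibly means, with `--bucket_step_size` as the spacing:

```python
# Assumption-labeled sketch: ratios stepped around 1:1 in both directions.
# The real derivation (incl. --bucket_max_pixels clamping) is in data/csv.py.
def bucket_ratios(num_buckets: int, step_size: float) -> list[float]:
    return [1.0 + step_size * i for i in range(-num_buckets, num_buckets + 1)]

print(bucket_ratios(2, 0.25))  # [0.5, 0.75, 1.0, 1.25, 1.5]
```
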
| 94 | parser.add_argument( | 168 | parser.add_argument( |
| @@ -120,19 +194,6 @@ def parse_args(): | |||
| 120 | help="Shuffle tags.", | 194 | help="Shuffle tags.", |
| 121 | ) | 195 | ) |
| 122 | parser.add_argument( | 196 | parser.add_argument( |
| 123 | "--vector_dropout", | ||
| 124 | type=int, | ||
| 125 | default=0, | ||
| 126 | help="Vector dropout probability.", | ||
| 127 | ) | ||
| 128 | parser.add_argument( | ||
| 129 | "--vector_shuffle", | ||
| 130 | type=str, | ||
| 131 | default="auto", | ||
| 132 | choices=["all", "trailing", "leading", "between", "auto", "off"], | ||
| 133 | help='Vector shuffling algorithm.', | ||
| 134 | ) | ||
| 135 | parser.add_argument( | ||
| 136 | "--guidance_scale", | 197 | "--guidance_scale", |
| 137 | type=float, | 198 | type=float, |
| 138 | default=0, | 199 | default=0, |
| @@ -141,7 +202,7 @@ def parse_args(): | |||
| 141 | "--num_class_images", | 202 | "--num_class_images", |
| 142 | type=int, | 203 | type=int, |
| 143 | default=0, | 204 | default=0, |
| 144 | help="How many class images to generate." | 205 | help="How many class images to generate.", |
| 145 | ) | 206 | ) |
| 146 | parser.add_argument( | 207 | parser.add_argument( |
| 147 | "--class_image_dir", | 208 | "--class_image_dir", |
| @@ -162,16 +223,18 @@ def parse_args(): | |||
| 162 | help="The embeddings directory where Textual Inversion embeddings are stored.", | 223 | help="The embeddings directory where Textual Inversion embeddings are stored.", |
| 163 | ) | 224 | ) |
| 164 | parser.add_argument( | 225 | parser.add_argument( |
| 226 | "--train_dir_embeddings", | ||
| 227 | action="store_true", | ||
| 228 | help="Train embeddings loaded from embeddings directory.", | ||
| 229 | ) | ||
| 230 | parser.add_argument( | ||
| 165 | "--collection", | 231 | "--collection", |
| 166 | type=str, | 232 | type=str, |
| 167 | nargs='*', | 233 | nargs="*", |
| 168 | help="A collection to filter the dataset.", | 234 | help="A collection to filter the dataset.", |
| 169 | ) | 235 | ) |
| 170 | parser.add_argument( | 236 | parser.add_argument( |
| 171 | "--seed", | 237 | "--seed", type=int, default=None, help="A seed for reproducible training." |
| 172 | type=int, | ||
| 173 | default=None, | ||
| 174 | help="A seed for reproducible training." | ||
| 175 | ) | 238 | ) |
| 176 | parser.add_argument( | 239 | parser.add_argument( |
| 177 | "--resolution", | 240 | "--resolution", |
| @@ -189,15 +252,13 @@ def parse_args(): | |||
| 189 | help="Perlin offset noise strength.", | 252 | help="Perlin offset noise strength.", |
| 190 | ) | 253 | ) |
| 191 | parser.add_argument( | 254 | parser.add_argument( |
| 192 | "--num_train_epochs", | 255 | "--input_pertubation", |
| 193 | type=int, | 256 | type=float, |
| 194 | default=None | 257 | default=0, |
| 195 | ) | 258 | help="The scale of input perturbation. Recommended 0.1.", |
| 196 | parser.add_argument( | ||
| 197 | "--num_train_steps", | ||
| 198 | type=int, | ||
| 199 | default=2000 | ||
| 200 | ) | 259 | ) |
| 260 | parser.add_argument("--num_train_epochs", type=int, default=None) | ||
| 261 | parser.add_argument("--num_train_steps", type=int, default=2000) | ||
| 201 | parser.add_argument( | 262 | parser.add_argument( |
| 202 | "--gradient_accumulation_steps", | 263 | "--gradient_accumulation_steps", |
| 203 | type=int, | 264 | type=int, |
| @@ -205,9 +266,9 @@ def parse_args(): | |||
| 205 | help="Number of updates steps to accumulate before performing a backward/update pass.", | 266 | help="Number of updates steps to accumulate before performing a backward/update pass.", |
| 206 | ) | 267 | ) |
| 207 | parser.add_argument( | 268 | parser.add_argument( |
| 208 | "--gradient_checkpointing", | 269 | "--train_text_encoder_cycles", |
| 209 | action="store_true", | 270 | default=999999, |
| 210 | help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", | 271 | help="Number of epochs the text encoder will be trained.", |
| 211 | ) | 272 | ) |
| 212 | parser.add_argument( | 273 | parser.add_argument( |
| 213 | "--find_lr", | 274 | "--find_lr", |
| @@ -215,9 +276,15 @@ def parse_args(): | |||
| 215 | help="Automatically find a learning rate (no training).", | 276 | help="Automatically find a learning rate (no training).", |
| 216 | ) | 277 | ) |
| 217 | parser.add_argument( | 278 | parser.add_argument( |
| 218 | "--learning_rate", | 279 | "--learning_rate_unet", |
| 219 | type=float, | 280 | type=float, |
| 220 | default=2e-6, | 281 | default=1e-4, |
| 282 | help="Initial learning rate (after the potential warmup period) to use.", | ||
| 283 | ) | ||
| 284 | parser.add_argument( | ||
| 285 | "--learning_rate_text", | ||
| 286 | type=float, | ||
| 287 | default=5e-5, | ||
| 221 | help="Initial learning rate (after the potential warmup period) to use.", | 288 | help="Initial learning rate (after the potential warmup period) to use.", |
| 222 | ) | 289 | ) |
| 223 | parser.add_argument( | 290 | parser.add_argument( |
| @@ -229,27 +296,31 @@ def parse_args(): | |||
| 229 | "--lr_scheduler", | 296 | "--lr_scheduler", |
| 230 | type=str, | 297 | type=str, |
| 231 | default="one_cycle", | 298 | default="one_cycle", |
| 232 | choices=["linear", "cosine", "cosine_with_restarts", "polynomial", | 299 | choices=[ |
| 233 | "constant", "constant_with_warmup", "one_cycle"], | 300 | "linear", |
| 234 | help='The scheduler type to use.', | 301 | "cosine", |
| 302 | "cosine_with_restarts", | ||
| 303 | "polynomial", | ||
| 304 | "constant", | ||
| 305 | "constant_with_warmup", | ||
| 306 | "one_cycle", | ||
| 307 | ], | ||
| 308 | help="The scheduler type to use.", | ||
| 235 | ) | 309 | ) |
| 236 | parser.add_argument( | 310 | parser.add_argument( |
| 237 | "--lr_warmup_epochs", | 311 | "--lr_warmup_epochs", |
| 238 | type=int, | 312 | type=int, |
| 239 | default=10, | 313 | default=10, |
| 240 | help="Number of steps for the warmup in the lr scheduler." | 314 | help="Number of steps for the warmup in the lr scheduler.", |
| 241 | ) | 315 | ) |
| 242 | parser.add_argument( | 316 | parser.add_argument( |
| 243 | "--lr_mid_point", | 317 | "--lr_mid_point", type=float, default=0.3, help="OneCycle schedule mid point." |
| 244 | type=float, | ||
| 245 | default=0.3, | ||
| 246 | help="OneCycle schedule mid point." | ||
| 247 | ) | 318 | ) |
| 248 | parser.add_argument( | 319 | parser.add_argument( |
| 249 | "--lr_cycles", | 320 | "--lr_cycles", |
| 250 | type=int, | 321 | type=int, |
| 251 | default=None, | 322 | default=None, |
| 252 | help="Number of restart cycles in the lr scheduler (if supported)." | 323 | help="Number of restart cycles in the lr scheduler (if supported).", |
| 253 | ) | 324 | ) |
| 254 | parser.add_argument( | 325 | parser.add_argument( |
| 255 | "--lr_warmup_func", | 326 | "--lr_warmup_func", |
| @@ -261,7 +332,7 @@ def parse_args(): | |||
| 261 | "--lr_warmup_exp", | 332 | "--lr_warmup_exp", |
| 262 | type=int, | 333 | type=int, |
| 263 | default=1, | 334 | default=1, |
| 264 | help='If lr_warmup_func is "cos", exponent to modify the function' | 335 | help='If lr_warmup_func is "cos", exponent to modify the function', |
| 265 | ) | 336 | ) |
| 266 | parser.add_argument( | 337 | parser.add_argument( |
| 267 | "--lr_annealing_func", | 338 | "--lr_annealing_func", |
| @@ -273,76 +344,76 @@ def parse_args(): | |||
| 273 | "--lr_annealing_exp", | 344 | "--lr_annealing_exp", |
| 274 | type=int, | 345 | type=int, |
| 275 | default=3, | 346 | default=3, |
| 276 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 347 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function', |
| 277 | ) | 348 | ) |
| 278 | parser.add_argument( | 349 | parser.add_argument( |
| 279 | "--lr_min_lr", | 350 | "--lr_min_lr", |
| 280 | type=float, | 351 | type=float, |
| 281 | default=0.04, | 352 | default=0.04, |
| 282 | help="Minimum learning rate in the lr scheduler." | 353 | help="Minimum learning rate in the lr scheduler.", |
| 283 | ) | 354 | ) |
| 355 | parser.add_argument("--min_snr_gamma", type=int, default=5, help="MinSNR gamma.") | ||
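
`--min_snr_gamma` feeds the Min-SNR loss weighting (Hang et al., 2023) that the trainer receives as `min_snr_gamma` further down. A sketch of the weighting as it is commonly implemented; the repo's actual version lives in `training/functional.py`, so take this as an assumption:

```python
# Min-SNR weighting sketch (epsilon-prediction form): each timestep's loss
# is scaled by min(SNR(t), gamma) / SNR(t), capping the influence of
# low-noise timesteps whose SNR is huge.
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    return snr.clamp(max=gamma) / snr
```
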
| 284 | parser.add_argument( | 356 | parser.add_argument( |
| 285 | "--use_ema", | 357 | "--schedule_sampler", |
| 286 | action="store_true", | 358 | type=str, |
| 287 | help="Whether to use EMA model." | 359 | default="uniform", |
| 288 | ) | 360 | choices=["uniform", "loss-second-moment"], |
| 289 | parser.add_argument( | 361 | help="Noise schedule sampler.", |
| 290 | "--ema_inv_gamma", | ||
| 291 | type=float, | ||
| 292 | default=1.0 | ||
| 293 | ) | ||
| 294 | parser.add_argument( | ||
| 295 | "--ema_power", | ||
| 296 | type=float, | ||
| 297 | default=6/7 | ||
| 298 | ) | ||
| 299 | parser.add_argument( | ||
| 300 | "--ema_max_decay", | ||
| 301 | type=float, | ||
| 302 | default=0.9999 | ||
| 303 | ) | 362 | ) |
| 304 | parser.add_argument( | 363 | parser.add_argument( |
| 305 | "--optimizer", | 364 | "--optimizer", |
| 306 | type=str, | 365 | type=str, |
| 307 | default="dadan", | 366 | default="adan", |
| 308 | choices=["adam", "adam8bit", "adan", "lion", "dadam", "dadan", "adafactor"], | 367 | choices=[ |
| 309 | help='Optimizer to use' | 368 | "adam", |
| 369 | "adam8bit", | ||
| 370 | "adan", | ||
| 371 | "lion", | ||
| 372 | "dadam", | ||
| 373 | "dadan", | ||
| 374 | "dlion", | ||
| 375 | "adafactor", | ||
| 376 | ], | ||
| 377 | help="Optimizer to use", | ||
| 310 | ) | 378 | ) |
| 311 | parser.add_argument( | 379 | parser.add_argument( |
| 312 | "--dadaptation_d0", | 380 | "--dadaptation_d0", |
| 313 | type=float, | 381 | type=float, |
| 314 | default=1e-6, | 382 | default=1e-6, |
| 315 | help="The d0 parameter for Dadaptation optimizers." | 383 | help="The d0 parameter for Dadaptation optimizers.", |
| 384 | ) | ||
| 385 | parser.add_argument( | ||
| 386 | "--dadaptation_growth_rate", | ||
| 387 | type=float, | ||
| 388 | default=math.inf, | ||
| 389 | help="The growth_rate parameter for Dadaptation optimizers.", | ||
| 316 | ) | 390 | ) |
| 317 | parser.add_argument( | 391 | parser.add_argument( |
| 318 | "--adam_beta1", | 392 | "--adam_beta1", |
| 319 | type=float, | 393 | type=float, |
| 320 | default=None, | 394 | default=None, |
| 321 | help="The beta1 parameter for the Adam optimizer." | 395 | help="The beta1 parameter for the Adam optimizer.", |
| 322 | ) | 396 | ) |
| 323 | parser.add_argument( | 397 | parser.add_argument( |
| 324 | "--adam_beta2", | 398 | "--adam_beta2", |
| 325 | type=float, | 399 | type=float, |
| 326 | default=None, | 400 | default=None, |
| 327 | help="The beta2 parameter for the Adam optimizer." | 401 | help="The beta2 parameter for the Adam optimizer.", |
| 328 | ) | 402 | ) |
| 329 | parser.add_argument( | 403 | parser.add_argument( |
| 330 | "--adam_weight_decay", | 404 | "--adam_weight_decay", type=float, default=2e-2, help="Weight decay to use." |
| 331 | type=float, | ||
| 332 | default=1e-2, | ||
| 333 | help="Weight decay to use." | ||
| 334 | ) | 405 | ) |
| 335 | parser.add_argument( | 406 | parser.add_argument( |
| 336 | "--adam_epsilon", | 407 | "--adam_epsilon", |
| 337 | type=float, | 408 | type=float, |
| 338 | default=1e-08, | 409 | default=1e-08, |
| 339 | help="Epsilon value for the Adam optimizer" | 410 | help="Epsilon value for the Adam optimizer", |
| 340 | ) | 411 | ) |
| 341 | parser.add_argument( | 412 | parser.add_argument( |
| 342 | "--adam_amsgrad", | 413 | "--adam_amsgrad", |
| 343 | type=bool, | 414 | type=bool, |
| 344 | default=False, | 415 | default=False, |
| 345 | help="Amsgrad value for the Adam optimizer" | 416 | help="Amsgrad value for the Adam optimizer", |
| 346 | ) | 417 | ) |
| 347 | parser.add_argument( | 418 | parser.add_argument( |
| 348 | "--mixed_precision", | 419 | "--mixed_precision", |
| @@ -356,12 +427,28 @@ def parse_args(): | |||
| 356 | ), | 427 | ), |
| 357 | ) | 428 | ) |
| 358 | parser.add_argument( | 429 | parser.add_argument( |
| 430 | "--compile_unet", | ||
| 431 | action="store_true", | ||
| 432 | help="Compile UNet with Torch Dynamo.", | ||
| 433 | ) | ||
| 434 | parser.add_argument( | ||
| 435 | "--use_xformers", | ||
| 436 | action="store_true", | ||
| 437 | help="Use xformers.", | ||
| 438 | ) | ||
| 439 | parser.add_argument( | ||
| 359 | "--sample_frequency", | 440 | "--sample_frequency", |
| 360 | type=int, | 441 | type=int, |
| 361 | default=1, | 442 | default=1, |
| 362 | help="How often to save a checkpoint and sample image", | 443 | help="How often to save a checkpoint and sample image", |
| 363 | ) | 444 | ) |
| 364 | parser.add_argument( | 445 | parser.add_argument( |
| 446 | "--sample_num", | ||
| 447 | type=int, | ||
| 448 | default=None, | ||
| 449 | help="How often to save a checkpoint and sample image (in number of samples)", | ||
| 450 | ) | ||
| 451 | parser.add_argument( | ||
| 365 | "--sample_image_size", | 452 | "--sample_image_size", |
| 366 | type=int, | 453 | type=int, |
| 367 | default=768, | 454 | default=768, |
| @@ -383,19 +470,19 @@ def parse_args(): | |||
| 383 | "--valid_set_size", | 470 | "--valid_set_size", |
| 384 | type=int, | 471 | type=int, |
| 385 | default=None, | 472 | default=None, |
| 386 | help="Number of images in the validation dataset." | 473 | help="Number of images in the validation dataset.", |
| 387 | ) | 474 | ) |
| 388 | parser.add_argument( | 475 | parser.add_argument( |
| 389 | "--valid_set_repeat", | 476 | "--valid_set_repeat", |
| 390 | type=int, | 477 | type=int, |
| 391 | default=1, | 478 | default=1, |
| 392 | help="Times the images in the validation dataset are repeated." | 479 | help="Times the images in the validation dataset are repeated.", |
| 393 | ) | 480 | ) |
| 394 | parser.add_argument( | 481 | parser.add_argument( |
| 395 | "--train_batch_size", | 482 | "--train_batch_size", |
| 396 | type=int, | 483 | type=int, |
| 397 | default=1, | 484 | default=1, |
| 398 | help="Batch size (per device) for the training dataloader." | 485 | help="Batch size (per device) for the training dataloader.", |
| 399 | ) | 486 | ) |
| 400 | parser.add_argument( | 487 | parser.add_argument( |
| 401 | "--sample_steps", | 488 | "--sample_steps", |
| @@ -407,13 +494,18 @@ def parse_args(): | |||
| 407 | "--prior_loss_weight", | 494 | "--prior_loss_weight", |
| 408 | type=float, | 495 | type=float, |
| 409 | default=1.0, | 496 | default=1.0, |
| 410 | help="The weight of prior preservation loss." | 497 | help="The weight of prior preservation loss.", |
| 411 | ) | 498 | ) |
| 499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | ||
| 412 | parser.add_argument( | 501 | parser.add_argument( |
| 413 | "--max_grad_norm", | 502 | "--emb_dropout", |
| 414 | default=1.0, | ||
| 415 | type=float, | 503 | type=float, |
| 416 | help="Max gradient norm." | 504 | default=0, |
| 505 | help="Embedding dropout probability.", | ||
| 506 | ) | ||
| 507 | parser.add_argument( | ||
| 508 | "--max_grad_norm", default=1.0, type=float, help="Max gradient norm." | ||
| 417 | ) | 509 | ) |
| 418 | parser.add_argument( | 510 | parser.add_argument( |
| 419 | "--noise_timesteps", | 511 | "--noise_timesteps", |
| @@ -424,7 +516,7 @@ def parse_args(): | |||
| 424 | "--config", | 516 | "--config", |
| 425 | type=str, | 517 | type=str, |
| 426 | default=None, | 518 | default=None, |
| 427 | help="Path to a JSON configuration file containing arguments for invoking this script." | 519 | help="Path to a JSON configuration file containing arguments for invoking this script.", |
| 428 | ) | 520 | ) |
| 429 | 521 | ||
| 430 | args = parser.parse_args() | 522 | args = parser.parse_args() |
| @@ -441,6 +533,67 @@ def parse_args(): | |||
| 441 | if args.project is None: | 533 | if args.project is None: |
| 442 | raise ValueError("You must specify --project") | 534 | raise ValueError("You must specify --project") |
| 443 | 535 | ||
| 536 | if args.initializer_tokens is None: | ||
| 537 | args.initializer_tokens = [] | ||
| 538 | |||
| 539 | if args.placeholder_tokens is None: | ||
| 540 | args.placeholder_tokens = [] | ||
| 541 | |||
| 542 | if isinstance(args.placeholder_tokens, str): | ||
| 543 | args.placeholder_tokens = [args.placeholder_tokens] | ||
| 544 | |||
| 545 | if isinstance(args.initializer_tokens, str): | ||
| 546 | args.initializer_tokens = [args.initializer_tokens] * len( | ||
| 547 | args.placeholder_tokens | ||
| 548 | ) | ||
| 549 | |||
| 550 | if len(args.placeholder_tokens) == 0: | ||
| 551 | args.placeholder_tokens = [ | ||
| 552 | f"<*{i}>" for i in range(len(args.initializer_tokens)) | ||
| 553 | ] | ||
| 554 | |||
| 555 | if len(args.initializer_tokens) == 0: | ||
| 556 | args.initializer_tokens = args.placeholder_tokens.copy() | ||
| 557 | |||
| 558 | if len(args.placeholder_tokens) != len(args.initializer_tokens): | ||
| 559 | raise ValueError( | ||
| 560 | "--placeholder_tokens and --initializer_tokens must have the same number of items" | ||
| 561 | ) | ||
| 562 | |||
| 563 | if isinstance(args.inverted_initializer_tokens, str): | ||
| 564 | args.inverted_initializer_tokens = [args.inverted_initializer_tokens] * len( | ||
| 565 | args.placeholder_tokens | ||
| 566 | ) | ||
| 567 | |||
| 568 | if ( | ||
| 569 | isinstance(args.inverted_initializer_tokens, list) | ||
| 570 | and len(args.inverted_initializer_tokens) != 0 | ||
| 571 | ): | ||
| 572 | args.placeholder_tokens += [f"inv_{t}" for t in args.placeholder_tokens] | ||
| 573 | args.initializer_tokens += args.inverted_initializer_tokens | ||
| 574 | |||
| 575 | if isinstance(args.num_vectors, int): | ||
| 576 | args.num_vectors = [args.num_vectors] * len(args.placeholder_tokens) | ||
| 577 | |||
| 578 | if isinstance(args.num_vectors, list) and len(args.placeholder_tokens) != len( | ||
| 579 | args.num_vectors | ||
| 580 | ): | ||
| 581 | raise ValueError( | ||
| 582 | "--placeholder_tokens and --num_vectors must have the same number of items" | ||
| 583 | ) | ||
| 584 | |||
| 585 | if args.alias_tokens is None: | ||
| 586 | args.alias_tokens = [] | ||
| 587 | |||
| 588 | if not isinstance(args.alias_tokens, list) or len(args.alias_tokens) % 2 != 0: | ||
| 589 | raise ValueError("--alias_tokens must be a list with an even number of items") | ||
| 590 | |||
| 591 | if args.filter_tokens is None: | ||
| 592 | args.filter_tokens = args.placeholder_tokens.copy() | ||
| 593 | |||
| 594 | if isinstance(args.filter_tokens, str): | ||
| 595 | args.filter_tokens = [args.filter_tokens] | ||
| 596 | |||
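
Worked examples of the token normalization above, with hypothetical CLI values:

```python
# 1) --initializer_tokens ada              (placeholders omitted)
#    -> placeholder_tokens == ["<*0>"], initializer_tokens == ["ada"]
#
# 2) --placeholder_tokens "<ada>" --inverted_initializer_tokens blurry
#    -> initializer_tokens first copies the placeholders: ["<ada>"]
#       then the inverted pair is appended:
#       placeholder_tokens == ["<ada>", "inv_<ada>"]
#       initializer_tokens == ["<ada>", "blurry"]
#
# 3) num_vectors given as a bare int (e.g. via a --config JSON file) is
#    broadcast: 3 with two placeholders -> [3, 3]
```
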
| 444 | if isinstance(args.collection, str): | 597 | if isinstance(args.collection, str): |
| 445 | args.collection = [args.collection] | 598 | args.collection = [args.collection] |
| 446 | 599 | ||
| @@ -451,15 +604,15 @@ def parse_args(): | |||
| 451 | raise ValueError("You must specify --output_dir") | 604 | raise ValueError("You must specify --output_dir") |
| 452 | 605 | ||
| 453 | if args.adam_beta1 is None: | 606 | if args.adam_beta1 is None: |
| 454 | if args.optimizer in ('adam', 'adam8bit'): | 607 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 455 | args.adam_beta1 = 0.9 | 608 | args.adam_beta1 = 0.9 |
| 456 | elif args.optimizer == 'lion': | 609 | elif args.optimizer in ("lion", "dlion"): |
| 457 | args.adam_beta1 = 0.95 | 610 | args.adam_beta1 = 0.95 |
| 458 | 611 | ||
| 459 | if args.adam_beta2 is None: | 612 | if args.adam_beta2 is None: |
| 460 | if args.optimizer in ('adam', 'adam8bit'): | 613 | if args.optimizer in ("adam", "adam8bit", "dadam"): |
| 461 | args.adam_beta2 = 0.999 | 614 | args.adam_beta2 = 0.999 |
| 462 | elif args.optimizer == 'lion': | 615 | elif args.optimizer in ("lion", "dlion"): |
| 463 | args.adam_beta2 = 0.98 | 616 | args.adam_beta2 = 0.98 |
| 464 | 617 | ||
| 465 | return args | 618 | return args |
| @@ -475,7 +628,7 @@ def main(): | |||
| 475 | accelerator = Accelerator( | 628 | accelerator = Accelerator( |
| 476 | log_with=LoggerType.TENSORBOARD, | 629 | log_with=LoggerType.TENSORBOARD, |
| 477 | project_dir=f"{output_dir}", | 630 | project_dir=f"{output_dir}", |
| 478 | mixed_precision=args.mixed_precision | 631 | mixed_precision=args.mixed_precision, |
| 479 | ) | 632 | ) |
| 480 | 633 | ||
| 481 | weight_dtype = torch.float32 | 634 | weight_dtype = torch.float32 |
| @@ -484,6 +637,8 @@ def main(): | |||
| 484 | elif args.mixed_precision == "bf16": | 637 | elif args.mixed_precision == "bf16": |
| 485 | weight_dtype = torch.bfloat16 | 638 | weight_dtype = torch.bfloat16 |
| 486 | 639 | ||
| 640 | patch_xformers(weight_dtype) | ||
| 641 | |||
| 487 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) | 642 | logging.basicConfig(filename=output_dir / "log.txt", level=logging.DEBUG) |
| 488 | 643 | ||
| 489 | if args.seed is None: | 644 | if args.seed is None: |
| @@ -493,44 +648,125 @@ def main(): | |||
| 493 | 648 | ||
| 494 | save_args(output_dir, args) | 649 | save_args(output_dir, args) |
| 495 | 650 | ||
| 496 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 651 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models( |
| 497 | args.pretrained_model_name_or_path) | 652 | args.pretrained_model_name_or_path |
| 498 | 653 | ) | |
| 499 | tokenizer.set_use_vector_shuffle(args.vector_shuffle) | 654 | embeddings = patch_managed_embeddings( |
| 500 | tokenizer.set_dropout(args.vector_dropout) | 655 | text_encoder, args.emb_alpha, args.emb_dropout |
| 656 | ) | ||
| 657 | schedule_sampler = create_named_schedule_sampler( | ||
| 658 | args.schedule_sampler, noise_scheduler.config.num_train_timesteps | ||
| 659 | ) | ||
| 501 | 660 | ||
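
`create_named_schedule_sampler` selects how training timesteps are drawn: "uniform" samples them evenly, while "loss-second-moment" importance-samples timesteps in proportion to the square root of their recent squared losses. The naming matches OpenAI's guided-diffusion `resample.py`, which `training/sampler.py` appears to follow; a hedged sketch of the uniform case:

```python
# Sketch modeled on guided-diffusion's resample.py (an assumption about
# training/sampler.py). sample() returns timesteps plus importance weights
# 1/(T*p_t) so the weighted loss stays an unbiased estimate; for the
# uniform sampler the weights are all ones.
import numpy as np
import torch

class UniformSampler:
    def __init__(self, num_timesteps: int):
        self._weights = np.ones([num_timesteps])

    def sample(self, batch_size: int, device: torch.device):
        p = self._weights / self._weights.sum()
        idx = np.random.choice(len(p), size=(batch_size,), p=p)
        weights = 1.0 / (len(p) * p[idx])
        return (
            torch.from_numpy(idx).long().to(device),
            torch.from_numpy(weights).float().to(device),
        )
```
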
| 502 | vae.enable_slicing() | 661 | vae.enable_slicing() |
| 503 | vae.set_use_memory_efficient_attention_xformers(True) | 662 | |
| 504 | unet.enable_xformers_memory_efficient_attention() | 663 | if args.use_xformers: |
| 664 | vae.set_use_memory_efficient_attention_xformers(True) | ||
| 665 | unet.enable_xformers_memory_efficient_attention() | ||
| 666 | # elif args.compile_unet: | ||
| 667 | # unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False | ||
| 668 | # | ||
| 669 | # proc = AttnProcessor() | ||
| 670 | # | ||
| 671 | # def fn_recursive_set_proc(module: torch.nn.Module): | ||
| 672 | # if hasattr(module, "processor"): | ||
| 673 | # module.processor = proc | ||
| 674 | # | ||
| 675 | # for child in module.children(): | ||
| 676 | # fn_recursive_set_proc(child) | ||
| 677 | # | ||
| 678 | # fn_recursive_set_proc(unet) | ||
| 505 | 679 | ||
| 506 | if args.gradient_checkpointing: | 680 | if args.gradient_checkpointing: |
| 507 | unet.enable_gradient_checkpointing() | 681 | unet.enable_gradient_checkpointing() |
| 508 | text_encoder.gradient_checkpointing_enable() | 682 | |
| 683 | if len(args.alias_tokens) != 0: | ||
| 684 | alias_placeholder_tokens = args.alias_tokens[::2] | ||
| 685 | alias_initializer_tokens = args.alias_tokens[1::2] | ||
| 686 | |||
| 687 | added_tokens, added_ids = add_placeholder_tokens( | ||
| 688 | tokenizer=tokenizer, | ||
| 689 | embeddings=embeddings, | ||
| 690 | placeholder_tokens=alias_placeholder_tokens, | ||
| 691 | initializer_tokens=alias_initializer_tokens, | ||
| 692 | ) | ||
| 693 | embeddings.persist() | ||
| 694 | print( | ||
| 695 | f"Added {len(added_tokens)} aliases: {list(zip(alias_placeholder_tokens, added_tokens, alias_initializer_tokens, added_ids))}" | ||
| 696 | ) | ||
| 697 | |||
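
The `[::2]`/`[1::2]` slices pair up `--alias_tokens` as (placeholder, initializer) couples, which is why `parse_args` rejects odd-length lists:

```python
# Worked example (hypothetical tokens):
# --alias_tokens "<ada>" woman "<bob>" man
alias_tokens = ["<ada>", "woman", "<bob>", "man"]
assert alias_tokens[::2] == ["<ada>", "<bob>"]   # placeholders (even indices)
assert alias_tokens[1::2] == ["woman", "man"]    # initializers (odd indices)
```
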
| 698 | placeholder_tokens = [] | ||
| 699 | placeholder_token_ids = [] | ||
| 509 | 700 | ||
| 510 | if args.embeddings_dir is not None: | 701 | if args.embeddings_dir is not None: |
| 511 | embeddings_dir = Path(args.embeddings_dir) | 702 | embeddings_dir = Path(args.embeddings_dir) |
| 512 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): | 703 | if not embeddings_dir.exists() or not embeddings_dir.is_dir(): |
| 513 | raise ValueError("--embeddings_dir must point to an existing directory") | 704 | raise ValueError("--embeddings_dir must point to an existing directory") |
| 514 | 705 | ||
| 515 | added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) | 706 | added_tokens, added_ids = load_embeddings_from_dir( |
| 516 | embeddings.persist() | 707 | tokenizer, embeddings, embeddings_dir |
| 517 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") | 708 | ) |
| 709 | |||
| 710 | placeholder_tokens = added_tokens | ||
| 711 | placeholder_token_ids = added_ids | ||
| 712 | |||
| 713 | print( | ||
| 714 | f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}" | ||
| 715 | ) | ||
| 716 | |||
| 717 | if args.train_dir_embeddings: | ||
| 718 | print("Training embeddings from embeddings dir") | ||
| 719 | else: | ||
| 720 | embeddings.persist() | ||
| 721 | |||
| 722 | if len(args.placeholder_tokens) != 0 and not args.train_dir_embeddings: | ||
| 723 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | ||
| 724 | tokenizer=tokenizer, | ||
| 725 | embeddings=embeddings, | ||
| 726 | placeholder_tokens=args.placeholder_tokens, | ||
| 727 | initializer_tokens=args.initializer_tokens, | ||
| 728 | num_vectors=args.num_vectors, | ||
| 729 | initializer_noise=args.initializer_noise, | ||
| 730 | ) | ||
| 731 | |||
| 732 | placeholder_tokens = args.placeholder_tokens | ||
| 733 | |||
| 734 | stats = list( | ||
| 735 | zip( | ||
| 736 | placeholder_tokens, | ||
| 737 | placeholder_token_ids, | ||
| 738 | args.initializer_tokens, | ||
| 739 | initializer_token_ids, | ||
| 740 | ) | ||
| 741 | ) | ||
| 742 | print(f"Training embeddings: {stats}") | ||
| 518 | 743 | ||
| 519 | if args.scale_lr: | 744 | if args.scale_lr: |
| 520 | args.learning_rate = ( | 745 | args.learning_rate_unet = ( |
| 521 | args.learning_rate * args.gradient_accumulation_steps * | 746 | args.learning_rate_unet |
| 522 | args.train_batch_size * accelerator.num_processes | 747 | * args.gradient_accumulation_steps |
| 748 | * args.train_batch_size | ||
| 749 | * accelerator.num_processes | ||
| 750 | ) | ||
| 751 | args.learning_rate_text = ( | ||
| 752 | args.learning_rate_text | ||
| 753 | * args.gradient_accumulation_steps | ||
| 754 | * args.train_batch_size | ||
| 755 | * accelerator.num_processes | ||
| 523 | ) | 756 | ) |
| 524 | 757 | ||
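
With `--scale_lr`, both learning rates are scaled linearly by the effective batch size, per the usual linear scaling rule:

```python
# Worked example (hypothetical values): --learning_rate_text 5e-5,
# --gradient_accumulation_steps 2, --train_batch_size 4, one process:
#   5e-5 * 2 * 4 * 1 == 4e-4
```
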
| 525 | if args.find_lr: | 758 | if args.find_lr: |
| 526 | args.learning_rate = 1e-6 | 759 | args.learning_rate_unet = 1e-6 |
| 760 | args.learning_rate_text = 1e-6 | ||
| 527 | args.lr_scheduler = "exponential_growth" | 761 | args.lr_scheduler = "exponential_growth" |
| 528 | 762 | ||
| 529 | if args.optimizer == 'adam8bit': | 763 | if args.optimizer == "adam8bit": |
| 530 | try: | 764 | try: |
| 531 | import bitsandbytes as bnb | 765 | import bitsandbytes as bnb |
| 532 | except ImportError: | 766 | except ImportError: |
| 533 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | 767 | raise ImportError( |
| 768 | "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`." | ||
| 769 | ) | ||
| 534 | 770 | ||
| 535 | create_optimizer = partial( | 771 | create_optimizer = partial( |
| 536 | bnb.optim.AdamW8bit, | 772 | bnb.optim.AdamW8bit, |
| @@ -539,7 +775,7 @@ def main(): | |||
| 539 | eps=args.adam_epsilon, | 775 | eps=args.adam_epsilon, |
| 540 | amsgrad=args.adam_amsgrad, | 776 | amsgrad=args.adam_amsgrad, |
| 541 | ) | 777 | ) |
| 542 | elif args.optimizer == 'adam': | 778 | elif args.optimizer == "adam": |
| 543 | create_optimizer = partial( | 779 | create_optimizer = partial( |
| 544 | torch.optim.AdamW, | 780 | torch.optim.AdamW, |
| 545 | betas=(args.adam_beta1, args.adam_beta2), | 781 | betas=(args.adam_beta1, args.adam_beta2), |
| @@ -547,22 +783,27 @@ def main(): | |||
| 547 | eps=args.adam_epsilon, | 783 | eps=args.adam_epsilon, |
| 548 | amsgrad=args.adam_amsgrad, | 784 | amsgrad=args.adam_amsgrad, |
| 549 | ) | 785 | ) |
| 550 | elif args.optimizer == 'adan': | 786 | elif args.optimizer == "adan": |
| 551 | try: | 787 | try: |
| 552 | import timm.optim | 788 | import timm.optim |
| 553 | except ImportError: | 789 | except ImportError: |
| 554 | raise ImportError("To use Adan, please install the PyTorch Image Models library: `pip install timm`.") | 790 | raise ImportError( |
| 791 | "To use Adan, please install the PyTorch Image Models library: `pip install timm`." | ||
| 792 | ) | ||
| 555 | 793 | ||
| 556 | create_optimizer = partial( | 794 | create_optimizer = partial( |
| 557 | timm.optim.Adan, | 795 | timm.optim.Adan, |
| 558 | weight_decay=args.adam_weight_decay, | 796 | weight_decay=args.adam_weight_decay, |
| 559 | eps=args.adam_epsilon, | 797 | eps=args.adam_epsilon, |
| 798 | no_prox=True, | ||
| 560 | ) | 799 | ) |
| 561 | elif args.optimizer == 'lion': | 800 | elif args.optimizer == "lion": |
| 562 | try: | 801 | try: |
| 563 | import lion_pytorch | 802 | import lion_pytorch |
| 564 | except ImportError: | 803 | except ImportError: |
| 565 | raise ImportError("To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`.") | 804 | raise ImportError( |
| 805 | "To use Lion, please install the lion_pytorch library: `pip install lion-pytorch`." | ||
| 806 | ) | ||
| 566 | 807 | ||
| 567 | create_optimizer = partial( | 808 | create_optimizer = partial( |
| 568 | lion_pytorch.Lion, | 809 | lion_pytorch.Lion, |
| @@ -570,7 +811,7 @@ def main(): | |||
| 570 | weight_decay=args.adam_weight_decay, | 811 | weight_decay=args.adam_weight_decay, |
| 571 | use_triton=True, | 812 | use_triton=True, |
| 572 | ) | 813 | ) |
| 573 | elif args.optimizer == 'adafactor': | 814 | elif args.optimizer == "adafactor": |
| 574 | create_optimizer = partial( | 815 | create_optimizer = partial( |
| 575 | transformers.optimization.Adafactor, | 816 | transformers.optimization.Adafactor, |
| 576 | weight_decay=args.adam_weight_decay, | 817 | weight_decay=args.adam_weight_decay, |
| @@ -580,13 +821,16 @@ def main(): | |||
| 580 | ) | 821 | ) |
| 581 | 822 | ||
| 582 | args.lr_scheduler = "adafactor" | 823 | args.lr_scheduler = "adafactor" |
| 583 | args.lr_min_lr = args.learning_rate | 824 | args.lr_min_lr = args.learning_rate_unet |
| 584 | args.learning_rate = None | 825 | args.learning_rate_unet = None |
| 585 | elif args.optimizer == 'dadam': | 826 | args.learning_rate_text = None |
| 827 | elif args.optimizer == "dadam": | ||
| 586 | try: | 828 | try: |
| 587 | import dadaptation | 829 | import dadaptation |
| 588 | except ImportError: | 830 | except ImportError: |
| 589 | raise ImportError("To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`.") | 831 | raise ImportError( |
| 832 | "To use DAdaptAdam, please install the dadaptation library: `pip install dadaptation`." | ||
| 833 | ) | ||
| 590 | 834 | ||
| 591 | create_optimizer = partial( | 835 | create_optimizer = partial( |
| 592 | dadaptation.DAdaptAdam, | 836 | dadaptation.DAdaptAdam, |
| @@ -595,46 +839,65 @@ def main(): | |||
| 595 | eps=args.adam_epsilon, | 839 | eps=args.adam_epsilon, |
| 596 | decouple=True, | 840 | decouple=True, |
| 597 | d0=args.dadaptation_d0, | 841 | d0=args.dadaptation_d0, |
| 842 | growth_rate=args.dadaptation_growth_rate, | ||
| 598 | ) | 843 | ) |
| 599 | 844 | ||
| 600 | args.learning_rate = 1.0 | 845 | args.learning_rate_unet = 1.0 |
| 601 | elif args.optimizer == 'dadan': | 846 | args.learning_rate_text = 1.0 |
| 847 | elif args.optimizer == "dadan": | ||
| 602 | try: | 848 | try: |
| 603 | import dadaptation | 849 | import dadaptation |
| 604 | except ImportError: | 850 | except ImportError: |
| 605 | raise ImportError("To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`.") | 851 | raise ImportError( |
| 852 | "To use DAdaptAdan, please install the dadaptation library: `pip install dadaptation`." | ||
| 853 | ) | ||
| 606 | 854 | ||
| 607 | create_optimizer = partial( | 855 | create_optimizer = partial( |
| 608 | dadaptation.DAdaptAdan, | 856 | dadaptation.DAdaptAdan, |
| 609 | weight_decay=args.adam_weight_decay, | 857 | weight_decay=args.adam_weight_decay, |
| 610 | eps=args.adam_epsilon, | 858 | eps=args.adam_epsilon, |
| 611 | d0=args.dadaptation_d0, | 859 | d0=args.dadaptation_d0, |
| 860 | growth_rate=args.dadaptation_growth_rate, | ||
| 612 | ) | 861 | ) |
| 613 | 862 | ||
| 614 | args.learning_rate = 1.0 | 863 | args.learning_rate_unet = 1.0 |
| 864 | args.learning_rate_text = 1.0 | ||
| 865 | elif args.optimizer == "dlion": | ||
| 866 | raise ImportError("DLion has not been merged into dadaptation yet") | ||
| 615 | else: | 867 | else: |
| 616 | raise ValueError(f"Unknown --optimizer \"{args.optimizer}\"") | 868 | raise ValueError(f'Unknown --optimizer "{args.optimizer}"') |
| 617 | 869 | ||
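
For the D-Adaptation optimizers the script pins both learning rates to 1.0: D-Adaptation estimates the step size itself, so the lr acts only as a multiplier, `d0` seeds the estimate, and the new `growth_rate` caps how fast it may grow. A hedged sketch of the `dadan` branch in isolation, using the script's defaults; `params` stands in for the unet/text-encoder parameter groups built later:

```python
import math
import dadaptation

optimizer = dadaptation.DAdaptAdan(
    params,                # hypothetical stand-in for the real param groups
    lr=1.0,                # pure multiplier on the adapted step size
    weight_decay=2e-2,     # --adam_weight_decay default
    eps=1e-8,              # --adam_epsilon default
    d0=1e-6,               # --dadaptation_d0 default
    growth_rate=math.inf,  # --dadaptation_growth_rate default
)
```
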
| 618 | trainer = partial( | 870 | trainer = partial( |
| 619 | train, | 871 | train, |
| 620 | accelerator=accelerator, | 872 | accelerator=accelerator, |
| 621 | unet=unet, | 873 | unet=unet, |
| 622 | text_encoder=text_encoder, | 874 | text_encoder=text_encoder, |
| 875 | tokenizer=tokenizer, | ||
| 623 | vae=vae, | 876 | vae=vae, |
| 624 | noise_scheduler=noise_scheduler, | 877 | noise_scheduler=noise_scheduler, |
| 878 | schedule_sampler=schedule_sampler, | ||
| 879 | min_snr_gamma=args.min_snr_gamma, | ||
| 625 | dtype=weight_dtype, | 880 | dtype=weight_dtype, |
| 881 | seed=args.seed, | ||
| 882 | compile_unet=args.compile_unet, | ||
| 626 | guidance_scale=args.guidance_scale, | 883 | guidance_scale=args.guidance_scale, |
| 627 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, | 884 | prior_loss_weight=args.prior_loss_weight if args.num_class_images != 0 else 0, |
| 628 | no_val=args.valid_set_size == 0, | 885 | sample_scheduler=sample_scheduler, |
| 886 | sample_batch_size=args.sample_batch_size, | ||
| 887 | sample_num_batches=args.sample_batches, | ||
| 888 | sample_num_steps=args.sample_steps, | ||
| 889 | sample_image_size=args.sample_image_size, | ||
| 890 | max_grad_norm=args.max_grad_norm, | ||
| 629 | ) | 891 | ) |
| 630 | 892 | ||
| 631 | checkpoint_output_dir = output_dir / "model" | 893 | data_generator = torch.Generator(device="cpu").manual_seed(args.seed) |
| 632 | sample_output_dir = output_dir / "samples" | 894 | data_npgenerator = np.random.default_rng(args.seed) |
| 633 | 895 | ||
| 634 | datamodule = VlpnDataModule( | 896 | create_datamodule = partial( |
| 897 | VlpnDataModule, | ||
| 635 | data_file=args.train_data_file, | 898 | data_file=args.train_data_file, |
| 636 | batch_size=args.train_batch_size, | ||
| 637 | tokenizer=tokenizer, | 899 | tokenizer=tokenizer, |
| 900 | constant_prompt_length=args.compile_unet, | ||
| 638 | class_subdir=args.class_image_dir, | 901 | class_subdir=args.class_image_dir, |
| 639 | with_guidance=args.guidance_scale != 0, | 902 | with_guidance=args.guidance_scale != 0, |
| 640 | num_class_images=args.num_class_images, | 903 | num_class_images=args.num_class_images, |
| @@ -643,83 +906,186 @@ def main(): | |||
| 643 | progressive_buckets=args.progressive_buckets, | 906 | progressive_buckets=args.progressive_buckets, |
| 644 | bucket_step_size=args.bucket_step_size, | 907 | bucket_step_size=args.bucket_step_size, |
| 645 | bucket_max_pixels=args.bucket_max_pixels, | 908 | bucket_max_pixels=args.bucket_max_pixels, |
| 646 | dropout=args.tag_dropout, | ||
| 647 | shuffle=not args.no_tag_shuffle, | 909 | shuffle=not args.no_tag_shuffle, |
| 648 | template_key=args.train_data_template, | 910 | template_key=args.train_data_template, |
| 649 | valid_set_size=args.valid_set_size, | ||
| 650 | train_set_pad=args.train_set_pad, | 911 | train_set_pad=args.train_set_pad, |
| 651 | valid_set_pad=args.valid_set_pad, | 912 | valid_set_pad=args.valid_set_pad, |
| 652 | seed=args.seed, | 913 | dtype=weight_dtype, |
| 653 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), | 914 | generator=data_generator, |
| 654 | dtype=weight_dtype | 915 | npgenerator=data_npgenerator, |
| 655 | ) | ||
| 656 | datamodule.setup() | ||
| 657 | |||
| 658 | num_train_epochs = args.num_train_epochs | ||
| 659 | sample_frequency = args.sample_frequency | ||
| 660 | if num_train_epochs is None: | ||
| 661 | num_train_epochs = math.ceil( | ||
| 662 | args.num_train_steps / len(datamodule.train_dataset) | ||
| 663 | ) * args.gradient_accumulation_steps | ||
| 664 | sample_frequency = math.ceil(num_train_epochs * (sample_frequency / args.num_train_steps)) | ||
| 665 | |||
| 666 | params_to_optimize = (unet.parameters(), ) | ||
| 667 | if args.train_text_encoder_epochs != 0: | ||
| 668 | params_to_optimize += ( | ||
| 669 | text_encoder.text_model.encoder.parameters(), | ||
| 670 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 671 | ) | ||
| 672 | |||
| 673 | optimizer = create_optimizer( | ||
| 674 | itertools.chain(*params_to_optimize), | ||
| 675 | lr=args.learning_rate, | ||
| 676 | ) | 916 | ) |
| 677 | 917 | ||
| 678 | lr_scheduler = get_scheduler( | 918 | create_lr_scheduler = partial( |
| 679 | args.lr_scheduler, | 919 | get_scheduler, |
| 680 | optimizer=optimizer, | ||
| 681 | num_training_steps_per_epoch=len(datamodule.train_dataloader), | ||
| 682 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 683 | min_lr=args.lr_min_lr, | 920 | min_lr=args.lr_min_lr, |
| 684 | warmup_func=args.lr_warmup_func, | 921 | warmup_func=args.lr_warmup_func, |
| 685 | annealing_func=args.lr_annealing_func, | 922 | annealing_func=args.lr_annealing_func, |
| 686 | warmup_exp=args.lr_warmup_exp, | 923 | warmup_exp=args.lr_warmup_exp, |
| 687 | annealing_exp=args.lr_annealing_exp, | 924 | annealing_exp=args.lr_annealing_exp, |
| 688 | cycles=args.lr_cycles, | ||
| 689 | end_lr=1e2, | 925 | end_lr=1e2, |
| 690 | train_epochs=num_train_epochs, | ||
| 691 | warmup_epochs=args.lr_warmup_epochs, | ||
| 692 | mid_point=args.lr_mid_point, | 926 | mid_point=args.lr_mid_point, |
| 693 | ) | 927 | ) |
| 694 | 928 | ||
| 695 | trainer( | 929 | # Dreambooth |
| 696 | strategy=dreambooth_strategy, | 930 | # -------------------------------------------------------------------------------- |
| 697 | project="dreambooth", | 931 | |
| 698 | train_dataloader=datamodule.train_dataloader, | 932 | dreambooth_datamodule = create_datamodule( |
| 699 | val_dataloader=datamodule.val_dataloader, | 933 | valid_set_size=args.valid_set_size, |
| 700 | seed=args.seed, | 934 | batch_size=args.train_batch_size, |
| 701 | optimizer=optimizer, | 935 | dropout=args.tag_dropout, |
| 702 | lr_scheduler=lr_scheduler, | 936 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
| 703 | num_train_epochs=num_train_epochs, | 937 | ) |
| 704 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 938 | dreambooth_datamodule.setup() |
| 705 | sample_frequency=sample_frequency, | 939 | |
| 706 | offset_noise_strength=args.offset_noise_strength, | 940 | num_train_epochs = args.num_train_epochs |
| 707 | # -- | 941 | dreambooth_sample_frequency = args.sample_frequency |
| 708 | tokenizer=tokenizer, | 942 | if num_train_epochs is None: |
| 709 | sample_scheduler=sample_scheduler, | 943 | num_train_epochs = ( |
| 710 | sample_output_dir=sample_output_dir, | 944 | math.ceil(args.num_train_steps / len(dreambooth_datamodule.train_dataset)) |
| 711 | checkpoint_output_dir=checkpoint_output_dir, | 945 | * args.gradient_accumulation_steps |
| 712 | train_text_encoder_epochs=args.train_text_encoder_epochs, | 946 | ) |
| 713 | max_grad_norm=args.max_grad_norm, | 947 | dreambooth_sample_frequency = math.ceil( |
| 714 | use_ema=args.use_ema, | 948 | num_train_epochs * (dreambooth_sample_frequency / args.num_train_steps) |
| 715 | ema_inv_gamma=args.ema_inv_gamma, | 949 | ) |
| 716 | ema_power=args.ema_power, | 950 | num_training_steps_per_epoch = math.ceil( |
| 717 | ema_max_decay=args.ema_max_decay, | 951 | len(dreambooth_datamodule.train_dataset) / args.gradient_accumulation_steps |
| 718 | sample_batch_size=args.sample_batch_size, | ||
| 719 | sample_num_batches=args.sample_batches, | ||
| 720 | sample_num_steps=args.sample_steps, | ||
| 721 | sample_image_size=args.sample_image_size, | ||
| 722 | ) | 952 | ) |
| 953 | num_train_steps = num_training_steps_per_epoch * num_train_epochs | ||
| 954 | if args.sample_num is not None: | ||
| 955 | dreambooth_sample_frequency = math.ceil(num_train_epochs / args.sample_num) | ||
| 956 | |||
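
Worked example of the epoch/step bookkeeping above (hypothetical numbers):

```python
# 400 training items, --num_train_steps 2000,
# --gradient_accumulation_steps 2, --sample_frequency 1, --sample_num unset:
#
# num_train_epochs             = ceil(2000 / 400) * 2  = 10
# dreambooth_sample_frequency  = ceil(10 * (1 / 2000)) = 1
# num_training_steps_per_epoch = ceil(400 / 2)         = 200
# num_train_steps              = 200 * 10              = 2000
```
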
| 957 | dreambooth_project = "dreambooth" | ||
| 958 | |||
| 959 | if accelerator.is_main_process: | ||
| 960 | accelerator.init_trackers(dreambooth_project) | ||
| 961 | |||
| 962 | dreambooth_sample_output_dir = output_dir / dreambooth_project / "samples" | ||
| 963 | |||
| 964 | training_iter = 0 | ||
| 965 | auto_cycles = list(args.auto_cycles) | ||
| 966 | learning_rate_unet = args.learning_rate_unet | ||
| 967 | learning_rate_text = args.learning_rate_text | ||
| 968 | lr_scheduler = args.lr_scheduler | ||
| 969 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 970 | lr_cycles = args.lr_cycles | ||
| 971 | |||
| 972 | avg_loss = AverageMeter() | ||
| 973 | avg_acc = AverageMeter() | ||
| 974 | avg_loss_val = AverageMeter() | ||
| 975 | avg_acc_val = AverageMeter() | ||
| 976 | |||
| 977 | params_to_optimize = [ | ||
| 978 | { | ||
| 979 | "params": (param for param in unet.parameters() if param.requires_grad), | ||
| 980 | "lr": learning_rate_unet, | ||
| 981 | }, | ||
| 982 | { | ||
| 983 | "params": ( | ||
| 984 | param for param in text_encoder.parameters() if param.requires_grad | ||
| 985 | ), | ||
| 986 | "lr": learning_rate_text, | ||
| 987 | }, | ||
| 988 | ] | ||
| 989 | group_labels = ["unet", "text"] | ||
| 990 | |||
| 991 | dreambooth_optimizer = create_optimizer(params_to_optimize) | ||
| 992 | |||
| 993 | while True: | ||
| 994 | if len(auto_cycles) != 0: | ||
| 995 | response = auto_cycles.pop(0) | ||
| 996 | else: | ||
| 997 | response = input( | ||
| 998 | "\n### Choose action: [o] one_cycle, [w] warmup, [c] constant, [d] decay, [s] stop \n--> " | ||
| 999 | ) | ||
| 1000 | |||
| 1001 | if response.lower().strip() == "o": | ||
| 1002 | if args.learning_rate_unet is not None: | ||
| 1003 | learning_rate_unet = ( | ||
| 1004 | args.learning_rate_unet * 2 * (args.cycle_decay**training_iter) | ||
| 1005 | ) | ||
| 1006 | if args.learning_rate_text is not None: | ||
| 1007 | learning_rate_text = ( | ||
| 1008 | args.learning_rate_text * 2 * (args.cycle_decay**training_iter) | ||
| 1009 | ) | ||
| 1010 | else: | ||
| 1011 | learning_rate_unet = args.learning_rate_unet * ( | ||
| 1012 | args.cycle_decay**training_iter | ||
| 1013 | ) | ||
| 1014 | learning_rate_text = args.learning_rate_text * ( | ||
| 1015 | args.cycle_decay**training_iter | ||
| 1016 | ) | ||
| 1017 | |||
| 1018 | if response.lower().strip() == "o": | ||
| 1019 | lr_scheduler = "one_cycle" | ||
| 1020 | lr_warmup_epochs = args.lr_warmup_epochs | ||
| 1021 | lr_cycles = args.lr_cycles | ||
| 1022 | elif response.lower().strip() == "w": | ||
| 1023 | lr_scheduler = "constant_with_warmup" | ||
| 1024 | lr_warmup_epochs = num_train_epochs | ||
| 1025 | elif response.lower().strip() == "c": | ||
| 1026 | lr_scheduler = "constant" | ||
| 1027 | elif response.lower().strip() == "d": | ||
| 1028 | lr_scheduler = "cosine" | ||
| 1029 | lr_warmup_epochs = 0 | ||
| 1030 | lr_cycles = 1 | ||
| 1031 | elif response.lower().strip() == "s": | ||
| 1032 | break | ||
| 1033 | else: | ||
| 1034 | continue | ||
| 1035 | |||
| 1036 | print("") | ||
| 1037 | print( | ||
| 1038 | f"============ Dreambooth cycle {training_iter + 1}: {response} ============" | ||
| 1039 | ) | ||
| 1040 | print("") | ||
| 1041 | |||
| 1042 | for group, lr in zip( | ||
| 1043 | dreambooth_optimizer.param_groups, [learning_rate_unet, learning_rate_text] | ||
| 1044 | ): | ||
| 1045 | group["lr"] = lr | ||
| 1046 | |||
| 1047 | dreambooth_lr_scheduler = create_lr_scheduler( | ||
| 1048 | lr_scheduler, | ||
| 1049 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 1050 | optimizer=dreambooth_optimizer, | ||
| 1051 | num_training_steps_per_epoch=len(dreambooth_datamodule.train_dataloader), | ||
| 1052 | train_epochs=num_train_epochs, | ||
| 1053 | cycles=lr_cycles, | ||
| 1054 | warmup_epochs=lr_warmup_epochs, | ||
| 1055 | ) | ||
| 1056 | |||
| 1057 | dreambooth_checkpoint_output_dir = ( | ||
| 1058 | output_dir / dreambooth_project / f"model_{training_iter}" | ||
| 1059 | ) | ||
| 1060 | |||
| 1061 | trainer( | ||
| 1062 | strategy=dreambooth_strategy, | ||
| 1063 | train_dataloader=dreambooth_datamodule.train_dataloader, | ||
| 1064 | val_dataloader=dreambooth_datamodule.val_dataloader, | ||
| 1065 | optimizer=dreambooth_optimizer, | ||
| 1066 | lr_scheduler=dreambooth_lr_scheduler, | ||
| 1067 | num_train_epochs=num_train_epochs, | ||
| 1068 | gradient_accumulation_steps=args.gradient_accumulation_steps, | ||
| 1069 | global_step_offset=training_iter * num_train_steps, | ||
| 1070 | cycle=training_iter, | ||
| 1071 | train_text_encoder_cycles=args.train_text_encoder_cycles, | ||
| 1072 | # -- | ||
| 1073 | group_labels=group_labels, | ||
| 1074 | sample_output_dir=dreambooth_sample_output_dir, | ||
| 1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | ||
| 1076 | sample_frequency=dreambooth_sample_frequency, | ||
| 1077 | offset_noise_strength=args.offset_noise_strength, | ||
| 1078 | input_pertubation=args.input_pertubation, | ||
| 1079 | no_val=args.valid_set_size == 0, | ||
| 1080 | avg_loss=avg_loss, | ||
| 1081 | avg_acc=avg_acc, | ||
| 1082 | avg_loss_val=avg_loss_val, | ||
| 1083 | avg_acc_val=avg_acc_val, | ||
| 1084 | ) | ||
| 1085 | |||
| 1086 | training_iter += 1 | ||
| 1087 | |||
| 1088 | accelerator.end_training() | ||
| 723 | 1089 | ||
| 724 | 1090 | ||
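
The loop above is driven by `--auto_cycles`: the string is consumed one character per cycle ("o" one_cycle, "w" warmup, "c" constant, "d" decay, "s" stop), after which the script falls back to the interactive prompt. Each cycle shrinks the base learning rates by `--cycle_decay`, and a one_cycle phase doubles the base rate, presumably because OneCycle spends only part of the cycle at its peak. A worked example with hypothetical values:

```python
# --learning_rate_unet 1e-4, --cycle_decay 0.9, --auto_cycles ood
#
# cycle 0 ("o", one_cycle):    peak lr = 1e-4 * 2 * 0.9**0 = 2.0e-4
# cycle 1 ("o", one_cycle):    peak lr = 1e-4 * 2 * 0.9**1 = 1.8e-4
# cycle 2 ("d", cosine decay): lr      = 1e-4 * 0.9**2     = 8.1e-5
```
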
| 725 | if __name__ == "__main__": | 1091 | if __name__ == "__main__": |
