Diffstat (limited to 'train.py')
-rw-r--r--  train.py  672
1 file changed, 672 insertions, 0 deletions
diff --git a/train.py b/train.py
new file mode 100644
index 0000000..d8644c4
--- /dev/null
+++ b/train.py
@@ -0,0 +1,672 @@
import argparse
import datetime
import logging
from pathlib import Path

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from slugify import slugify

from data.csv import VlpnDataModule, VlpnDataItem
from util import load_config, load_embeddings_from_dir

from trainer.ti import TextualInversionTrainingStrategy
from trainer.base import Trainer
from training.optimization import get_scheduler
from training.util import save_args, generate_class_images, add_placeholder_tokens, get_models

logger = get_logger(__name__)

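# Allow TF32 matmuls and enable cuDNN autotuning for faster training on Ampere+ GPUs.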
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True

def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
        help="A CSV file containing the training data."
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--placeholder_tokens",
        type=str,
        nargs='*',
        help="One or more tokens to use as placeholders for the concepts.",
    )
    parser.add_argument(
        "--initializer_tokens",
        type=str,
        nargs='*',
        help="One or more tokens to use as initializer words."
    )
    parser.add_argument(
        "--num_vectors",
        type=int,
        nargs='*',
        help="Number of vectors per embedding."
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=1,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/text-inversion",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
            "The resolution for input images; all images in the train/validation dataset will be resized to this"
            " resolution."
        ),
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=0,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
        help="Disable tag shuffling.",
    )
    parser.add_argument(
        "--vector_dropout",
        type=float,
        default=0,
        help="Vector dropout probability.",
    )
    parser.add_argument(
        "--vector_shuffle",
        type=str,
        default="auto",
        help='Vector shuffling algorithm. Choose between ["all", "trailing", "leading", "between", "auto", "off"]',
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
        help="Number of update steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
        help="Number of epochs for the warmup in the lr scheduler."
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=1,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_ema",
        action="store_true",
        help="Whether to use EMA model."
    )
    parser.add_argument(
        "--ema_inv_gamma",
        type=float,
        default=1.0
    )
    parser.add_argument(
        "--ema_power",
        type=float,
        default=1
    )
    parser.add_argument(
        "--ema_max_decay",
        type=float,
        default=0.9999
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=0,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_amsgrad",
        # argparse's type=bool treats any non-empty string as True; a flag is safer.
        action="store_true",
        help="Whether to use the AMSGrad variant of the Adam optimizer."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). "
            "Bf16 requires PyTorch >= 1.10 and an Nvidia Ampere GPU."
        ),
    )
    parser.add_argument(
        "--checkpoint_frequency",
        type=int,
        default=5,
        help="How often to save a checkpoint (in epochs)",
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to generate sample images (in epochs)",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--emb_decay_target",
        default=0.4,
        type=float,
        help="Embedding decay target."
    )
    parser.add_argument(
        "--emb_decay_factor",
        default=0,
        type=float,
        help="Embedding decay factor."
    )
    parser.add_argument(
        "--emb_decay_start",
        default=1e-4,
        type=float,
        help="Embedding decay start offset."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--resume_from",
        type=str,
        default=None,
        help="Path to a directory to resume training from (e.g., logs/token_name/2022-09-22T23-36-27)"
    )
    parser.add_argument(
        "--global_step",
        type=int,
        default=0,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
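    # Values loaded from --config pre-populate the namespace on the second parse;
    # argparse only applies defaults to missing attributes, so flags given
    # explicitly on the command line still take precedence.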
    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if args.placeholder_tokens is None:
        args.placeholder_tokens = []

    if args.initializer_tokens is None:
        args.initializer_tokens = []

    if isinstance(args.placeholder_tokens, str):
        args.placeholder_tokens = [args.placeholder_tokens]

    if len(args.placeholder_tokens) == 0:
        args.placeholder_tokens = [f"<*{i}>" for i in range(len(args.initializer_tokens))]

    if isinstance(args.initializer_tokens, str):
        args.initializer_tokens = [args.initializer_tokens] * len(args.placeholder_tokens)

    if len(args.initializer_tokens) == 0:
        raise ValueError("You must specify --initializer_tokens")

    if len(args.placeholder_tokens) != len(args.initializer_tokens):
        raise ValueError("--placeholder_tokens and --initializer_tokens must have the same number of items")

    if args.num_vectors is None:
        args.num_vectors = 1

    if isinstance(args.num_vectors, int):
        args.num_vectors = [args.num_vectors] * len(args.initializer_tokens)

    if len(args.placeholder_tokens) != len(args.num_vectors):
        raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args

def main():
    args = parse_args()

    global_step_offset = args.global_step
    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

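    # torch.random.seed() returns a 64-bit value; shifting right by 32 keeps it
    # within the 32-bit range that set_seed (via numpy) expects.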
    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    tokenizer.set_use_vector_shuffle(args.vector_shuffle)
    tokenizer.set_dropout(args.vector_dropout)

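    # Memory savers: sliced VAE decoding plus xFormers memory-efficient attention.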
    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.set_use_memory_efficient_attention_xformers(True)

    if args.gradient_checkpointing:
        unet.enable_gradient_checkpointing()
        text_encoder.gradient_checkpointing_enable()

    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

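    # Register the placeholder tokens with the tokenizer and initialize their
    # embeddings from the corresponding initializer tokens.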
    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=args.placeholder_tokens,
        initializer_tokens=args.initializer_tokens,
        num_vectors=args.num_vectors
    )

    if len(placeholder_token_ids) != 0:
        initializer_token_id_lens = [len(ids) for ids in initializer_token_ids]
        placeholder_token_stats = list(zip(args.placeholder_tokens, placeholder_token_ids, initializer_token_id_lens))
        print(f"Added {len(placeholder_token_ids)} new tokens: {placeholder_token_stats}")

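    # Scale the base learning rate linearly with the effective batch size
    # (per-device batch size × gradient accumulation steps × number of processes).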
    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

    if args.find_lr:
        args.learning_rate = 1e-5

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

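    # Textual Inversion optimizes only the new token embeddings; the rest of the
    # text encoder, the UNet, and the VAE receive no optimizer updates.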
    optimizer = optimizer_class(
        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

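    # Keep only dataset items whose prompt mentions one of the placeholder tokens
    # and which pass the collection include/exclude filters.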
    def keyword_filter(item: VlpnDataItem):
        cond1 = any(
            keyword in part
            for keyword in args.placeholder_tokens
            for part in item.prompt
        )
        cond2 = args.collection is None or any(
            collection in item.collection
            for collection in args.collection
        )
        cond3 = args.exclude_collections is None or not any(
            collection in item.collection
            for collection in args.exclude_collections
        )
        return cond1 and cond2 and cond3

    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        valid_set_repeat=args.valid_set_repeat,
        seed=args.seed,
        filter=keyword_filter,
        dtype=weight_dtype
    )
    datamodule.setup()

    train_dataloader = datamodule.train_dataloader
    val_dataloader = datamodule.val_dataloader

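    # Pre-generate the class images used for the prior preservation loss, if requested.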
    if args.num_class_images != 0:
        generate_class_images(
            accelerator,
            text_encoder,
            vae,
            unet,
            tokenizer,
            sample_scheduler,
            datamodule.data_train,
            args.sample_batch_size,
            args.sample_image_size,
            args.sample_steps
        )

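    # Build the per-step learning-rate schedule (defaults to one-cycle with
    # cosine warmup and annealing).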
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

    trainer = Trainer(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        noise_scheduler=noise_scheduler,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
    )

    trainer(
        strategy_class=TextualInversionTrainingStrategy,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        sample_frequency=args.sample_frequency,
        checkpoint_frequency=args.checkpoint_frequency,
        global_step_offset=global_step_offset,
        prior_loss_weight=args.prior_loss_weight,
        output_dir=output_dir,
        placeholder_tokens=args.placeholder_tokens,
        placeholder_token_ids=placeholder_token_ids,
        learning_rate=args.learning_rate,
        sample_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
        sample_batch_size=args.sample_batch_size,
        sample_batches=args.sample_batches,
        seed=args.seed,
    )


if __name__ == "__main__":
    main()
