Diffstat (limited to 'train_lora.py')
-rw-r--r--  train_lora.py  566
1 file changed, 566 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
new file mode 100644
index 0000000..2cb85cc
--- /dev/null
+++ b/train_lora.py
@@ -0,0 +1,566 @@
import argparse
import datetime
import logging
import itertools
from pathlib import Path
from functools import partial

import torch
import torch.utils.checkpoint

from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import LoggerType, set_seed
from slugify import slugify
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

from util import load_config, load_embeddings_from_dir
from data.csv import VlpnDataModule, keyword_filter
from training.functional import train, get_models
from training.lr import plot_metrics
from training.strategy.lora import lora_strategy
from training.optimization import get_scheduler
from training.util import save_args

logger = get_logger(__name__)

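# Allow TF32 matmul kernels and let cuDNN autotune convolution
# algorithms; both speed up training on Ampere or newer GPUs with
# negligible precision impact.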
torch.backends.cuda.matmul.allow_tf32 = True
torch.backends.cudnn.benchmark = True


def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--tokenizer_name",
        type=str,
        default=None,
        help="Pretrained tokenizer name or path if not the same as model_name",
    )
    parser.add_argument(
        "--train_data_file",
        type=str,
        default=None,
| 53 | help="A folder containing the training data." | ||
    )
    parser.add_argument(
        "--train_data_template",
        type=str,
        default="template",
    )
    parser.add_argument(
        "--train_set_pad",
        type=int,
        default=None,
| 64 | help="The number to fill train dataset items up to." | ||
    )
    parser.add_argument(
        "--valid_set_pad",
        type=int,
        default=None,
| 70 | help="The number to fill validation dataset items up to." | ||
    )
    parser.add_argument(
        "--project",
        type=str,
        default=None,
        help="The name of the current project.",
    )
    parser.add_argument(
        "--exclude_collections",
        type=str,
        nargs='*',
        help="Exclude all items with a listed collection.",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=4,
        help="Number of aspect ratio buckets in either direction.",
    )
    parser.add_argument(
        "--progressive_buckets",
        action="store_true",
        help="Include images in smaller buckets as well.",
    )
    parser.add_argument(
        "--bucket_step_size",
        type=int,
        default=64,
        help="Step size between buckets.",
    )
    parser.add_argument(
        "--bucket_max_pixels",
        type=int,
        default=None,
        help="Maximum pixels per bucket.",
    )
    parser.add_argument(
        "--tag_dropout",
        type=float,
        default=0.1,
        help="Tag dropout probability.",
    )
    parser.add_argument(
        "--no_tag_shuffle",
        action="store_true",
| 116 | help="Shuffle tags.", | ||
    )
    parser.add_argument(
        "--num_class_images",
        type=int,
        default=0,
        help="How many class images to generate."
    )
    parser.add_argument(
        "--class_image_dir",
        type=str,
        default="cls",
        help="The directory where class images will be saved.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="output/lora",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--embeddings_dir",
        type=str,
        default=None,
        help="The embeddings directory where Textual Inversion embeddings are stored.",
    )
    parser.add_argument(
        "--collection",
        type=str,
        nargs='*',
        help="A collection to filter the dataset.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=768,
        help=(
| 159 | "The resolution for input images, all the images in the train/validation dataset will be resized to this" | ||
| 160 | " resolution" | ||
        ),
    )
    parser.add_argument(
        "--num_train_epochs",
        type=int,
        default=100
    )
    parser.add_argument(
        "--max_train_steps",
        type=int,
        default=None,
        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
    )
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=1,
| 178 | help="Number of updates steps to accumulate before performing a backward/update pass.", | ||
    )
    parser.add_argument(
        "--find_lr",
        action="store_true",
        help="Automatically find a learning rate (no training).",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=2e-6,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="one_cycle",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup", "one_cycle"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_epochs",
        type=int,
        default=10,
| 209 | help="Number of steps for the warmup in the lr scheduler." | ||
    )
    parser.add_argument(
        "--lr_cycles",
        type=int,
        default=None,
        help="Number of restart cycles in the lr scheduler (if supported)."
    )
    parser.add_argument(
        "--lr_warmup_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "cos"]'
    )
    parser.add_argument(
        "--lr_warmup_exp",
        type=int,
        default=1,
        help='If lr_warmup_func is "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_annealing_func",
        type=str,
        default="cos",
        help='Choose between ["linear", "half_cos", "cos"]'
    )
    parser.add_argument(
        "--lr_annealing_exp",
        type=int,
        default=3,
        help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function'
    )
    parser.add_argument(
        "--lr_min_lr",
        type=float,
        default=0.04,
        help="Minimum learning rate in the lr scheduler."
    )
    parser.add_argument(
        "--use_8bit_adam",
        action="store_true",
        help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument(
        "--adam_beta1",
        type=float,
        default=0.9,
        help="The beta1 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_beta2",
        type=float,
        default=0.999,
        help="The beta2 parameter for the Adam optimizer."
    )
    parser.add_argument(
        "--adam_weight_decay",
        type=float,
        default=1e-2,
        help="Weight decay to use."
    )
    parser.add_argument(
        "--adam_epsilon",
        type=float,
        default=1e-08,
        help="Epsilon value for the Adam optimizer"
    )
    parser.add_argument(
        "--adam_amsgrad",
        action="store_true",
        help="Whether to use the AMSGrad variant of the Adam optimizer."
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default="no",
        choices=["no", "fp16", "bf16"],
        help=(
| 288 | "Whether to use mixed precision. Choose" | ||
| 289 | "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10." | ||
| 290 | "and an Nvidia Ampere GPU." | ||
        ),
    )
    parser.add_argument(
        "--sample_frequency",
        type=int,
        default=1,
        help="How often to save a checkpoint and sample image",
    )
    parser.add_argument(
        "--sample_image_size",
        type=int,
        default=768,
        help="Size of sample images",
    )
    parser.add_argument(
        "--sample_batches",
        type=int,
        default=1,
        help="Number of sample batches to generate per checkpoint",
    )
    parser.add_argument(
        "--sample_batch_size",
        type=int,
        default=1,
        help="Number of samples to generate per batch",
    )
    parser.add_argument(
        "--valid_set_size",
        type=int,
        default=None,
        help="Number of images in the validation dataset."
    )
    parser.add_argument(
        "--valid_set_repeat",
        type=int,
        default=1,
        help="Times the images in the validation dataset are repeated."
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--sample_steps",
        type=int,
        default=20,
        help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
    )
    parser.add_argument(
        "--prior_loss_weight",
        type=float,
        default=1.0,
        help="The weight of prior preservation loss."
    )
    parser.add_argument(
        "--max_grad_norm",
        default=1.0,
        type=float,
        help="Max gradient norm."
    )
    parser.add_argument(
        "--noise_timesteps",
        type=int,
        default=1000,
    )
    parser.add_argument(
        "--config",
        type=str,
        default=None,
        help="Path to a JSON configuration file containing arguments for invoking this script."
    )

    args = parser.parse_args()
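    # When a config file is given, load it as defaults into the namespace
    # and parse the command line again on top, so explicit CLI flags win.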
    if args.config is not None:
        args = load_config(args.config)
        args = parser.parse_args(namespace=argparse.Namespace(**args))

    if args.train_data_file is None:
        raise ValueError("You must specify --train_data_file")

    if args.pretrained_model_name_or_path is None:
        raise ValueError("You must specify --pretrained_model_name_or_path")

    if args.project is None:
        raise ValueError("You must specify --project")

    if isinstance(args.collection, str):
        args.collection = [args.collection]

    if isinstance(args.exclude_collections, str):
        args.exclude_collections = [args.exclude_collections]

    if args.output_dir is None:
        raise ValueError("You must specify --output_dir")

    return args


def main():
    args = parse_args()

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
    output_dir = Path(args.output_dir).joinpath(slugify(args.project), now)
    output_dir.mkdir(parents=True, exist_ok=True)

    accelerator = Accelerator(
        log_with=LoggerType.TENSORBOARD,
        logging_dir=f"{output_dir}",
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision
    )

    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)

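    # Fall back to a randomly drawn 32-bit seed so the run is always
    # seeded; the value is recorded via save_args below.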
    if args.seed is None:
        args.seed = torch.random.seed() >> 32

    set_seed(args.seed)

    save_args(output_dir, args)

    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        args.pretrained_model_name_or_path)

    vae.enable_slicing()
    vae.set_use_memory_efficient_attention_xformers(True)
    unet.enable_xformers_memory_efficient_attention()

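    # Attach a LoRA processor to every attention layer in the UNet.
    # Self-attention layers ("attn1") take no cross-attention dim; the
    # hidden size comes from the channel count of the enclosing block,
    # so the three branches cover the UNet's attention processors.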
    lora_attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]

        lora_attn_procs[name] = LoRACrossAttnProcessor(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim
        )

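    # Register the LoRA processors and collect their weights; only these
    # low-rank layers are trained, the rest of the UNet stays frozen.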
    unet.set_attn_processor(lora_attn_procs)
    lora_layers = AttnProcsLayers(unet.attn_processors)

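    # Optionally load Textual Inversion embeddings so their tokens can be
    # used in training prompts while the LoRA layers are trained.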
    if args.embeddings_dir is not None:
        embeddings_dir = Path(args.embeddings_dir)
        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
            raise ValueError("--embeddings_dir must point to an existing directory")

        embeddings.persist()

        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate * args.gradient_accumulation_steps *
            args.train_batch_size * accelerator.num_processes
        )

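    # For the LR range test, start from a tiny learning rate and let the
    # scheduler grow it exponentially; plot_metrics renders the curve at
    # the end of the run.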
    if args.find_lr:
        args.learning_rate = 1e-6
        args.lr_scheduler = "exponential_growth"

    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")

        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

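    # Map the mixed-precision mode to the torch dtype used for the frozen
    # model weights and the data pipeline.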
    weight_dtype = torch.float32
    if args.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif args.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

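    # Bind the model- and loss-related arguments once; the remaining,
    # strategy-specific arguments are supplied at the call site below.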
    trainer = partial(
        train,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        lora_layers=lora_layers,
        noise_scheduler=noise_scheduler,
        dtype=weight_dtype,
        with_prior_preservation=args.num_class_images != 0,
        prior_loss_weight=args.prior_loss_weight,
    )

    checkpoint_output_dir = output_dir.joinpath("model")
    sample_output_dir = output_dir.joinpath("samples")

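    # The data module handles aspect-ratio bucketing, tag dropout and
    # shuffling, and the train/validation split.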
    datamodule = VlpnDataModule(
        data_file=args.train_data_file,
        batch_size=args.train_batch_size,
        tokenizer=tokenizer,
        class_subdir=args.class_image_dir,
        num_class_images=args.num_class_images,
        size=args.resolution,
        num_buckets=args.num_buckets,
        progressive_buckets=args.progressive_buckets,
        bucket_step_size=args.bucket_step_size,
        bucket_max_pixels=args.bucket_max_pixels,
        dropout=args.tag_dropout,
        shuffle=not args.no_tag_shuffle,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
        train_set_pad=args.train_set_pad,
        valid_set_pad=args.valid_set_pad,
        seed=args.seed,
        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
        dtype=weight_dtype
    )
    datamodule.setup()

    optimizer = optimizer_class(
        lora_layers.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
        amsgrad=args.adam_amsgrad,
    )

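    # Build the LR schedule from per-epoch step counts, so warmup and
    # annealing lengths track the size of the bucketed dataloader.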
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_training_steps_per_epoch=len(datamodule.train_dataloader),
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        min_lr=args.lr_min_lr,
        warmup_func=args.lr_warmup_func,
        annealing_func=args.lr_annealing_func,
        warmup_exp=args.lr_warmup_exp,
        annealing_exp=args.lr_annealing_exp,
        cycles=args.lr_cycles,
        end_lr=1e2,
        train_epochs=args.num_train_epochs,
        warmup_epochs=args.lr_warmup_epochs,
    )

    metrics = trainer(
        strategy=lora_strategy,
        project="lora",
        train_dataloader=datamodule.train_dataloader,
        val_dataloader=datamodule.val_dataloader,
        seed=args.seed,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        num_train_epochs=args.num_train_epochs,
        sample_frequency=args.sample_frequency,
        # --
        tokenizer=tokenizer,
        sample_scheduler=sample_scheduler,
        sample_output_dir=sample_output_dir,
        checkpoint_output_dir=checkpoint_output_dir,
        max_grad_norm=args.max_grad_norm,
        sample_batch_size=args.sample_batch_size,
        sample_num_batches=args.sample_batches,
        sample_num_steps=args.sample_steps,
        sample_image_size=args.sample_image_size,
    )

    plot_metrics(metrics, output_dir.joinpath("lr.png"))


if __name__ == "__main__":
    main()
