diff options
| -rw-r--r-- | train_dreambooth.py | 5 | ||||
| -rw-r--r-- | train_ti.py | 109 | ||||
| -rw-r--r-- | training/functional.py | 19 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 10 | ||||
| -rw-r--r-- | training/strategy/ti.py | 19 | ||||
| -rw-r--r-- | training/util.py | 11 |
6 files changed, 102 insertions, 71 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index d722e68..48bdcf8 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -14,8 +14,7 @@ from slugify import slugify | |||
| 14 | 14 | ||
| 15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
| 16 | from data.csv import VlpnDataModule, keyword_filter | 16 | from data.csv import VlpnDataModule, keyword_filter |
| 17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, get_models |
| 18 | from training.strategy.ti import textual_inversion_strategy | ||
| 19 | from training.strategy.dreambooth import dreambooth_strategy | 18 | from training.strategy.dreambooth import dreambooth_strategy |
| 20 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
| 21 | from training.util import save_args | 20 | from training.util import save_args |
| @@ -610,7 +609,7 @@ def main(): | |||
| 610 | ) | 609 | ) |
| 611 | 610 | ||
| 612 | trainer( | 611 | trainer( |
| 613 | callbacks_fn=dreambooth_strategy, | 612 | strategy=dreambooth_strategy, |
| 614 | project="dreambooth", | 613 | project="dreambooth", |
| 615 | train_dataloader=datamodule.train_dataloader, | 614 | train_dataloader=datamodule.train_dataloader, |
| 616 | val_dataloader=datamodule.val_dataloader, | 615 | val_dataloader=datamodule.val_dataloader, |
diff --git a/train_ti.py b/train_ti.py index e7aeb23..0891c49 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -14,7 +14,7 @@ from slugify import slugify | |||
| 14 | 14 | ||
| 15 | from util import load_config, load_embeddings_from_dir | 15 | from util import load_config, load_embeddings_from_dir |
| 16 | from data.csv import VlpnDataModule, keyword_filter | 16 | from data.csv import VlpnDataModule, keyword_filter |
| 17 | from training.functional import train, generate_class_images, add_placeholder_tokens, get_models | 17 | from training.functional import train, add_placeholder_tokens, get_models |
| 18 | from training.strategy.ti import textual_inversion_strategy | 18 | from training.strategy.ti import textual_inversion_strategy |
| 19 | from training.optimization import get_scheduler | 19 | from training.optimization import get_scheduler |
| 20 | from training.util import save_args | 20 | from training.util import save_args |
| @@ -79,6 +79,10 @@ def parse_args(): | |||
| 79 | help="Number of vectors per embedding." | 79 | help="Number of vectors per embedding." |
| 80 | ) | 80 | ) |
| 81 | parser.add_argument( | 81 | parser.add_argument( |
| 82 | "--simultaneous", | ||
| 83 | action="store_true", | ||
| 84 | ) | ||
| 85 | parser.add_argument( | ||
| 82 | "--num_class_images", | 86 | "--num_class_images", |
| 83 | type=int, | 87 | type=int, |
| 84 | default=0, | 88 | default=0, |
| @@ -474,11 +478,12 @@ def parse_args(): | |||
| 474 | if len(args.placeholder_tokens) != len(args.num_vectors): | 478 | if len(args.placeholder_tokens) != len(args.num_vectors): |
| 475 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") | 479 | raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") |
| 476 | 480 | ||
| 477 | if isinstance(args.train_data_template, str): | 481 | if not args.simultaneous: |
| 478 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | 482 | if isinstance(args.train_data_template, str): |
| 483 | args.train_data_template = [args.train_data_template] * len(args.placeholder_tokens) | ||
| 479 | 484 | ||
| 480 | if len(args.placeholder_tokens) != len(args.train_data_template): | 485 | if len(args.placeholder_tokens) != len(args.train_data_template): |
| 481 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") | 486 | raise ValueError("--placeholder_tokens and --train_data_template must have the same number of items") |
| 482 | 487 | ||
| 483 | if isinstance(args.collection, str): | 488 | if isinstance(args.collection, str): |
| 484 | args.collection = [args.collection] | 489 | args.collection = [args.collection] |
| @@ -560,6 +565,8 @@ def main(): | |||
| 560 | elif args.mixed_precision == "bf16": | 565 | elif args.mixed_precision == "bf16": |
| 561 | weight_dtype = torch.bfloat16 | 566 | weight_dtype = torch.bfloat16 |
| 562 | 567 | ||
| 568 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | ||
| 569 | |||
| 563 | trainer = partial( | 570 | trainer = partial( |
| 564 | train, | 571 | train, |
| 565 | accelerator=accelerator, | 572 | accelerator=accelerator, |
| @@ -569,30 +576,50 @@ def main(): | |||
| 569 | noise_scheduler=noise_scheduler, | 576 | noise_scheduler=noise_scheduler, |
| 570 | dtype=weight_dtype, | 577 | dtype=weight_dtype, |
| 571 | seed=args.seed, | 578 | seed=args.seed, |
| 572 | callbacks_fn=textual_inversion_strategy | 579 | with_prior_preservation=args.num_class_images != 0, |
| 580 | prior_loss_weight=args.prior_loss_weight, | ||
| 581 | strategy=textual_inversion_strategy, | ||
| 582 | num_train_epochs=args.num_train_epochs, | ||
| 583 | sample_frequency=args.sample_frequency, | ||
| 584 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 585 | global_step_offset=global_step_offset, | ||
| 586 | # -- | ||
| 587 | tokenizer=tokenizer, | ||
| 588 | sample_scheduler=sample_scheduler, | ||
| 589 | checkpoint_output_dir=checkpoint_output_dir, | ||
| 590 | learning_rate=args.learning_rate, | ||
| 591 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 592 | use_emb_decay=args.use_emb_decay, | ||
| 593 | emb_decay_target=args.emb_decay_target, | ||
| 594 | emb_decay_factor=args.emb_decay_factor, | ||
| 595 | emb_decay_start=args.emb_decay_start, | ||
| 596 | use_ema=args.use_ema, | ||
| 597 | ema_inv_gamma=args.ema_inv_gamma, | ||
| 598 | ema_power=args.ema_power, | ||
| 599 | ema_max_decay=args.ema_max_decay, | ||
| 600 | sample_batch_size=args.sample_batch_size, | ||
| 601 | sample_num_batches=args.sample_batches, | ||
| 602 | sample_num_steps=args.sample_steps, | ||
| 603 | sample_image_size=args.sample_image_size, | ||
| 573 | ) | 604 | ) |
| 574 | 605 | ||
| 575 | checkpoint_output_dir = output_dir.joinpath("checkpoints") | 606 | def run(i: int, placeholder_tokens, initializer_tokens, num_vectors, data_template): |
| 576 | 607 | if len(placeholder_tokens) == 1: | |
| 577 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | 608 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_tokens[0]}")
| 578 | range(len(args.placeholder_tokens)), | 609 | else: |
| 579 | args.placeholder_tokens, | 610 | sample_output_dir = output_dir.joinpath("samples") |
| 580 | args.initializer_tokens, | ||
| 581 | args.num_vectors, | ||
| 582 | args.train_data_template | ||
| 583 | ): | ||
| 584 | sample_output_dir = output_dir.joinpath(f"samples_{placeholder_token}") | ||
| 585 | 611 | ||
| 586 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( | 612 | placeholder_token_ids, initializer_token_ids = add_placeholder_tokens( |
| 587 | tokenizer=tokenizer, | 613 | tokenizer=tokenizer, |
| 588 | embeddings=embeddings, | 614 | embeddings=embeddings, |
| 589 | placeholder_tokens=[placeholder_token], | 615 | placeholder_tokens=placeholder_tokens, |
| 590 | initializer_tokens=[initializer_token], | 616 | initializer_tokens=initializer_tokens, |
| 591 | num_vectors=[num_vectors] | 617 | num_vectors=num_vectors |
| 592 | ) | 618 | ) |
| 593 | 619 | ||
| 594 | print( | 620 | stats = list(zip(placeholder_tokens, placeholder_token_ids, initializer_tokens, initializer_token_ids)) |
| 595 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | 621 | |
| 622 | print(f"{i + 1}: {stats}") | ||
| 596 | 623 | ||
| 597 | datamodule = VlpnDataModule( | 624 | datamodule = VlpnDataModule( |
| 598 | data_file=args.train_data_file, | 625 | data_file=args.train_data_file, |
| @@ -612,7 +639,7 @@ def main(): | |||
| 612 | train_set_pad=args.train_set_pad, | 639 | train_set_pad=args.train_set_pad, |
| 613 | valid_set_pad=args.valid_set_pad, | 640 | valid_set_pad=args.valid_set_pad, |
| 614 | seed=args.seed, | 641 | seed=args.seed, |
| 615 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), | 642 | filter=partial(keyword_filter, placeholder_tokens, args.collection, args.exclude_collections), |
| 616 | dtype=weight_dtype | 643 | dtype=weight_dtype |
| 617 | ) | 644 | ) |
| 618 | datamodule.setup() | 645 | datamodule.setup() |
| @@ -647,36 +674,24 @@ def main(): | |||
| 647 | val_dataloader=datamodule.val_dataloader, | 674 | val_dataloader=datamodule.val_dataloader, |
| 648 | optimizer=optimizer, | 675 | optimizer=optimizer, |
| 649 | lr_scheduler=lr_scheduler, | 676 | lr_scheduler=lr_scheduler, |
| 650 | num_train_epochs=args.num_train_epochs, | ||
| 651 | sample_frequency=args.sample_frequency, | ||
| 652 | checkpoint_frequency=args.checkpoint_frequency, | ||
| 653 | global_step_offset=global_step_offset, | ||
| 654 | with_prior_preservation=args.num_class_images != 0, | ||
| 655 | prior_loss_weight=args.prior_loss_weight, | ||
| 656 | # -- | 677 | # -- |
| 657 | tokenizer=tokenizer, | ||
| 658 | sample_scheduler=sample_scheduler, | ||
| 659 | sample_output_dir=sample_output_dir, | 678 | sample_output_dir=sample_output_dir, |
| 660 | checkpoint_output_dir=checkpoint_output_dir, | 679 | placeholder_tokens=placeholder_tokens, |
| 661 | placeholder_tokens=[placeholder_token], | ||
| 662 | placeholder_token_ids=placeholder_token_ids, | 680 | placeholder_token_ids=placeholder_token_ids, |
| 663 | learning_rate=args.learning_rate, | ||
| 664 | gradient_checkpointing=args.gradient_checkpointing, | ||
| 665 | use_emb_decay=args.use_emb_decay, | ||
| 666 | emb_decay_target=args.emb_decay_target, | ||
| 667 | emb_decay_factor=args.emb_decay_factor, | ||
| 668 | emb_decay_start=args.emb_decay_start, | ||
| 669 | use_ema=args.use_ema, | ||
| 670 | ema_inv_gamma=args.ema_inv_gamma, | ||
| 671 | ema_power=args.ema_power, | ||
| 672 | ema_max_decay=args.ema_max_decay, | ||
| 673 | sample_batch_size=args.sample_batch_size, | ||
| 674 | sample_num_batches=args.sample_batches, | ||
| 675 | sample_num_steps=args.sample_steps, | ||
| 676 | sample_image_size=args.sample_image_size, | ||
| 677 | ) | 681 | ) |
| 678 | 682 | ||
| 679 | embeddings.persist() | 683 | if args.simultaneous: |
| 684 | run(0, args.placeholder_tokens, args.initializer_tokens, args.num_vectors, args.train_data_template) | ||
| 685 | else: | ||
| 686 | for i, placeholder_token, initializer_token, num_vectors, data_template in zip( | ||
| 687 | range(len(args.placeholder_tokens)), | ||
| 688 | args.placeholder_tokens, | ||
| 689 | args.initializer_tokens, | ||
| 690 | args.num_vectors, | ||
| 691 | args.train_data_template | ||
| 692 | ): | ||
| 693 | run(i, [placeholder_token], [initializer_token], [num_vectors], data_template) | ||
| 694 | embeddings.persist() | ||
| 680 | 695 | ||
| 681 | 696 | ||
| 682 | if __name__ == "__main__": | 697 | if __name__ == "__main__": |
diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -39,11 +39,18 @@ class TrainingCallbacks(): | |||
| 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
| 40 | on_before_optimize: Callable[[int], None] = const() | 40 | on_before_optimize: Callable[[int], None] = const() |
| 41 | on_after_optimize: Callable[[float], None] = const() | 41 | on_after_optimize: Callable[[float], None] = const() |
| 42 | on_after_epoch: Callable[[float], None] = const() | ||
| 42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
| 43 | on_sample: Callable[[int], None] = const() | 44 | on_sample: Callable[[int], None] = const() |
| 44 | on_checkpoint: Callable[[int, str], None] = const() | 45 | on_checkpoint: Callable[[int, str], None] = const() |
| 45 | 46 | ||
| 46 | 47 | ||
| 48 | @dataclass | ||
| 49 | class TrainingStrategy(): | ||
| 50 | callbacks: Callable[..., TrainingCallbacks] | ||
| 51 | prepare_unet: bool = False | ||
| 52 | |||
| 53 | |||
| 47 | def make_grid(images, rows, cols): | 54 | def make_grid(images, rows, cols): |
| 48 | w, h = images[0].size | 55 | w, h = images[0].size |
| 49 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 56 | grid = Image.new('RGB', size=(cols*w, rows*h)) |
| @@ -373,6 +380,7 @@ def train_loop( | |||
| 373 | on_train = callbacks.on_train | 380 | on_train = callbacks.on_train |
| 374 | on_before_optimize = callbacks.on_before_optimize | 381 | on_before_optimize = callbacks.on_before_optimize |
| 375 | on_after_optimize = callbacks.on_after_optimize | 382 | on_after_optimize = callbacks.on_after_optimize |
| 383 | on_after_epoch = callbacks.on_after_epoch | ||
| 376 | on_eval = callbacks.on_eval | 384 | on_eval = callbacks.on_eval |
| 377 | on_sample = callbacks.on_sample | 385 | on_sample = callbacks.on_sample |
| 378 | on_checkpoint = callbacks.on_checkpoint | 386 | on_checkpoint = callbacks.on_checkpoint |
| @@ -434,6 +442,8 @@ def train_loop( | |||
| 434 | 442 | ||
| 435 | accelerator.wait_for_everyone() | 443 | accelerator.wait_for_everyone() |
| 436 | 444 | ||
| 445 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | ||
| 446 | |||
| 437 | if val_dataloader is not None: | 447 | if val_dataloader is not None: |
| 438 | model.eval() | 448 | model.eval() |
| 439 | 449 | ||
| @@ -512,8 +522,7 @@ def train( | |||
| 512 | val_dataloader: Optional[DataLoader], | 522 | val_dataloader: Optional[DataLoader], |
| 513 | optimizer: torch.optim.Optimizer, | 523 | optimizer: torch.optim.Optimizer, |
| 514 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 524 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 515 | callbacks_fn: Callable[..., TrainingCallbacks], | 525 | strategy: TrainingStrategy, |
| 516 | prepare_unet: bool = False, | ||
| 517 | num_train_epochs: int = 100, | 526 | num_train_epochs: int = 100, |
| 518 | sample_frequency: int = 20, | 527 | sample_frequency: int = 20, |
| 519 | checkpoint_frequency: int = 50, | 528 | checkpoint_frequency: int = 50, |
| @@ -524,12 +533,12 @@ def train( | |||
| 524 | ): | 533 | ): |
| 525 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 534 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
| 526 | 535 | ||
| 527 | if prepare_unet: | 536 | if strategy.prepare_unet: |
| 528 | prep.append(unet) | 537 | prep.append(unet) |
| 529 | 538 | ||
| 530 | prep = accelerator.prepare(*prep) | 539 | prep = accelerator.prepare(*prep) |
| 531 | 540 | ||
| 532 | if prepare_unet: | 541 | if strategy.prepare_unet: |
| 533 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | 542 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep |
| 534 | else: | 543 | else: |
| 535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | 544 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep |
| @@ -542,7 +551,7 @@ def train( | |||
| 542 | model.requires_grad_(False) | 551 | model.requires_grad_(False) |
| 543 | model.eval() | 552 | model.eval() |
| 544 | 553 | ||
| 545 | callbacks = callbacks_fn( | 554 | callbacks = strategy.callbacks( |
| 546 | accelerator=accelerator, | 555 | accelerator=accelerator, |
| 547 | unet=unet, | 556 | unet=unet, |
| 548 | text_encoder=text_encoder, | 557 | text_encoder=text_encoder, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
| 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def dreambooth_strategy( | 21 | def dreambooth_strategy_callbacks( |
| 22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
| 23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
| 24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
| @@ -185,3 +185,9 @@ def dreambooth_strategy( | |||
| 185 | on_checkpoint=on_checkpoint, | 185 | on_checkpoint=on_checkpoint, |
| 186 | on_sample=on_sample, | 186 | on_sample=on_sample, |
| 187 | ) | 187 | ) |
| 188 | |||
| 189 | |||
| 190 | dreambooth_strategy = TrainingStrategy( | ||
| 191 | callbacks=dreambooth_strategy_callbacks, | ||
| 192 | prepare_unet=True | ||
| 193 | ) | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 00f3529..597abd0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -15,10 +15,10 @@ from slugify import slugify | |||
| 15 | 15 | ||
| 16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
| 17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
| 18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
| 19 | 19 | ||
| 20 | 20 | ||
| 21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy_callbacks( |
| 22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
| 23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
| 24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
| @@ -119,17 +119,18 @@ def textual_inversion_strategy( | |||
| 119 | with ema_context(): | 119 | with ema_context(): |
| 120 | yield | 120 | yield |
| 121 | 121 | ||
| 122 | @torch.no_grad() | ||
| 123 | def on_after_optimize(lr: float): | 122 | def on_after_optimize(lr: float): |
| 123 | if use_ema: | ||
| 124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 125 | |||
| 126 | @torch.no_grad() | ||
| 127 | def on_after_epoch(lr: float): | ||
| 124 | if use_emb_decay: | 128 | if use_emb_decay: |
| 125 | text_encoder.text_model.embeddings.normalize( | 129 | text_encoder.text_model.embeddings.normalize( |
| 126 | emb_decay_target, | 130 | emb_decay_target, |
| 127 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) | 131 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) |
| 128 | ) | 132 | ) |
| 129 | 133 | ||
| 130 | if use_ema: | ||
| 131 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
| 132 | |||
| 133 | def on_log(): | 134 | def on_log(): |
| 134 | if use_ema: | 135 | if use_ema: |
| 135 | return {"ema_decay": ema_embeddings.decay} | 136 | return {"ema_decay": ema_embeddings.decay} |
| @@ -157,7 +158,13 @@ def textual_inversion_strategy( | |||
| 157 | on_train=on_train, | 158 | on_train=on_train, |
| 158 | on_eval=on_eval, | 159 | on_eval=on_eval, |
| 159 | on_after_optimize=on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
| 161 | on_after_epoch=on_after_epoch, | ||
| 160 | on_log=on_log, | 162 | on_log=on_log, |
| 161 | on_checkpoint=on_checkpoint, | 163 | on_checkpoint=on_checkpoint, |
| 162 | on_sample=on_sample, | 164 | on_sample=on_sample, |
| 163 | ) | 165 | ) |
| 166 | |||
| 167 | |||
| 168 | textual_inversion_strategy = TrainingStrategy( | ||
| 169 | callbacks=textual_inversion_strategy_callbacks, | ||
| 170 | ) | ||
diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -1,18 +1,11 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | 2 | import json |
| 3 | import copy | 3 | import copy |
| 4 | from typing import Iterable, Union | 4 | from typing import Iterable, Any |
| 5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
| 6 | 6 | ||
| 7 | import torch | 7 | import torch |
| 8 | 8 | ||
| 9 | from transformers import CLIPTextModel | ||
| 10 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
| 11 | |||
| 12 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 13 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 14 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
| 15 | |||
| 16 | 9 | ||
| 17 | def save_args(basepath: Path, args, extra={}): | 10 | def save_args(basepath: Path, args, extra={}): |
| 18 | info = {"args": vars(args)} | 11 | info = {"args": vars(args)} |
| @@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): | |||
| 22 | 15 | ||
| 23 | 16 | ||
| 24 | class AverageMeter: | 17 | class AverageMeter: |
| 18 | avg: Any | ||
| 19 | |||
| 25 | def __init__(self, name=None): | 20 | def __init__(self, name=None): |
| 26 | self.name = name | 21 | self.name = name |
| 27 | self.reset() | 22 | self.reset() |
