| author | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
| commit | 632ce00b54ffeacfc18f44f10827f167ab3ac37c (patch) | |
| tree | ecf58df2b176d3c7d1583136bf453ed24de8d7f3 | |
| parent | Fixed Conda env (diff) | |
| download | textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.gz textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.bz2 textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.zip | |
Restored functional trainer
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | data/csv.py | 8 |
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 16 |
| -rw-r--r-- | train_ti.py | 82 |
| -rw-r--r-- | training/functional.py | 102 |
| -rw-r--r-- | training/util.py | 8 |
5 files changed, 112 insertions, 104 deletions
diff --git a/data/csv.py b/data/csv.py
index 5de3ac7..2a8115b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,9 +15,6 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.util import unify_input_ids
 
 
-image_cache: dict[str, Image.Image] = {}
-
-
 interpolations = {
     "linear": transforms.InterpolationMode.NEAREST,
     "bilinear": transforms.InterpolationMode.BILINEAR,
@@ -27,14 +24,9 @@ interpolations = {
 
 
 def get_image(path):
-    if path in image_cache:
-        return image_cache[path]
-
     image = Image.open(path)
     if not image.mode == "RGB":
         image = image.convert("RGB")
-    image_cache[path] = image
-
     return image
 
 
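Note: this change drops the module-level `image_cache` dict, so `get_image` now re-reads the file on every call instead of keeping every opened image in memory for the lifetime of the process. If some caching were still wanted without unbounded growth, a bounded cache along the following lines could be used — a sketch only, not part of this commit:

```python
from functools import lru_cache

from PIL import Image


# Hypothetical bounded alternative to the removed module-level dict: keeps at
# most 256 decoded images alive instead of growing without limit.
@lru_cache(maxsize=256)
def get_image_cached(path: str) -> Image.Image:
    image = Image.open(path)
    if image.mode != "RGB":
        image = image.convert("RGB")
    return image
```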
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 43141bd..3027421 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -162,8 +162,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
-        width: Optional[int],
-        height: Optional[int],
+        width: int,
+        height: int,
         strength: float,
         callback_steps: Optional[int]
     ):
@@ -324,19 +324,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         strength: float = 0.8,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        output_type: Optional[str] = "pil",
+        output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
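Note: the pipeline signature changes are type-hint corrections rather than behaviour changes. `Optional[X]` means "X or None" and is unrelated to whether a parameter has a default, so parameters that always receive a concrete value (such as `num_inference_steps`, `guidance_scale`, `callback_steps`) lose the `Optional`, while `height`/`width`, for which `None` is a meaningful value, keep it. A small illustration (hypothetical function, not from the pipeline):

```python
from typing import Optional


def sample(
    num_inference_steps: int = 50,   # has a default, but is never None
    height: Optional[int] = None,    # None is a valid value ("derive later")
) -> None:
    steps = num_inference_steps      # always an int, no None check needed
    actual_height = height if height is not None else 512
    print(steps, actual_height)
```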
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-    accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
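Note: `main()` no longer prepares the models, builds the callbacks, and drives `train_loop` itself. It binds the long-lived objects (accelerator, models, dataloaders, the strategy factory) into a `trainer` with `functools.partial`, then calls it with the run-specific options; the keyword arguments after the `# --` marker are not consumed by `train` and fall through its `**kwargs` to the strategy. A minimal sketch of that calling pattern (illustrative names, not the real training objects):

```python
from functools import partial


def train(model, dataloader, *, num_train_epochs=100, **strategy_kwargs):
    # Stand-in for training.functional.train: consumes the options it knows
    # about and forwards the rest to the strategy callback factory.
    print(f"training {model} on {dataloader} for {num_train_epochs} epochs")
    print(f"forwarded to the strategy: {strategy_kwargs}")


# Bind what never changes between calls once ...
trainer = partial(train, model="unet+text_encoder", dataloader="train_dataloader")

# ... then invoke with run-specific options; the unknown keyword is forwarded.
trainer(num_train_epochs=30, placeholder_tokens=["<new-token>"])
```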
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@ generate_class_images(
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
         if accelerator.is_main_process:
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
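Note: the new `train()` entry point now owns the sequence that previously lived in the scripts: `accelerator.prepare`, moving the VAE to the right device/dtype, freezing and eval-ing all models, building the callbacks via the injected `callbacks_fn`, constructing the loss closure, initialising trackers, running `train_loop`, then `end_training`/`free_memory`. The contract it assumes from `callbacks_fn` is roughly the following sketch (simplified; the real `TrainingCallbacks` dataclass has more fields):

```python
from dataclasses import dataclass
from typing import Any, Callable, Dict


@dataclass
class Callbacks:
    # Simplified stand-in for TrainingCallbacks; only two hooks shown.
    on_prepare: Callable[[], None]
    on_log: Callable[[], Dict[str, Any]]


def example_strategy(accelerator, unet, text_encoder, vae,
                     train_dataloader, val_dataloader, seed, **kwargs) -> Callbacks:
    # Receives the prepared objects plus everything train() forwarded via
    # **kwargs (tokenizer, sample_scheduler, output_dir, ...) and returns the
    # hooks that train_loop will call.
    def on_prepare() -> None:
        print("strategy configured with:", sorted(kwargs))

    return Callbacks(on_prepare=on_prepare, on_log=lambda: {"seed": seed})
```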
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
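Note on the `EMAModel.apply_temporary` fix: the backup of the live parameters and the copy of the EMA weights now happen before entering `try`, so the `finally` block can always assume `original_params` exists (previously, a failure while materialising or cloning the parameters would have raised a `NameError` inside `finally` and masked the real error), and the clones are explicitly dropped once restored. The general pattern, as a standalone sketch (hypothetical helper, simplified from `EMAModel`):

```python
from contextlib import contextmanager
from typing import Iterable, Iterator

import torch


@contextmanager
def swap_in(parameters: Iterable[torch.nn.Parameter],
            replacements: Iterable[torch.Tensor]) -> Iterator[None]:
    # Set-up happens *outside* try/finally: if it fails, there is nothing to
    # undo, and the cleanup below never runs with a missing backup.
    parameters = list(parameters)
    backup = [p.detach().clone() for p in parameters]
    for param, repl in zip(parameters, replacements):
        param.data.copy_(repl)
    try:
        yield
    finally:
        # Always restore the original weights, then release the clones promptly.
        for saved, param in zip(backup, parameters):
            param.data.copy_(saved)
        del backup
```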
