-rw-r--r--   data/csv.py                                          |   8
-rw-r--r--   pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  16
-rw-r--r--   train_ti.py                                          |  82
-rw-r--r--   training/functional.py                               | 102
-rw-r--r--   training/util.py                                     |   8
5 files changed, 112 insertions, 104 deletions
diff --git a/data/csv.py b/data/csv.py
index 5de3ac7..2a8115b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,9 +15,6 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.util import unify_input_ids
 
 
-image_cache: dict[str, Image.Image] = {}
-
-
 interpolations = {
     "linear": transforms.InterpolationMode.NEAREST,
     "bilinear": transforms.InterpolationMode.BILINEAR,
@@ -27,14 +24,9 @@ interpolations = {
 
 
 def get_image(path):
-    if path in image_cache:
-        return image_cache[path]
-
     image = Image.open(path)
     if not image.mode == "RGB":
         image = image.convert("RGB")
-    image_cache[path] = image
-
     return image
 
 
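The csv.py change drops the module-level image_cache dict, so get_image now re-opens and converts the file on every call; the old cache kept every decoded image alive for the life of the process. If bounded caching were ever wanted instead, an lru_cache wrapper is one option. A minimal sketch, not part of this patch, with the maxsize value assumed:

    # Illustrative only, not part of this patch: a bounded replacement for the
    # removed module-level dict, in case re-decoding ever becomes a bottleneck.
    from functools import lru_cache

    from PIL import Image


    @lru_cache(maxsize=256)  # cache size is an assumption for this sketch
    def get_image_cached(path: str) -> Image.Image:
        image = Image.open(path)
        if image.mode != "RGB":
            image = image.convert("RGB")
        return image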
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 43141bd..3027421 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -162,8 +162,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
-        width: Optional[int],
-        height: Optional[int],
+        width: int,
+        height: int,
         strength: float,
         callback_steps: Optional[int]
     ):
@@ -324,19 +324,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         strength: float = 0.8,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        output_type: Optional[str] = "pil",
+        output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-        accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
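In train_ti.py the inline accelerator.prepare / loss_step / train_loop wiring is replaced by one call into the new train entry point: functools.partial first pins the models, data loaders, dtype and strategy, then the second call supplies the per-run options. The arguments after the "# --" marker are not named parameters of train, so they fall into **kwargs and are forwarded to the strategy's callback factory. A stripped-down sketch of that partial-then-call pattern, with simplified names rather than the project's real signatures:

    # Simplified illustration of the pattern above; names are made up.
    from functools import partial


    def train(model, dataloader, *, num_train_epochs=100, **strategy_kwargs):
        # strategy_kwargs would be handed on to the callbacks factory
        print(f"{model}: {num_train_epochs} epochs, extras={strategy_kwargs}")


    trainer = partial(train, "unet+text_encoder", "train_dataloader")
    trainer(num_train_epochs=10, output_dir="out")  # extras flow via **strategy_kwargs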
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@
     if accelerator.is_main_process:
         print("Interrupted")
         on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
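training/functional.py gains the train function, which now owns preparation, freezing, callback construction, tracker initialisation, the train_loop call and teardown. Its callbacks_fn parameter is the strategy hook: it receives the prepared models plus any extra **kwargs and must return a TrainingCallbacks instance, and the fields shown in the hunk above all carry no-op defaults, so a strategy only needs to override what it uses. A minimal sketch of a compatible factory, with the shape inferred from the call site rather than from the real textual_inversion_strategy:

    # Assumed-shape sketch; the actual strategy lives in training/strategy/ti.py.
    from training.functional import TrainingCallbacks


    def noop_strategy(accelerator, unet, text_encoder, vae,
                      train_dataloader, val_dataloader, seed, **kwargs):
        def on_prepare():
            # a real strategy would unfreeze the parameters it trains here
            pass

        return TrainingCallbacks(on_prepare=on_prepare)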
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
    def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
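The EMAModel change hoists the clone and copy_to setup out of the try block and adds an explicit del. With the old ordering, a failure while building the parameter list or cloning meant the finally clause ran before original_params was ever bound, so the restore code raised a NameError on top of the real error; moving the setup ahead of try ensures finally only runs once the originals have actually been captured, and del original_params drops the clone list as soon as the weights are restored. A self-contained sketch of the corrected pattern, with illustrative names:

    # Sketch of the corrected ordering (names here are illustrative): do the
    # setup before `try`, so `finally` never sees an unbound name.
    from contextlib import contextmanager

    import torch


    @contextmanager
    def swap_temporarily(params, replacements):
        params = list(params)
        original = [p.clone() for p in params]   # setup outside try
        for p, r in zip(params, replacements):
            p.data.copy_(r)
        try:
            yield
        finally:
            for o, p in zip(original, params):   # original is guaranteed bound here
                p.data.copy_(o)
            del original                         # drop the clones promptly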