-rw-r--r--  data/csv.py                                          |   8
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py  |  16
-rw-r--r--  train_ti.py                                          |  82
-rw-r--r--  training/functional.py                               | 102
-rw-r--r--  training/util.py                                     |   8
5 files changed, 112 insertions(+), 104 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 5de3ac7..2a8115b 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -15,9 +15,6 @@ from data.keywords import prompt_to_keywords, keywords_to_prompt
 from models.clip.util import unify_input_ids
 
 
-image_cache: dict[str, Image.Image] = {}
-
-
 interpolations = {
     "linear": transforms.InterpolationMode.NEAREST,
     "bilinear": transforms.InterpolationMode.BILINEAR,
@@ -27,14 +24,9 @@ interpolations = {
 
 
 def get_image(path):
-    if path in image_cache:
-        return image_cache[path]
-
     image = Image.open(path)
     if not image.mode == "RGB":
         image = image.convert("RGB")
-    image_cache[path] = image
-
     return image
 
 
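The module-level image_cache above grew without bound: every image in the
dataset stayed pinned in memory after its first load. Dropping it trades
repeated decodes for flat memory use. A hypothetical middle ground, not part
of this commit, would be a bounded LRU cache that keeps the reuse benefit
while evicting old entries:

# Sketch only; the maxsize of 256 is an assumption about available RAM.
from functools import lru_cache

from PIL import Image


@lru_cache(maxsize=256)
def get_image_cached(path: str) -> Image.Image:
    # Identical loading logic, but entries beyond maxsize are evicted.
    image = Image.open(path)
    if not image.mode == "RGB":
        image = image.convert("RGB")
    return image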
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 43141bd..3027421 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -162,8 +162,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]],
-        width: Optional[int],
-        height: Optional[int],
+        width: int,
+        height: int,
         strength: float,
         callback_steps: Optional[int]
     ):
@@ -324,19 +324,19 @@ class VlpnStableDiffusion(DiffusionPipeline):
         self,
         prompt: Union[str, List[str], List[int], List[List[int]]],
         negative_prompt: Optional[Union[str, List[str], List[int], List[List[int]]]] = None,
-        num_images_per_prompt: Optional[int] = 1,
+        num_images_per_prompt: int = 1,
         strength: float = 0.8,
         height: Optional[int] = None,
         width: Optional[int] = None,
-        num_inference_steps: Optional[int] = 50,
-        guidance_scale: Optional[float] = 7.5,
-        eta: Optional[float] = 0.0,
+        num_inference_steps: int = 50,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         image: Optional[Union[torch.FloatTensor, PIL.Image.Image]] = None,
-        output_type: Optional[str] = "pil",
+        output_type: str = "pil",
         return_dict: bool = True,
         callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
-        callback_steps: Optional[int] = 1,
+        callback_steps: int = 1,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
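The annotation changes above are corrections rather than behavior changes:
Optional[T] means "T or None", so parameters like num_inference_steps that
always carry a concrete default were declaring a None case the code never
handles. A minimal illustration (function names are hypothetical):

from typing import Optional


def before(num_inference_steps: Optional[int] = 50) -> int:
    # A type checker must assume None is possible here, forcing a guard.
    assert num_inference_steps is not None
    return num_inference_steps * 2


def after(num_inference_steps: int = 50) -> int:
    # Always an int; no None guard needed.
    return num_inference_steps * 2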
diff --git a/train_ti.py b/train_ti.py
index 4bac736..77dec12 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -10,15 +10,13 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-import matplotlib.pyplot as plt
 from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, VlpnDataItem
-from training.functional import train_loop, loss_step, generate_class_images, add_placeholder_tokens, get_models
+from training.functional import train, generate_class_images, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
-from training.lr import LRFinder
 from training.util import save_args
 
 logger = get_logger(__name__)
@@ -644,23 +642,33 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )
 
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=weight_dtype)
-
-    callbacks = textual_inversion_strategy(
+    trainer = partial(
+        train,
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
-        tokenizer=tokenizer,
         vae=vae,
-        sample_scheduler=sample_scheduler,
+        noise_scheduler=noise_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        output_dir=output_dir,
+        dtype=weight_dtype,
         seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    trainer(
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        num_train_epochs=args.num_train_epochs,
+        sample_frequency=args.sample_frequency,
+        checkpoint_frequency=args.checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        with_prior_preservation=args.num_class_images != 0,
+        prior_loss_weight=args.prior_loss_weight,
+        # --
+        tokenizer=tokenizer,
+        sample_scheduler=sample_scheduler,
+        output_dir=output_dir,
         placeholder_tokens=args.placeholder_tokens,
         placeholder_token_ids=placeholder_token_ids,
         learning_rate=args.learning_rate,
@@ -679,54 +687,6 @@ def main():
         sample_image_size=args.sample_image_size,
     )
 
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        args.num_class_images != 0,
-        args.prior_loss_weight,
-        args.seed,
-    )
-
-    if args.find_lr:
-        lr_finder = LRFinder(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            callbacks=callbacks,
-        )
-        lr_finder.run(num_epochs=100, end_lr=1e3)
-
-        plt.savefig(output_dir.joinpath("lr.png"), dpi=300)
-        plt.close()
-    else:
-        if accelerator.is_main_process:
-            accelerator.init_trackers("textual_inversion")
-
-        train_loop(
-            accelerator=accelerator,
-            optimizer=optimizer,
-            lr_scheduler=lr_scheduler,
-            train_dataloader=train_dataloader,
-            val_dataloader=val_dataloader,
-            loss_step=loss_step_,
-            sample_frequency=args.sample_frequency,
-            checkpoint_frequency=args.checkpoint_frequency,
-            global_step_offset=global_step_offset,
-            callbacks=callbacks,
-        )
-
-    accelerator.end_training()
-
 
 if __name__ == "__main__":
     main()
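After this change main() only wires objects together: everything that is
fixed for the run (models, dataloaders, dtype, seed, strategy) is bound into
`trainer` via functools.partial, and the remaining hyperparameters are
supplied in the final call. A reduced sketch of the pattern, with
hypothetical names:

from functools import partial


def train(model: str, data: str, num_train_epochs: int = 100, lr: float = 1e-4):
    print(f"training {model} on {data}: epochs={num_train_epochs}, lr={lr}")


trainer = partial(train, model="unet", data="csv")  # fixed per-run wiring
trainer(num_train_epochs=10, lr=2e-5)               # run-specific settings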
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
         if accelerator.is_main_process:
             print("Interrupted")
             on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
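train() is now the single entry point: it prepares models with Accelerate,
freezes them and switches them to eval mode, builds the callbacks via the
injected callbacks_fn (so the strategy is constructed after
accelerator.prepare and receives the wrapped modules), binds loss_step, and
runs train_loop. Since callbacks_fn is typed Callable[..., TrainingCallbacks],
any strategy with a compatible signature should plug in; a minimal
hypothetical one, as a sketch:

from training.functional import TrainingCallbacks


def minimal_strategy(accelerator, unet, text_encoder, vae,
                     train_dataloader, val_dataloader, seed, **kwargs):
    # Only override what is needed; the dataclass defaults cover the rest.
    def on_log():
        return {"seed": seed}  # extra values merged into progress logging

    return TrainingCallbacks(on_log=on_log)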
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
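Moving the setup out of the try block fixes a latent bug: if listing or
cloning the parameters raised inside the old try, the finally block ran
before original_params was bound, turning the real error into a confusing
NameError. With setup before the try, cleanup only runs once there is
something to restore. The same pattern in isolation (hypothetical example):

from contextlib import contextmanager


@contextmanager
def temporarily_set(obj, attr, value):
    original = getattr(obj, attr)  # setup outside try: if this raises,
    setattr(obj, attr, value)      # the finally block is never entered
    try:
        yield obj
    finally:
        setattr(obj, attr, original)  # restore only after a clean setup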