summaryrefslogtreecommitdiffstats
path: root/training/functional.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/functional.py')
-rw-r--r--training/functional.py102
1 files changed, 78 insertions, 24 deletions
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
1from dataclasses import dataclass 1from dataclasses import dataclass
2import math 2import math
3from contextlib import _GeneratorContextManager, nullcontext 3from contextlib import _GeneratorContextManager, nullcontext
4from typing import Callable, Any, Tuple, Union, Optional 4from typing import Callable, Any, Tuple, Union, Optional, Type
5from functools import partial 5from functools import partial
6from pathlib import Path 6from pathlib import Path
7import itertools 7import itertools
@@ -32,7 +32,7 @@ def const(result=None):
32 32
33@dataclass 33@dataclass
34class TrainingCallbacks(): 34class TrainingCallbacks():
35 on_prepare: Callable[[float], None] = const() 35 on_prepare: Callable[[], None] = const()
36 on_model: Callable[[], torch.nn.Module] = const(None) 36 on_model: Callable[[], torch.nn.Module] = const(None)
37 on_log: Callable[[], dict[str, Any]] = const({}) 37 on_log: Callable[[], dict[str, Any]] = const({})
38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) 38 on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
@@ -220,28 +220,6 @@ def generate_class_images(
220 torch.cuda.empty_cache() 220 torch.cuda.empty_cache()
221 221
222 222
223def get_models(pretrained_model_name_or_path: str):
224 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
225 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
226 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
227 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
228 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
229 sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
230 pretrained_model_name_or_path, subfolder='scheduler')
231
232 vae.enable_slicing()
233 vae.set_use_memory_efficient_attention_xformers(True)
234 unet.set_use_memory_efficient_attention_xformers(True)
235
236 embeddings = patch_managed_embeddings(text_encoder)
237
238 vae.requires_grad_(False)
239 unet.requires_grad_(False)
240 text_encoder.requires_grad_(False)
241
242 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
243
244
245def add_placeholder_tokens( 223def add_placeholder_tokens(
246 tokenizer: MultiCLIPTokenizer, 224 tokenizer: MultiCLIPTokenizer,
247 embeddings: ManagedCLIPTextEmbeddings, 225 embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
508 if accelerator.is_main_process: 486 if accelerator.is_main_process:
509 print("Interrupted") 487 print("Interrupted")
510 on_checkpoint(global_step + global_step_offset, "end") 488 on_checkpoint(global_step + global_step_offset, "end")
489
490
491def train(
492 accelerator: Accelerator,
493 unet: UNet2DConditionModel,
494 text_encoder: CLIPTextModel,
495 vae: AutoencoderKL,
496 noise_scheduler: DDPMScheduler,
497 train_dataloader: DataLoader,
498 val_dataloader: DataLoader,
499 dtype: torch.dtype,
500 seed: int,
501 optimizer: torch.optim.Optimizer,
502 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
503 callbacks_fn: Callable[..., TrainingCallbacks],
504 num_train_epochs: int = 100,
505 sample_frequency: int = 20,
506 checkpoint_frequency: int = 50,
507 global_step_offset: int = 0,
508 with_prior_preservation: bool = False,
509 prior_loss_weight: float = 1.0,
510 **kwargs,
511):
512 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
513 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
514 )
515
516 vae.to(accelerator.device, dtype=dtype)
517
518 for model in (unet, text_encoder, vae):
519 model.requires_grad_(False)
520 model.eval()
521
522 callbacks = callbacks_fn(
523 accelerator=accelerator,
524 unet=unet,
525 text_encoder=text_encoder,
526 vae=vae,
527 train_dataloader=train_dataloader,
528 val_dataloader=val_dataloader,
529 seed=seed,
530 **kwargs,
531 )
532
533 callbacks.on_prepare()
534
535 loss_step_ = partial(
536 loss_step,
537 vae,
538 noise_scheduler,
539 unet,
540 text_encoder,
541 with_prior_preservation,
542 prior_loss_weight,
543 seed,
544 )
545
546 if accelerator.is_main_process:
547 accelerator.init_trackers("textual_inversion")
548
549 train_loop(
550 accelerator=accelerator,
551 optimizer=optimizer,
552 lr_scheduler=lr_scheduler,
553 train_dataloader=train_dataloader,
554 val_dataloader=val_dataloader,
555 loss_step=loss_step_,
556 sample_frequency=sample_frequency,
557 checkpoint_frequency=checkpoint_frequency,
558 global_step_offset=global_step_offset,
559 num_epochs=num_train_epochs,
560 callbacks=callbacks,
561 )
562
563 accelerator.end_training()
564 accelerator.free_memory()