| author    | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
|-----------|---------------------------|---------------------------|
| committer | Volpeon <git@volpeon.ink> | 2023-01-15 21:06:16 +0100 |
| commit    | 632ce00b54ffeacfc18f44f10827f167ab3ac37c (patch) | |
| tree      | ecf58df2b176d3c7d1583136bf453ed24de8d7f3 /training | |
| parent    | Fixed Conda env (diff) | |
| download  | textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.gz, textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.tar.bz2, textual-inversion-diff-632ce00b54ffeacfc18f44f10827f167ab3ac37c.zip | |
Restored functional trainer
Diffstat (limited to 'training')

-rw-r--r--   training/functional.py   102
-rw-r--r--   training/util.py           8

2 files changed, 83 insertions(+), 27 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index c01595a..5984ffb 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -1,7 +1,7 @@
 from dataclasses import dataclass
 import math
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Optional
+from typing import Callable, Any, Tuple, Union, Optional, Type
 from functools import partial
 from pathlib import Path
 import itertools
@@ -32,7 +32,7 @@ def const(result=None):
 
 @dataclass
 class TrainingCallbacks():
-    on_prepare: Callable[[float], None] = const()
+    on_prepare: Callable[[], None] = const()
     on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
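The `on_prepare` hook drops its unused `float` parameter, matching the no-argument call site introduced below (`callbacks.on_prepare()`). The defaults all come from the `const()` helper named in the hunk header; its body lies outside this diff's context, but it is presumably the usual constant-function factory. A minimal sketch, assuming `const()` simply freezes a return value:

```python
# Sketch of the const() default-callback pattern (assumption: the helper
# returns a function that ignores its arguments and returns a fixed value;
# the actual body is not shown in this diff).
def const(result=None):
    def fn(*args, **kwargs):
        return result
    return fn

on_log = const({})                        # default for TrainingCallbacks.on_log
assert on_log("anything", step=1) == {}   # always yields the frozen {} result
```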
@@ -220,28 +220,6 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def get_models(pretrained_model_name_or_path: str):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
-        pretrained_model_name_or_path, subfolder='scheduler')
-
-    vae.enable_slicing()
-    vae.set_use_memory_efficient_attention_xformers(True)
-    unet.set_use_memory_efficient_attention_xformers(True)
-
-    embeddings = patch_managed_embeddings(text_encoder)
-
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
-
-
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
     embeddings: ManagedCLIPTextEmbeddings,
@@ -508,3 +486,79 @@ def train_loop(
     if accelerator.is_main_process:
         print("Interrupted")
         on_checkpoint(global_step + global_step_offset, "end")
+
+
+def train(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    seed: int,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    callbacks_fn: Callable[..., TrainingCallbacks],
+    num_train_epochs: int = 100,
+    sample_frequency: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    with_prior_preservation: bool = False,
+    prior_loss_weight: float = 1.0,
+    **kwargs,
+):
+    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
+
+    vae.to(accelerator.device, dtype=dtype)
+
+    for model in (unet, text_encoder, vae):
+        model.requires_grad_(False)
+        model.eval()
+
+    callbacks = callbacks_fn(
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        seed=seed,
+        **kwargs,
+    )
+
+    callbacks.on_prepare()
+
+    loss_step_ = partial(
+        loss_step,
+        vae,
+        noise_scheduler,
+        unet,
+        text_encoder,
+        with_prior_preservation,
+        prior_loss_weight,
+        seed,
+    )
+
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
+    train_loop(
+        accelerator=accelerator,
+        optimizer=optimizer,
+        lr_scheduler=lr_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        loss_step=loss_step_,
+        sample_frequency=sample_frequency,
+        checkpoint_frequency=checkpoint_frequency,
+        global_step_offset=global_step_offset,
+        num_epochs=num_train_epochs,
+        callbacks=callbacks,
+    )
+
+    accelerator.end_training()
+    accelerator.free_memory()
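Taken together, the restored `train()` is a thin orchestration layer: `accelerator.prepare(...)` wraps the models, optimizer, dataloaders, and scheduler; the vae, unet, and text encoder are frozen and switched to eval mode; the injected `callbacks_fn` receives the prepared objects and returns the `TrainingCallbacks`; `partial` pins the fixed arguments of `loss_step`; and `train_loop` does the actual work. A minimal sketch of a caller, under the assumption that models, dataloaders, optimizer, and scheduler are constructed elsewhere (the wrapper function below is illustrative, not part of this commit):

```python
# Hypothetical caller of the train() entry point added in this commit.
# Everything outside the train(...) call is an assumption, not repo code.
import torch
from accelerate import Accelerator

from training.functional import train, TrainingCallbacks


def run_training(unet, text_encoder, vae, noise_scheduler,
                 train_dataloader, val_dataloader, optimizer, lr_scheduler):
    accelerator = Accelerator()

    def callbacks_fn(**kwargs):
        # train() forwards the prepared models, dataloaders, and seed here;
        # this sketch keeps every hook at its const() no-op default.
        return TrainingCallbacks()

    train(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        vae=vae,
        noise_scheduler=noise_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=torch.float32,
        seed=42,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        callbacks_fn=callbacks_fn,
    )
```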
diff --git a/training/util.py b/training/util.py
index f46cc61..557b196 100644
--- a/training/util.py
+++ b/training/util.py
@@ -180,11 +180,13 @@ class EMAModel:
 
     @contextmanager
     def apply_temporary(self, parameters: Iterable[torch.nn.Parameter]):
+        parameters = list(parameters)
+        original_params = [p.clone() for p in parameters]
+        self.copy_to(parameters)
+
         try:
-            parameters = list(parameters)
-            original_params = [p.clone() for p in parameters]
-            self.copy_to(parameters)
             yield
         finally:
             for o_param, param in zip(original_params, parameters):
                 param.data.copy_(o_param.data)
+            del original_params
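The `apply_temporary` change in `training/util.py` is more than a tidy-up. The snapshot steps (`list(parameters)`, the `clone()` comprehension, `copy_to`) now run before the `try`, so the `finally` block can only execute once `original_params` exists; in the old layout, an exception raised while cloning would have reached `finally` with `original_params` unbound and masked the real error with a `NameError`. The added `del original_params` also releases the cloned tensors as soon as the originals are restored. A minimal usage sketch, assuming `ema` is the repo's `EMAModel` tracking the model's parameters and `validate` stands in for any evaluation routine:

```python
import torch


def validate_with_ema(ema, model: torch.nn.Module, validate) -> None:
    # Inside the block the model carries the EMA-averaged weights; on exit,
    # whether normal or via an exception, apply_temporary's finally clause
    # copies the original weights back.
    with ema.apply_temporary(model.parameters()):
        validate(model)
```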