diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 19 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 10 | ||||
-rw-r--r-- | training/strategy/ti.py | 19 | ||||
-rw-r--r-- | training/util.py | 11 |
4 files changed, 38 insertions, 21 deletions
diff --git a/training/functional.py b/training/functional.py index 3d27380..7a3e821 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -39,11 +39,18 @@ class TrainingCallbacks(): | |||
39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) | 39 | on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext()) |
40 | on_before_optimize: Callable[[int], None] = const() | 40 | on_before_optimize: Callable[[int], None] = const() |
41 | on_after_optimize: Callable[[float], None] = const() | 41 | on_after_optimize: Callable[[float], None] = const() |
42 | on_after_epoch: Callable[[float], None] = const() | ||
42 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) | 43 | on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext()) |
43 | on_sample: Callable[[int], None] = const() | 44 | on_sample: Callable[[int], None] = const() |
44 | on_checkpoint: Callable[[int, str], None] = const() | 45 | on_checkpoint: Callable[[int, str], None] = const() |
45 | 46 | ||
46 | 47 | ||
48 | @dataclass | ||
49 | class TrainingStrategy(): | ||
50 | callbacks: Callable[..., TrainingCallbacks] | ||
51 | prepare_unet: bool = False | ||
52 | |||
53 | |||
47 | def make_grid(images, rows, cols): | 54 | def make_grid(images, rows, cols): |
48 | w, h = images[0].size | 55 | w, h = images[0].size |
49 | grid = Image.new('RGB', size=(cols*w, rows*h)) | 56 | grid = Image.new('RGB', size=(cols*w, rows*h)) |
@@ -373,6 +380,7 @@ def train_loop( | |||
373 | on_train = callbacks.on_train | 380 | on_train = callbacks.on_train |
374 | on_before_optimize = callbacks.on_before_optimize | 381 | on_before_optimize = callbacks.on_before_optimize |
375 | on_after_optimize = callbacks.on_after_optimize | 382 | on_after_optimize = callbacks.on_after_optimize |
383 | on_after_epoch = callbacks.on_after_epoch | ||
376 | on_eval = callbacks.on_eval | 384 | on_eval = callbacks.on_eval |
377 | on_sample = callbacks.on_sample | 385 | on_sample = callbacks.on_sample |
378 | on_checkpoint = callbacks.on_checkpoint | 386 | on_checkpoint = callbacks.on_checkpoint |
@@ -434,6 +442,8 @@ def train_loop( | |||
434 | 442 | ||
435 | accelerator.wait_for_everyone() | 443 | accelerator.wait_for_everyone() |
436 | 444 | ||
445 | on_after_epoch(lr_scheduler.get_last_lr()[0]) | ||
446 | |||
437 | if val_dataloader is not None: | 447 | if val_dataloader is not None: |
438 | model.eval() | 448 | model.eval() |
439 | 449 | ||
@@ -512,8 +522,7 @@ def train( | |||
512 | val_dataloader: Optional[DataLoader], | 522 | val_dataloader: Optional[DataLoader], |
513 | optimizer: torch.optim.Optimizer, | 523 | optimizer: torch.optim.Optimizer, |
514 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 524 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
515 | callbacks_fn: Callable[..., TrainingCallbacks], | 525 | strategy: TrainingStrategy, |
516 | prepare_unet: bool = False, | ||
517 | num_train_epochs: int = 100, | 526 | num_train_epochs: int = 100, |
518 | sample_frequency: int = 20, | 527 | sample_frequency: int = 20, |
519 | checkpoint_frequency: int = 50, | 528 | checkpoint_frequency: int = 50, |
@@ -524,12 +533,12 @@ def train( | |||
524 | ): | 533 | ): |
525 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] | 534 | prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler] |
526 | 535 | ||
527 | if prepare_unet: | 536 | if strategy.prepare_unet: |
528 | prep.append(unet) | 537 | prep.append(unet) |
529 | 538 | ||
530 | prep = accelerator.prepare(*prep) | 539 | prep = accelerator.prepare(*prep) |
531 | 540 | ||
532 | if prepare_unet: | 541 | if strategy.prepare_unet: |
533 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep | 542 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep |
534 | else: | 543 | else: |
535 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep | 544 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep |
@@ -542,7 +551,7 @@ def train( | |||
542 | model.requires_grad_(False) | 551 | model.requires_grad_(False) |
543 | model.eval() | 552 | model.eval() |
544 | 553 | ||
545 | callbacks = callbacks_fn( | 554 | callbacks = strategy.callbacks( |
546 | accelerator=accelerator, | 555 | accelerator=accelerator, |
547 | unet=unet, | 556 | unet=unet, |
548 | text_encoder=text_encoder, | 557 | text_encoder=text_encoder, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 93c81cb..bc26ee6 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepSch | |||
15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
19 | 19 | ||
20 | 20 | ||
21 | def dreambooth_strategy( | 21 | def dreambooth_strategy_callbacks( |
22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
@@ -185,3 +185,9 @@ def dreambooth_strategy( | |||
185 | on_checkpoint=on_checkpoint, | 185 | on_checkpoint=on_checkpoint, |
186 | on_sample=on_sample, | 186 | on_sample=on_sample, |
187 | ) | 187 | ) |
188 | |||
189 | |||
190 | dreambooth_strategy = TrainingStrategy( | ||
191 | callbacks=dreambooth_strategy_callbacks, | ||
192 | prepare_unet=True | ||
193 | ) | ||
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 00f3529..597abd0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -15,10 +15,10 @@ from slugify import slugify | |||
15 | 15 | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | 16 | from models.clip.tokenizer import MultiCLIPTokenizer |
17 | from training.util import EMAModel | 17 | from training.util import EMAModel |
18 | from training.functional import TrainingCallbacks, save_samples | 18 | from training.functional import TrainingStrategy, TrainingCallbacks, save_samples |
19 | 19 | ||
20 | 20 | ||
21 | def textual_inversion_strategy( | 21 | def textual_inversion_strategy_callbacks( |
22 | accelerator: Accelerator, | 22 | accelerator: Accelerator, |
23 | unet: UNet2DConditionModel, | 23 | unet: UNet2DConditionModel, |
24 | text_encoder: CLIPTextModel, | 24 | text_encoder: CLIPTextModel, |
@@ -119,17 +119,18 @@ def textual_inversion_strategy( | |||
119 | with ema_context(): | 119 | with ema_context(): |
120 | yield | 120 | yield |
121 | 121 | ||
122 | @torch.no_grad() | ||
123 | def on_after_optimize(lr: float): | 122 | def on_after_optimize(lr: float): |
123 | if use_ema: | ||
124 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
125 | |||
126 | @torch.no_grad() | ||
127 | def on_after_epoch(lr: float): | ||
124 | if use_emb_decay: | 128 | if use_emb_decay: |
125 | text_encoder.text_model.embeddings.normalize( | 129 | text_encoder.text_model.embeddings.normalize( |
126 | emb_decay_target, | 130 | emb_decay_target, |
127 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) | 131 | min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start)))) |
128 | ) | 132 | ) |
129 | 133 | ||
130 | if use_ema: | ||
131 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
132 | |||
133 | def on_log(): | 134 | def on_log(): |
134 | if use_ema: | 135 | if use_ema: |
135 | return {"ema_decay": ema_embeddings.decay} | 136 | return {"ema_decay": ema_embeddings.decay} |
@@ -157,7 +158,13 @@ def textual_inversion_strategy( | |||
157 | on_train=on_train, | 158 | on_train=on_train, |
158 | on_eval=on_eval, | 159 | on_eval=on_eval, |
159 | on_after_optimize=on_after_optimize, | 160 | on_after_optimize=on_after_optimize, |
161 | on_after_epoch=on_after_epoch, | ||
160 | on_log=on_log, | 162 | on_log=on_log, |
161 | on_checkpoint=on_checkpoint, | 163 | on_checkpoint=on_checkpoint, |
162 | on_sample=on_sample, | 164 | on_sample=on_sample, |
163 | ) | 165 | ) |
166 | |||
167 | |||
168 | textual_inversion_strategy = TrainingStrategy( | ||
169 | callbacks=textual_inversion_strategy_callbacks, | ||
170 | ) | ||
diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,18 +1,11 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable, Union | 4 | from typing import Iterable, Any |
5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
6 | 6 | ||
7 | import torch | 7 | import torch |
8 | 8 | ||
9 | from transformers import CLIPTextModel | ||
10 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
11 | |||
12 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
13 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
14 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
15 | |||
16 | 9 | ||
17 | def save_args(basepath: Path, args, extra={}): | 10 | def save_args(basepath: Path, args, extra={}): |
18 | info = {"args": vars(args)} | 11 | info = {"args": vars(args)} |
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): | |||
22 | 15 | ||
23 | 16 | ||
24 | class AverageMeter: | 17 | class AverageMeter: |
18 | avg: Any | ||
19 | |||
25 | def __init__(self, name=None): | 20 | def __init__(self, name=None): |
26 | self.name = name | 21 | self.name = name |
27 | self.reset() | 22 | self.reset() |