author     Volpeon <git@volpeon.ink>  2023-01-17 07:20:45 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-17 07:20:45 +0100
commit     5821523a524190490a287c5e2aacb6e72cc3a4cf (patch)
tree       c0eac536c754f078683be6d59893ad23d70baf51 /training
parent     Training update (diff)
Update
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py           | 19
-rw-r--r--  training/strategy/dreambooth.py  | 10
-rw-r--r--  training/strategy/ti.py          | 19
-rw-r--r--  training/util.py                 | 11

4 files changed, 38 insertions(+), 21 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index 3d27380..7a3e821 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -39,11 +39,18 @@ class TrainingCallbacks():
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
     on_after_optimize: Callable[[float], None] = const()
+    on_after_epoch: Callable[[float], None] = const()
     on_eval: Callable[[], _GeneratorContextManager] = const(nullcontext())
     on_sample: Callable[[int], None] = const()
     on_checkpoint: Callable[[int, str], None] = const()
 
 
+@dataclass
+class TrainingStrategy():
+    callbacks: Callable[..., TrainingCallbacks]
+    prepare_unet: bool = False
+
+
 def make_grid(images, rows, cols):
     w, h = images[0].size
     grid = Image.new('RGB', size=(cols*w, rows*h))
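The new on_after_epoch hook gets the same const() default as its siblings, so a strategy that never sets it still invokes a safe no-op. const itself is not part of this diff; a minimal sketch of what such a default factory plausibly looks like (an assumption, not the repo's verified implementation):

def const(result=None):
    # Build a callable that swallows any arguments and returns a fixed
    # value, so an unset callback can be invoked unconditionally.
    def fn(*args, **kwargs):
        return result
    return fn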
@@ -373,6 +380,7 @@ def train_loop(
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
     on_after_optimize = callbacks.on_after_optimize
+    on_after_epoch = callbacks.on_after_epoch
     on_eval = callbacks.on_eval
     on_sample = callbacks.on_sample
     on_checkpoint = callbacks.on_checkpoint
@@ -434,6 +442,8 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
+        on_after_epoch(lr_scheduler.get_last_lr()[0])
+
         if val_dataloader is not None:
             model.eval()
 
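The hook fires exactly once per epoch, after all processes synchronize and before any validation pass. A toy of that ordering, with stand-ins for the repo's objects:

class SchedulerStub:
    # Mimics torch's LR scheduler interface for this illustration only.
    def get_last_lr(self):
        return [5e-5]

def on_after_epoch(lr):
    print(f"epoch finished, last lr = {lr}")

lr_scheduler = SchedulerStub()
for epoch in range(2):
    # ... training steps, then accelerator.wait_for_everyone() ...
    on_after_epoch(lr_scheduler.get_last_lr()[0])  # before validation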
@@ -512,8 +522,7 @@ def train(
     val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    callbacks_fn: Callable[..., TrainingCallbacks],
-    prepare_unet: bool = False,
+    strategy: TrainingStrategy,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -524,12 +533,12 @@ def train(
 ):
     prep = [text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler]
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         prep.append(unet)
 
     prep = accelerator.prepare(*prep)
 
-    if prepare_unet:
+    if strategy.prepare_unet:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler, unet = prep
     else:
         text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = prep
@@ -542,7 +551,7 @@ def train(
         model.requires_grad_(False)
         model.eval()
 
-    callbacks = callbacks_fn(
+    callbacks = strategy.callbacks(
         accelerator=accelerator,
         unet=unet,
         text_encoder=text_encoder,
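With callbacks_fn and prepare_unet folded into a single TrainingStrategy value, a caller now passes one strategy object instead of wiring the two pieces separately. A self-contained toy of the pattern (names here are illustrative, not the repo's):

from dataclasses import dataclass
from typing import Callable

@dataclass
class Strategy:
    callbacks: Callable[..., dict]   # factory producing the callback set
    prepare_unet: bool = False       # whether accelerator.prepare() wraps the UNet

def run(strategy: Strategy):
    cb = strategy.callbacks(step=0)
    print(strategy.prepare_unet, cb)

run(Strategy(callbacks=lambda **kw: {"on_log": kw}, prepare_unet=True))
# -> True {'on_log': {'step': 0}}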
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 93c81cb..bc26ee6 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -15,10 +15,10 @@ from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def dreambooth_strategy(
+def dreambooth_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -185,3 +185,9 @@ def dreambooth_strategy(
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+dreambooth_strategy = TrainingStrategy(
+    callbacks=dreambooth_strategy_callbacks,
+    prepare_unet=True
+)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 00f3529..597abd0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -15,10 +15,10 @@ from slugify import slugify
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import EMAModel
-from training.functional import TrainingCallbacks, save_samples
+from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
-def textual_inversion_strategy(
+def textual_inversion_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
@@ -119,17 +119,18 @@ def textual_inversion_strategy(
         with ema_context():
             yield
 
-    @torch.no_grad()
     def on_after_optimize(lr: float):
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    @torch.no_grad()
+    def on_after_epoch(lr: float):
         if use_emb_decay:
             text_encoder.text_model.embeddings.normalize(
                 emb_decay_target,
                 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
             )
 
-        if use_ema:
-            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
     def on_log():
         if use_ema:
             return {"ema_decay": ema_embeddings.decay}
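The refactor splits two concerns: the EMA step now runs after every optimizer step, while embedding decay runs once per epoch. The clamped weight passed to normalize() scales decay with how far the current lr sits between emb_decay_start and the configured learning_rate; as a standalone function with made-up example values:

def emb_decay_weight(lr, learning_rate, emb_decay_start, emb_decay_factor):
    # Linear position of lr between emb_decay_start and learning_rate,
    # scaled by emb_decay_factor and clamped to [0, 1].
    t = (lr - emb_decay_start) / (learning_rate - emb_decay_start)
    return min(1.0, max(0.0, emb_decay_factor * t))

print(emb_decay_weight(1e-4, 1e-4, 0.0, 1.0))  # 1.0 at the initial lr
print(emb_decay_weight(2e-5, 1e-4, 0.0, 1.0))  # 0.2 once the lr has decayed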
@@ -157,7 +158,13 @@ def textual_inversion_strategy(
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
+        on_after_epoch=on_after_epoch,
         on_log=on_log,
         on_checkpoint=on_checkpoint,
         on_sample=on_sample,
     )
+
+
+textual_inversion_strategy = TrainingStrategy(
+    callbacks=textual_inversion_strategy_callbacks,
+)
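Textual inversion leaves prepare_unet at its False default: only the embedding parameters train, so the UNet stays out of accelerator.prepare(). A quick check of the two instances defined above (runnable inside this repo):

from training.strategy.dreambooth import dreambooth_strategy
from training.strategy.ti import textual_inversion_strategy

assert dreambooth_strategy.prepare_unet is True          # UNet weights are trained
assert textual_inversion_strategy.prepare_unet is False  # only embeddings train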
diff --git a/training/util.py b/training/util.py
index 557b196..237626f 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,18 +1,11 @@
 from pathlib import Path
 import json
 import copy
-from typing import Iterable, Union
+from typing import Iterable, Any
 from contextlib import contextmanager
 
 import torch
 
-from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.tokenizer import MultiCLIPTokenizer
-from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
-
 
 def save_args(basepath: Path, args, extra={}):
     info = {"args": vars(args)}
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}):
 
 
 class AverageMeter:
+    avg: Any
+
     def __init__(self, name=None):
         self.name = name
         self.reset()
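The avg: Any annotation keeps type checkers satisfied whether avg holds a float or a tensor. reset() and update() are not shown in this hunk; a self-contained sketch of the conventional shape such a class takes (the method bodies below are assumed from convention, not taken from the repo):

from typing import Any

class AverageMeter:
    avg: Any  # running mean; may be a float or a tensor

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum, self.count, self.avg = 0, 0, 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count

m = AverageMeter("loss")
for v in (0.9, 0.7, 0.5):
    m.update(v)
print(m.avg)  # 0.7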