path: root/training/strategy/lora.py
author    Volpeon <git@volpeon.ink>  2023-02-07 20:44:43 +0100
committer Volpeon <git@volpeon.ink>  2023-02-07 20:44:43 +0100
commit    7ccd4614a56cfd6ecacba85605f338593f1059f0 (patch)
tree      fa9882b256c752705bc42229bac4e00ed7088643 /training/strategy/lora.py
parent    Restored LR finder (diff)
Add Lora
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--  training/strategy/lora.py  147
1 file changed, 147 insertions, 0 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
new file mode 100644
index 0000000..88d1824
--- /dev/null
+++ b/training/strategy/lora.py
@@ -0,0 +1,147 @@
from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
from diffusers.loaders import AttnProcsLayers

from slugify import slugify

from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingStrategy, TrainingCallbacks, save_samples


def lora_strategy_callbacks(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    sample_output_dir: Path,
    checkpoint_output_dir: Path,
    seed: int,
    lora_layers: AttnProcsLayers,
    max_grad_norm: float = 1.0,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
    sample_output_dir.mkdir(parents=True, exist_ok=True)
    checkpoint_output_dir.mkdir(parents=True, exist_ok=True)

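    # Match the working dtype to Accelerate's mixed-precision setting so
    # sampling runs in the same precision as training.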
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

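    # Pre-bind the static sampling arguments; on_sample below then only has
    # to supply the current step.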
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        output_dir=sample_output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

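    # Only the LoRA adapter layers are trained; the base UNet and text
    # encoder stay frozen.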
    def on_prepare():
        lora_layers.requires_grad_(True)

    def on_accum_model():
        return unet

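    # Switch the tokenizer between its train/eval behavior for the duration
    # of each phase.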
    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()
        yield

    @contextmanager
    def on_eval():
        tokenizer.eval()
        yield

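    # Clip only on steps where Accelerate has synced gradients across
    # accumulation steps.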
    def on_before_optimize(lr: float, epoch: int):
        if accelerator.sync_gradients:
            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)

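    # Temporarily cast to float32 so the saved attention processors are full
    # precision, then restore the training dtype.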
    @torch.no_grad()
    def on_checkpoint(step, postfix):
        print(f"Saving checkpoint for step {step}...")
        orig_unet_dtype = unet.dtype
        unet.to(dtype=torch.float32)
        unet.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}"))
        unet.to(dtype=orig_unet_dtype)

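    # Sample in the mixed-precision dtype, restore the original dtype, and
    # release cached CUDA memory afterwards.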
    @torch.no_grad()
    def on_sample(step):
        orig_unet_dtype = unet.dtype
        unet.to(dtype=weight_dtype)
        save_samples_(step=step)
        unet.to(dtype=orig_unet_dtype)

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_accum_model=on_accum_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    lora_layers: AttnProcsLayers,
    **kwargs
):
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

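    # Only the trainable LoRA layers go through accelerator.prepare(); the
    # frozen UNet and text encoder are moved to the device manually.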
    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
    unet.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}


lora_strategy = TrainingStrategy(
    callbacks=lora_strategy_callbacks,
    prepare=lora_prepare,
)
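
For context, a minimal sketch of how the lora_layers argument could be constructed before this strategy is invoked. The commit itself does not contain this setup; the LoRACrossAttnProcessor wiring below follows the diffusers LoRA examples of this period, and the model id is just an illustration.

# Hypothetical setup sketch, not part of this commit: attach LoRA attention
# processors to every attention block of the UNet, then wrap them in
# AttnProcsLayers so they can be passed as `lora_layers`.
from diffusers import UNet2DConditionModel
from diffusers.loaders import AttnProcsLayers
from diffusers.models.cross_attention import LoRACrossAttnProcessor

# Example model id; any Stable Diffusion checkpoint with a UNet works.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)

lora_attn_procs = {}
for name in unet.attn_processors.keys():
    # Self-attention (attn1) has no text conditioning; cross-attention does.
    cross_attention_dim = (
        None if name.endswith("attn1.processor")
        else unet.config.cross_attention_dim
    )
    # Derive the hidden size from the UNet block the processor belongs to.
    if name.startswith("mid_block"):
        hidden_size = unet.config.block_out_channels[-1]
    elif name.startswith("up_blocks"):
        block_id = int(name[len("up_blocks.")])
        hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
    else:  # down_blocks
        block_id = int(name[len("down_blocks.")])
        hidden_size = unet.config.block_out_channels[block_id]
    lora_attn_procs[name] = LoRACrossAttnProcessor(
        hidden_size=hidden_size,
        cross_attention_dim=cross_attention_dim,
    )

unet.set_attn_processor(lora_attn_procs)
lora_layers = AttnProcsLayers(unet.attn_processors)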