path: root/training
author     Volpeon <git@volpeon.ink>  2023-01-15 22:26:43 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-15 22:26:43 +0100
commit     3f922880475c2c0a5679987d4a9a43606e838566 (patch)
tree       757746927e34aa7fddff1e44c837b489233029d7 /training
parent     Restored functional trainer (diff)
download   textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.gz
           textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.bz2
           textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.zip
Added Dreambooth strategy
Diffstat (limited to 'training')
-rw-r--r--  training/strategy/dreambooth.py  183
1 file changed, 183 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
new file mode 100644
index 0000000..6e7ebe2
--- /dev/null
+++ b/training/strategy/dreambooth.py
@@ -0,0 +1,183 @@
from typing import Optional
from functools import partial
from contextlib import contextmanager, nullcontext
from pathlib import Path
import itertools

import torch
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from training.util import EMAModel
from training.functional import TrainingCallbacks, save_samples

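# Strategy factory for Dreambooth-style fine-tuning. It bundles all
# Dreambooth-specific behaviour into a TrainingCallbacks instance, which is
# presumably consumed by the functional training loop in training.functional
# (the trainer itself is not part of this diff).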
def dreambooth_strategy(
    accelerator: Accelerator,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    tokenizer: MultiCLIPTokenizer,
    vae: AutoencoderKL,
    sample_scheduler: DPMSolverMultistepScheduler,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    output_dir: Path,
    seed: int,
    train_text_encoder_epochs: int,
    max_grad_norm: float = 1.0,
    use_ema: bool = False,
    ema_inv_gamma: float = 1.0,
    ema_power: int = 1,
    ema_max_decay: float = 0.9999,
    sample_batch_size: int = 1,
    sample_num_batches: int = 1,
    sample_num_steps: int = 20,
    sample_guidance_scale: float = 7.5,
    sample_image_size: Optional[int] = None,
):
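    """
    Build the TrainingCallbacks for Dreambooth-style fine-tuning: the UNet is
    always trained, the text encoder only for the first
    `train_text_encoder_epochs` epochs, optionally with an EMA copy of the
    UNet, and the final checkpoint is exported as a VlpnStableDiffusion
    pipeline.
    """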
    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
        raise ValueError(
            "Gradient accumulation is not supported when training the text encoder in distributed training. "
            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
        )

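    # Map the accelerator's mixed-precision setting to a torch dtype. Within
    # this strategy the dtype is only used for sample generation below.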
    weight_dtype = torch.float32
    if accelerator.state.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.state.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

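    # Pre-bind everything save_samples needs, so the sampling callback below
    # only has to supply the current step.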
    save_samples_ = partial(
        save_samples,
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=weight_dtype,
        output_dir=output_dir,
        seed=seed,
        batch_size=sample_batch_size,
        num_batches=sample_num_batches,
        num_steps=sample_num_steps,
        guidance_scale=sample_guidance_scale,
        image_size=sample_image_size,
    )

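    # Optionally keep an exponential moving average of the UNet weights; the
    # decay schedule is presumably governed by inv_gamma/power and capped at
    # max_value (see training.util.EMAModel).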
    if use_ema:
        ema_unet = EMAModel(
            unet.parameters(),
            inv_gamma=ema_inv_gamma,
            power=ema_power,
            max_value=ema_max_decay,
        )
    else:
        ema_unet = None

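    # apply_temporary presumably swaps the EMA weights into the UNet for the
    # duration of the context and restores the training weights afterwards;
    # used for evaluation, sampling and the final export.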
    def ema_context():
        if use_ema:
            return ema_unet.apply_temporary(unet.parameters())
        else:
            return nullcontext()

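    # The module returned here is presumably what the trainer treats as the
    # primary model (e.g. for gradient accumulation).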
    def on_model():
        return unet

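    # Both the UNet and the text encoder are trained. persist() presumably
    # bakes the temporary token embeddings into the regular embedding matrix,
    # while the temporary embedding itself stays frozen.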
    def on_prepare():
        unet.requires_grad_(True)
        text_encoder.requires_grad_(True)
        text_encoder.text_model.embeddings.persist()
        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)

        if use_ema:
            ema_unet.to(accelerator.device)

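    # Entered as a context manager around each training epoch: the text
    # encoder is only trained for the first train_text_encoder_epochs epochs
    # and is frozen afterwards.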
    @contextmanager
    def on_train(epoch: int):
        tokenizer.train()

        if epoch < train_text_encoder_epochs:
            text_encoder.train()
        elif epoch == train_text_encoder_epochs:
            text_encoder.requires_grad_(False)
            text_encoder.eval()

        yield

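    # Entered around validation: everything switches to eval mode, and
    # evaluation runs against the EMA weights when EMA is enabled.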
    @contextmanager
    def on_eval():
        tokenizer.eval()
        text_encoder.eval()

        with ema_context():
            yield

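    # Clip gradients of the UNet (and of the text encoder while it is still
    # being trained) right before the optimizer step, but only when gradients
    # are actually synced.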
    def on_before_optimize(epoch: int):
        if accelerator.sync_gradients:
            params_to_clip = [unet.parameters()]
            if epoch < train_text_encoder_epochs:
                params_to_clip.append(text_encoder.parameters())
            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)

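    # Update the EMA copy after every optimizer step.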
    @torch.no_grad()
    def on_after_optimize(lr: float):
        if use_ema:
            ema_unet.step(unet.parameters())

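    # Extra values that are presumably merged into the training logs.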
    def on_log():
        if use_ema:
            return {"ema_decay": ema_unet.decay}
        return {}

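    # Intermediate checkpoints are skipped; only the final call
    # (postfix == "end") exports the trained models as a complete
    # VlpnStableDiffusion pipeline, using the EMA weights when enabled.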
    @torch.no_grad()
    def on_checkpoint(step, postfix):
        if postfix != "end":
            return

        print("Saving model...")

        unet_ = accelerator.unwrap_model(unet)
        text_encoder_ = accelerator.unwrap_model(text_encoder)

        with ema_context():
            pipeline = VlpnStableDiffusion(
                text_encoder=text_encoder_,
                vae=vae,
                unet=unet_,
                tokenizer=tokenizer,
                scheduler=sample_scheduler,
            )
            pipeline.save_pretrained(output_dir.joinpath("model"))

        del unet_
        del text_encoder_
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()

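    # Generate validation samples, with the EMA weights applied when enabled.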
    @torch.no_grad()
    def on_sample(step):
        with ema_context():
            save_samples_(step=step)

    return TrainingCallbacks(
        on_prepare=on_prepare,
        on_model=on_model,
        on_train=on_train,
        on_eval=on_eval,
        on_before_optimize=on_before_optimize,
        on_after_optimize=on_after_optimize,
        on_log=on_log,
        on_checkpoint=on_checkpoint,
        on_sample=on_sample,
    )
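

# A minimal usage sketch (an assumption, not part of this commit): the strategy
# is expected to be constructed once and its callbacks handed to the functional
# training loop. All variables below are placeholders.
#
#   callbacks = dreambooth_strategy(
#       accelerator=accelerator,
#       unet=unet,
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       vae=vae,
#       sample_scheduler=sample_scheduler,
#       train_dataloader=train_dataloader,
#       val_dataloader=val_dataloader,
#       output_dir=Path("output/dreambooth"),
#       seed=42,
#       train_text_encoder_epochs=10,
#       use_ema=True,
#   )
#   # hypothetical entry point; the trainer in training.functional is not shown in this diff
#   # train(..., callbacks=callbacks)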