summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-13 18:59:26 +0100
committerVolpeon <git@volpeon.ink>2023-01-13 18:59:26 +0100
commit127ec21e5bd4e7df21e36c561d070f8b9a0e19f5 (patch)
tree61cb98adbf33ed08506601f8b70f1b62bc42c4ee /training
parentSimplified step calculations (diff)
downloadtextual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.gz
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.tar.bz2
textual-inversion-diff-127ec21e5bd4e7df21e36c561d070f8b9a0e19f5.zip
More modularization
Diffstat (limited to 'training')
-rw-r--r--training/common.py260
-rw-r--r--training/lr.py14
-rw-r--r--training/modules/dreambooth.py0
-rw-r--r--training/modules/lora.py0
-rw-r--r--training/modules/ti.py284
-rw-r--r--training/util.py15
6 files changed, 541 insertions, 32 deletions
diff --git a/training/common.py b/training/common.py
index 180396e..73ce814 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,46 +1,77 @@
1import math 1import math
2from pathlib import Path
2from contextlib import _GeneratorContextManager, nullcontext 3from contextlib import _GeneratorContextManager, nullcontext
3from typing import Callable, Any, Tuple, Union 4from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
5import datetime
6import logging
4 7
5import torch 8import torch
6import torch.nn.functional as F 9import torch.nn.functional as F
7from torch.utils.data import DataLoader 10from torch.utils.data import DataLoader
8 11
9from accelerate import Accelerator 12from accelerate import Accelerator
10from transformers import CLIPTokenizer, CLIPTextModel 13from accelerate.utils import LoggerType, set_seed
11from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
12from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup 16from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
13 17
14from tqdm.auto import tqdm 18from tqdm.auto import tqdm
19from slugify import slugify
15 20
21from data.csv import VlpnDataModule, VlpnDataItem
22from util import load_embeddings_from_dir
16from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24from models.clip.embeddings import patch_managed_embeddings
17from models.clip.util import get_extended_embeddings 25from models.clip.util import get_extended_embeddings
26from models.clip.tokenizer import MultiCLIPTokenizer
18from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
19from training.util import AverageMeter, CheckpointerBase 28from training.util import AverageMeter, CheckpointerBase
20 29
21 30
31class TrainingSetup(NamedTuple):
32 accelerator: Accelerator
33 tokenizer: MultiCLIPTokenizer
34 text_encoder: CLIPTextModel
35 vae: AutoencoderKL
36 unet: UNet2DConditionModel
37 noise_scheduler: DDPMScheduler
38 checkpoint_scheduler: DPMSolverMultistepScheduler
39 optimizer_class: Callable
40 learning_rate: float
41 weight_dtype: torch.dtype
42 output_dir: Path
43 seed: int
44 train_dataloader: DataLoader
45 val_dataloader: DataLoader
46 placeholder_token: list[str]
47 placeholder_token_ids: list[list[int]]
48
49
22def noop(*args, **kwards): 50def noop(*args, **kwards):
23 pass 51 pass
24 52
25 53
54def noop_ctx(*args, **kwards):
55 return nullcontext()
56
57
26def noop_on_log(): 58def noop_on_log():
27 return {} 59 return {}
28 60
29 61
30def get_scheduler( 62def get_scheduler(
31 id: str, 63 id: str,
32 min_lr: float,
33 lr: float,
34 warmup_func: str,
35 annealing_func: str,
36 warmup_exp: int,
37 annealing_exp: int,
38 cycles: int,
39 train_epochs: int,
40 warmup_epochs: int,
41 optimizer: torch.optim.Optimizer, 64 optimizer: torch.optim.Optimizer,
42 num_training_steps_per_epoch: int, 65 num_training_steps_per_epoch: int,
43 gradient_accumulation_steps: int, 66 gradient_accumulation_steps: int,
67 min_lr: float = 0.04,
68 warmup_func: str = "cos",
69 annealing_func: str = "cos",
70 warmup_exp: int = 1,
71 annealing_exp: int = 1,
72 cycles: int = 1,
73 train_epochs: int = 100,
74 warmup_epochs: int = 10,
44): 75):
45 num_training_steps_per_epoch = math.ceil( 76 num_training_steps_per_epoch = math.ceil(
46 num_training_steps_per_epoch / gradient_accumulation_steps 77 num_training_steps_per_epoch / gradient_accumulation_steps
@@ -49,8 +80,6 @@ def get_scheduler(
49 num_warmup_steps = warmup_epochs * num_training_steps_per_epoch 80 num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
50 81
51 if id == "one_cycle": 82 if id == "one_cycle":
52 min_lr = 0.04 if min_lr is None else min_lr / lr
53
54 lr_scheduler = get_one_cycle_schedule( 83 lr_scheduler = get_one_cycle_schedule(
55 optimizer=optimizer, 84 optimizer=optimizer,
56 num_training_steps=num_training_steps, 85 num_training_steps=num_training_steps,
@@ -133,6 +162,196 @@ def generate_class_images(
133 torch.cuda.empty_cache() 162 torch.cuda.empty_cache()
134 163
135 164
165def train_setup(
166 output_dir: str,
167 project: str,
168 pretrained_model_name_or_path: str,
169 learning_rate: float,
170 data_file: str,
171 gradient_accumulation_steps: int = 1,
172 mixed_precision: Literal["no", "fp16", "bf16"] = "no",
173 seed: Optional[int] = None,
174 vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
175 vector_dropout: float = 0.1,
176 gradient_checkpointing: bool = True,
177 embeddings_dir: Optional[str] = None,
178 placeholder_token: list[str] = [],
179 initializer_token: list[str] = [],
180 num_vectors: int = 1,
181 scale_lr: bool = False,
182 use_8bit_adam: bool = False,
183 train_batch_size: int = 1,
184 class_image_dir: Optional[str] = None,
185 num_class_images: int = 0,
186 resolution: int = 768,
187 num_buckets: int = 0,
188 progressive_buckets: bool = False,
189 bucket_step_size: int = 64,
190 bucket_max_pixels: Optional[int] = None,
191 tag_dropout: float = 0.1,
192 tag_shuffle: bool = True,
193 data_template: str = "template",
194 valid_set_size: Optional[int] = None,
195 valid_set_repeat: int = 1,
196 data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
197 sample_batch_size: int = 1,
198 sample_image_size: int = 768,
199 sample_steps: int = 20,
200) -> TrainingSetup:
201 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
202 output_dir = Path(output_dir).joinpath(slugify(project), now)
203 output_dir.mkdir(parents=True, exist_ok=True)
204
205 accelerator = Accelerator(
206 log_with=LoggerType.TENSORBOARD,
207 logging_dir=f"{output_dir}",
208 gradient_accumulation_steps=gradient_accumulation_steps,
209 mixed_precision=mixed_precision
210 )
211
212 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
213
214 seed = seed or (torch.random.seed() >> 32)
215 set_seed(seed)
216
217 # Load the tokenizer and add the placeholder token as a additional special token
218 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
219 tokenizer.set_use_vector_shuffle(vector_shuffle)
220 tokenizer.set_dropout(vector_dropout)
221
222 # Load models and create wrapper for stable diffusion
223 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
224 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
225 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
226 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
227 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
228 pretrained_model_name_or_path, subfolder='scheduler')
229
230 vae.enable_slicing()
231 vae.set_use_memory_efficient_attention_xformers(True)
232 unet.set_use_memory_efficient_attention_xformers(True)
233
234 if gradient_checkpointing:
235 unet.enable_gradient_checkpointing()
236 text_encoder.gradient_checkpointing_enable()
237
238 embeddings = patch_managed_embeddings(text_encoder)
239
240 if embeddings_dir is not None:
241 embeddings_dir = Path(embeddings_dir)
242 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
243 raise ValueError("--embeddings_dir must point to an existing directory")
244
245 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
246 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
247
248 # Convert the initializer_token, placeholder_token to ids
249 initializer_token_ids = [
250 tokenizer.encode(token, add_special_tokens=False)
251 for token in initializer_token
252 ]
253
254 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
255 embeddings.resize(len(tokenizer))
256
257 for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
258 embeddings.add_embed(new_id, init_ids)
259
260 init_ratios = [
261 f"{len(init_ids)} / {len(new_id)}"
262 for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
263 ]
264
265 print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
266
267 vae.requires_grad_(False)
268 unet.requires_grad_(False)
269 text_encoder.requires_grad_(False)
270
271 if scale_lr:
272 learning_rate = (
273 learning_rate * gradient_accumulation_steps *
274 train_batch_size * accelerator.num_processes
275 )
276
277 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
278 if use_8bit_adam:
279 try:
280 import bitsandbytes as bnb
281 except ImportError:
282 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
283
284 optimizer_class = bnb.optim.AdamW8bit
285 else:
286 optimizer_class = torch.optim.AdamW
287
288 weight_dtype = torch.float32
289 if mixed_precision == "fp16":
290 weight_dtype = torch.float16
291 elif mixed_precision == "bf16":
292 weight_dtype = torch.bfloat16
293
294 datamodule = VlpnDataModule(
295 data_file=data_file,
296 batch_size=train_batch_size,
297 tokenizer=tokenizer,
298 class_subdir=class_image_dir,
299 num_class_images=num_class_images,
300 size=resolution,
301 num_buckets=num_buckets,
302 progressive_buckets=progressive_buckets,
303 bucket_step_size=bucket_step_size,
304 bucket_max_pixels=bucket_max_pixels,
305 dropout=tag_dropout,
306 shuffle=tag_shuffle,
307 template_key=data_template,
308 valid_set_size=valid_set_size,
309 valid_set_repeat=valid_set_repeat,
310 seed=seed,
311 filter=data_filter,
312 dtype=weight_dtype
313 )
314 datamodule.setup()
315
316 train_dataloader = datamodule.train_dataloader
317 val_dataloader = datamodule.val_dataloader
318
319 train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
320
321 if num_class_images != 0:
322 generate_class_images(
323 accelerator,
324 text_encoder,
325 vae,
326 unet,
327 tokenizer,
328 checkpoint_scheduler,
329 datamodule.data_train,
330 sample_batch_size,
331 sample_image_size,
332 sample_steps
333 )
334
335 return TrainingSetup(
336 accelerator=accelerator,
337 tokenizer=tokenizer,
338 text_encoder=text_encoder,
339 vae=vae,
340 unet=unet,
341 noise_scheduler=noise_scheduler,
342 checkpoint_scheduler=checkpoint_scheduler,
343 optimizer_class=optimizer_class,
344 learning_rate=learning_rate,
345 output_dir=output_dir,
346 weight_dtype=weight_dtype,
347 seed=seed,
348 train_dataloader=train_dataloader,
349 val_dataloader=val_dataloader,
350 placeholder_token=placeholder_token,
351 placeholder_token_ids=placeholder_token_ids
352 )
353
354
136def loss_step( 355def loss_step(
137 vae: AutoencoderKL, 356 vae: AutoencoderKL,
138 noise_scheduler: DDPMScheduler, 357 noise_scheduler: DDPMScheduler,
@@ -221,15 +440,14 @@ def train_loop(
221 sample_steps: int = 20, 440 sample_steps: int = 20,
222 checkpoint_frequency: int = 50, 441 checkpoint_frequency: int = 50,
223 global_step_offset: int = 0, 442 global_step_offset: int = 0,
224 gradient_accumulation_steps: int = 1,
225 num_epochs: int = 100, 443 num_epochs: int = 100,
226 on_log: Callable[[], dict[str, Any]] = noop_on_log, 444 on_log: Callable[[], dict[str, Any]] = noop_on_log,
227 on_train: Callable[[], _GeneratorContextManager] = nullcontext, 445 on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
228 on_before_optimize: Callable[[], None] = noop, 446 on_before_optimize: Callable[[int], None] = noop,
229 on_after_optimize: Callable[[float], None] = noop, 447 on_after_optimize: Callable[[float], None] = noop,
230 on_eval: Callable[[], _GeneratorContextManager] = nullcontext 448 on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
231): 449):
232 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps) 450 num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
233 num_val_steps_per_epoch = len(val_dataloader) 451 num_val_steps_per_epoch = len(val_dataloader)
234 452
235 num_training_steps = num_training_steps_per_epoch * num_epochs 453 num_training_steps = num_training_steps_per_epoch * num_epochs
@@ -273,14 +491,14 @@ def train_loop(
273 491
274 model.train() 492 model.train()
275 493
276 with on_train(): 494 with on_train(epoch):
277 for step, batch in enumerate(train_dataloader): 495 for step, batch in enumerate(train_dataloader):
278 with accelerator.accumulate(model): 496 with accelerator.accumulate(model):
279 loss, acc, bsz = loss_step(step, batch) 497 loss, acc, bsz = loss_step(step, batch)
280 498
281 accelerator.backward(loss) 499 accelerator.backward(loss)
282 500
283 on_before_optimize() 501 on_before_optimize(epoch)
284 502
285 optimizer.step() 503 optimizer.step()
286 lr_scheduler.step() 504 lr_scheduler.step()
diff --git a/training/lr.py b/training/lr.py
index 84e30a0..7584ba2 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -16,6 +16,10 @@ def noop(*args, **kwards):
16 pass 16 pass
17 17
18 18
19def noop_ctx(*args, **kwards):
20 return nullcontext()
21
22
19class LRFinder(): 23class LRFinder():
20 def __init__( 24 def __init__(
21 self, 25 self,
@@ -25,10 +29,10 @@ class LRFinder():
25 train_dataloader, 29 train_dataloader,
26 val_dataloader, 30 val_dataloader,
27 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], 31 loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
28 on_train: Callable[[], _GeneratorContextManager] = nullcontext, 32 on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
29 on_before_optimize: Callable[[], None] = noop, 33 on_before_optimize: Callable[[int], None] = noop,
30 on_after_optimize: Callable[[float], None] = noop, 34 on_after_optimize: Callable[[float], None] = noop,
31 on_eval: Callable[[], _GeneratorContextManager] = nullcontext 35 on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
32 ): 36 ):
33 self.accelerator = accelerator 37 self.accelerator = accelerator
34 self.model = model 38 self.model = model
@@ -86,7 +90,7 @@ class LRFinder():
86 90
87 self.model.train() 91 self.model.train()
88 92
89 with self.on_train(): 93 with self.on_train(epoch):
90 for step, batch in enumerate(self.train_dataloader): 94 for step, batch in enumerate(self.train_dataloader):
91 if step >= num_train_batches: 95 if step >= num_train_batches:
92 break 96 break
@@ -96,7 +100,7 @@ class LRFinder():
96 100
97 self.accelerator.backward(loss) 101 self.accelerator.backward(loss)
98 102
99 self.on_before_optimize() 103 self.on_before_optimize(epoch)
100 104
101 self.optimizer.step() 105 self.optimizer.step()
102 lr_scheduler.step() 106 lr_scheduler.step()
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/modules/dreambooth.py
diff --git a/training/modules/lora.py b/training/modules/lora.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/training/modules/lora.py
diff --git a/training/modules/ti.py b/training/modules/ti.py
new file mode 100644
index 0000000..2db6f88
--- /dev/null
+++ b/training/modules/ti.py
@@ -0,0 +1,284 @@
1from typing import Literal
2from functools import partial
3from contextlib import contextmanager, nullcontext
4
5import torch
6
7from slugify import slugify
8
9from accelerate import Accelerator
10from transformers import CLIPTextModel
11from diffusers import AutoencoderKL, UNet2DConditionModel
12
13from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
14from models.clip.tokenizer import MultiCLIPTokenizer
15
16from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
17from training.util import EMAModel, CheckpointerBase
18
19
20class Checkpointer(CheckpointerBase):
21 def __init__(
22 self,
23 accelerator: Accelerator,
24 vae: AutoencoderKL,
25 unet: UNet2DConditionModel,
26 tokenizer: MultiCLIPTokenizer,
27 text_encoder: CLIPTextModel,
28 ema_embeddings: EMAModel,
29 weight_dtype: torch.dtype,
30 scheduler,
31 placeholder_token,
32 placeholder_token_ids,
33 *args,
34 **kwargs
35 ):
36 super().__init__(*args, **kwargs)
37
38 self.weight_dtype = weight_dtype
39 self.accelerator = accelerator
40 self.vae = vae
41 self.unet = unet
42 self.tokenizer = tokenizer
43 self.text_encoder = text_encoder
44 self.ema_embeddings = ema_embeddings
45 self.scheduler = scheduler
46 self.placeholder_token = placeholder_token
47 self.placeholder_token_ids = placeholder_token_ids
48
49 @torch.no_grad()
50 def checkpoint(self, step, postfix):
51 print("Saving checkpoint for step %d..." % step)
52
53 checkpoints_path = self.output_dir.joinpath("checkpoints")
54 checkpoints_path.mkdir(parents=True, exist_ok=True)
55
56 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
57
58 ema_context = nullcontext()
59 if self.ema_embeddings is not None:
60 ema_context = self.ema_embeddings.apply_temporary(
61 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
62
63 with ema_context:
64 for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
65 text_encoder.text_model.embeddings.save_embed(
66 ids,
67 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
68 )
69
70 del text_encoder
71
72 @torch.no_grad()
73 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
74 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
75
76 ema_context = nullcontext()
77 if self.ema_embeddings is not None:
78 ema_context = self.ema_embeddings.apply_temporary(
79 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
80
81 with ema_context:
82 orig_dtype = text_encoder.dtype
83 text_encoder.to(dtype=self.weight_dtype)
84
85 pipeline = VlpnStableDiffusion(
86 text_encoder=text_encoder,
87 vae=self.vae,
88 unet=self.unet,
89 tokenizer=self.tokenizer,
90 scheduler=self.scheduler,
91 ).to(self.accelerator.device)
92 pipeline.set_progress_bar_config(dynamic_ncols=True)
93
94 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
95
96 text_encoder.to(dtype=orig_dtype)
97
98 del text_encoder
99 del pipeline
100
101 if torch.cuda.is_available():
102 torch.cuda.empty_cache()
103
104
105def train_ti(
106 setup: TrainingSetup,
107 num_train_epochs: int = 100,
108 num_class_images: int = 0,
109 prior_loss_weight: float = 1.0,
110 use_ema: bool = False,
111 ema_inv_gamma: float = 1.0,
112 ema_power: float = 4/5,
113 ema_max_decay: float = .9999,
114 adam_beta1: float = 0.9,
115 adam_beta2: float = 0.999,
116 adam_weight_decay: float = 0,
117 adam_epsilon: float = 1e-08,
118 adam_amsgrad: bool = False,
119 lr_scheduler: Literal[
120 "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
121 ] = "one_cycle",
122 lr_min_lr: float = 0.04,
123 lr_warmup_func: Literal["linear", "cos"] = "cos",
124 lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
125 lr_warmup_exp: int = 1,
126 lr_annealing_exp: int = 1,
127 lr_cycles: int = 1,
128 lr_warmup_epochs: int = 10,
129 emb_decay_target: float = 0.4,
130 emb_decay_factor: float = 1,
131 emb_decay_start: float = 1e-4,
132 sample_image_size: int = 768,
133 sample_batch_size: int = 1,
134 sample_batches: int = 1,
135 sample_frequency: int = 10,
136 sample_steps: int = 20,
137 checkpoint_frequency: int = 50,
138 global_step_offset: int = 0,
139):
140 if use_ema:
141 ema_embeddings = EMAModel(
142 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
143 inv_gamma=ema_inv_gamma,
144 power=ema_power,
145 max_value=ema_max_decay,
146 )
147 else:
148 ema_embeddings = None
149
150 setup.text_encoder.requires_grad_(True)
151 setup.text_encoder.text_model.encoder.requires_grad_(False)
152 setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
153 setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
154 setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
155
156 # Initialize the optimizer
157 optimizer = setup.optimizer_class(
158 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
159 lr=setup.learning_rate,
160 betas=(adam_beta1, adam_beta2),
161 weight_decay=adam_weight_decay,
162 eps=adam_epsilon,
163 amsgrad=adam_amsgrad,
164 )
165
166 lr_scheduler = get_scheduler(
167 lr_scheduler,
168 optimizer=optimizer,
169 min_lr=lr_min_lr,
170 warmup_func=lr_warmup_func,
171 annealing_func=lr_annealing_func,
172 warmup_exp=lr_warmup_exp,
173 annealing_exp=lr_annealing_exp,
174 cycles=lr_cycles,
175 train_epochs=num_train_epochs,
176 warmup_epochs=lr_warmup_epochs,
177 num_training_steps_per_epoch=len(setup.train_dataloader),
178 gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
179 )
180
181 text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
182 setup.text_encoder, optimizer, lr_scheduler
183 )
184
185 # Move vae and unet to device
186 setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
187 setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
188
189 if use_ema:
190 ema_embeddings.to(setup.accelerator.device)
191
192 setup.unet.train()
193
194 @contextmanager
195 def on_train(epoch: int):
196 try:
197 setup.tokenizer.train()
198 yield
199 finally:
200 pass
201
202 @contextmanager
203 def on_eval():
204 try:
205 setup.tokenizer.eval()
206
207 ema_context = nullcontext()
208 if use_ema:
209 ema_context = ema_embeddings.apply_temporary(
210 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
211
212 with ema_context:
213 yield
214 finally:
215 pass
216
217 @torch.no_grad()
218 def on_after_optimize(lr: float):
219 text_encoder.text_model.embeddings.normalize(
220 emb_decay_target,
221 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
222 )
223
224 if use_ema:
225 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
226
227 def on_log():
228 if use_ema:
229 return {"ema_decay": ema_embeddings.decay}
230 return {}
231
232 loss_step_ = partial(
233 loss_step,
234 setup.vae,
235 setup.noise_scheduler,
236 setup.unet,
237 text_encoder,
238 num_class_images != 0,
239 prior_loss_weight,
240 setup.seed,
241 )
242
243 checkpointer = Checkpointer(
244 accelerator=setup.accelerator,
245 vae=setup.vae,
246 unet=setup.unet,
247 tokenizer=setup.tokenizer,
248 text_encoder=text_encoder,
249 ema_embeddings=ema_embeddings,
250 weight_dtype=setup.weight_dtype,
251 scheduler=setup.checkpoint_scheduler,
252 placeholder_token=setup.placeholder_token,
253 placeholder_token_ids=setup.placeholder_token_ids,
254 train_dataloader=setup.train_dataloader,
255 val_dataloader=setup.val_dataloader,
256 output_dir=setup.output_dir,
257 seed=setup.seed,
258 sample_image_size=sample_image_size,
259 sample_batch_size=sample_batch_size,
260 sample_batches=sample_batches
261 )
262
263 if setup.accelerator.is_main_process:
264 setup.accelerator.init_trackers("textual_inversion")
265
266 train_loop(
267 accelerator=setup.accelerator,
268 optimizer=optimizer,
269 lr_scheduler=lr_scheduler,
270 model=text_encoder,
271 checkpointer=checkpointer,
272 train_dataloader=setup.train_dataloader,
273 val_dataloader=setup.val_dataloader,
274 loss_step=loss_step_,
275 sample_frequency=sample_frequency,
276 sample_steps=sample_steps,
277 checkpoint_frequency=checkpoint_frequency,
278 global_step_offset=global_step_offset,
279 num_epochs=num_train_epochs,
280 on_log=on_log,
281 on_train=on_train,
282 on_after_optimize=on_after_optimize,
283 on_eval=on_eval
284 )
diff --git a/training/util.py b/training/util.py
index 0ec2032..cc4cdee 100644
--- a/training/util.py
+++ b/training/util.py
@@ -41,14 +41,16 @@ class AverageMeter:
41class CheckpointerBase: 41class CheckpointerBase:
42 def __init__( 42 def __init__(
43 self, 43 self,
44 datamodule, 44 train_dataloader,
45 val_dataloader,
45 output_dir: Path, 46 output_dir: Path,
46 sample_image_size: int, 47 sample_image_size: int,
47 sample_batches: int, 48 sample_batches: int,
48 sample_batch_size: int, 49 sample_batch_size: int,
49 seed: Optional[int] = None 50 seed: Optional[int] = None
50 ): 51 ):
51 self.datamodule = datamodule 52 self.train_dataloader = train_dataloader
53 self.val_dataloader = val_dataloader
52 self.output_dir = output_dir 54 self.output_dir = output_dir
53 self.sample_image_size = sample_image_size 55 self.sample_image_size = sample_image_size
54 self.seed = seed if seed is not None else torch.random.seed() 56 self.seed = seed if seed is not None else torch.random.seed()
@@ -70,15 +72,16 @@ class CheckpointerBase:
70 ): 72 ):
71 samples_path = Path(self.output_dir).joinpath("samples") 73 samples_path = Path(self.output_dir).joinpath("samples")
72 74
73 train_data = self.datamodule.train_dataloader
74 val_data = self.datamodule.val_dataloader
75
76 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) 75 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
77 76
78 grid_cols = min(self.sample_batch_size, 4) 77 grid_cols = min(self.sample_batch_size, 4)
79 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols 78 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
80 79
81 for pool, data, gen in [("stable", val_data, generator), ("val", val_data, None), ("train", train_data, None)]: 80 for pool, data, gen in [
81 ("stable", self.val_dataloader, generator),
82 ("val", self.val_dataloader, None),
83 ("train", self.train_dataloader, None)
84 ]:
82 all_samples = [] 85 all_samples = []
83 file_path = samples_path.joinpath(pool, f"step_{step}.jpg") 86 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
84 file_path.parent.mkdir(parents=True, exist_ok=True) 87 file_path.parent.mkdir(parents=True, exist_ok=True)