author     Volpeon <git@volpeon.ink>  2023-01-13 13:49:35 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-13 13:49:35 +0100
commit     7b149930bb53b93db74106ad20a30abf4b114f9b (patch)
tree       67c2ccbce2a9838ad8a020ee527b19113e67e30a /training
parent     Added TI decay start offset (diff)
Removed PromptProcessor, modularized training loop
Diffstat (limited to 'training')
-rw-r--r--  training/common.py | 205
-rw-r--r--  training/util.py   |  13
2 files changed, 208 insertions, 10 deletions
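
For orientation: the "modularized training loop" from the commit message is the new train_loop helper added to training/common.py below, which takes a pre-bound loss_step callable plus optional hook callbacks. The following is a minimal sketch of how a training script might wire it up, assuming that vae, noise_scheduler, unet and text_encoder are the leading positional parameters of loss_step as shown in the hunk, and using an illustrative run_training wrapper and args namespace that are not part of this commit:

from functools import partial

from training.common import loss_step, train_loop


def run_training(accelerator, unet, text_encoder, vae, noise_scheduler,
                 optimizer, lr_scheduler, checkpointer,
                 train_dataloader, val_dataloader, args):
    # Bind everything except (step, batch[, eval]) so train_loop can call the
    # result as loss_step_fn(step, batch) during training and
    # loss_step_fn(step, batch, True) during validation.
    loss_step_fn = partial(
        loss_step,
        vae,
        noise_scheduler,
        unet,
        text_encoder,
        args.num_class_images,
        args.prior_loss_weight,
        args.seed,
    )

    train_loop(
        accelerator=accelerator,
        optimizer=optimizer,
        lr_scheduler=lr_scheduler,
        model=text_encoder,          # whichever module is being trained
        checkpointer=checkpointer,   # a concrete CheckpointerBase subclass
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        loss_step=loss_step_fn,
        num_epochs=args.num_train_epochs,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
    )

Binding with functools.partial keeps train_loop agnostic about what is being trained; only the (step, batch) interface is fixed.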
diff --git a/training/common.py b/training/common.py
index 90cf910..842ac07 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,14 +1,30 @@
 import math
+from contextlib import _GeneratorContextManager, nullcontext
+from typing import Callable, Any, Tuple, Union
 
 import torch
 import torch.nn.functional as F
+from torch.utils.data import DataLoader
 
+from accelerate import Accelerator
+from transformers import CLIPTokenizer, CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
-from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from tqdm.auto import tqdm
 
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.util import get_extended_embeddings
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase
+
+
+def noop(*args, **kwards):
+    pass
+
+
+def noop_on_log():
+    return {}
 
 
 def get_scheduler(
@@ -22,10 +38,11 @@ def get_scheduler(
     cycles: int,
     warmup_epochs: int,
     optimizer: torch.optim.Optimizer,
-    max_train_steps: int,
+    num_train_epochs: int,
     num_update_steps_per_epoch: int,
     gradient_accumulation_steps: int,
 ):
+    num_train_steps = num_train_epochs * num_update_steps_per_epoch
     warmup_steps = warmup_epochs * num_update_steps_per_epoch * gradient_accumulation_steps
 
     if id == "one_cycle":
@@ -33,7 +50,7 @@ def get_scheduler(
 
         lr_scheduler = get_one_cycle_schedule(
             optimizer=optimizer,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
             warmup=warmup_func,
             annealing=annealing_func,
             warmup_exp=warmup_exp,
@@ -42,12 +59,12 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         cycles = cycles if cycles is not None else math.ceil(
-            math.sqrt(((max_train_steps - warmup_steps) / num_update_steps_per_epoch)))
+            math.sqrt(((num_train_steps - warmup_steps) / num_update_steps_per_epoch)))
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
             num_warmup_steps=warmup_steps,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
             num_cycles=cycles,
         )
     else:
@@ -55,7 +72,7 @@ def get_scheduler(
             id,
             optimizer=optimizer,
             num_warmup_steps=warmup_steps,
-            num_training_steps=max_train_steps * gradient_accumulation_steps,
+            num_training_steps=num_train_steps * gradient_accumulation_steps,
         )
 
     return lr_scheduler
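
To make the get_scheduler change concrete: the function now derives the step count itself as num_train_steps = num_train_epochs * num_update_steps_per_epoch instead of receiving a precomputed max_train_steps. Under assumed example values (not from the commit), num_train_epochs = 10, num_update_steps_per_epoch = 1000, gradient_accumulation_steps = 2 and warmup_epochs = 1 give num_train_steps = 10 * 1000 = 10000, num_training_steps = 10000 * 2 = 20000 for the underlying schedulers, and warmup_steps = 1 * 1000 * 2 = 2000.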
@@ -117,12 +134,12 @@ def loss_step(
     vae: AutoencoderKL,
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
-    prompt_processor,
+    text_encoder: CLIPTextModel,
     num_class_images: int,
     prior_loss_weight: float,
     seed: int,
     step: int,
-    batch,
+    batch: dict[str, Any],
     eval: bool = False
 ):
     # Convert images to latent space
@@ -149,7 +166,8 @@ def loss_step(
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
-    encoder_hidden_states = prompt_processor.get_embeddings(
+    encoder_hidden_states = get_extended_embeddings(
+        text_encoder,
         batch["input_ids"],
         batch["attention_mask"]
     )
@@ -185,3 +203,172 @@ def loss_step(
     acc = (model_pred == target).float().mean()
 
     return loss, acc, bsz
+
+
+def train_loop(
+    accelerator: Accelerator,
+    optimizer: torch.optim.Optimizer,
+    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    model: torch.nn.Module,
+    checkpointer: CheckpointerBase,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
+    sample_frequency: int = 10,
+    sample_steps: int = 20,
+    checkpoint_frequency: int = 50,
+    global_step_offset: int = 0,
+    gradient_accumulation_steps: int = 1,
+    num_epochs: int = 100,
+    on_log: Callable[[], dict[str, Any]] = noop_on_log,
+    on_train: Callable[[], _GeneratorContextManager] = nullcontext,
+    on_before_optimize: Callable[[], None] = noop,
+    on_after_optimize: Callable[[float], None] = noop,
+    on_eval: Callable[[], _GeneratorContextManager] = nullcontext
+):
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
+    num_train_steps = num_epochs * num_update_steps_per_epoch
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(num_train_steps / num_update_steps_per_epoch)
+    num_val_steps = num_val_steps_per_epoch * num_epochs
+
+    global_step = 0
+
+    avg_loss = AverageMeter()
+    avg_acc = AverageMeter()
+
+    avg_loss_val = AverageMeter()
+    avg_acc_val = AverageMeter()
+
+    max_acc_val = 0.0
+
+    local_progress_bar = tqdm(
+        range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
+
+    global_progress_bar = tqdm(
+        range(num_train_steps + num_val_steps),
+        disable=not accelerator.is_local_main_process,
+        dynamic_ncols=True
+    )
+    global_progress_bar.set_description("Total progress")
+
+    try:
+        for epoch in range(num_epochs):
+            if accelerator.is_main_process:
+                if epoch % sample_frequency == 0:
+                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+
+                if epoch % checkpoint_frequency == 0 and epoch != 0:
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
+            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
+            local_progress_bar.reset()
+
+            model.train()
+
+            with on_train():
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(model):
+                        loss, acc, bsz = loss_step(step, batch)
+
+                        accelerator.backward(loss)
+
+                        on_before_optimize()
+
+                        optimizer.step()
+                        lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    logs.update(on_log())
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
+
+                    if global_step >= num_train_steps:
+                        break
+
+            accelerator.wait_for_everyone()
+
+            model.eval()
+
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
+            with torch.inference_mode():
+                with on_eval():
+                    for step, batch in enumerate(val_dataloader):
+                        loss, acc, bsz = loss_step(step, batch, True)
+
+                        loss = loss.detach_()
+                        acc = acc.detach_()
+
+                        cur_loss_val.update(loss, bsz)
+                        cur_acc_val.update(acc, bsz)
+
+                        avg_loss_val.update(loss, bsz)
+                        avg_acc_val.update(acc, bsz)
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        logs = {
+                            "val/loss": avg_loss_val.avg.item(),
+                            "val/acc": avg_acc_val.avg.item(),
+                            "val/cur_loss": loss.item(),
+                            "val/cur_acc": acc.item(),
+                        }
+                        local_progress_bar.set_postfix(**logs)
+
+            logs["val/cur_loss"] = cur_loss_val.avg.item()
+            logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+            accelerator.log(logs, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    checkpointer.checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc_val = avg_acc_val.avg.item()
+
+        # Create the pipeline using the trained modules and save it.
+        if accelerator.is_main_process:
+            print("Finished!")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            accelerator.end_training()
+
+    except KeyboardInterrupt:
+        if accelerator.is_main_process:
+            print("Interrupted")
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            accelerator.end_training()
+        quit()
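
The hook parameters above (on_train, on_eval, on_before_optimize, on_after_optimize, on_log) default to no-ops or nullcontext, so callers only override what they need. A sketch of what such hooks could look like; the bodies and names here are illustrative assumptions, not code from this repository:

from contextlib import contextmanager
from typing import Any


@contextmanager
def on_train():
    # Entered once per epoch around the training loop; a trainer could, for
    # example, unfreeze token embeddings here and refreeze them on exit.
    yield


@contextmanager
def on_eval():
    # Entered around validation; e.g. temporarily swap in EMA weights.
    yield


def on_before_optimize():
    # Runs right before optimizer.step(); a typical use is gradient clipping,
    # e.g. accelerator.clip_grad_norm_(model.parameters(), max_norm).
    pass


def on_after_optimize(lr: float):
    # Runs only when accelerator.sync_gradients is True, i.e. after a real
    # optimizer step; receives the current learning rate.
    pass


def on_log() -> dict[str, Any]:
    # Extra key/value pairs merged into the per-step logs dict.
    return {}

Because on_train and on_eval must be callables returning context managers, contextlib.nullcontext works as the default and @contextmanager generators work as overrides.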
diff --git a/training/util.py b/training/util.py
index 60d64f0..0ec2032 100644
--- a/training/util.py
+++ b/training/util.py
@@ -55,8 +55,19 @@ class CheckpointerBase:
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
+    @torch.no_grad()
+    def checkpoint(self, step: int, postfix: str):
+        pass
+
     @torch.inference_mode()
-    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
+    def save_samples(
+        self,
+        pipeline,
+        step: int,
+        num_inference_steps: int,
+        guidance_scale: float = 7.5,
+        eta: float = 0.0
+    ):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         train_data = self.datamodule.train_dataloader
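
The change to training/util.py gives CheckpointerBase a no-op checkpoint() so train_loop can call it unconditionally; concrete checkpointers override it (and, since train_loop calls save_samples with only a step and a step count, presumably also save_samples). A minimal sketch of such a subclass; the class name, constructor arguments and save format are assumptions for illustration only:

from pathlib import Path

import torch

from training.util import CheckpointerBase


class EmbeddingCheckpointer(CheckpointerBase):
    # Hypothetical checkpointer that saves a learned embedding at each call.

    def __init__(self, *args, text_encoder, placeholder_token_id, **kwargs):
        super().__init__(*args, **kwargs)
        self.text_encoder = text_encoder
        self.placeholder_token_id = placeholder_token_id

    @torch.no_grad()
    def checkpoint(self, step: int, postfix: str):
        # train_loop passes postfix "training", "milestone" or "end".
        checkpoints_path = Path(self.output_dir).joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        embeds = self.text_encoder.get_input_embeddings().weight[self.placeholder_token_id]
        torch.save(embeds.detach().cpu(), checkpoints_path.joinpath(f"step_{step}_{postfix}.bin"))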