summaryrefslogtreecommitdiffstats
path: root/trainer/base.py
diff options
context:
space:
mode:
Diffstat (limited to 'trainer/base.py')
-rw-r--r--trainer/base.py544
1 files changed, 544 insertions, 0 deletions
diff --git a/trainer/base.py b/trainer/base.py
new file mode 100644
index 0000000..e700dd6
--- /dev/null
+++ b/trainer/base.py
@@ -0,0 +1,544 @@
1from pathlib import Path
2import math
3from contextlib import contextmanager
4from typing import Type, Optional
5import itertools
6from functools import partial
7
8import torch
9import torch.nn as nn
10import torch.nn.functional as F
11from torch.utils.data import DataLoader
12
13from accelerate import Accelerator
14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
16
17from tqdm.auto import tqdm
18from PIL import Image
19
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from models.clip.tokenizer import MultiCLIPTokenizer
22from models.clip.util import get_extended_embeddings
23from training.util import AverageMeter
24
25
26def make_grid(images, rows, cols):
27 w, h = images[0].size
28 grid = Image.new('RGB', size=(cols*w, rows*h))
29 for i, image in enumerate(images):
30 grid.paste(image, box=(i % cols*w, i//cols*h))
31 return grid
32
33
34class Checkpointer():
35 def __init__(
36 self,
37 accelerator: Accelerator,
38 vae: AutoencoderKL,
39 unet: UNet2DConditionModel,
40 text_encoder: CLIPTextModel,
41 tokenizer: MultiCLIPTokenizer,
42 sample_scheduler,
43 dtype,
44 train_dataloader: DataLoader,
45 val_dataloader: DataLoader,
46 output_dir: Path,
47 sample_steps: int = 20,
48 sample_guidance_scale: float = 7.5,
49 sample_image_size: int = 768,
50 sample_batches: int = 1,
51 sample_batch_size: int = 1,
52 seed: Optional[int] = None,
53 *args,
54 **kwargs,
55 ):
56 self.accelerator = accelerator
57 self.vae = vae
58 self.unet = unet
59 self.text_encoder = text_encoder
60 self.tokenizer = tokenizer
61 self.sample_scheduler = sample_scheduler
62 self.dtype = dtype
63 self.train_dataloader = train_dataloader
64 self.val_dataloader = val_dataloader
65 self.output_dir = output_dir
66 self.sample_steps = sample_steps
67 self.sample_guidance_scale = sample_guidance_scale
68 self.sample_image_size = sample_image_size
69 self.sample_batches = sample_batches
70 self.sample_batch_size = sample_batch_size
71 self.seed = seed if seed is not None else torch.random.seed()
72
73 @torch.no_grad()
74 def checkpoint(self, step: int, postfix: str):
75 pass
76
77 @torch.inference_mode()
78 def save_samples(self, step: int):
79 print(f"Saving samples for step {step}...")
80
81 samples_path = self.output_dir.joinpath("samples")
82
83 grid_cols = min(self.sample_batch_size, 4)
84 grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
85
86 unet = self.accelerator.unwrap_model(self.unet)
87 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
88
89 orig_unet_dtype = unet.dtype
90 orig_text_encoder_dtype = text_encoder.dtype
91
92 unet.to(dtype=self.dtype)
93 text_encoder.to(dtype=self.dtype)
94
95 pipeline = VlpnStableDiffusion(
96 text_encoder=text_encoder,
97 vae=self.vae,
98 unet=self.unet,
99 tokenizer=self.tokenizer,
100 scheduler=self.sample_scheduler,
101 ).to(self.accelerator.device)
102 pipeline.set_progress_bar_config(dynamic_ncols=True)
103
104 generator = torch.Generator(device=self.accelerator.device).manual_seed(self.seed)
105
106 for pool, data, gen in [
107 ("stable", self.val_dataloader, generator),
108 ("val", self.val_dataloader, None),
109 ("train", self.train_dataloader, None)
110 ]:
111 all_samples = []
112 file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
113 file_path.parent.mkdir(parents=True, exist_ok=True)
114
115 batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
116 prompt_ids = [
117 prompt
118 for batch in batches
119 for prompt in batch["prompt_ids"]
120 ]
121 nprompt_ids = [
122 prompt
123 for batch in batches
124 for prompt in batch["nprompt_ids"]
125 ]
126
127 for i in range(self.sample_batches):
128 start = i * self.sample_batch_size
129 end = (i + 1) * self.sample_batch_size
130 prompt = prompt_ids[start:end]
131 nprompt = nprompt_ids[start:end]
132
133 samples = pipeline(
134 prompt=prompt,
135 negative_prompt=nprompt,
136 height=self.sample_image_size,
137 width=self.sample_image_size,
138 generator=gen,
139 guidance_scale=self.sample_guidance_scale,
140 num_inference_steps=self.sample_steps,
141 output_type='pil'
142 ).images
143
144 all_samples += samples
145
146 image_grid = make_grid(all_samples, grid_rows, grid_cols)
147 image_grid.save(file_path, quality=85)
148
149 unet.to(dtype=orig_unet_dtype)
150 text_encoder.to(dtype=orig_text_encoder_dtype)
151
152 del unet
153 del text_encoder
154 del generator
155 del pipeline
156
157 if torch.cuda.is_available():
158 torch.cuda.empty_cache()
159
160
161class TrainingStrategy():
162 def __init__(
163 self,
164 tokenizer: MultiCLIPTokenizer,
165 *args,
166 **kwargs,
167 ):
168 self.tokenizer = tokenizer
169 self.checkpointer = Checkpointer(tokenizer=tokenizer, *args, **kwargs)
170
171 @property
172 def main_model(self) -> nn.Module:
173 ...
174
175 @contextmanager
176 def on_train(self, epoch: int):
177 try:
178 self.tokenizer.train()
179 yield
180 finally:
181 pass
182
183 @contextmanager
184 def on_eval(self):
185 try:
186 self.tokenizer.eval()
187 yield
188 finally:
189 pass
190
191 def on_before_optimize(self, epoch: int):
192 ...
193
194 def on_after_optimize(self, lr: float):
195 ...
196
197 def on_log():
198 return {}
199
200
201def loss_step(
202 vae: AutoencoderKL,
203 unet: UNet2DConditionModel,
204 text_encoder: CLIPTextModel,
205 seed: int,
206 noise_scheduler,
207 prior_loss_weight: float,
208 step: int,
209 batch: dict,
210 eval: bool = False
211):
212 # Convert images to latent space
213 latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
214 latents = latents * 0.18215
215
216 generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
217
218 # Sample noise that we'll add to the latents
219 noise = torch.randn(
220 latents.shape,
221 dtype=latents.dtype,
222 layout=latents.layout,
223 device=latents.device,
224 generator=generator
225 )
226 bsz = latents.shape[0]
227 # Sample a random timestep for each image
228 timesteps = torch.randint(
229 0,
230 noise_scheduler.config.num_train_timesteps,
231 (bsz,),
232 generator=generator,
233 device=latents.device,
234 )
235 timesteps = timesteps.long()
236
237 # Add noise to the latents according to the noise magnitude at each timestep
238 # (this is the forward diffusion process)
239 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
240 noisy_latents = noisy_latents.to(dtype=unet.dtype)
241
242 # Get the text embedding for conditioning
243 encoder_hidden_states = get_extended_embeddings(
244 text_encoder,
245 batch["input_ids"],
246 batch["attention_mask"]
247 )
248 encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
249
250 # Predict the noise residual
251 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
252
253 # Get the target for loss depending on the prediction type
254 if noise_scheduler.config.prediction_type == "epsilon":
255 target = noise
256 elif noise_scheduler.config.prediction_type == "v_prediction":
257 target = noise_scheduler.get_velocity(latents, noise, timesteps)
258 else:
259 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
260
261 if batch["with_prior"].all():
262 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
263 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
264 target, target_prior = torch.chunk(target, 2, dim=0)
265
266 # Compute instance loss
267 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
268
269 # Compute prior loss
270 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
271
272 # Add the prior loss to the instance loss.
273 loss = loss + prior_loss_weight * prior_loss
274 else:
275 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
276
277 acc = (model_pred == target).float().mean()
278
279 return loss, acc, bsz
280
281
282def train_loop(
283 strategy: TrainingStrategy,
284 accelerator: Accelerator,
285 vae: AutoencoderKL,
286 unet: UNet2DConditionModel,
287 text_encoder: CLIPTextModel,
288 train_dataloader: DataLoader,
289 val_dataloader: DataLoader,
290 seed: int,
291 optimizer: torch.optim.Optimizer,
292 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
293 noise_scheduler,
294 prior_loss_weight: float = 1.0,
295 sample_frequency: int = 10,
296 checkpoint_frequency: int = 50,
297 global_step_offset: int = 0,
298 num_epochs: int = 100,
299):
300 num_training_steps_per_epoch = math.ceil(
301 len(train_dataloader) / accelerator.gradient_accumulation_steps
302 )
303 num_val_steps_per_epoch = len(val_dataloader)
304
305 num_training_steps = num_training_steps_per_epoch * num_epochs
306 num_val_steps = num_val_steps_per_epoch * num_epochs
307
308 global_step = 0
309
310 avg_loss = AverageMeter()
311 avg_acc = AverageMeter()
312
313 avg_loss_val = AverageMeter()
314 avg_acc_val = AverageMeter()
315
316 max_acc_val = 0.0
317
318 local_progress_bar = tqdm(
319 range(num_training_steps_per_epoch + num_val_steps_per_epoch),
320 disable=not accelerator.is_local_main_process,
321 dynamic_ncols=True
322 )
323 local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
324
325 global_progress_bar = tqdm(
326 range(num_training_steps + num_val_steps),
327 disable=not accelerator.is_local_main_process,
328 dynamic_ncols=True
329 )
330 global_progress_bar.set_description("Total progress")
331
332 loss_step_ = partial(
333 loss_step,
334 vae,
335 unet,
336 text_encoder,
337 seed,
338 noise_scheduler,
339 prior_loss_weight
340 )
341
342 try:
343 for epoch in range(num_epochs):
344 if accelerator.is_main_process:
345 if epoch % sample_frequency == 0 and epoch != 0:
346 strategy.checkpointer.save_samples(global_step + global_step_offset)
347
348 if epoch % checkpoint_frequency == 0 and epoch != 0:
349 strategy.checkpointer.checkpoint(global_step + global_step_offset, "training")
350
351 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
352 local_progress_bar.reset()
353
354 strategy.main_model.train()
355
356 with strategy.on_train(epoch):
357 for step, batch in enumerate(train_dataloader):
358 with accelerator.accumulate(strategy.main_model):
359 loss, acc, bsz = loss_step_(step, batch)
360
361 accelerator.backward(loss)
362
363 strategy.on_before_optimize(epoch)
364
365 optimizer.step()
366 lr_scheduler.step()
367 optimizer.zero_grad(set_to_none=True)
368
369 avg_loss.update(loss.detach_(), bsz)
370 avg_acc.update(acc.detach_(), bsz)
371
372 # Checks if the accelerator has performed an optimization step behind the scenes
373 if accelerator.sync_gradients:
374 strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])
375
376 local_progress_bar.update(1)
377 global_progress_bar.update(1)
378
379 global_step += 1
380
381 logs = {
382 "train/loss": avg_loss.avg.item(),
383 "train/acc": avg_acc.avg.item(),
384 "train/cur_loss": loss.item(),
385 "train/cur_acc": acc.item(),
386 "lr": lr_scheduler.get_last_lr()[0],
387 }
388 logs.update(strategy.on_log())
389
390 accelerator.log(logs, step=global_step)
391
392 local_progress_bar.set_postfix(**logs)
393
394 if global_step >= num_training_steps:
395 break
396
397 accelerator.wait_for_everyone()
398
399 strategy.main_model.eval()
400
401 cur_loss_val = AverageMeter()
402 cur_acc_val = AverageMeter()
403
404 with torch.inference_mode(), strategy.on_eval():
405 for step, batch in enumerate(val_dataloader):
406 loss, acc, bsz = loss_step_(step, batch, True)
407
408 loss = loss.detach_()
409 acc = acc.detach_()
410
411 cur_loss_val.update(loss, bsz)
412 cur_acc_val.update(acc, bsz)
413
414 avg_loss_val.update(loss, bsz)
415 avg_acc_val.update(acc, bsz)
416
417 local_progress_bar.update(1)
418 global_progress_bar.update(1)
419
420 logs = {
421 "val/loss": avg_loss_val.avg.item(),
422 "val/acc": avg_acc_val.avg.item(),
423 "val/cur_loss": loss.item(),
424 "val/cur_acc": acc.item(),
425 }
426 local_progress_bar.set_postfix(**logs)
427
428 logs["val/cur_loss"] = cur_loss_val.avg.item()
429 logs["val/cur_acc"] = cur_acc_val.avg.item()
430
431 accelerator.log(logs, step=global_step)
432
433 local_progress_bar.clear()
434 global_progress_bar.clear()
435
436 if accelerator.is_main_process:
437 if avg_acc_val.avg.item() > max_acc_val:
438 accelerator.print(
439 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
440 strategy.checkpointer.checkpoint(global_step + global_step_offset, "milestone")
441 max_acc_val = avg_acc_val.avg.item()
442
443 # Create the pipeline using using the trained modules and save it.
444 if accelerator.is_main_process:
445 print("Finished!")
446 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
447 strategy.checkpointer.save_samples(global_step + global_step_offset)
448 accelerator.end_training()
449
450 except KeyboardInterrupt:
451 if accelerator.is_main_process:
452 print("Interrupted")
453 strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
454 accelerator.end_training()
455
456
457class Trainer():
458 def __init__(
459 self,
460 accelerator: Accelerator,
461 unet: UNet2DConditionModel,
462 text_encoder: CLIPTextModel,
463 tokenizer: MultiCLIPTokenizer,
464 vae: AutoencoderKL,
465 noise_scheduler: DDPMScheduler,
466 sample_scheduler: DPMSolverMultistepScheduler,
467 train_dataloader: DataLoader,
468 val_dataloader: DataLoader,
469 dtype: torch.dtype,
470 ):
471 self.accelerator = accelerator
472 self.unet = unet
473 self.text_encoder = text_encoder
474 self.tokenizer = tokenizer
475 self.vae = vae
476 self.noise_scheduler = noise_scheduler
477 self.sample_scheduler = sample_scheduler
478 self.train_dataloader = train_dataloader
479 self.val_dataloader = val_dataloader
480 self.dtype = dtype
481
482 def __call__(
483 self,
484 strategy_class: Type[TrainingStrategy],
485 optimizer,
486 lr_scheduler,
487 num_train_epochs: int = 100,
488 sample_frequency: int = 20,
489 checkpoint_frequency: int = 50,
490 global_step_offset: int = 0,
491 prior_loss_weight: float = 0,
492 seed: Optional[int] = None,
493 **kwargs,
494 ):
495 unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = self.accelerator.prepare(
496 self.unet, self.text_encoder, optimizer, self.train_dataloader, self.val_dataloader, lr_scheduler
497 )
498
499 self.vae.to(self.accelerator.device, dtype=self.dtype)
500
501 for model in (unet, text_encoder, self.vae):
502 model.requires_grad_(False)
503 model.eval()
504
505 if seed is None:
506 seed = torch.random.seed()
507
508 strategy = strategy_class(
509 accelerator=self.accelerator,
510 vae=self.vae,
511 unet=unet,
512 text_encoder=text_encoder,
513 tokenizer=self.tokenizer,
514 sample_scheduler=self.sample_scheduler,
515 train_dataloader=train_dataloader,
516 val_dataloader=val_dataloader,
517 dtype=self.dtype,
518 seed=seed,
519 **kwargs
520 )
521
522 if self.accelerator.is_main_process:
523 self.accelerator.init_trackers("textual_inversion")
524
525 train_loop(
526 strategy=strategy,
527 accelerator=self.accelerator,
528 vae=self.vae,
529 unet=unet,
530 text_encoder=text_encoder,
531 train_dataloader=train_dataloader,
532 val_dataloader=val_dataloader,
533 seed=seed,
534 optimizer=optimizer,
535 lr_scheduler=lr_scheduler,
536 noise_scheduler=self.noise_scheduler,
537 prior_loss_weight=prior_loss_weight,
538 sample_frequency=sample_frequency,
539 checkpoint_frequency=checkpoint_frequency,
540 global_step_offset=global_step_offset,
541 num_epochs=num_train_epochs,
542 )
543
544 self.accelerator.free_memory()