Diffstat (limited to 'trainer_old/base.py')
-rw-r--r--  trainer_old/base.py  538
1 file changed, 0 insertions(+), 538 deletions(-)
diff --git a/trainer_old/base.py b/trainer_old/base.py
deleted file mode 100644
index 5903d96..0000000
--- a/trainer_old/base.py
+++ /dev/null
@@ -1,538 +0,0 @@
from pathlib import Path
import math
from contextlib import contextmanager
from typing import Type, Optional
import itertools
from functools import partial

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader

from accelerate import Accelerator
from transformers import CLIPTextModel
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler

from tqdm.auto import tqdm
from PIL import Image

from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
from models.clip.tokenizer import MultiCLIPTokenizer
from models.clip.util import get_extended_embeddings
from training.util import AverageMeter

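# Helper: tile a list of PIL images into a single rows x cols grid image,
# filled row by row; all images are assumed to share the size of the first one.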
def make_grid(images, rows, cols):
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class Checkpointer():
    def __init__(
        self,
        accelerator: Accelerator,
        vae: AutoencoderKL,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        tokenizer: MultiCLIPTokenizer,
        sample_scheduler,
        dtype,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        output_dir: Path,
        sample_steps: int = 20,
        sample_guidance_scale: float = 7.5,
        sample_image_size: int = 768,
        sample_batches: int = 1,
        sample_batch_size: int = 1,
        seed: Optional[int] = None,
        *args,
        **kwargs,
    ):
        self.accelerator = accelerator
        self.vae = vae
        self.unet = unet
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.sample_scheduler = sample_scheduler
        self.dtype = dtype
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.output_dir = output_dir
        self.sample_steps = sample_steps
        self.sample_guidance_scale = sample_guidance_scale
        self.sample_image_size = sample_image_size
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size
        self.seed = seed if seed is not None else torch.random.seed()

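    # Checkpointing hook; the base implementation is a no-op and is meant to be
    # overridden by strategy-specific checkpointers.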
    @torch.no_grad()
    def checkpoint(self, step: int, postfix: str):
        pass

    @torch.no_grad()
    def save_samples(self, step: int):
        print(f"Saving samples for step {step}...")

        samples_path = self.output_dir.joinpath("samples")

        grid_cols = min(self.sample_batch_size, 4)
        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols

        unet = self.accelerator.unwrap_model(self.unet)
        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

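        # Temporarily cast the unwrapped models to the sampling dtype; their
        # original dtypes are restored after sampling below.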
        orig_unet_dtype = unet.dtype
        orig_text_encoder_dtype = text_encoder.dtype

        unet.to(dtype=self.dtype)
        text_encoder.to(dtype=self.dtype)

        pipeline = VlpnStableDiffusion(
            text_encoder=text_encoder,
            vae=self.vae,
            unet=unet,
            tokenizer=self.tokenizer,
            scheduler=self.sample_scheduler,
        ).to(self.accelerator.device)
        pipeline.set_progress_bar_config(dynamic_ncols=True)

        generator = torch.Generator(device=self.accelerator.device).manual_seed(self.seed)

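        # Three sample pools: "stable" renders validation prompts with a fixed
        # generator so the images stay comparable across steps, while "val" and
        # "train" render their prompts with fresh randomness.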
        for pool, data, gen in [
            ("stable", self.val_dataloader, generator),
            ("val", self.val_dataloader, None),
            ("train", self.train_dataloader, None)
        ]:
            all_samples = []
            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
            file_path.parent.mkdir(parents=True, exist_ok=True)

            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
            prompt_ids = [
                prompt
                for batch in batches
                for prompt in batch["prompt_ids"]
            ]
            nprompt_ids = [
                prompt
                for batch in batches
                for prompt in batch["nprompt_ids"]
            ]

            for i in range(self.sample_batches):
                start = i * self.sample_batch_size
                end = (i + 1) * self.sample_batch_size
                prompt = prompt_ids[start:end]
                nprompt = nprompt_ids[start:end]

                samples = pipeline(
                    prompt=prompt,
                    negative_prompt=nprompt,
                    height=self.sample_image_size,
                    width=self.sample_image_size,
                    generator=gen,
                    guidance_scale=self.sample_guidance_scale,
                    num_inference_steps=self.sample_steps,
                    output_type='pil'
                ).images

                all_samples += samples

            image_grid = make_grid(all_samples, grid_rows, grid_cols)
            image_grid.save(file_path, quality=85)

        unet.to(dtype=orig_unet_dtype)
        text_encoder.to(dtype=orig_text_encoder_dtype)

        del unet
        del text_encoder
        del generator
        del pipeline

        if torch.cuda.is_available():
            torch.cuda.empty_cache()


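# Base class for training strategies. Concrete strategies (e.g. textual
# inversion or fine-tuning of the diffusion model) are expected to provide the
# trainable main_model and may override the remaining hooks.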
class TrainingStrategy():
    def __init__(
        self,
        tokenizer: MultiCLIPTokenizer,
        *args,
        **kwargs,
    ):
        self.tokenizer = tokenizer
        self.checkpointer = Checkpointer(*args, tokenizer=tokenizer, **kwargs)

    @property
    def main_model(self) -> nn.Module:
        ...

    @contextmanager
    def on_train(self, epoch: int):
        self.tokenizer.train()
        yield

    @contextmanager
    def on_eval(self):
        self.tokenizer.eval()
        yield

    def on_before_optimize(self, epoch: int):
        ...

    def on_after_optimize(self, lr: float):
        ...

    def on_log(self):
        return {}


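# One training/validation step: encode images into VAE latents, add noise at a
# random timestep, predict it with the UNet, and return the (optionally
# prior-preserving) MSE loss together with an exact-match rate and batch size.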
def loss_step(
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    seed: int,
    noise_scheduler,
    prior_loss_weight: float,
    step: int,
    batch: dict,
    eval: bool = False
):
    # Convert images to latent space
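    # (0.18215 is the standard Stable Diffusion latent scaling factor)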
    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
    latents = latents * 0.18215

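    # During evaluation, use a step-dependent fixed seed (seed + step) so that
    # validation noise and timesteps are reproducible across epochs.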
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    # Sample noise that we'll add to the latents
    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator
    )
    bsz = latents.shape[0]
    # Sample a random timestep for each image
    timesteps = torch.randint(
        0,
        noise_scheduler.config.num_train_timesteps,
        (bsz,),
        generator=generator,
        device=latents.device,
    )
    timesteps = timesteps.long()

    # Add noise to the latents according to the noise magnitude at each timestep
    # (this is the forward diffusion process)
    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noisy_latents = noisy_latents.to(dtype=unet.dtype)

    # Get the text embedding for conditioning
    encoder_hidden_states = get_extended_embeddings(
        text_encoder,
        batch["input_ids"],
        batch["attention_mask"]
    )
    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)

    # Predict the noise residual
    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    # Get the target for loss depending on the prediction type
    if noise_scheduler.config.prediction_type == "epsilon":
        target = noise
    elif noise_scheduler.config.prediction_type == "v_prediction":
        target = noise_scheduler.get_velocity(latents, noise, timesteps)
    else:
        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")

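    # Prior preservation (DreamBooth-style): the batch is assumed to be the
    # concatenation of instance examples and class ("prior") examples, so
    # chunking in half separates the two contributions.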
    if batch["with_prior"].all():
        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
        target, target_prior = torch.chunk(target, 2, dim=0)

        # Compute instance loss
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

        # Compute prior loss
        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")

        # Add the prior loss to the instance loss.
        loss = loss + prior_loss_weight * prior_loss
    else:
        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")

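    # Element-wise exact-match rate between prediction and target; a coarse
    # secondary metric rather than an accuracy in the classification sense.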
    acc = (model_pred == target).float().mean()

    return loss, acc, bsz


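# Main loop: for each epoch, run a gradient-accumulated training pass followed
# by a validation pass, log metrics, periodically save samples and checkpoints,
# and write a "milestone" checkpoint whenever validation accuracy improves.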
def train_loop(
    strategy: TrainingStrategy,
    accelerator: Accelerator,
    vae: AutoencoderKL,
    unet: UNet2DConditionModel,
    text_encoder: CLIPTextModel,
    train_dataloader: DataLoader,
    val_dataloader: DataLoader,
    seed: int,
    optimizer: torch.optim.Optimizer,
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    noise_scheduler,
    prior_loss_weight: float = 1.0,
    sample_frequency: int = 10,
    checkpoint_frequency: int = 50,
    global_step_offset: int = 0,
    num_epochs: int = 100,
):
    num_training_steps_per_epoch = math.ceil(
        len(train_dataloader) / accelerator.gradient_accumulation_steps
    )
    num_val_steps_per_epoch = len(val_dataloader)

    num_training_steps = num_training_steps_per_epoch * num_epochs
    num_val_steps = num_val_steps_per_epoch * num_epochs

    global_step = 0

    avg_loss = AverageMeter()
    avg_acc = AverageMeter()

    avg_loss_val = AverageMeter()
    avg_acc_val = AverageMeter()

    max_acc_val = 0.0

    local_progress_bar = tqdm(
        range(num_training_steps_per_epoch + num_val_steps_per_epoch),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")

    global_progress_bar = tqdm(
        range(num_training_steps + num_val_steps),
        disable=not accelerator.is_local_main_process,
        dynamic_ncols=True
    )
    global_progress_bar.set_description("Total progress")

    loss_step_ = partial(
        loss_step,
        vae,
        unet,
        text_encoder,
        seed,
        noise_scheduler,
        prior_loss_weight
    )

    try:
        for epoch in range(num_epochs):
            if accelerator.is_main_process:
                if epoch % sample_frequency == 0 and epoch != 0:
                    strategy.checkpointer.save_samples(global_step + global_step_offset)

                if epoch % checkpoint_frequency == 0 and epoch != 0:
                    strategy.checkpointer.checkpoint(global_step + global_step_offset, "training")

            local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
            local_progress_bar.reset()

            strategy.main_model.train()

            with strategy.on_train(epoch):
                for step, batch in enumerate(train_dataloader):
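                    # accelerator.accumulate() handles gradient accumulation:
                    # with accelerate, the prepared optimizer only performs a
                    # real step on iterations where gradients are synced, which
                    # is what sync_gradients checks below.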
                    with accelerator.accumulate(strategy.main_model):
                        loss, acc, bsz = loss_step_(step, batch)

                        accelerator.backward(loss)

                        strategy.on_before_optimize(epoch)

                        optimizer.step()
                        lr_scheduler.step()
                        optimizer.zero_grad(set_to_none=True)

                        avg_loss.update(loss.detach_(), bsz)
                        avg_acc.update(acc.detach_(), bsz)

                    # Checks if the accelerator has performed an optimization step behind the scenes
                    if accelerator.sync_gradients:
                        strategy.on_after_optimize(lr_scheduler.get_last_lr()[0])

                        local_progress_bar.update(1)
                        global_progress_bar.update(1)

                        global_step += 1

                        logs = {
                            "train/loss": avg_loss.avg.item(),
                            "train/acc": avg_acc.avg.item(),
                            "train/cur_loss": loss.item(),
                            "train/cur_acc": acc.item(),
                            "lr": lr_scheduler.get_last_lr()[0],
                        }
                        logs.update(strategy.on_log())

                        accelerator.log(logs, step=global_step)

                        local_progress_bar.set_postfix(**logs)

                    if global_step >= num_training_steps:
                        break

            accelerator.wait_for_everyone()

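            # Validation pass: evaluate the same loss on the validation set
            # with a fixed, step-dependent seed (eval=True in loss_step).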
            strategy.main_model.eval()

            cur_loss_val = AverageMeter()
            cur_acc_val = AverageMeter()

            with torch.inference_mode(), strategy.on_eval():
                for step, batch in enumerate(val_dataloader):
                    loss, acc, bsz = loss_step_(step, batch, True)

                    loss = loss.detach_()
                    acc = acc.detach_()

                    cur_loss_val.update(loss, bsz)
                    cur_acc_val.update(acc, bsz)

                    avg_loss_val.update(loss, bsz)
                    avg_acc_val.update(acc, bsz)

                    local_progress_bar.update(1)
                    global_progress_bar.update(1)

                    logs = {
                        "val/loss": avg_loss_val.avg.item(),
                        "val/acc": avg_acc_val.avg.item(),
                        "val/cur_loss": loss.item(),
                        "val/cur_acc": acc.item(),
                    }
                    local_progress_bar.set_postfix(**logs)

                logs["val/cur_loss"] = cur_loss_val.avg.item()
                logs["val/cur_acc"] = cur_acc_val.avg.item()

                accelerator.log(logs, step=global_step)

            local_progress_bar.clear()
            global_progress_bar.clear()

            if accelerator.is_main_process:
                if avg_acc_val.avg.item() > max_acc_val:
                    accelerator.print(
                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                    strategy.checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                    max_acc_val = avg_acc_val.avg.item()

        # Create the pipeline using the trained modules and save it.
        if accelerator.is_main_process:
            print("Finished!")
            strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
            strategy.checkpointer.save_samples(global_step + global_step_offset)
            accelerator.end_training()

    except KeyboardInterrupt:
        if accelerator.is_main_process:
            print("Interrupted")
            strategy.checkpointer.checkpoint(global_step + global_step_offset, "end")
            accelerator.end_training()


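# Thin orchestration wrapper: prepares the models, optimizer and dataloaders
# with accelerate, instantiates the requested training strategy, and runs
# train_loop with the configured schedulers.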
class Trainer():
    def __init__(
        self,
        accelerator: Accelerator,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        tokenizer: MultiCLIPTokenizer,
        vae: AutoencoderKL,
        noise_scheduler: DDPMScheduler,
        sample_scheduler: DPMSolverMultistepScheduler,
        train_dataloader: DataLoader,
        val_dataloader: DataLoader,
        dtype: torch.dtype,
    ):
        self.accelerator = accelerator
        self.unet = unet
        self.text_encoder = text_encoder
        self.tokenizer = tokenizer
        self.vae = vae
        self.noise_scheduler = noise_scheduler
        self.sample_scheduler = sample_scheduler
        self.train_dataloader = train_dataloader
        self.val_dataloader = val_dataloader
        self.dtype = dtype

    def __call__(
        self,
        strategy_class: Type[TrainingStrategy],
        optimizer,
        lr_scheduler,
        num_train_epochs: int = 100,
        sample_frequency: int = 20,
        checkpoint_frequency: int = 50,
        global_step_offset: int = 0,
        prior_loss_weight: float = 0,
        seed: Optional[int] = None,
        **kwargs,
    ):
        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = self.accelerator.prepare(
            self.unet, self.text_encoder, optimizer, self.train_dataloader, self.val_dataloader, lr_scheduler
        )

        self.vae.to(self.accelerator.device, dtype=self.dtype)

        for model in (unet, text_encoder, self.vae):
            model.requires_grad_(False)
            model.eval()

        if seed is None:
            seed = torch.random.seed()

        strategy = strategy_class(
            accelerator=self.accelerator,
            vae=self.vae,
            unet=unet,
            text_encoder=text_encoder,
            tokenizer=self.tokenizer,
            sample_scheduler=self.sample_scheduler,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            dtype=self.dtype,
            seed=seed,
            **kwargs
        )

        if self.accelerator.is_main_process:
            self.accelerator.init_trackers("textual_inversion")

        train_loop(
            strategy=strategy,
            accelerator=self.accelerator,
            vae=self.vae,
            unet=unet,
            text_encoder=text_encoder,
            train_dataloader=train_dataloader,
            val_dataloader=val_dataloader,
            seed=seed,
            optimizer=optimizer,
            lr_scheduler=lr_scheduler,
            noise_scheduler=self.noise_scheduler,
            prior_loss_weight=prior_loss_weight,
            sample_frequency=sample_frequency,
            checkpoint_frequency=checkpoint_frequency,
            global_step_offset=global_step_offset,
            num_epochs=num_train_epochs,
        )

        self.accelerator.free_memory()
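
# Minimal usage sketch (hypothetical wiring; the concrete strategy class and the
# construction of models, dataloaders and schedulers live elsewhere in the repo):
#
#   trainer = Trainer(
#       accelerator=accelerator,
#       unet=unet,
#       text_encoder=text_encoder,
#       tokenizer=tokenizer,
#       vae=vae,
#       noise_scheduler=noise_scheduler,
#       sample_scheduler=sample_scheduler,
#       train_dataloader=train_dataloader,
#       val_dataloader=val_dataloader,
#       dtype=torch.float32,
#   )
#   trainer(
#       strategy_class=MyTrainingStrategy,  # hypothetical TrainingStrategy subclass
#       optimizer=optimizer,
#       lr_scheduler=lr_scheduler,
#       num_train_epochs=100,
#       output_dir=Path("output"),          # forwarded to the Checkpointer via **kwargs
#   )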