Diffstat (limited to 'training/util.py')
-rw-r--r--  training/util.py  214
1 file changed, 118 insertions(+), 96 deletions(-)
diff --git a/training/util.py b/training/util.py
index 781cf04..a292edd 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,12 +1,40 @@
 from pathlib import Path
 import json
 import copy
-import itertools
-from typing import Iterable, Optional
+from typing import Iterable, Union
 from contextlib import contextmanager
 
 import torch
-from PIL import Image
+
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
+
+
+class TrainingStrategy():
+    @property
+    def main_model(self) -> torch.nn.Module:
+        ...
+
+    @contextmanager
+    def on_train(self, epoch: int):
+        yield
+
+    @contextmanager
+    def on_eval(self):
+        yield
+
+    def on_before_optimize(self, epoch: int):
+        ...
+
+    def on_after_optimize(self, lr: float):
+        ...
+
+    def on_log():
+        return {}
 
 
 def save_args(basepath: Path, args, extra={}):
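Note: the TrainingStrategy added above only declares no-op hooks; concrete behaviour is supplied by subclasses. A minimal sketch of such a subclass, assuming the trained module is the text encoder (the name TextEncoderStrategy and its body are illustrative, not part of this commit):

    # Hypothetical subclass, for illustration only.
    class TextEncoderStrategy(TrainingStrategy):
        def __init__(self, text_encoder: torch.nn.Module):
            self.text_encoder = text_encoder

        @property
        def main_model(self) -> torch.nn.Module:
            # The module whose parameters the optimizer should update.
            return self.text_encoder

        @contextmanager
        def on_train(self, epoch: int):
            # Switch to training mode for the duration of the epoch.
            self.text_encoder.train()
            try:
                yield
            finally:
                self.text_encoder.eval()

        @contextmanager
        def on_eval(self):
            # Evaluation runs without gradient tracking.
            with torch.inference_mode():
                yield

        def on_log(self):
            return {"train_target": "text_encoder"}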
@@ -16,12 +44,93 @@ def save_args(basepath: Path, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
+def generate_class_images(
+    accelerator,
+    text_encoder,
+    vae,
+    unet,
+    tokenizer,
+    scheduler,
+    data_train,
+    sample_batch_size,
+    sample_image_size,
+    sample_steps
+):
+    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
+def get_models(pretrained_model_name_or_path: str):
+    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
+    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
+    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
+    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder='scheduler')
+
+    embeddings = patch_managed_embeddings(text_encoder)
+
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
+
+
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
+    initializer_token_ids = [
+        tokenizer.encode(token, add_special_tokens=False)
+        for token in initializer_tokens
+    ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
+
+    embeddings.resize(len(tokenizer))
+
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
+
+    return placeholder_token_ids, initializer_token_ids
 
 
 class AverageMeter:
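The three helpers introduced in this hunk are intended to be wired together at start-up: get_models() loads all sub-models from one pretrained checkpoint, add_placeholder_tokens() registers new tokens and initializes their embeddings, and generate_class_images() creates class images for dataset items whose class_image_path does not yet exist on disk. A rough usage sketch, assuming such a flow (the model path and token values are placeholders, not taken from this commit):

    # Illustrative wiring only; argument values are placeholders.
    tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
        "path/to/pretrained-model"
    )

    placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
        tokenizer=tokenizer,
        embeddings=embeddings,
        placeholder_tokens=["<new-concept>"],
        initializer_tokens=["person"],
        num_vectors=1,
    )

    # Afterwards, generate_class_images(accelerator, text_encoder, vae, unet,
    # tokenizer, sample_scheduler, data_train, ...) would fill in any class
    # images still missing from disk before training starts.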
@@ -38,93 +147,6 @@ class AverageMeter:
         self.avg = self.sum / self.count
 
 
-class CheckpointerBase:
-    def __init__(
-        self,
-        train_dataloader,
-        val_dataloader,
-        output_dir: Path,
-        sample_steps: int = 20,
-        sample_guidance_scale: float = 7.5,
-        sample_image_size: int = 768,
-        sample_batches: int = 1,
-        sample_batch_size: int = 1,
-        seed: Optional[int] = None
-    ):
-        self.train_dataloader = train_dataloader
-        self.val_dataloader = val_dataloader
-        self.output_dir = output_dir
-        self.sample_image_size = sample_image_size
-        self.sample_steps = sample_steps
-        self.sample_guidance_scale = sample_guidance_scale
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
-        self.seed = seed if seed is not None else torch.random.seed()
-
-    @torch.no_grad()
-    def checkpoint(self, step: int, postfix: str):
-        pass
-
-    @torch.inference_mode()
-    def save_samples(self, pipeline, step: int):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-
-        grid_cols = min(self.sample_batch_size, 4)
-        grid_rows = (self.sample_batches * self.sample_batch_size) // grid_cols
-
-        for pool, data, gen in [
-            ("stable", self.val_dataloader, generator),
-            ("val", self.val_dataloader, None),
-            ("train", self.train_dataloader, None)
-        ]:
-            all_samples = []
-            file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            batches = list(itertools.islice(itertools.cycle(data), self.sample_batch_size * self.sample_batches))
-            prompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["prompt_ids"]
-            ]
-            nprompt_ids = [
-                prompt
-                for batch in batches
-                for prompt in batch["nprompt_ids"]
-            ]
-
-            for i in range(self.sample_batches):
-                start = i * self.sample_batch_size
-                end = (i + 1) * self.sample_batch_size
-                prompt = prompt_ids[start:end]
-                nprompt = nprompt_ids[start:end]
-
-                samples = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    generator=gen,
-                    guidance_scale=self.sample_guidance_scale,
-                    num_inference_steps=self.sample_steps,
-                    output_type='pil'
-                ).images
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, grid_rows, grid_cols)
-            image_grid.save(file_path, quality=85)
-
-            del all_samples
-            del image_grid
-
-        del generator
-
-
 # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
 class EMAModel:
     """