Diffstat (limited to 'training')
-rw-r--r--  training/functional.py   118
-rw-r--r--  training/strategy/ti.py  164
2 files changed, 282 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py
index 1f2ca6d..e54c9c8 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -2,6 +2,8 @@ import math
 from contextlib import _GeneratorContextManager, nullcontext
 from typing import Callable, Any, Tuple, Union, Optional
 from functools import partial
+from pathlib import Path
+import itertools
 
 import torch
 import torch.nn.functional as F
@@ -26,6 +28,14 @@ def const(result=None):
     return fn
 
 
+def make_grid(images, rows, cols):
+    w, h = images[0].size
+    grid = Image.new('RGB', size=(cols*w, rows*h))
+    for i, image in enumerate(images):
+        grid.paste(image, box=(i % cols*w, i//cols*h))
+    return grid
+
+
 def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
@@ -40,6 +50,107 @@ def get_models(pretrained_model_name_or_path: str):
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
+def save_samples(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    step: int,
+    batch_size: int = 1,
+    num_batches: int = 1,
+    num_steps: int = 20,
+    guidance_scale: float = 7.5,
+    image_size: Optional[int] = None,
+):
+    print(f"Saving samples for step {step}...")
+
+    samples_path = output_dir.joinpath("samples")
+
+    grid_cols = min(batch_size, 4)
+    grid_rows = (num_batches * batch_size) // grid_cols
+
+    unet = accelerator.unwrap_model(unet)
+    text_encoder = accelerator.unwrap_model(text_encoder)
+
+    orig_unet_dtype = unet.dtype
+    orig_text_encoder_dtype = text_encoder.dtype
+
+    unet.to(dtype=dtype)
+    text_encoder.to(dtype=dtype)
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=sample_scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    generator = torch.Generator(device=accelerator.device).manual_seed(seed)
+
+    for pool, data, gen in [
+        ("stable", val_dataloader, generator),
+        ("val", val_dataloader, None),
+        ("train", train_dataloader, None)
+    ]:
+        all_samples = []
+        file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
+        file_path.parent.mkdir(parents=True, exist_ok=True)
+
+        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
+        prompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["prompt_ids"]
+        ]
+        nprompt_ids = [
+            prompt
+            for batch in batches
+            for prompt in batch["nprompt_ids"]
+        ]
+
+        for i in range(num_batches):
+            start = i * batch_size
+            end = (i + 1) * batch_size
+            prompt = prompt_ids[start:end]
+            nprompt = nprompt_ids[start:end]
+
+            samples = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=image_size,
+                width=image_size,
+                generator=gen,
+                guidance_scale=guidance_scale,
+                num_inference_steps=num_steps,
+                output_type='pil'
+            ).images
+
+            all_samples += samples
+
+        image_grid = make_grid(all_samples, grid_rows, grid_cols)
+        image_grid.save(file_path, quality=85)
+
+    unet.to(dtype=orig_unet_dtype)
+    text_encoder.to(dtype=orig_text_encoder_dtype)
+
+    del unet
+    del text_encoder
+    del generator
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
+
+
 def generate_class_images(
     accelerator: Accelerator,
     text_encoder: CLIPTextModel,
@@ -109,6 +220,10 @@ def get_models(pretrained_model_name_or_path: str):
 
     embeddings = patch_managed_embeddings(text_encoder)
 
+    vae.requires_grad_(False)
+    unet.requires_grad_(False)
+    text_encoder.requires_grad_(False)
+
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
 
@@ -427,6 +542,9 @@ def train(
         seed,
     )
 
+    if accelerator.is_main_process:
+        accelerator.init_trackers("textual_inversion")
+
     train_loop(
         accelerator=accelerator,
         optimizer=optimizer,
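
The hunks above add a make_grid helper and a save_samples routine that renders prompt grids from the train/val dataloaders through a VlpnStableDiffusion pipeline. A minimal usage sketch follows; it is not part of the commit, and the checkpoint name, output path, and sampling parameters below are illustrative assumptions.

from pathlib import Path

import torch
from accelerate import Accelerator

from training.functional import get_models, save_samples


def sample_at_step(train_dataloader, val_dataloader, step):
    """Illustrative only: load the models and write one sample grid at `step`."""
    accelerator = Accelerator()
    tokenizer, text_encoder, vae, unet, _, sample_scheduler, _ = get_models(
        "runwayml/stable-diffusion-v1-5"  # example checkpoint, not from the commit
    )
    save_samples(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,  # batches must carry "prompt_ids"/"nprompt_ids"
        val_dataloader=val_dataloader,
        dtype=torch.float16,                # sample in half precision; dtypes are restored afterwards
        output_dir=Path("output/run-001"),  # grids land in output/run-001/samples/<pool>/
        seed=42,
        step=step,
        batch_size=4,
        num_batches=2,
        image_size=512,
    )

Note that save_samples temporarily casts the UNet and text encoder to the sampling dtype and restores their original dtypes before returning, so it can be called mid-training without disturbing training precision.
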
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
new file mode 100644
index 0000000..83dc566
--- /dev/null
+++ b/training/strategy/ti.py
@@ -0,0 +1,164 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+
+import torch
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from slugify import slugify
+
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import save_samples
+
+
+def textual_inversion_strategy(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    dtype: torch.dtype,
+    output_dir: Path,
+    seed: int,
+    placeholder_tokens: list[str],
+    placeholder_token_ids: list[list[int]],
+    learning_rate: float,
+    gradient_checkpointing: bool = False,
+    use_emb_decay: bool = False,
+    emb_decay_target: float = 0.4,
+    emb_decay_factor: float = 1,
+    emb_decay_start: float = 1e-4,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=dtype,
+        output_dir=output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    if use_ema:
+        ema_embeddings = EMAModel(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_embeddings = None
+
+    def on_prepare():
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
+
+        if use_ema:
+            ema_embeddings.to(accelerator.device)
+
+        if gradient_checkpointing:
+            unet.train()
+
+    @contextmanager
+    def on_train(epoch: int):
+        try:
+            tokenizer.train()
+            yield
+        finally:
+            pass
+
+    @contextmanager
+    def on_eval():
+        try:
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
+
+            with ema_context:
+                yield
+        finally:
+            pass
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if use_emb_decay:
+            text_encoder.text_model.embeddings.normalize(
+                emb_decay_target,
+                min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (learning_rate - emb_decay_start))))
+            )
+
+        if use_ema:
+            ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_embeddings.decay}
+        return {}
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        print(f"Saving checkpoint for step {step}...")
+
+        checkpoints_path = output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder_.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+                text_encoder_.text_model.embeddings.save_embed(
+                    ids,
+                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
+                )
+
+    @torch.no_grad()
+    def on_sample(step):
+        ema_context = ema_embeddings.apply_temporary(
+            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+        ) if ema_embeddings is not None else nullcontext()
+
+        with ema_context:
+            save_samples_(step=step)
+
+    return {
+        "on_prepare": on_prepare,
+        "on_train": on_train,
+        "on_eval": on_eval,
+        "on_after_optimize": on_after_optimize,
+        "on_log": on_log,
+        "on_checkpoint": on_checkpoint,
+        "on_sample": on_sample,
+    }
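
The new training/strategy/ti.py packages the textual-inversion-specific behaviour (EMA of the temporary token embeddings, optional embedding decay, checkpointing of the learned embeddings, and sampling) as a dict of callbacks. A hedged wiring sketch follows; the call sites and every concrete value (token names, ids, learning rate, step numbers) are assumptions for illustration, not the actual wiring in training/functional.py.

from pathlib import Path

import torch

from training.strategy.ti import textual_inversion_strategy


def run_ti_callbacks(accelerator, unet, text_encoder, tokenizer, vae,
                     sample_scheduler, train_dataloader, val_dataloader):
    """Illustrative driver for the callback dict; concrete values are assumptions."""
    callbacks = textual_inversion_strategy(
        accelerator=accelerator,
        unet=unet,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        vae=vae,
        sample_scheduler=sample_scheduler,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        dtype=torch.float16,
        output_dir=Path("output/run-001"),
        seed=42,
        placeholder_tokens=["<my-token>"],   # example placeholder token
        placeholder_token_ids=[[49408]],     # example token ids
        learning_rate=1e-4,
        use_ema=True,
    )

    callbacks["on_prepare"]()                      # unfreeze the new token embeddings
    with callbacks["on_train"](epoch=0):
        # forward/backward/optimizer.step() would run here
        callbacks["on_after_optimize"](lr=1e-4)    # EMA update / embedding decay
    with callbacks["on_eval"]():
        pass                                       # validation runs under EMA weights
    print(callbacks["on_log"]())                   # {"ema_decay": ...} when use_ema=True
    callbacks["on_checkpoint"](1000, "milestone")  # writes <my-token>_1000_milestone.bin
    callbacks["on_sample"](1000)                   # writes samples/<pool>/step_1000.jpg

on_checkpoint writes one .bin file per placeholder token under output_dir/checkpoints, and on_sample reuses save_samples through the partial bound at the top of the strategy.
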