Diffstat (limited to 'training/common.py')
-rw-r--r--  training/common.py  264
1 files changed, 17 insertions, 247 deletions
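
The commit strips the monolithic train_setup() helper (scheduler selection, Accelerator and logging setup, data loading, optimizer choice) out of training/common.py, leaving two small helpers, get_models() and add_placeholder_tokens(). Below is a minimal sketch of how a training script might wire them together after this change; the checkpoint id and token values are illustrative assumptions, and only the function names, signatures, and return values come from the diff that follows.

# Hypothetical usage sketch (not part of the commit); checkpoint id and tokens are placeholders.
from training.common import get_models, add_placeholder_tokens

tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
    "stabilityai/stable-diffusion-2-1"  # assumed checkpoint id
)

placeholder_token_ids = add_placeholder_tokens(
    tokenizer=tokenizer,
    embeddings=embeddings,
    placeholder_tokens=["<concept>"],  # new token(s) to learn
    initializer_tokens=["person"],     # existing token(s) used to initialize them
    num_vectors=1,                     # embedding vectors per placeholder token
)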
diff --git a/training/common.py b/training/common.py
index 73ce814..b6964a3 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,52 +1,24 @@
 import math
-from pathlib import Path
 from contextlib import _GeneratorContextManager, nullcontext
-from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple
-import datetime
-import logging
+from typing import Callable, Any, Tuple, Union
 
 import torch
 import torch.nn.functional as F
 from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
-from accelerate.utils import LoggerType, set_seed
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
 
 from tqdm.auto import tqdm
-from slugify import slugify
 
-from data.csv import VlpnDataModule, VlpnDataItem
-from util import load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from models.clip.embeddings import patch_managed_embeddings
+from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.optimization import get_one_cycle_schedule
 from training.util import AverageMeter, CheckpointerBase
 
 
-class TrainingSetup(NamedTuple):
-    accelerator: Accelerator
-    tokenizer: MultiCLIPTokenizer
-    text_encoder: CLIPTextModel
-    vae: AutoencoderKL
-    unet: UNet2DConditionModel
-    noise_scheduler: DDPMScheduler
-    checkpoint_scheduler: DPMSolverMultistepScheduler
-    optimizer_class: Callable
-    learning_rate: float
-    weight_dtype: torch.dtype
-    output_dir: Path
-    seed: int
-    train_dataloader: DataLoader
-    val_dataloader: DataLoader
-    placeholder_token: list[str]
-    placeholder_token_ids: list[list[int]]
-
-
 def noop(*args, **kwards):
     pass
 
@@ -59,57 +31,6 @@ def noop_on_log():
     return {}
 
 
-def get_scheduler(
-    id: str,
-    optimizer: torch.optim.Optimizer,
-    num_training_steps_per_epoch: int,
-    gradient_accumulation_steps: int,
-    min_lr: float = 0.04,
-    warmup_func: str = "cos",
-    annealing_func: str = "cos",
-    warmup_exp: int = 1,
-    annealing_exp: int = 1,
-    cycles: int = 1,
-    train_epochs: int = 100,
-    warmup_epochs: int = 10,
-):
-    num_training_steps_per_epoch = math.ceil(
-        num_training_steps_per_epoch / gradient_accumulation_steps
-    ) * gradient_accumulation_steps
-    num_training_steps = train_epochs * num_training_steps_per_epoch
-    num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
-
-    if id == "one_cycle":
-        lr_scheduler = get_one_cycle_schedule(
-            optimizer=optimizer,
-            num_training_steps=num_training_steps,
-            warmup=warmup_func,
-            annealing=annealing_func,
-            warmup_exp=warmup_exp,
-            annealing_exp=annealing_exp,
-            min_lr=min_lr,
-        )
-    elif id == "cosine_with_restarts":
-        if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
-
-        lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-            num_cycles=cycles,
-        )
-    else:
-        lr_scheduler = get_scheduler_(
-            id,
-            optimizer=optimizer,
-            num_warmup_steps=num_warmup_steps,
-            num_training_steps=num_training_steps,
-        )
-
-    return lr_scheduler
-
-
 def generate_class_images(
     accelerator,
     text_encoder,
@@ -162,194 +83,43 @@ def generate_class_images(
     torch.cuda.empty_cache()
 
 
-def train_setup(
-    output_dir: str,
-    project: str,
-    pretrained_model_name_or_path: str,
-    learning_rate: float,
-    data_file: str,
-    gradient_accumulation_steps: int = 1,
-    mixed_precision: Literal["no", "fp16", "bf16"] = "no",
-    seed: Optional[int] = None,
-    vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
-    vector_dropout: float = 0.1,
-    gradient_checkpointing: bool = True,
-    embeddings_dir: Optional[str] = None,
-    placeholder_token: list[str] = [],
-    initializer_token: list[str] = [],
-    num_vectors: int = 1,
-    scale_lr: bool = False,
-    use_8bit_adam: bool = False,
-    train_batch_size: int = 1,
-    class_image_dir: Optional[str] = None,
-    num_class_images: int = 0,
-    resolution: int = 768,
-    num_buckets: int = 0,
-    progressive_buckets: bool = False,
-    bucket_step_size: int = 64,
-    bucket_max_pixels: Optional[int] = None,
-    tag_dropout: float = 0.1,
-    tag_shuffle: bool = True,
-    data_template: str = "template",
-    valid_set_size: Optional[int] = None,
-    valid_set_repeat: int = 1,
-    data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
-    sample_batch_size: int = 1,
-    sample_image_size: int = 768,
-    sample_steps: int = 20,
-) -> TrainingSetup:
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    output_dir = Path(output_dir).joinpath(slugify(project), now)
-    output_dir.mkdir(parents=True, exist_ok=True)
-
-    accelerator = Accelerator(
-        log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
-        gradient_accumulation_steps=gradient_accumulation_steps,
-        mixed_precision=mixed_precision
-    )
-
-    logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
-
-    seed = seed or (torch.random.seed() >> 32)
-    set_seed(seed)
-
-    # Load the tokenizer and add the placeholder token as a additional special token
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    tokenizer.set_use_vector_shuffle(vector_shuffle)
-    tokenizer.set_dropout(vector_dropout)
-
-    # Load models and create wrapper for stable diffusion
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
     vae.enable_slicing()
     vae.set_use_memory_efficient_attention_xformers(True)
     unet.set_use_memory_efficient_attention_xformers(True)
 
-    if gradient_checkpointing:
-        unet.enable_gradient_checkpointing()
-        text_encoder.gradient_checkpointing_enable()
-
     embeddings = patch_managed_embeddings(text_encoder)
 
-    if embeddings_dir is not None:
-        embeddings_dir = Path(embeddings_dir)
-        if not embeddings_dir.exists() or not embeddings_dir.is_dir():
-            raise ValueError("--embeddings_dir must point to an existing directory")
+    return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
 
-        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
 
-    # Convert the initializer_token, placeholder_token to ids
+def add_placeholder_tokens(
+    tokenizer: MultiCLIPTokenizer,
+    embeddings: ManagedCLIPTextEmbeddings,
+    placeholder_tokens: list[str],
+    initializer_tokens: list[str],
+    num_vectors: Union[list[int], int]
+):
     initializer_token_ids = [
         tokenizer.encode(token, add_special_tokens=False)
-        for token in initializer_token
+        for token in initializer_tokens
     ]
+    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
 
-    placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
     embeddings.resize(len(tokenizer))
 
-    for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(new_id, init_ids)
-
-    init_ratios = [
-        f"{len(init_ids)} / {len(new_id)}"
-        for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
-    ]
-
-    print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
+    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
+        embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    vae.requires_grad_(False)
-    unet.requires_grad_(False)
-    text_encoder.requires_grad_(False)
-
-    if scale_lr:
-        learning_rate = (
-            learning_rate * gradient_accumulation_steps *
-            train_batch_size * accelerator.num_processes
-        )
-
-    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
-    if use_8bit_adam:
-        try:
-            import bitsandbytes as bnb
-        except ImportError:
-            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
-
-        optimizer_class = bnb.optim.AdamW8bit
-    else:
-        optimizer_class = torch.optim.AdamW
-
-    weight_dtype = torch.float32
-    if mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
-    datamodule = VlpnDataModule(
-        data_file=data_file,
-        batch_size=train_batch_size,
-        tokenizer=tokenizer,
-        class_subdir=class_image_dir,
-        num_class_images=num_class_images,
-        size=resolution,
-        num_buckets=num_buckets,
-        progressive_buckets=progressive_buckets,
-        bucket_step_size=bucket_step_size,
-        bucket_max_pixels=bucket_max_pixels,
-        dropout=tag_dropout,
-        shuffle=tag_shuffle,
-        template_key=data_template,
-        valid_set_size=valid_set_size,
-        valid_set_repeat=valid_set_repeat,
-        seed=seed,
-        filter=data_filter,
-        dtype=weight_dtype
-    )
-    datamodule.setup()
-
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
-    train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
-
-    if num_class_images != 0:
-        generate_class_images(
-            accelerator,
-            text_encoder,
-            vae,
-            unet,
-            tokenizer,
-            checkpoint_scheduler,
-            datamodule.data_train,
-            sample_batch_size,
-            sample_image_size,
-            sample_steps
-        )
-
-    return TrainingSetup(
-        accelerator=accelerator,
-        tokenizer=tokenizer,
-        text_encoder=text_encoder,
-        vae=vae,
-        unet=unet,
-        noise_scheduler=noise_scheduler,
-        checkpoint_scheduler=checkpoint_scheduler,
-        optimizer_class=optimizer_class,
-        learning_rate=learning_rate,
-        output_dir=output_dir,
-        weight_dtype=weight_dtype,
-        seed=seed,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        placeholder_token=placeholder_token,
-        placeholder_token_ids=placeholder_token_ids
-    )
+    return placeholder_token_ids
 
 
 def loss_step(