 infer.py                         |   6
 models/clip/embeddings.py        |   5
 models/clip/tokenizer.py         |   9
 train_dreambooth.py              |  86
 train_lora.py                    | 946
 train_ti.py                      |  86
 training/common.py               |  75
 util.py (renamed from common.py) |  11
 8 files changed, 133 insertions, 1091 deletions
diff --git a/infer.py b/infer.py
index b29b136..507d0cf 100644
--- a/infer.py
+++ b/infer.py
@@ -28,7 +28,7 @@ from transformers import CLIPTextModel
 from models.clip.embeddings import patch_managed_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir


 torch.backends.cuda.matmul.allow_tf32 = True
@@ -192,12 +192,12 @@ def save_args(basepath, args, extra={}):


 def load_embeddings(pipeline, embeddings_dir):
-    added_tokens = load_embeddings_from_dir(
+    added_tokens, added_ids = load_embeddings_from_dir(
         pipeline.tokenizer,
         pipeline.text_encoder.text_model.embeddings,
         Path(embeddings_dir)
     )
-    print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
+    print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")


 def create_pipeline(model, dtype):
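
Note: load_embeddings_from_dir now returns two parallel lists (tokens and their id lists) instead of a single list. A minimal consumer-side sketch, with an illustrative directory path; since interpolating zip(...) into an f-string only prints a zip object, the sketch materializes it with list() to show the (token, ids) pairs:

    added_tokens, added_ids = load_embeddings_from_dir(
        pipeline.tokenizer,
        pipeline.text_encoder.text_model.embeddings,
        Path("embeddings"),  # illustrative path, not from the commit
    )
    print(f"Added {len(added_tokens)} tokens: {list(zip(added_tokens, added_ids))}")
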
diff --git a/models/clip/embeddings.py b/models/clip/embeddings.py
index 1280ebd..fb639f1 100644
--- a/models/clip/embeddings.py
+++ b/models/clip/embeddings.py
@@ -53,6 +53,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
         self.token_embedding = resize_embedding(self.token_embedding, size, self.initializer_factor)

     def add_embed(self, token_ids: Union[int, list[int]], initializer: Optional[Union[int, list[int], torch.FloatTensor]] = None):
+        init_ratio = 1.0
+
         if isinstance(token_ids, int):
             token_ids = [token_ids]

@@ -63,6 +65,7 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             initializer = [initializer]

         if isinstance(initializer, list):
+            init_ratio = len(initializer) / len(token_ids)
             initializer = (initializer * len(token_ids))[:len(token_ids)]

         with torch.no_grad():
@@ -76,6 +79,8 @@ class ManagedCLIPTextEmbeddings(CLIPTextEmbeddings):
             dtype=self.temp_token_embedding.weight.dtype,
         )

+        return init_ratio
+
     def load_embed(self, input_ids: list[int], filename: Path):
         with safe_open(filename, framework="pt", device="cpu") as file:
             self.add_embed(input_ids, file.get_tensor("embed"))
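
Note: add_embed now reports an init_ratio describing how much of a multi-vector token was covered by its initializer. A sketch using hypothetical token ids:

    # Three token ids seeded from a single initializer id: the lone initializer
    # embedding is repeated across all three vectors and the reported ratio is 1/3.
    init_ratio = embeddings.add_embed([49408, 49409, 49410], initializer=[320])
    assert init_ratio == 1 / 3

    # A full-length initializer (or a tensor initializer) keeps the default ratio of 1.0.
    init_ratio = embeddings.add_embed([49411, 49412], initializer=[320, 525])
    assert init_ratio == 1.0
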
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 4e97ab5..034adf9 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -55,11 +55,6 @@ def shuffle_auto(tokens: list[int]):
         return shuffle_all(tokens)


-class MultiCLIPTokenizerItem(NamedTuple):
-    token: str
-    ids: list[int]
-
-
 class MultiCLIPTokenizer(CLIPTokenizer):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -96,7 +91,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):
         self,
         new_tokens: Union[str, list[str]],
         num_vectors: Union[int, list[int]] = 1
-    ) -> Union[MultiCLIPTokenizerItem, list[MultiCLIPTokenizerItem]]:
+    ) -> Union[list[int], list[list[int]]]:
         if isinstance(new_tokens, list):
             if isinstance(num_vectors, int):
                 num_vectors = [num_vectors] * len(new_tokens)
@@ -119,7 +114,7 @@ class MultiCLIPTokenizer(CLIPTokenizer):

         self.token_map[ids[0]] = ids

-        return MultiCLIPTokenizerItem(new_tokens, ids)
+        return ids

     def expand_id(self, id: int):
         if id in self.token_map:
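
Note: add_multi_tokens now hands back the new token ids directly (a list of ints for one token, a list of lists for several), replacing the removed MultiCLIPTokenizerItem wrapper. Sketch with hypothetical ids:

    ids = tokenizer.add_multi_tokens("<concept>", num_vectors=3)
    # e.g. [49408, 49409, 49410]; ids[0] maps to the full list via tokenizer.token_map

    ids_per_token = tokenizer.add_multi_tokens(["<a>", "<b>"], num_vectors=2)
    # e.g. [[49411, 49412], [49413, 49414]]
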
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2e0696b..c658ad6 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,9 +4,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial

 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint

 from accelerate import Accelerator
@@ -20,9 +20,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify

-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -610,8 +611,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")

-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")

     if len(args.placeholder_token) != 0:
         # Convert the initializer_token, placeholder_token to ids
@@ -620,13 +621,15 @@ def main():
             for token in args.initializer_token
         ]

-        new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+        new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
         embeddings.resize(len(tokenizer))

-        for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-            embeddings.add_embed(new_token.ids, init_ids)
+        init_ratios = [
+            embeddings.add_embed(new_id, init_ids)
+            for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+        ]

-        print(f"Added {len(new_tokens)} new tokens.")
+        print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")
     else:
         placeholder_token_id = []

@@ -856,63 +859,16 @@ def main():
     def on_eval():
         tokenizer.eval()

-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-        noisy_latents = noisy_latents.to(dtype=unet.dtype)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
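
Note: the inline loop closure is replaced by functools.partial over the shared run_model helper. Because the model-side arguments are pre-bound as keywords, a purely positional call such as loop(step, batch) would collide with the already-bound vae parameter, so the sketch below (the call sites are outside this diff) passes the per-batch values by keyword:

    loss, acc, bsz = loop(step=step, batch=batch)             # training step
    loss, acc, bsz = loop(step=step, batch=batch, eval=True)  # seeded, reproducible timesteps
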
diff --git a/train_lora.py b/train_lora.py
deleted file mode 100644
index de878a4..0000000
--- a/train_lora.py
+++ /dev/null
@@ -1,946 +0,0 @@
1import argparse
2import itertools
3import math
4import datetime
5import logging
6import json
7from pathlib import Path
8
9import torch
10import torch.nn.functional as F
11import torch.utils.checkpoint
12
13from accelerate import Accelerator
14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
17from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
18from diffusers.training_utils import EMAModel
19from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify
22
23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule
26from training.lora import LoraAttnProcessor
27from training.optimization import get_one_cycle_schedule
28from training.util import AverageMeter, CheckpointerBase, save_args
29from models.clip.prompt import PromptProcessor
30
31logger = get_logger(__name__)
32
33
34torch.backends.cuda.matmul.allow_tf32 = True
35torch.backends.cudnn.benchmark = True
36
37
38def parse_args():
39 parser = argparse.ArgumentParser(
40 description="Simple example of a training script."
41 )
42 parser.add_argument(
43 "--pretrained_model_name_or_path",
44 type=str,
45 default=None,
46 help="Path to pretrained model or model identifier from huggingface.co/models.",
47 )
48 parser.add_argument(
49 "--tokenizer_name",
50 type=str,
51 default=None,
52 help="Pretrained tokenizer name or path if not the same as model_name",
53 )
54 parser.add_argument(
55 "--train_data_file",
56 type=str,
57 default=None,
58 help="A folder containing the training data."
59 )
60 parser.add_argument(
61 "--train_data_template",
62 type=str,
63 default="template",
64 )
65 parser.add_argument(
66 "--instance_identifier",
67 type=str,
68 default=None,
69 help="A token to use as a placeholder for the concept.",
70 )
71 parser.add_argument(
72 "--class_identifier",
73 type=str,
74 default=None,
75 help="A token to use as a placeholder for the concept.",
76 )
77 parser.add_argument(
78 "--placeholder_token",
79 type=str,
80 nargs='*',
81 default=[],
82 help="A token to use as a placeholder for the concept.",
83 )
84 parser.add_argument(
85 "--initializer_token",
86 type=str,
87 nargs='*',
88 default=[],
89 help="A token to use as initializer word."
90 )
91 parser.add_argument(
92 "--tag_dropout",
93 type=float,
94 default=0.1,
95 help="Tag dropout probability.",
96 )
97 parser.add_argument(
98 "--num_class_images",
99 type=int,
100 default=400,
101 help="How many class images to generate."
102 )
103 parser.add_argument(
104 "--repeats",
105 type=int,
106 default=1,
107 help="How many times to repeat the training data."
108 )
109 parser.add_argument(
110 "--output_dir",
111 type=str,
112 default="output/lora",
113 help="The output directory where the model predictions and checkpoints will be written.",
114 )
115 parser.add_argument(
116 "--embeddings_dir",
117 type=str,
118 default=None,
119 help="The embeddings directory where Textual Inversion embeddings are stored.",
120 )
121 parser.add_argument(
122 "--mode",
123 type=str,
124 default=None,
125 help="A mode to filter the dataset.",
126 )
127 parser.add_argument(
128 "--seed",
129 type=int,
130 default=None,
131 help="A seed for reproducible training."
132 )
133 parser.add_argument(
134 "--resolution",
135 type=int,
136 default=768,
137 help=(
138 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
139 " resolution"
140 ),
141 )
142 parser.add_argument(
143 "--center_crop",
144 action="store_true",
145 help="Whether to center crop images before resizing to resolution"
146 )
147 parser.add_argument(
148 "--dataloader_num_workers",
149 type=int,
150 default=0,
151 help=(
152 "The number of subprocesses to use for data loading. 0 means that the data will be loaded in the main"
153 " process."
154 ),
155 )
156 parser.add_argument(
157 "--num_train_epochs",
158 type=int,
159 default=100
160 )
161 parser.add_argument(
162 "--max_train_steps",
163 type=int,
164 default=None,
165 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
166 )
167 parser.add_argument(
168 "--gradient_accumulation_steps",
169 type=int,
170 default=1,
171 help="Number of updates steps to accumulate before performing a backward/update pass.",
172 )
173 parser.add_argument(
174 "--gradient_checkpointing",
175 action="store_true",
176 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
177 )
178 parser.add_argument(
179 "--learning_rate",
180 type=float,
181 default=2e-6,
182 help="Initial learning rate (after the potential warmup period) to use.",
183 )
184 parser.add_argument(
185 "--scale_lr",
186 action="store_true",
187 default=True,
188 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
189 )
190 parser.add_argument(
191 "--lr_scheduler",
192 type=str,
193 default="one_cycle",
194 help=(
195 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
196 ' "constant", "constant_with_warmup", "one_cycle"]'
197 ),
198 )
199 parser.add_argument(
200 "--lr_warmup_epochs",
201 type=int,
202 default=10,
203 help="Number of steps for the warmup in the lr scheduler."
204 )
205 parser.add_argument(
206 "--lr_cycles",
207 type=int,
208 default=None,
209 help="Number of restart cycles in the lr scheduler (if supported)."
210 )
211 parser.add_argument(
212 "--use_8bit_adam",
213 action="store_true",
214 default=True,
215 help="Whether or not to use 8-bit Adam from bitsandbytes."
216 )
217 parser.add_argument(
218 "--adam_beta1",
219 type=float,
220 default=0.9,
221 help="The beta1 parameter for the Adam optimizer."
222 )
223 parser.add_argument(
224 "--adam_beta2",
225 type=float,
226 default=0.999,
227 help="The beta2 parameter for the Adam optimizer."
228 )
229 parser.add_argument(
230 "--adam_weight_decay",
231 type=float,
232 default=1e-2,
233 help="Weight decay to use."
234 )
235 parser.add_argument(
236 "--adam_epsilon",
237 type=float,
238 default=1e-08,
239 help="Epsilon value for the Adam optimizer"
240 )
241 parser.add_argument(
242 "--mixed_precision",
243 type=str,
244 default="no",
245 choices=["no", "fp16", "bf16"],
246 help=(
247 "Whether to use mixed precision. Choose"
248 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
249 "and an Nvidia Ampere GPU."
250 ),
251 )
252 parser.add_argument(
253 "--sample_frequency",
254 type=int,
255 default=1,
256 help="How often to save a checkpoint and sample image",
257 )
258 parser.add_argument(
259 "--sample_image_size",
260 type=int,
261 default=768,
262 help="Size of sample images",
263 )
264 parser.add_argument(
265 "--sample_batches",
266 type=int,
267 default=1,
268 help="Number of sample batches to generate per checkpoint",
269 )
270 parser.add_argument(
271 "--sample_batch_size",
272 type=int,
273 default=1,
274 help="Number of samples to generate per batch",
275 )
276 parser.add_argument(
277 "--valid_set_size",
278 type=int,
279 default=None,
280 help="Number of images in the validation dataset."
281 )
282 parser.add_argument(
283 "--train_batch_size",
284 type=int,
285 default=1,
286 help="Batch size (per device) for the training dataloader."
287 )
288 parser.add_argument(
289 "--sample_steps",
290 type=int,
291 default=15,
292 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
293 )
294 parser.add_argument(
295 "--prior_loss_weight",
296 type=float,
297 default=1.0,
298 help="The weight of prior preservation loss."
299 )
300 parser.add_argument(
301 "--max_grad_norm",
302 default=1.0,
303 type=float,
304 help="Max gradient norm."
305 )
306 parser.add_argument(
307 "--noise_timesteps",
308 type=int,
309 default=1000,
310 )
311 parser.add_argument(
312 "--config",
313 type=str,
314 default=None,
315 help="Path to a JSON configuration file containing arguments for invoking this script."
316 )
317
318 args = parser.parse_args()
319 if args.config is not None:
320 args = load_config(args.config)
321 args = parser.parse_args(namespace=argparse.Namespace(**args))
322
323 if args.train_data_file is None:
324 raise ValueError("You must specify --train_data_file")
325
326 if args.pretrained_model_name_or_path is None:
327 raise ValueError("You must specify --pretrained_model_name_or_path")
328
329 if args.instance_identifier is None:
330 raise ValueError("You must specify --instance_identifier")
331
332 if isinstance(args.initializer_token, str):
333 args.initializer_token = [args.initializer_token]
334
335 if isinstance(args.placeholder_token, str):
336 args.placeholder_token = [args.placeholder_token]
337
338 if len(args.placeholder_token) == 0:
339 args.placeholder_token = [f"<*{i}>" for i in range(len(args.initializer_token))]
340
341 if len(args.placeholder_token) != len(args.initializer_token):
342 raise ValueError("Number of items in --placeholder_token and --initializer_token must match")
343
344 if args.output_dir is None:
345 raise ValueError("You must specify --output_dir")
346
347 return args
348
349
350class Checkpointer(CheckpointerBase):
351 def __init__(
352 self,
353 datamodule,
354 accelerator,
355 vae,
356 unet,
357 tokenizer,
358 text_encoder,
359 unet_lora,
360 scheduler,
361 instance_identifier,
362 placeholder_token,
363 placeholder_token_id,
364 output_dir: Path,
365 sample_image_size,
366 sample_batches,
367 sample_batch_size,
368 seed
369 ):
370 super().__init__(
371 datamodule=datamodule,
372 output_dir=output_dir,
373 instance_identifier=instance_identifier,
374 placeholder_token=placeholder_token,
375 placeholder_token_id=placeholder_token_id,
376 sample_image_size=sample_image_size,
377 seed=seed or torch.random.seed(),
378 sample_batches=sample_batches,
379 sample_batch_size=sample_batch_size
380 )
381
382 self.accelerator = accelerator
383 self.vae = vae
384 self.unet = unet
385 self.tokenizer = tokenizer
386 self.text_encoder = text_encoder
387 self.unet_lora = unet_lora
388 self.scheduler = scheduler
389
390 @torch.no_grad()
391 def save_model(self):
392 print("Saving model...")
393
394 unet_lora = self.accelerator.unwrap_model(self.unet_lora)
395 unet_lora.save_pretrained(self.output_dir.joinpath("model"))
396
397 del unet_lora
398
399 if torch.cuda.is_available():
400 torch.cuda.empty_cache()
401
402 @torch.no_grad()
403 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
404 # Save a sample image
405 pipeline = VlpnStableDiffusion(
406 text_encoder=self.text_encoder,
407 vae=self.vae,
408 unet=self.unet,
409 tokenizer=self.tokenizer,
410 scheduler=self.scheduler,
411 ).to(self.accelerator.device)
412 pipeline.set_progress_bar_config(dynamic_ncols=True)
413
414 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
415
416 del pipeline
417 del generator
418 del stable_latents
419
420 if torch.cuda.is_available():
421 torch.cuda.empty_cache()
422
423
424def main():
425 args = parse_args()
426
427 instance_identifier = args.instance_identifier
428
429 if len(args.placeholder_token) != 0:
430 instance_identifier = instance_identifier.format(args.placeholder_token[0])
431
432 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
433 basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
434 basepath.mkdir(parents=True, exist_ok=True)
435
436 accelerator = Accelerator(
437 log_with=LoggerType.TENSORBOARD,
438 logging_dir=f"{basepath}",
439 gradient_accumulation_steps=args.gradient_accumulation_steps,
440 mixed_precision=args.mixed_precision
441 )
442
443 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
444
445 args.seed = args.seed or (torch.random.seed() >> 32)
446 set_seed(args.seed)
447
448 save_args(basepath, args)
449
450 # Load the tokenizer and add the placeholder token as a additional special token
451 if args.tokenizer_name:
452 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
453 elif args.pretrained_model_name_or_path:
454 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
455
456 # Load models and create wrapper for stable diffusion
457 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
458 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
459 unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='unet')
460 noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder='scheduler')
461 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained(
462 args.pretrained_model_name_or_path, subfolder='scheduler')
463
464 unet_lora = LoraAttnProcessor(
465 cross_attention_dim=unet.cross_attention_dim,
466 inner_dim=unet.in_channels,
467 r=4,
468 )
469
470 vae.enable_slicing()
471 vae.set_use_memory_efficient_attention_xformers(True)
472 unet.set_use_memory_efficient_attention_xformers(True)
473 unet.set_attn_processor(unet_lora)
474
475 if args.gradient_checkpointing:
476 unet.enable_gradient_checkpointing()
477 text_encoder.gradient_checkpointing_enable()
478
479 # Freeze text_encoder and vae
480 vae.requires_grad_(False)
481 unet.requires_grad_(False)
482
483 if args.embeddings_dir is not None:
484 embeddings_dir = Path(args.embeddings_dir)
485 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
486 raise ValueError("--embeddings_dir must point to an existing directory")
487 added_tokens = load_text_embeddings(tokenizer, text_encoder, embeddings_dir)
488 print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}")
489
490 if len(args.placeholder_token) != 0:
491 # Convert the initializer_token, placeholder_token to ids
492 initializer_token_ids = torch.stack([
493 torch.tensor(tokenizer.encode(token, add_special_tokens=False)[:1])
494 for token in args.initializer_token
495 ])
496
497 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
498 print(f"Added {num_added_tokens} new tokens.")
499
500 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
501
502 # Resize the token embeddings as we are adding new special tokens to the tokenizer
503 text_encoder.resize_token_embeddings(len(tokenizer))
504
505 token_embeds = text_encoder.get_input_embeddings().weight.data
506 original_token_embeds = token_embeds.clone().to(accelerator.device)
507 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
508
509 for (token_id, embeddings) in zip(placeholder_token_id, initializer_token_embeddings):
510 token_embeds[token_id] = embeddings
511 else:
512 placeholder_token_id = []
513
514 print(f"Training added text embeddings")
515
516 text_encoder.text_model.encoder.requires_grad_(False)
517 text_encoder.text_model.final_layer_norm.requires_grad_(False)
518 text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
519
520 index_fixed_tokens = torch.arange(len(tokenizer))
521 index_fixed_tokens = index_fixed_tokens[~torch.isin(index_fixed_tokens, torch.tensor(placeholder_token_id))]
522
523 prompt_processor = PromptProcessor(tokenizer, text_encoder)
524
525 if args.scale_lr:
526 args.learning_rate = (
527 args.learning_rate * args.gradient_accumulation_steps *
528 args.train_batch_size * accelerator.num_processes
529 )
530
531 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
532 if args.use_8bit_adam:
533 try:
534 import bitsandbytes as bnb
535 except ImportError:
536 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
537
538 optimizer_class = bnb.optim.AdamW8bit
539 else:
540 optimizer_class = torch.optim.AdamW
541
542 # Initialize the optimizer
543 optimizer = optimizer_class(
544 [
545 {
546 'params': unet_lora.parameters(),
547 'lr': args.learning_rate,
548 },
549 ],
550 betas=(args.adam_beta1, args.adam_beta2),
551 weight_decay=args.adam_weight_decay,
552 eps=args.adam_epsilon,
553 )
554
555 weight_dtype = torch.float32
556 if args.mixed_precision == "fp16":
557 weight_dtype = torch.float16
558 elif args.mixed_precision == "bf16":
559 weight_dtype = torch.bfloat16
560
561 def collate_fn(examples):
562 prompts = [example["prompts"] for example in examples]
563 nprompts = [example["nprompts"] for example in examples]
564 input_ids = [example["instance_prompt_ids"] for example in examples]
565 pixel_values = [example["instance_images"] for example in examples]
566
567 # concat class and instance examples for prior preservation
568 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
569 input_ids += [example["class_prompt_ids"] for example in examples]
570 pixel_values += [example["class_images"] for example in examples]
571
572 pixel_values = torch.stack(pixel_values)
573 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
574
575 inputs = prompt_processor.unify_input_ids(input_ids)
576
577 batch = {
578 "prompts": prompts,
579 "nprompts": nprompts,
580 "input_ids": inputs.input_ids,
581 "pixel_values": pixel_values,
582 "attention_mask": inputs.attention_mask,
583 }
584 return batch
585
586 datamodule = CSVDataModule(
587 data_file=args.train_data_file,
588 batch_size=args.train_batch_size,
589 prompt_processor=prompt_processor,
590 instance_identifier=instance_identifier,
591 class_identifier=args.class_identifier,
592 class_subdir="cls",
593 num_class_images=args.num_class_images,
594 size=args.resolution,
595 repeats=args.repeats,
596 mode=args.mode,
597 dropout=args.tag_dropout,
598 center_crop=args.center_crop,
599 template_key=args.train_data_template,
600 valid_set_size=args.valid_set_size,
601 num_workers=args.dataloader_num_workers,
602 collate_fn=collate_fn
603 )
604
605 datamodule.prepare_data()
606 datamodule.setup()
607
608 if args.num_class_images != 0:
609 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
610
611 if len(missing_data) != 0:
612 batched_data = [
613 missing_data[i:i+args.sample_batch_size]
614 for i in range(0, len(missing_data), args.sample_batch_size)
615 ]
616
617 pipeline = VlpnStableDiffusion(
618 text_encoder=text_encoder,
619 vae=vae,
620 unet=unet,
621 tokenizer=tokenizer,
622 scheduler=checkpoint_scheduler,
623 ).to(accelerator.device)
624 pipeline.set_progress_bar_config(dynamic_ncols=True)
625
626 with torch.autocast("cuda"), torch.inference_mode():
627 for batch in batched_data:
628 image_name = [item.class_image_path for item in batch]
629 prompt = [item.prompt.format(identifier=args.class_identifier) for item in batch]
630 nprompt = [item.nprompt for item in batch]
631
632 images = pipeline(
633 prompt=prompt,
634 negative_prompt=nprompt,
635 num_inference_steps=args.sample_steps
636 ).images
637
638 for i, image in enumerate(images):
639 image.save(image_name[i])
640
641 del pipeline
642
643 if torch.cuda.is_available():
644 torch.cuda.empty_cache()
645
646 train_dataloader = datamodule.train_dataloader()
647 val_dataloader = datamodule.val_dataloader()
648
649 # Scheduler and math around the number of training steps.
650 overrode_max_train_steps = False
651 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
652 if args.max_train_steps is None:
653 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
654 overrode_max_train_steps = True
655 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
656
657 warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps
658
659 if args.lr_scheduler == "one_cycle":
660 lr_scheduler = get_one_cycle_schedule(
661 optimizer=optimizer,
662 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
663 )
664 elif args.lr_scheduler == "cosine_with_restarts":
665 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
666 optimizer=optimizer,
667 num_warmup_steps=warmup_steps,
668 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
669 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
670 ((args.max_train_steps - warmup_steps) / num_update_steps_per_epoch))),
671 )
672 else:
673 lr_scheduler = get_scheduler(
674 args.lr_scheduler,
675 optimizer=optimizer,
676 num_warmup_steps=warmup_steps,
677 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
678 )
679
680 unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
681 unet_lora, optimizer, train_dataloader, val_dataloader, lr_scheduler
682 )
683
684 # Move text_encoder and vae to device
685 vae.to(accelerator.device, dtype=weight_dtype)
686 unet.to(accelerator.device, dtype=weight_dtype)
687 text_encoder.to(accelerator.device, dtype=weight_dtype)
688
689 # Keep text_encoder and vae in eval mode as we don't train these
690 vae.eval()
691 unet.eval()
692 text_encoder.eval()
693
694 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
695 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
696 if overrode_max_train_steps:
697 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
698
699 num_val_steps_per_epoch = len(val_dataloader)
700 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
701 val_steps = num_val_steps_per_epoch * num_epochs
702
703 # We need to initialize the trackers we use, and also store our configuration.
704 # The trackers initializes automatically on the main process.
705 if accelerator.is_main_process:
706 config = vars(args).copy()
707 config["initializer_token"] = " ".join(config["initializer_token"])
708 config["placeholder_token"] = " ".join(config["placeholder_token"])
709 accelerator.init_trackers("lora", config=config)
710
711 # Train!
712 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
713
714 logger.info("***** Running training *****")
715 logger.info(f" Num Epochs = {num_epochs}")
716 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
717 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
718 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
719 logger.info(f" Total optimization steps = {args.max_train_steps}")
720 # Only show the progress bar once on each machine.
721
722 global_step = 0
723
724 avg_loss = AverageMeter()
725 avg_acc = AverageMeter()
726
727 avg_loss_val = AverageMeter()
728 avg_acc_val = AverageMeter()
729
730 max_acc_val = 0.0
731
732 checkpointer = Checkpointer(
733 datamodule=datamodule,
734 accelerator=accelerator,
735 vae=vae,
736 unet=unet,
737 tokenizer=tokenizer,
738 text_encoder=text_encoder,
739 scheduler=checkpoint_scheduler,
740 unet_lora=unet_lora,
741 output_dir=basepath,
742 instance_identifier=instance_identifier,
743 placeholder_token=args.placeholder_token,
744 placeholder_token_id=placeholder_token_id,
745 sample_image_size=args.sample_image_size,
746 sample_batch_size=args.sample_batch_size,
747 sample_batches=args.sample_batches,
748 seed=args.seed
749 )
750
751 if accelerator.is_main_process:
752 checkpointer.save_samples(0, args.sample_steps)
753
754 local_progress_bar = tqdm(
755 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
756 disable=not accelerator.is_local_main_process,
757 dynamic_ncols=True
758 )
759 local_progress_bar.set_description("Epoch X / Y")
760
761 global_progress_bar = tqdm(
762 range(args.max_train_steps + val_steps),
763 disable=not accelerator.is_local_main_process,
764 dynamic_ncols=True
765 )
766 global_progress_bar.set_description("Total progress")
767
768 try:
769 for epoch in range(num_epochs):
770 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
771 local_progress_bar.reset()
772
773 unet_lora.train()
774
775 for step, batch in enumerate(train_dataloader):
776 with accelerator.accumulate(unet_lora):
777 # Convert images to latent space
778 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
779 latents = latents * 0.18215
780
781 # Sample noise that we'll add to the latents
782 noise = torch.randn_like(latents)
783 bsz = latents.shape[0]
784 # Sample a random timestep for each image
785 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
786 (bsz,), device=latents.device)
787 timesteps = timesteps.long()
788
789 # Add noise to the latents according to the noise magnitude at each timestep
790 # (this is the forward diffusion process)
791 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
792
793 # Get the text embedding for conditioning
794 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
795
796 # Predict the noise residual
797 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
798
799 # Get the target for loss depending on the prediction type
800 if noise_scheduler.config.prediction_type == "epsilon":
801 target = noise
802 elif noise_scheduler.config.prediction_type == "v_prediction":
803 target = noise_scheduler.get_velocity(latents, noise, timesteps)
804 else:
805 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
806
807 if args.num_class_images != 0:
808 # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
809 model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
810 target, target_prior = torch.chunk(target, 2, dim=0)
811
812 # Compute instance loss
813 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
814
815 # Compute prior loss
816 prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
817
818 # Add the prior loss to the instance loss.
819 loss = loss + args.prior_loss_weight * prior_loss
820 else:
821 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
822
823 acc = (model_pred == latents).float().mean()
824
825 accelerator.backward(loss)
826
827 if accelerator.sync_gradients:
828 accelerator.clip_grad_norm_(unet_lora.parameters(), args.max_grad_norm)
829
830 optimizer.step()
831 if not accelerator.optimizer_step_was_skipped:
832 lr_scheduler.step()
833 optimizer.zero_grad(set_to_none=True)
834
835 with torch.no_grad():
836 text_encoder.get_input_embeddings(
837 ).weight[index_fixed_tokens] = original_token_embeds[index_fixed_tokens]
838
839 avg_loss.update(loss.detach_(), bsz)
840 avg_acc.update(acc.detach_(), bsz)
841
842 # Checks if the accelerator has performed an optimization step behind the scenes
843 if accelerator.sync_gradients:
844 local_progress_bar.update(1)
845 global_progress_bar.update(1)
846
847 global_step += 1
848
849 logs = {
850 "train/loss": avg_loss.avg.item(),
851 "train/acc": avg_acc.avg.item(),
852 "train/cur_loss": loss.item(),
853 "train/cur_acc": acc.item(),
854 "lr/unet": lr_scheduler.get_last_lr()[0],
855 "lr/text": lr_scheduler.get_last_lr()[1]
856 }
857
858 accelerator.log(logs, step=global_step)
859
860 local_progress_bar.set_postfix(**logs)
861
862 if global_step >= args.max_train_steps:
863 break
864
865 accelerator.wait_for_everyone()
866
867 unet_lora.eval()
868
869 with torch.inference_mode():
870 for step, batch in enumerate(val_dataloader):
871 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
872 latents = latents * 0.18215
873
874 noise = torch.randn_like(latents)
875 bsz = latents.shape[0]
876 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
877 (bsz,), device=latents.device)
878 timesteps = timesteps.long()
879
880 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
881
882 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
883
884 model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
885
886 # Get the target for loss depending on the prediction type
887 if noise_scheduler.config.prediction_type == "epsilon":
888 target = noise
889 elif noise_scheduler.config.prediction_type == "v_prediction":
890 target = noise_scheduler.get_velocity(latents, noise, timesteps)
891 else:
892 raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
893
894 loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
895
896 acc = (model_pred == latents).float().mean()
897
898 avg_loss_val.update(loss.detach_(), bsz)
899 avg_acc_val.update(acc.detach_(), bsz)
900
901 if accelerator.sync_gradients:
902 local_progress_bar.update(1)
903 global_progress_bar.update(1)
904
905 logs = {
906 "val/loss": avg_loss_val.avg.item(),
907 "val/acc": avg_acc_val.avg.item(),
908 "val/cur_loss": loss.item(),
909 "val/cur_acc": acc.item(),
910 }
911 local_progress_bar.set_postfix(**logs)
912
913 accelerator.log({
914 "val/loss": avg_loss_val.avg.item(),
915 "val/acc": avg_acc_val.avg.item(),
916 }, step=global_step)
917
918 local_progress_bar.clear()
919 global_progress_bar.clear()
920
921 if avg_acc_val.avg.item() > max_acc_val:
922 accelerator.print(
923 f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
924 max_acc_val = avg_acc_val.avg.item()
925
926 if accelerator.is_main_process:
927 if (epoch + 1) % args.sample_frequency == 0:
928 checkpointer.save_samples(global_step, args.sample_steps)
929
930 # Create the pipeline using using the trained modules and save it.
931 if accelerator.is_main_process:
932 print("Finished! Saving final checkpoint and resume state.")
933 checkpointer.save_model()
934
935 accelerator.end_training()
936
937 except KeyboardInterrupt:
938 if accelerator.is_main_process:
939 print("Interrupted, saving checkpoint and resume state...")
940 checkpointer.save_model()
941 accelerator.end_training()
942 quit()
943
944
945if __name__ == "__main__":
946 main()
diff --git a/train_ti.py b/train_ti.py
index 8ada98c..5df6850 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -3,9 +3,9 @@ import math
 import datetime
 import logging
 from pathlib import Path
+from functools import partial

 import torch
-import torch.nn.functional as F
 import torch.utils.checkpoint

 from accelerate import Accelerator
@@ -18,9 +18,10 @@ from tqdm.auto import tqdm
 from transformers import CLIPTextModel
 from slugify import slugify

-from common import load_config, load_embeddings_from_dir
+from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule, CSVDataItem
+from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
 from training.util import AverageMeter, CheckpointerBase, save_args
@@ -570,8 +571,8 @@ def main():
         if not embeddings_dir.exists() or not embeddings_dir.is_dir():
             raise ValueError("--embeddings_dir must point to an existing directory")

-        added_tokens_from_dir = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
-        print(f"Added {len(added_tokens_from_dir)} tokens from embeddings dir: {added_tokens_from_dir}")
+        added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
+        print(f"Added {len(added_tokens)} tokens from embeddings dir: {zip(added_tokens, added_ids)}")

     # Convert the initializer_token, placeholder_token to ids
     initializer_token_ids = [
@@ -579,13 +580,15 @@ def main():
         for token in args.initializer_token
     ]

-    new_tokens = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
+    new_ids = tokenizer.add_multi_tokens(args.placeholder_token, args.num_vectors)
     embeddings.resize(len(tokenizer))

-    for (new_token, init_ids) in zip(new_tokens, initializer_token_ids):
-        embeddings.add_embed(new_token.ids, init_ids)
+    init_ratios = [
+        embeddings.add_embed(new_id, init_ids)
+        for (new_id, init_ids) in zip(new_ids, initializer_token_ids)
+    ]

-    print(f"Added {len(new_tokens)} new tokens.")
+    print(f"Added {len(new_ids)} new tokens: {zip(args.placeholder_token, new_ids, init_ratios)}")

     vae.requires_grad_(False)
     unet.requires_grad_(False)
@@ -807,63 +810,16 @@ def main():
     def on_eval():
         tokenizer.eval()

-    def loop(step: int, batch, eval: bool = False):
-        # Convert images to latent space
-        latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
-        latents = latents * 0.18215
-
-        # Sample noise that we'll add to the latents
-        noise = torch.randn_like(latents)
-        bsz = latents.shape[0]
-        # Sample a random timestep for each image
-        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed + step) if eval else None
-        timesteps = torch.randint(
-            0,
-            noise_scheduler.config.num_train_timesteps,
-            (bsz,),
-            generator=timesteps_gen,
-            device=latents.device,
-        )
-        timesteps = timesteps.long()
-
-        # Add noise to the latents according to the noise magnitude at each timestep
-        # (this is the forward diffusion process)
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
-
-        # Get the text embedding for conditioning
-        encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
-        encoder_hidden_states = encoder_hidden_states.to(dtype=weight_dtype)
-
-        # Predict the noise residual
-        model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
-
-        # Get the target for loss depending on the prediction type
-        if noise_scheduler.config.prediction_type == "epsilon":
-            target = noise
-        elif noise_scheduler.config.prediction_type == "v_prediction":
-            target = noise_scheduler.get_velocity(latents, noise, timesteps)
-        else:
-            raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
-
-        if args.num_class_images != 0:
-            # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
-            model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
-            target, target_prior = torch.chunk(target, 2, dim=0)
-
-            # Compute instance loss
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-            # Compute prior loss
-            prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
-
-            # Add the prior loss to the instance loss.
-            loss = loss + args.prior_loss_weight * prior_loss
-        else:
-            loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
-
-        acc = (model_pred == target).float().mean()
-
-        return loss, acc, bsz
+    loop = partial(
+        run_model,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        unet=unet,
+        prompt_processor=prompt_processor,
+        num_class_images=args.num_class_images,
+        prior_loss_weight=args.prior_loss_weight,
+        seed=args.seed,
+    )

     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
diff --git a/training/common.py b/training/common.py
new file mode 100644
index 0000000..99a6e67
--- /dev/null
+++ b/training/common.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn.functional as F
+
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
+
+
+def run_model(
+    vae: AutoencoderKL,
+    noise_scheduler: DDPMScheduler,
+    unet: UNet2DConditionModel,
+    prompt_processor,
+    num_class_images: int,
+    prior_loss_weight: float,
+    seed: int,
+    step: int,
+    batch,
+    eval: bool = False
+):
+    # Convert images to latent space
+    latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
+    latents = latents * 0.18215
+
+    # Sample noise that we'll add to the latents
+    noise = torch.randn_like(latents)
+    bsz = latents.shape[0]
+    # Sample a random timestep for each image
+    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+    timesteps = torch.randint(
+        0,
+        noise_scheduler.config.num_train_timesteps,
+        (bsz,),
+        generator=timesteps_gen,
+        device=latents.device,
+    )
+    timesteps = timesteps.long()
+
+    # Add noise to the latents according to the noise magnitude at each timestep
+    # (this is the forward diffusion process)
+    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+    noisy_latents = noisy_latents.to(dtype=unet.dtype)
+
+    # Get the text embedding for conditioning
+    encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"], batch["attention_mask"])
+    encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
+
+    # Predict the noise residual
+    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+
+    # Get the target for loss depending on the prediction type
+    if noise_scheduler.config.prediction_type == "epsilon":
+        target = noise
+    elif noise_scheduler.config.prediction_type == "v_prediction":
+        target = noise_scheduler.get_velocity(latents, noise, timesteps)
+    else:
+        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+
+    if num_class_images != 0:
+        # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
+        model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
+        target, target_prior = torch.chunk(target, 2, dim=0)
+
+        # Compute instance loss
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+        # Compute prior loss
+        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
+
+        # Add the prior loss to the instance loss.
+        loss = loss + prior_loss_weight * prior_loss
+    else:
+        loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
+
+    acc = (model_pred == target).float().mean()
+
+    return loss, acc, bsz
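
Note: run_model consolidates the training/eval step that was previously duplicated in train_ti.py and train_dreambooth.py. With eval=True it seeds a per-step torch.Generator with seed + step, so validation batches draw the same timesteps on every epoch and validation losses stay comparable across runs. A standalone usage sketch (the surrounding variables are assumptions, not part of this file):

    loss, acc, bsz = run_model(
        vae=vae,
        noise_scheduler=noise_scheduler,
        unet=unet,
        prompt_processor=prompt_processor,
        num_class_images=0,      # no prior-preservation images in this example
        prior_loss_weight=1.0,
        seed=42,
        step=step,
        batch=batch,
        eval=True,
    )
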
diff --git a/common.py b/util.py
index 0887197..545bcb5 100644
--- a/common.py
+++ b/util.py
@@ -24,8 +24,9 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
         return []

     filenames = [filename for filename in embeddings_dir.iterdir() if filename.is_file()]
+    tokens = [filename.stem for filename in filenames]

-    new_tokens = []
+    new_ids: list[list[int]] = []
     new_embeds = []

     for filename in filenames:
@@ -33,12 +34,12 @@ def load_embeddings_from_dir(tokenizer: MultiCLIPTokenizer, embeddings: ManagedC
             embed = file.get_tensor("embed")

             added = tokenizer.add_multi_tokens(filename.stem, embed.shape[0])
-            new_tokens.append(added)
+            new_ids.append(added)
             new_embeds.append(embed)

     embeddings.resize(len(tokenizer))

-    for (new_token, embeds) in zip(new_tokens, new_embeds):
-        embeddings.add_embed(new_token.ids, embeds)
+    for (new_id, embeds) in zip(new_ids, new_embeds):
+        embeddings.add_embed(new_id, embeds)

     return tokens, new_ids
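
Note: with the module renamed from common.py to util.py, load_embeddings_from_dir returns (tokens, new_ids): each token is the stem of an embedding file, and its id list is as long as the first dimension of the stored "embed" tensor. Sketch with a hypothetical directory layout:

    # embeddings/<concept>.safetensors containing a [3, 768] "embed" tensor
    tokens, new_ids = load_embeddings_from_dir(tokenizer, embeddings, Path("embeddings"))
    # tokens  == ["<concept>"]
    # new_ids == [[49408, 49409, 49410]]   (ids are illustrative)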