summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/dreambooth/csv.py11
-rw-r--r--data/dreambooth/prompt.py18
-rw-r--r--dreambooth.py25
-rw-r--r--textual_dreambooth.py948
-rw-r--r--textual_inversion.py13
5 files changed, 968 insertions, 47 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 9075979..abd329d 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -15,6 +15,7 @@ class CSVDataModule(pl.LightningDataModule):
15 tokenizer, 15 tokenizer,
16 instance_identifier, 16 instance_identifier,
17 class_identifier=None, 17 class_identifier=None,
18 class_subdir="db_cls",
18 size=512, 19 size=512,
19 repeats=100, 20 repeats=100,
20 interpolation="bicubic", 21 interpolation="bicubic",
@@ -30,7 +31,7 @@ class CSVDataModule(pl.LightningDataModule):
30 raise ValueError("data_file must be a file") 31 raise ValueError("data_file must be a file")
31 32
32 self.data_root = self.data_file.parent 33 self.data_root = self.data_file.parent
33 self.class_root = self.data_root.joinpath("db_cls") 34 self.class_root = self.data_root.joinpath(class_subdir)
34 self.class_root.mkdir(parents=True, exist_ok=True) 35 self.class_root.mkdir(parents=True, exist_ok=True)
35 36
36 self.tokenizer = tokenizer 37 self.tokenizer = tokenizer
@@ -140,11 +141,9 @@ class CSVDataset(Dataset):
140 if not instance_image.mode == "RGB": 141 if not instance_image.mode == "RGB":
141 instance_image = instance_image.convert("RGB") 142 instance_image = instance_image.convert("RGB")
142 143
143 instance_prompt = prompt.format(self.instance_identifier)
144
145 example["instance_images"] = instance_image 144 example["instance_images"] = instance_image
146 example["instance_prompt_ids"] = self.tokenizer( 145 example["instance_prompt_ids"] = self.tokenizer(
147 instance_prompt, 146 prompt.format(self.instance_identifier),
148 padding="do_not_pad", 147 padding="do_not_pad",
149 truncation=True, 148 truncation=True,
150 max_length=self.tokenizer.model_max_length, 149 max_length=self.tokenizer.model_max_length,
@@ -155,11 +154,9 @@ class CSVDataset(Dataset):
155 if not class_image.mode == "RGB": 154 if not class_image.mode == "RGB":
156 class_image = class_image.convert("RGB") 155 class_image = class_image.convert("RGB")
157 156
158 class_prompt = prompt.format(self.class_identifier)
159
160 example["class_images"] = class_image 157 example["class_images"] = class_image
161 example["class_prompt_ids"] = self.tokenizer( 158 example["class_prompt_ids"] = self.tokenizer(
162 class_prompt, 159 prompt.format(self.class_identifier),
163 padding="do_not_pad", 160 padding="do_not_pad",
164 truncation=True, 161 truncation=True,
165 max_length=self.tokenizer.model_max_length, 162 max_length=self.tokenizer.model_max_length,
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
deleted file mode 100644
index b3a83ce..0000000
--- a/data/dreambooth/prompt.py
+++ /dev/null
@@ -1,18 +0,0 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, nprompt, num_samples):
6 self.prompt = prompt
7 self.nprompt = nprompt
8 self.num_samples = num_samples
9
10 def __len__(self):
11 return self.num_samples
12
13 def __getitem__(self, index):
14 example = {}
15 example["prompt"] = self.prompt
16 example["nprompt"] = self.nprompt
17 example["index"] = index
18 return example
diff --git a/dreambooth.py b/dreambooth.py
index aedf25c..0c5c42a 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -24,7 +24,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24import json 24import json
25 25
26from data.dreambooth.csv import CSVDataModule 26from data.dreambooth.csv import CSVDataModule
27from data.dreambooth.prompt import PromptDataset
28 27
29logger = get_logger(__name__) 28logger = get_logger(__name__)
30 29
@@ -122,7 +121,7 @@ def parse_args():
122 parser.add_argument( 121 parser.add_argument(
123 "--learning_rate", 122 "--learning_rate",
124 type=float, 123 type=float,
125 default=3e-6, 124 default=1e-6,
126 help="Initial learning rate (after the potential warmup period) to use.", 125 help="Initial learning rate (after the potential warmup period) to use.",
127 ) 126 )
128 parser.add_argument( 127 parser.add_argument(
@@ -219,17 +218,10 @@ def parse_args():
219 parser.add_argument( 218 parser.add_argument(
220 "--sample_steps", 219 "--sample_steps",
221 type=int, 220 type=int,
222 default=40, 221 default=30,
223 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 222 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
224 ) 223 )
225 parser.add_argument( 224 parser.add_argument(
226 "--class_data_dir",
227 type=str,
228 default=None,
229 required=False,
230 help="A folder containing the training data of class images.",
231 )
232 parser.add_argument(
233 "--prior_loss_weight", 225 "--prior_loss_weight",
234 type=float, 226 type=float,
235 default=1.0, 227 default=1.0,
@@ -311,7 +303,7 @@ class Checkpointer:
311 self.output_dir = output_dir 303 self.output_dir = output_dir
312 self.instance_identifier = instance_identifier 304 self.instance_identifier = instance_identifier
313 self.sample_image_size = sample_image_size 305 self.sample_image_size = sample_image_size
314 self.seed = seed 306 self.seed = seed or torch.random.seed()
315 self.sample_batches = sample_batches 307 self.sample_batches = sample_batches
316 self.sample_batch_size = sample_batch_size 308 self.sample_batch_size = sample_batch_size
317 309
@@ -406,6 +398,8 @@ class Checkpointer:
406 del unwrapped 398 del unwrapped
407 del scheduler 399 del scheduler
408 del pipeline 400 del pipeline
401 del generator
402 del stable_latents
409 403
410 if torch.cuda.is_available(): 404 if torch.cuda.is_available():
411 torch.cuda.empty_cache() 405 torch.cuda.empty_cache()
@@ -523,11 +517,13 @@ def main():
523 tokenizer=tokenizer, 517 tokenizer=tokenizer,
524 instance_identifier=args.instance_identifier, 518 instance_identifier=args.instance_identifier,
525 class_identifier=args.class_identifier, 519 class_identifier=args.class_identifier,
520 class_subdir="db_cls",
526 size=args.resolution, 521 size=args.resolution,
527 repeats=args.repeats, 522 repeats=args.repeats,
528 center_crop=args.center_crop, 523 center_crop=args.center_crop,
529 valid_set_size=args.sample_batch_size*args.sample_batches, 524 valid_set_size=args.sample_batch_size*args.sample_batches,
530 collate_fn=collate_fn) 525 collate_fn=collate_fn
526 )
531 527
532 datamodule.prepare_data() 528 datamodule.prepare_data()
533 datamodule.setup() 529 datamodule.setup()
@@ -587,7 +583,7 @@ def main():
587 sample_image_size=args.sample_image_size, 583 sample_image_size=args.sample_image_size,
588 sample_batch_size=args.sample_batch_size, 584 sample_batch_size=args.sample_batch_size,
589 sample_batches=args.sample_batches, 585 sample_batches=args.sample_batches,
590 seed=args.seed or torch.random.seed() 586 seed=args.seed
591 ) 587 )
592 588
593 # Scheduler and math around the number of training steps. 589 # Scheduler and math around the number of training steps.
@@ -699,8 +695,7 @@ def main():
699 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 695 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
700 696
701 # Compute prior loss 697 # Compute prior loss
702 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, 698 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
703 reduction="none").mean([1, 2, 3]).mean()
704 699
705 # Add the prior loss to the instance loss. 700 # Add the prior loss to the instance loss.
706 loss = loss + args.prior_loss_weight * prior_loss 701 loss = loss + args.prior_loss_weight * prior_loss
diff --git a/textual_dreambooth.py b/textual_dreambooth.py
new file mode 100644
index 0000000..a46953d
--- /dev/null
+++ b/textual_dreambooth.py
@@ -0,0 +1,948 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7from pathlib import Path
8
9import numpy as np
10import torch
11import torch.nn.functional as F
12import torch.utils.checkpoint
13
14from accelerate import Accelerator
15from accelerate.logging import get_logger
16from accelerate.utils import LoggerType, set_seed
17from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler
19from diffusers.optimization import get_scheduler
20from PIL import Image
21from tqdm.auto import tqdm
22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25import json
26import os
27
28from data.dreambooth.csv import CSVDataModule
29
30logger = get_logger(__name__)
31
32
33torch.backends.cuda.matmul.allow_tf32 = True
34
35
36def parse_args():
37 parser = argparse.ArgumentParser(
38 description="Simple example of a training script."
39 )
40 parser.add_argument(
41 "--pretrained_model_name_or_path",
42 type=str,
43 default=None,
44 help="Path to pretrained model or model identifier from huggingface.co/models.",
45 )
46 parser.add_argument(
47 "--tokenizer_name",
48 type=str,
49 default=None,
50 help="Pretrained tokenizer name or path if not the same as model_name",
51 )
52 parser.add_argument(
53 "--train_data_file",
54 type=str,
55 default=None,
56 help="A CSV file containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token",
66 type=str,
67 default=None,
68 help="A token to use as initializer word."
69 )
70 parser.add_argument(
71 "--num_vec_per_token",
72 type=int,
73 default=1,
74 help=(
75 "The number of vectors used to represent the placeholder token. The higher the number, the better the"
76 " result at the cost of editability. This can be fixed by prompt editing."
77 ),
78 )
79 parser.add_argument(
80 "--initialize_rest_random",
81 action="store_true",
82 help="Initialize rest of the placeholder tokens with random."
83 )
84 parser.add_argument(
85 "--use_class_images",
86 action="store_true",
87 default=True,
88 help="Include class images in the loss calculation a la Dreambooth.",
89 )
90 parser.add_argument(
91 "--repeats",
92 type=int,
93 default=100,
94 help="How many times to repeat the training data.")
95 parser.add_argument(
96 "--output_dir",
97 type=str,
98 default="output/text-inversion",
99 help="The output directory where the model predictions and checkpoints will be written.",
100 )
101 parser.add_argument(
102 "--seed",
103 type=int,
104 default=None,
105 help="A seed for reproducible training.")
106 parser.add_argument(
107 "--resolution",
108 type=int,
109 default=512,
110 help=(
111 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
112 " resolution"
113 ),
114 )
115 parser.add_argument(
116 "--center_crop",
117 action="store_true",
118 help="Whether to center crop images before resizing to resolution"
119 )
120 parser.add_argument(
121 "--num_train_epochs",
122 type=int,
123 default=100)
124 parser.add_argument(
125 "--max_train_steps",
126 type=int,
127 default=5000,
128 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
129 )
130 parser.add_argument(
131 "--gradient_accumulation_steps",
132 type=int,
133 default=1,
134 help="Number of updates steps to accumulate before performing a backward/update pass.",
135 )
136 parser.add_argument(
137 "--gradient_checkpointing",
138 action="store_true",
139 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
140 )
141 parser.add_argument(
142 "--learning_rate",
143 type=float,
144 default=1e-4,
145 help="Initial learning rate (after the potential warmup period) to use.",
146 )
147 parser.add_argument(
148 "--scale_lr",
149 action="store_true",
150 default=True,
151 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
152 )
153 parser.add_argument(
154 "--lr_scheduler",
155 type=str,
156 default="constant",
157 help=(
158 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
159 ' "constant", "constant_with_warmup"]'
160 ),
161 )
162 parser.add_argument(
163 "--lr_warmup_steps",
164 type=int,
165 default=500,
166 help="Number of steps for the warmup in the lr scheduler."
167 )
168 parser.add_argument(
169 "--use_8bit_adam",
170 action="store_true",
171 help="Whether or not to use 8-bit Adam from bitsandbytes."
172 )
173 parser.add_argument(
174 "--adam_beta1",
175 type=float,
176 default=0.9,
177 help="The beta1 parameter for the Adam optimizer."
178 )
179 parser.add_argument(
180 "--adam_beta2",
181 type=float,
182 default=0.999,
183 help="The beta2 parameter for the Adam optimizer."
184 )
185 parser.add_argument(
186 "--adam_weight_decay",
187 type=float,
188 default=1e-2,
189 help="Weight decay to use."
190 )
191 parser.add_argument(
192 "--adam_epsilon",
193 type=float,
194 default=1e-08,
195 help="Epsilon value for the Adam optimizer"
196 )
197 parser.add_argument(
198 "--mixed_precision",
199 type=str,
200 default="no",
201 choices=["no", "fp16", "bf16"],
202 help=(
203 "Whether to use mixed precision. Choose"
204 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
205 "and an Nvidia Ampere GPU."
206 ),
207 )
208 parser.add_argument(
209 "--local_rank",
210 type=int,
211 default=-1,
212 help="For distributed training: local_rank"
213 )
214 parser.add_argument(
215 "--checkpoint_frequency",
216 type=int,
217 default=500,
218 help="How often to save a checkpoint and sample image",
219 )
220 parser.add_argument(
221 "--sample_image_size",
222 type=int,
223 default=512,
224 help="Size of sample images",
225 )
226 parser.add_argument(
227 "--sample_batches",
228 type=int,
229 default=1,
230 help="Number of sample batches to generate per checkpoint",
231 )
232 parser.add_argument(
233 "--sample_batch_size",
234 type=int,
235 default=1,
236 help="Number of samples to generate per batch",
237 )
238 parser.add_argument(
239 "--train_batch_size",
240 type=int,
241 default=1,
242 help="Batch size (per device) for the training dataloader."
243 )
244 parser.add_argument(
245 "--sample_steps",
246 type=int,
247 default=30,
248 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
249 )
250 parser.add_argument(
251 "--prior_loss_weight",
252 type=float,
253 default=1.0,
254 help="The weight of prior preservation loss."
255 )
256 parser.add_argument(
257 "--resume_from",
258 type=str,
259 default=None,
260 help="Path to a directory to resume training from (ie, logs/token_name/2022-09-22T23-36-27)"
261 )
262 parser.add_argument(
263 "--resume_checkpoint",
264 type=str,
265 default=None,
266 help="Path to a specific checkpoint to resume training from (ie, logs/token_name/2022-09-22T23-36-27/checkpoints/something.bin)."
267 )
268 parser.add_argument(
269 "--config",
270 type=str,
271 default=None,
272 help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
273 )
274
275 args = parser.parse_args()
276 if args.resume_from is not None:
277 with open(f"{args.resume_from}/resume.json", 'rt') as f:
278 args = parser.parse_args(
279 namespace=argparse.Namespace(**json.load(f)["args"]))
280 elif args.config is not None:
281 with open(args.config, 'rt') as f:
282 args = parser.parse_args(
283 namespace=argparse.Namespace(**json.load(f)["args"]))
284
285 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
286 if env_local_rank != -1 and env_local_rank != args.local_rank:
287 args.local_rank = env_local_rank
288
289 if args.train_data_file is None:
290 raise ValueError("You must specify --train_data_file")
291
292 if args.pretrained_model_name_or_path is None:
293 raise ValueError("You must specify --pretrained_model_name_or_path")
294
295 if args.placeholder_token is None:
296 raise ValueError("You must specify --placeholder_token")
297
298 if args.initializer_token is None:
299 raise ValueError("You must specify --initializer_token")
300
301 if args.output_dir is None:
302 raise ValueError("You must specify --output_dir")
303
304 return args
305
306
307def freeze_params(params):
308 for param in params:
309 param.requires_grad = False
310
311
312def save_resume_file(basepath, args, extra={}):
313 info = {"args": vars(args)}
314 info["args"].update(extra)
315 with open(f"{basepath}/resume.json", "w") as f:
316 json.dump(info, f, indent=4)
317
318
319def make_grid(images, rows, cols):
320 w, h = images[0].size
321 grid = Image.new('RGB', size=(cols*w, rows*h))
322 for i, image in enumerate(images):
323 grid.paste(image, box=(i % cols*w, i//cols*h))
324 return grid
325
326
327def add_tokens_and_get_placeholder_token(args, token_ids, tokenizer, text_encoder):
328 assert args.num_vec_per_token >= len(token_ids)
329 placeholder_tokens = [f"{args.placeholder_token}_{i}" for i in range(args.num_vec_per_token)]
330
331 for placeholder_token in placeholder_tokens:
332 num_added_tokens = tokenizer.add_tokens(placeholder_token)
333 if num_added_tokens == 0:
334 raise ValueError(
335 f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
336 " `placeholder_token` that is not already in the tokenizer."
337 )
338
339 placeholder_token = " ".join(placeholder_tokens)
340 placeholder_token_ids = tokenizer.encode(placeholder_token, add_special_tokens=False)
341
342 print(f"The placeholder tokens are {placeholder_token} while the ids are {placeholder_token_ids}")
343
344 text_encoder.resize_token_embeddings(len(tokenizer))
345 token_embeds = text_encoder.get_input_embeddings().weight.data
346
347 if args.initialize_rest_random:
348 # The idea is that the placeholder tokens form adjectives as in x x x white dog.
349 for i, placeholder_token_id in enumerate(placeholder_token_ids):
350 if len(placeholder_token_ids) - i < len(token_ids):
351 token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
352 else:
353 token_embeds[placeholder_token_id] = torch.rand_like(token_embeds[placeholder_token_id])
354 else:
355 for i, placeholder_token_id in enumerate(placeholder_token_ids):
356 token_embeds[placeholder_token_id] = token_embeds[token_ids[i % len(token_ids)]]
357
358 return placeholder_token, placeholder_token_ids
359
360
361class Checkpointer:
362 def __init__(
363 self,
364 datamodule,
365 accelerator,
366 vae,
367 unet,
368 tokenizer,
369 placeholder_token,
370 placeholder_token_ids,
371 output_dir,
372 sample_image_size,
373 sample_batches,
374 sample_batch_size,
375 seed
376 ):
377 self.datamodule = datamodule
378 self.accelerator = accelerator
379 self.vae = vae
380 self.unet = unet
381 self.tokenizer = tokenizer
382 self.placeholder_token = placeholder_token
383 self.placeholder_token_ids = placeholder_token_ids
384 self.output_dir = output_dir
385 self.sample_image_size = sample_image_size
386 self.seed = seed or torch.random.seed()
387 self.sample_batches = sample_batches
388 self.sample_batch_size = sample_batch_size
389
390 @torch.no_grad()
391 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
392 print("Saving checkpoint for step %d..." % step)
393
394 if path is None:
395 checkpoints_path = f"{self.output_dir}/checkpoints"
396 os.makedirs(checkpoints_path, exist_ok=True)
397
398 unwrapped = self.accelerator.unwrap_model(text_encoder)
399
400 # Save a checkpoint
401 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_ids]
402 learned_embeds_dict = {}
403 for i, placeholder_token in enumerate(self.placeholder_token.split(" ")):
404 learned_embeds_dict[placeholder_token] = learned_embeds[i].detach().cpu()
405
406 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
407 if path is not None:
408 torch.save(learned_embeds_dict, path)
409 else:
410 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
411 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
412
413 del unwrapped
414 del learned_embeds
415
416 @torch.no_grad()
417 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
418 samples_path = Path(self.output_dir).joinpath("samples")
419
420 unwrapped = self.accelerator.unwrap_model(text_encoder)
421 scheduler = EulerAScheduler(
422 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
423 )
424
425 # Save a sample image
426 pipeline = VlpnStableDiffusion(
427 text_encoder=unwrapped,
428 vae=self.vae,
429 unet=self.unet,
430 tokenizer=self.tokenizer,
431 scheduler=scheduler,
432 ).to(self.accelerator.device)
433 pipeline.enable_attention_slicing()
434
435 train_data = self.datamodule.train_dataloader()
436 val_data = self.datamodule.val_dataloader()
437
438 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
439 stable_latents = torch.randn(
440 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
441 device=pipeline.device,
442 generator=generator,
443 )
444
445 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
446 all_samples = []
447 file_path = samples_path.joinpath(pool, f"step_{step}.png")
448 file_path.parent.mkdir(parents=True, exist_ok=True)
449
450 data_enum = enumerate(data)
451
452 for i in range(self.sample_batches):
453 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
454 prompt = [prompt.format(self.placeholder_token)
455 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
456 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
457
458 with self.accelerator.autocast():
459 samples = pipeline(
460 prompt=prompt,
461 negative_prompt=nprompt,
462 height=self.sample_image_size,
463 width=self.sample_image_size,
464 latents=latents[:len(prompt)] if latents is not None else None,
465 generator=generator if latents is not None else None,
466 guidance_scale=guidance_scale,
467 eta=eta,
468 num_inference_steps=num_inference_steps,
469 output_type='pil'
470 )["sample"]
471
472 all_samples += samples
473
474 del samples
475
476 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
477 image_grid.save(file_path)
478
479 del all_samples
480 del image_grid
481
482 del unwrapped
483 del scheduler
484 del pipeline
485 del generator
486 del stable_latents
487
488 if torch.cuda.is_available():
489 torch.cuda.empty_cache()
490
491
492def main():
493 args = parse_args()
494
495 global_step_offset = 0
496 if args.resume_from is not None:
497 basepath = Path(args.resume_from)
498 print("Resuming state from %s" % args.resume_from)
499 with open(basepath.joinpath("resume.json"), 'r') as f:
500 state = json.load(f)
501 global_step_offset = state["args"].get("global_step", 0)
502
503 print("We've trained %d steps so far" % global_step_offset)
504 else:
505 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
506 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
507 basepath.mkdir(parents=True, exist_ok=True)
508
509 accelerator = Accelerator(
510 log_with=LoggerType.TENSORBOARD,
511 logging_dir=f"{basepath}",
512 gradient_accumulation_steps=args.gradient_accumulation_steps,
513 mixed_precision=args.mixed_precision
514 )
515
516 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
517
518 # If passed along, set the training seed now.
519 if args.seed is not None:
520 set_seed(args.seed)
521
522 # Load the tokenizer and add the placeholder token as a additional special token
523 if args.tokenizer_name:
524 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
525 elif args.pretrained_model_name_or_path:
526 tokenizer = CLIPTokenizer.from_pretrained(
527 args.pretrained_model_name_or_path + '/tokenizer'
528 )
529
530 # Load models and create wrapper for stable diffusion
531 text_encoder = CLIPTextModel.from_pretrained(
532 args.pretrained_model_name_or_path + '/text_encoder',
533 )
534 vae = AutoencoderKL.from_pretrained(
535 args.pretrained_model_name_or_path + '/vae',
536 )
537 unet = UNet2DConditionModel.from_pretrained(
538 args.pretrained_model_name_or_path + '/unet',
539 )
540
541 if args.gradient_checkpointing:
542 unet.enable_gradient_checkpointing()
543
544 slice_size = unet.config.attention_head_dim // 2
545 unet.set_attention_slice(slice_size)
546
547 token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
548 # regardless of whether the number of token_ids is 1 or more, it'll set one and then keep repeating.
549 placeholder_token, placeholder_token_ids = add_tokens_and_get_placeholder_token(
550 args, token_ids, tokenizer, text_encoder)
551
552 # if args.resume_checkpoint is not None:
553 # token_embeds[placeholder_token_id] = torch.load(args.resume_checkpoint)[
554 # args.placeholder_token]
555 # else:
556 # token_embeds[placeholder_token_id] = initializer_token_embeddings
557
558 # Freeze vae and unet
559 freeze_params(vae.parameters())
560 freeze_params(unet.parameters())
561 # Freeze all parameters except for the token embeddings in text encoder
562 params_to_freeze = itertools.chain(
563 text_encoder.text_model.encoder.parameters(),
564 text_encoder.text_model.final_layer_norm.parameters(),
565 text_encoder.text_model.embeddings.position_embedding.parameters(),
566 )
567 freeze_params(params_to_freeze)
568
569 if args.scale_lr:
570 args.learning_rate = (
571 args.learning_rate * args.gradient_accumulation_steps *
572 args.train_batch_size * accelerator.num_processes
573 )
574
575 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
576 if args.use_8bit_adam:
577 try:
578 import bitsandbytes as bnb
579 except ImportError:
580 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
581
582 optimizer_class = bnb.optim.AdamW8bit
583 else:
584 optimizer_class = torch.optim.AdamW
585
586 # Initialize the optimizer
587 optimizer = optimizer_class(
588 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
589 lr=args.learning_rate,
590 betas=(args.adam_beta1, args.adam_beta2),
591 weight_decay=args.adam_weight_decay,
592 eps=args.adam_epsilon,
593 )
594
595 noise_scheduler = DDPMScheduler(
596 beta_start=0.00085,
597 beta_end=0.012,
598 beta_schedule="scaled_linear",
599 num_train_timesteps=1000
600 )
601
602 def collate_fn(examples):
603 prompts = [example["prompts"] for example in examples]
604 nprompts = [example["nprompts"] for example in examples]
605 input_ids = [example["instance_prompt_ids"] for example in examples]
606 pixel_values = [example["instance_images"] for example in examples]
607
608 # concat class and instance examples for prior preservation
609 if args.use_class_images and "class_prompt_ids" in examples[0]:
610 input_ids += [example["class_prompt_ids"] for example in examples]
611 pixel_values += [example["class_images"] for example in examples]
612
613 pixel_values = torch.stack(pixel_values)
614 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
615
616 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
617
618 batch = {
619 "prompts": prompts,
620 "nprompts": nprompts,
621 "input_ids": input_ids,
622 "pixel_values": pixel_values,
623 }
624 return batch
625
626 datamodule = CSVDataModule(
627 data_file=args.train_data_file,
628 batch_size=args.train_batch_size,
629 tokenizer=tokenizer,
630 instance_identifier=placeholder_token,
631 class_identifier=args.initializer_token if args.use_class_images else None,
632 class_subdir="ti_cls",
633 size=args.resolution,
634 repeats=args.repeats,
635 center_crop=args.center_crop,
636 valid_set_size=args.sample_batch_size*args.sample_batches,
637 collate_fn=collate_fn
638 )
639
640 datamodule.prepare_data()
641 datamodule.setup()
642
643 if args.use_class_images:
644 missing_data = [item for item in datamodule.data if not item[1].exists()]
645
646 if len(missing_data) != 0:
647 batched_data = [missing_data[i:i+args.sample_batch_size]
648 for i in range(0, len(missing_data), args.sample_batch_size)]
649
650 scheduler = EulerAScheduler(
651 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
652 )
653
654 pipeline = VlpnStableDiffusion(
655 text_encoder=text_encoder,
656 vae=vae,
657 unet=unet,
658 tokenizer=tokenizer,
659 scheduler=scheduler,
660 ).to(accelerator.device)
661 pipeline.enable_attention_slicing()
662
663 for batch in batched_data:
664 image_name = [p[1] for p in batch]
665 prompt = [p[2].format(args.initializer_token) for p in batch]
666 nprompt = [p[3] for p in batch]
667
668 with accelerator.autocast():
669 images = pipeline(
670 prompt=prompt,
671 negative_prompt=nprompt,
672 num_inference_steps=args.sample_steps
673 ).images
674
675 for i, image in enumerate(images):
676 image.save(image_name[i])
677
678 del pipeline
679
680 if torch.cuda.is_available():
681 torch.cuda.empty_cache()
682
683 train_dataloader = datamodule.train_dataloader()
684 val_dataloader = datamodule.val_dataloader()
685
686 checkpointer = Checkpointer(
687 datamodule=datamodule,
688 accelerator=accelerator,
689 vae=vae,
690 unet=unet,
691 tokenizer=tokenizer,
692 placeholder_token=args.placeholder_token,
693 placeholder_token_ids=placeholder_token_ids,
694 output_dir=basepath,
695 sample_image_size=args.sample_image_size,
696 sample_batch_size=args.sample_batch_size,
697 sample_batches=args.sample_batches,
698 seed=args.seed
699 )
700
701 # Scheduler and math around the number of training steps.
702 overrode_max_train_steps = False
703 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
704 if args.max_train_steps is None:
705 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
706 overrode_max_train_steps = True
707
708 lr_scheduler = get_scheduler(
709 args.lr_scheduler,
710 optimizer=optimizer,
711 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
712 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
713 )
714
715 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
716 text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
717 )
718
719 # Move vae and unet to device
720 vae.to(accelerator.device)
721 unet.to(accelerator.device)
722
723 # Keep vae and unet in eval mode as we don't train these
724 vae.eval()
725 unet.eval()
726
727 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
728 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
729 if overrode_max_train_steps:
730 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
731
732 num_val_steps_per_epoch = len(val_dataloader)
733 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
734 val_steps = num_val_steps_per_epoch * num_epochs
735
736 # We need to initialize the trackers we use, and also store our configuration.
737 # The trackers initializes automatically on the main process.
738 if accelerator.is_main_process:
739 accelerator.init_trackers("textual_inversion", config=vars(args))
740
741 # Train!
742 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
743
744 logger.info("***** Running training *****")
745 logger.info(f" Num Epochs = {num_epochs}")
746 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
747 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
748 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
749 logger.info(f" Total optimization steps = {args.max_train_steps}")
750 # Only show the progress bar once on each machine.
751
752 global_step = 0
753 min_val_loss = np.inf
754
755 if accelerator.is_main_process:
756 checkpointer.save_samples(
757 0,
758 text_encoder,
759 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
760
761 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
762 disable=not accelerator.is_local_main_process)
763 local_progress_bar.set_description("Batch X out of Y")
764
765 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
766 global_progress_bar.set_description("Total progress")
767
768 try:
769 for epoch in range(num_epochs):
770 local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
771 local_progress_bar.reset()
772
773 text_encoder.train()
774 train_loss = 0.0
775
776 for step, batch in enumerate(train_dataloader):
777 with accelerator.accumulate(text_encoder):
778 # Convert images to latent space
779 with torch.no_grad():
780 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
781 latents = latents * 0.18215
782
783 # Sample noise that we'll add to the latents
784 noise = torch.randn(latents.shape).to(latents.device)
785 bsz = latents.shape[0]
786 # Sample a random timestep for each image
787 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
788 (bsz,), device=latents.device)
789 timesteps = timesteps.long()
790
791 # Add noise to the latents according to the noise magnitude at each timestep
792 # (this is the forward diffusion process)
793 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
794
795 # Get the text embedding for conditioning
796 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
797
798 # Predict the noise residual
799 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
800
801 if args.use_class_images:
802 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
803 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
804 noise, noise_prior = torch.chunk(noise, 2, dim=0)
805
806 # Compute instance loss
807 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
808
809 # Compute prior loss
810 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
811
812 # Add the prior loss to the instance loss.
813 loss = loss + args.prior_loss_weight * prior_loss
814 else:
815 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
816
817 accelerator.backward(loss)
818
819 # Zero out the gradients for all token embeddings except the newly added
820 # embeddings for the concept, as we only want to optimize the concept embeddings
821 if accelerator.num_processes > 1:
822 grads = text_encoder.module.get_input_embeddings().weight.grad
823 else:
824 grads = text_encoder.get_input_embeddings().weight.grad
825 # Get the index for tokens that we want to zero the grads for
826 grad_mask = torch.arange(len(tokenizer)) != placeholder_token_ids[0]
827 for i in range(1, len(placeholder_token_ids)):
828 grad_mask = grad_mask & (torch.arange(len(tokenizer)) != placeholder_token_ids[i])
829 grads.data[grad_mask, :] = grads.data[grad_mask, :].fill_(0)
830
831 optimizer.step()
832 if not accelerator.optimizer_step_was_skipped:
833 lr_scheduler.step()
834 optimizer.zero_grad(set_to_none=True)
835
836 loss = loss.detach().item()
837 train_loss += loss
838
839 # Checks if the accelerator has performed an optimization step behind the scenes
840 if accelerator.sync_gradients:
841 local_progress_bar.update(1)
842 global_progress_bar.update(1)
843
844 global_step += 1
845
846 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
847 local_progress_bar.clear()
848 global_progress_bar.clear()
849
850 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
851 save_resume_file(basepath, args, {
852 "global_step": global_step + global_step_offset,
853 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
854 })
855
856 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
857 local_progress_bar.set_postfix(**logs)
858
859 if global_step >= args.max_train_steps:
860 break
861
862 train_loss /= len(train_dataloader)
863
864 accelerator.wait_for_everyone()
865
866 text_encoder.eval()
867 val_loss = 0.0
868
869 for step, batch in enumerate(val_dataloader):
870 with torch.no_grad():
871 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
872 latents = latents * 0.18215
873
874 noise = torch.randn(latents.shape).to(latents.device)
875 bsz = latents.shape[0]
876 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
877 (bsz,), device=latents.device)
878 timesteps = timesteps.long()
879
880 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
881
882 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
883
884 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
885
886 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
887
888 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
889
890 loss = loss.detach().item()
891 val_loss += loss
892
893 if accelerator.sync_gradients:
894 local_progress_bar.update(1)
895 global_progress_bar.update(1)
896
897 logs = {"mode": "validation", "loss": loss}
898 local_progress_bar.set_postfix(**logs)
899
900 val_loss /= len(val_dataloader)
901
902 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
903
904 local_progress_bar.clear()
905 global_progress_bar.clear()
906
907 if min_val_loss > val_loss:
908 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
909 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
910 min_val_loss = val_loss
911
912 if accelerator.is_main_process:
913 checkpointer.save_samples(
914 global_step + global_step_offset,
915 text_encoder,
916 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
917
918 # Create the pipeline using using the trained modules and save it.
919 if accelerator.is_main_process:
920 print("Finished! Saving final checkpoint and resume state.")
921 checkpointer.checkpoint(
922 global_step + global_step_offset,
923 "end",
924 text_encoder,
925 path=f"{basepath}/learned_embeds.bin"
926 )
927
928 save_resume_file(basepath, args, {
929 "global_step": global_step + global_step_offset,
930 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
931 })
932
933 accelerator.end_training()
934
935 except KeyboardInterrupt:
936 if accelerator.is_main_process:
937 print("Interrupted, saving checkpoint and resume state...")
938 checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
939 save_resume_file(basepath, args, {
940 "global_step": global_step + global_step_offset,
941 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
942 })
943 accelerator.end_training()
944 quit()
945
946
947if __name__ == "__main__":
948 main()
diff --git a/textual_inversion.py b/textual_inversion.py
index d842288..7919ebd 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -230,7 +230,7 @@ def parse_args():
230 parser.add_argument( 230 parser.add_argument(
231 "--sample_steps", 231 "--sample_steps",
232 type=int, 232 type=int,
233 default=40, 233 default=30,
234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 234 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
235 ) 235 )
236 parser.add_argument( 236 parser.add_argument(
@@ -329,7 +329,7 @@ class Checkpointer:
329 self.placeholder_token_id = placeholder_token_id 329 self.placeholder_token_id = placeholder_token_id
330 self.output_dir = output_dir 330 self.output_dir = output_dir
331 self.sample_image_size = sample_image_size 331 self.sample_image_size = sample_image_size
332 self.seed = seed 332 self.seed = seed or torch.random.seed()
333 self.sample_batches = sample_batches 333 self.sample_batches = sample_batches
334 self.sample_batch_size = sample_batch_size 334 self.sample_batch_size = sample_batch_size
335 335
@@ -481,9 +481,9 @@ def main():
481 # Convert the initializer_token, placeholder_token to ids 481 # Convert the initializer_token, placeholder_token to ids
482 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False) 482 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
483 # Check if initializer_token is a single token or a sequence of tokens 483 # Check if initializer_token is a single token or a sequence of tokens
484 if args.vectors_per_token % len(initializer_token_ids) != 0: 484 if len(initializer_token_ids) > 1:
485 raise ValueError( 485 raise ValueError(
486 f"vectors_per_token ({args.vectors_per_token}) must be divisible by initializer token ({len(initializer_token_ids)}).") 486 f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
487 487
488 initializer_token_ids = torch.tensor(initializer_token_ids) 488 initializer_token_ids = torch.tensor(initializer_token_ids)
489 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token) 489 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
@@ -590,7 +590,7 @@ def main():
590 sample_image_size=args.sample_image_size, 590 sample_image_size=args.sample_image_size,
591 sample_batch_size=args.sample_batch_size, 591 sample_batch_size=args.sample_batch_size,
592 sample_batches=args.sample_batches, 592 sample_batches=args.sample_batches,
593 seed=args.seed or torch.random.seed() 593 seed=args.seed
594 ) 594 )
595 595
596 # Scheduler and math around the number of training steps. 596 # Scheduler and math around the number of training steps.
@@ -620,8 +620,7 @@ def main():
620 unet.eval() 620 unet.eval()
621 621
622 # We need to recalculate our total training steps as the size of the training dataloader may have changed. 622 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
623 num_update_steps_per_epoch = math.ceil( 623 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
624 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
625 if overrode_max_train_steps: 624 if overrode_max_train_steps:
626 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 625 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
627 626