-rw-r--r--  dreambooth.py        48
-rw-r--r--  dreambooth_plus.py 1023
2 files changed, 42 insertions(+), 1029 deletions(-)
diff --git a/dreambooth.py b/dreambooth.py
index d1cf535..da8399f 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -69,6 +69,18 @@ def parse_args():
69 help="A token to use as a placeholder for the concept.", 69 help="A token to use as a placeholder for the concept.",
70 ) 70 )
71 parser.add_argument( 71 parser.add_argument(
72 "--placeholder_token",
73 type=str,
74 default="<*>",
75 help="A token to use as a placeholder for the concept.",
76 )
77 parser.add_argument(
78 "--initializer_token",
79 type=str,
80 default=None,
81 help="A token to use as initializer word."
82 )
83 parser.add_argument(
72 "--num_class_images", 84 "--num_class_images",
73 type=int, 85 type=int,
74 default=400, 86 default=400,
@@ -114,7 +126,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3600,
+        default=6000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -131,13 +143,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=3e-6,
+        default=2e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -476,8 +488,13 @@ class Checkpointer:
 def main():
     args = parse_args()
 
+    instance_identifier = args.instance_identifier
+
+    if args.placeholder_token is not None:
+        instance_identifier = instance_identifier.format(args.placeholder_token)
+
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir).joinpath(slugify(args.instance_identifier), now)
+    basepath = Path(args.output_dir).joinpath(slugify(instance_identifier), now)
     basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
@@ -514,6 +531,25 @@ def main():
         device=accelerator.device
     ) if args.use_ema else None
 
+    if args.initializer_token is not None:
+        # Convert the initializer_token, placeholder_token to ids
+        initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
+        print(f"Initializer token {args.initializer_token} maps to {len(initializer_token_ids)} embeddings.")
+        initializer_token_ids = torch.tensor(initializer_token_ids[:1])
+
+        # Add the placeholder token in tokenizer
+        num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
+        if num_added_tokens == 0:
+            print(f"Re-using existing token {args.placeholder_token}.")
+        else:
+            print(f"Training new token {args.placeholder_token}.")
+        placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
+
+        text_encoder.resize_token_embeddings(len(tokenizer))
+        token_embeds = text_encoder.get_input_embeddings()
+        initializer_token_embeddings = token_embeds(initializer_token_ids)
+        token_embeds.weight.data[placeholder_token_id] = initializer_token_embeddings
+
     prompt_processor = PromptProcessor(tokenizer, text_encoder)
 
     if args.gradient_checkpointing:
@@ -605,7 +641,7 @@ def main():
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
@@ -735,7 +771,7 @@ def main():
         tokenizer=tokenizer,
         text_encoder=text_encoder,
         output_dir=basepath,
-        instance_identifier=args.instance_identifier,
+        instance_identifier=instance_identifier,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
         sample_batches=args.sample_batches,
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
deleted file mode 100644
index 413abe3..0000000
--- a/dreambooth_plus.py
+++ /dev/null
@@ -1,1023 +0,0 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7import json
8from pathlib import Path
9
10import numpy as np
11import torch
12import torch.nn.functional as F
13import torch.utils.checkpoint
14
15from accelerate import Accelerator
16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
20from diffusers.training_utils import EMAModel
21from PIL import Image
22from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify
25
26from schedulers.scheduling_euler_a import EulerAScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule
29from models.clip.prompt import PromptProcessor
30
31logger = get_logger(__name__)
32
33
34torch.backends.cuda.matmul.allow_tf32 = True
35
36
37def parse_args():
38 parser = argparse.ArgumentParser(
39 description="Simple example of a training script."
40 )
41 parser.add_argument(
42 "--pretrained_model_name_or_path",
43 type=str,
44 default=None,
45 help="Path to pretrained model or model identifier from huggingface.co/models.",
46 )
47 parser.add_argument(
48 "--tokenizer_name",
49 type=str,
50 default=None,
51 help="Pretrained tokenizer name or path if not the same as model_name",
52 )
53 parser.add_argument(
54 "--train_data_file",
55 type=str,
56 default=None,
57 help="A CSV file containing the training data."
58 )
59 parser.add_argument(
60 "--instance_identifier",
61 type=str,
62 default=None,
63 help="A token to use as a placeholder for the concept.",
64 )
65 parser.add_argument(
66 "--class_identifier",
67 type=str,
68 default=None,
69 help="A token to use as a placeholder for the concept.",
70 )
71 parser.add_argument(
72 "--placeholder_token",
73 type=str,
74 default="<*>",
75 help="A token to use as a placeholder for the concept.",
76 )
77 parser.add_argument(
78 "--initializer_token",
79 type=str,
80 default=None,
81 help="A token to use as initializer word."
82 )
83 parser.add_argument(
84 "--num_class_images",
85 type=int,
86 default=400,
87 help="How many class images to generate."
88 )
89 parser.add_argument(
90 "--repeats",
91 type=int,
92 default=1,
93 help="How many times to repeat the training data."
94 )
95 parser.add_argument(
96 "--output_dir",
97 type=str,
98 default="output/dreambooth-plus",
99 help="The output directory where the model predictions and checkpoints will be written.",
100 )
101 parser.add_argument(
102 "--seed",
103 type=int,
104 default=None,
105 help="A seed for reproducible training.")
106 parser.add_argument(
107 "--resolution",
108 type=int,
109 default=512,
110 help=(
111 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
112 " resolution"
113 ),
114 )
115 parser.add_argument(
116 "--center_crop",
117 action="store_true",
118 help="Whether to center crop images before resizing to resolution"
119 )
120 parser.add_argument(
121 "--num_train_epochs",
122 type=int,
123 default=100
124 )
125 parser.add_argument(
126 "--max_train_steps",
127 type=int,
128 default=4700,
129 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
130 )
131 parser.add_argument(
132 "--gradient_accumulation_steps",
133 type=int,
134 default=1,
135 help="Number of updates steps to accumulate before performing a backward/update pass.",
136 )
137 parser.add_argument(
138 "--gradient_checkpointing",
139 action="store_true",
140 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
141 )
142 parser.add_argument(
143 "--learning_rate_unet",
144 type=float,
145 default=2e-6,
146 help="Initial learning rate (after the potential warmup period) to use.",
147 )
148 parser.add_argument(
149 "--learning_rate_text",
150 type=float,
151 default=2e-6,
152 help="Initial learning rate (after the potential warmup period) to use.",
153 )
154 parser.add_argument(
155 "--scale_lr",
156 action="store_true",
157 default=True,
158 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
159 )
160 parser.add_argument(
161 "--lr_scheduler",
162 type=str,
163 default="cosine_with_restarts",
164 help=(
165 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
166 ' "constant", "constant_with_warmup"]'
167 ),
168 )
169 parser.add_argument(
170 "--lr_warmup_steps",
171 type=int,
172 default=300,
173 help="Number of steps for the warmup in the lr scheduler."
174 )
175 parser.add_argument(
176 "--lr_cycles",
177 type=int,
178 default=None,
179 help="Number of restart cycles in the lr scheduler."
180 )
181 parser.add_argument(
182 "--use_ema",
183 action="store_true",
184 default=True,
185 help="Whether to use EMA model."
186 )
187 parser.add_argument(
188 "--ema_inv_gamma",
189 type=float,
190 default=1.0
191 )
192 parser.add_argument(
193 "--ema_power",
194 type=float,
195 default=9 / 10
196 )
197 parser.add_argument(
198 "--ema_max_decay",
199 type=float,
200 default=0.9999
201 )
202 parser.add_argument(
203 "--use_8bit_adam",
204 action="store_true",
205 default=True,
206 help="Whether or not to use 8-bit Adam from bitsandbytes."
207 )
208 parser.add_argument(
209 "--adam_beta1",
210 type=float,
211 default=0.9,
212 help="The beta1 parameter for the Adam optimizer."
213 )
214 parser.add_argument(
215 "--adam_beta2",
216 type=float,
217 default=0.999,
218 help="The beta2 parameter for the Adam optimizer."
219 )
220 parser.add_argument(
221 "--adam_weight_decay",
222 type=float,
223 default=1e-2,
224 help="Weight decay to use."
225 )
226 parser.add_argument(
227 "--adam_epsilon",
228 type=float,
229 default=1e-08,
230 help="Epsilon value for the Adam optimizer"
231 )
232 parser.add_argument(
233 "--mixed_precision",
234 type=str,
235 default="no",
236 choices=["no", "fp16", "bf16"],
237 help=(
238 "Whether to use mixed precision. Choose"
239 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
240 "and an Nvidia Ampere GPU."
241 ),
242 )
243 parser.add_argument(
244 "--checkpoint_frequency",
245 type=int,
246 default=500,
247 help="How often to save a checkpoint and sample image",
248 )
249 parser.add_argument(
250 "--sample_frequency",
251 type=int,
252 default=100,
253 help="How often to save a checkpoint and sample image",
254 )
255 parser.add_argument(
256 "--sample_image_size",
257 type=int,
258 default=512,
259 help="Size of sample images",
260 )
261 parser.add_argument(
262 "--sample_batches",
263 type=int,
264 default=1,
265 help="Number of sample batches to generate per checkpoint",
266 )
267 parser.add_argument(
268 "--sample_batch_size",
269 type=int,
270 default=1,
271 help="Number of samples to generate per batch",
272 )
273 parser.add_argument(
274 "--train_batch_size",
275 type=int,
276 default=1,
277 help="Batch size (per device) for the training dataloader."
278 )
279 parser.add_argument(
280 "--sample_steps",
281 type=int,
282 default=30,
283 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
284 )
285 parser.add_argument(
286 "--prior_loss_weight",
287 type=float,
288 default=1.0,
289 help="The weight of prior preservation loss."
290 )
291 parser.add_argument(
292 "--max_grad_norm",
293 default=1.0,
294 type=float,
295 help="Max gradient norm."
296 )
297 parser.add_argument(
298 "--noise_timesteps",
299 type=int,
300 default=1000,
301 )
302 parser.add_argument(
303 "--config",
304 type=str,
305 default=None,
306 help="Path to a JSON configuration file containing arguments for invoking this script."
307 )
308
309 args = parser.parse_args()
310 if args.config is not None:
311 with open(args.config, 'rt') as f:
312 args = parser.parse_args(
313 namespace=argparse.Namespace(**json.load(f)["args"]))
314
315 if args.train_data_file is None:
316 raise ValueError("You must specify --train_data_file")
317
318 if args.pretrained_model_name_or_path is None:
319 raise ValueError("You must specify --pretrained_model_name_or_path")
320
321 if args.placeholder_token is None:
322 raise ValueError("You must specify --placeholder_token")
323
324 if args.initializer_token is None:
325 raise ValueError("You must specify --initializer_token")
326
327 if args.output_dir is None:
328 raise ValueError("You must specify --output_dir")
329
330 return args
331
332
333def save_args(basepath: Path, args, extra={}):
334 info = {"args": vars(args)}
335 info["args"].update(extra)
336 with open(basepath.joinpath("args.json"), "w") as f:
337 json.dump(info, f, indent=4)
338
339
340def freeze_params(params):
341 for param in params:
342 param.requires_grad = False
343
344
345def make_grid(images, rows, cols):
346 w, h = images[0].size
347 grid = Image.new('RGB', size=(cols*w, rows*h))
348 for i, image in enumerate(images):
349 grid.paste(image, box=(i % cols*w, i//cols*h))
350 return grid
351
352
353class Checkpointer:
354 def __init__(
355 self,
356 datamodule,
357 accelerator,
358 vae,
359 unet,
360 ema_unet,
361 tokenizer,
362 text_encoder,
363 instance_identifier,
364 placeholder_token,
365 placeholder_token_id,
366 output_dir: Path,
367 sample_image_size,
368 sample_batches,
369 sample_batch_size,
370 seed
371 ):
372 self.datamodule = datamodule
373 self.accelerator = accelerator
374 self.vae = vae
375 self.unet = unet
376 self.ema_unet = ema_unet
377 self.tokenizer = tokenizer
378 self.text_encoder = text_encoder
379 self.instance_identifier = instance_identifier
380 self.placeholder_token = placeholder_token
381 self.placeholder_token_id = placeholder_token_id
382 self.output_dir = output_dir
383 self.sample_image_size = sample_image_size
384 self.seed = seed or torch.random.seed()
385 self.sample_batches = sample_batches
386 self.sample_batch_size = sample_batch_size
387
388 @torch.no_grad()
389 def checkpoint(self, step, postfix):
390 print("Saving checkpoint for step %d..." % step)
391
392 checkpoints_path = self.output_dir.joinpath("checkpoints")
393 checkpoints_path.mkdir(parents=True, exist_ok=True)
394
395 unwrapped = self.accelerator.unwrap_model(self.text_encoder)
396
397 # Save a checkpoint
398 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
399 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
400
401 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
402 torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
403
404 del unwrapped
405 del learned_embeds
406
407 @torch.no_grad()
408 def save_model(self):
409 print("Saving model...")
410
411 unwrapped_unet = self.accelerator.unwrap_model(
412 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
413 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
414 pipeline = VlpnStableDiffusion(
415 text_encoder=unwrapped_text_encoder,
416 vae=self.vae,
417 unet=unwrapped_unet,
418 tokenizer=self.tokenizer,
419 scheduler=PNDMScheduler(
420 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
421 ),
422 )
423 pipeline.save_pretrained(self.output_dir.joinpath("model"))
424
425 del unwrapped_unet
426 del unwrapped_text_encoder
427 del pipeline
428
429 if torch.cuda.is_available():
430 torch.cuda.empty_cache()
431
432 @torch.no_grad()
433 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
434 samples_path = Path(self.output_dir).joinpath("samples")
435
436 unwrapped_unet = self.accelerator.unwrap_model(
437 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
438 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
439 scheduler = EulerAScheduler(
440 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
441 )
442
443 # Save a sample image
444 pipeline = VlpnStableDiffusion(
445 text_encoder=unwrapped_text_encoder,
446 vae=self.vae,
447 unet=unwrapped_unet,
448 tokenizer=self.tokenizer,
449 scheduler=scheduler,
450 ).to(self.accelerator.device)
451 pipeline.set_progress_bar_config(dynamic_ncols=True)
452
453 train_data = self.datamodule.train_dataloader()
454 val_data = self.datamodule.val_dataloader()
455
456 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
457 stable_latents = torch.randn(
458 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
459 device=pipeline.device,
460 generator=generator,
461 )
462
463 with torch.inference_mode():
464 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
465 all_samples = []
466 file_path = samples_path.joinpath(pool, f"step_{step}.png")
467 file_path.parent.mkdir(parents=True, exist_ok=True)
468
469 data_enum = enumerate(data)
470
471 for i in range(self.sample_batches):
472 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
473 prompt = [
474 prompt.format(self.instance_identifier)
475 for batch in batches
476 for prompt in batch["prompts"]
477 ][:self.sample_batch_size]
478 nprompt = [
479 prompt
480 for batch in batches
481 for prompt in batch["nprompts"]
482 ][:self.sample_batch_size]
483
484 samples = pipeline(
485 prompt=prompt,
486 negative_prompt=nprompt,
487 height=self.sample_image_size,
488 width=self.sample_image_size,
489 latents=latents[:len(prompt)] if latents is not None else None,
490 generator=generator if latents is not None else None,
491 guidance_scale=guidance_scale,
492 eta=eta,
493 num_inference_steps=num_inference_steps,
494 output_type='pil'
495 ).images
496
497 all_samples += samples
498
499 del samples
500
501 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
502 image_grid.save(file_path)
503
504 del all_samples
505 del image_grid
506
507 del unwrapped_unet
508 del unwrapped_text_encoder
509 del scheduler
510 del pipeline
511 del generator
512 del stable_latents
513
514 if torch.cuda.is_available():
515 torch.cuda.empty_cache()
516
517
518def main():
519 args = parse_args()
520
521 global_step_offset = 0
522 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
523 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
524 basepath.mkdir(parents=True, exist_ok=True)
525
526 accelerator = Accelerator(
527 log_with=LoggerType.TENSORBOARD,
528 logging_dir=f"{basepath}",
529 gradient_accumulation_steps=args.gradient_accumulation_steps,
530 mixed_precision=args.mixed_precision
531 )
532
533 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
534
535 save_args(basepath, args)
536
537 # If passed along, set the training seed now.
538 if args.seed is not None:
539 set_seed(args.seed)
540
541 args.instance_identifier = args.instance_identifier.format(args.placeholder_token)
542
543 # Load the tokenizer and add the placeholder token as a additional special token
544 if args.tokenizer_name:
545 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
546 elif args.pretrained_model_name_or_path:
547 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
548
549 # Convert the initializer_token, placeholder_token to ids
550 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
551 print(f"Initializer token maps to {len(initializer_token_ids)} embeddings.")
552 initializer_token_ids = torch.tensor(initializer_token_ids[:1])
553
554 # Add the placeholder token in tokenizer
555 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
556 if num_added_tokens == 0:
557 print(f"Re-using existing token {args.placeholder_token}.")
558 else:
559 print(f"Training new token {args.placeholder_token}.")
560
561 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
562
563 # Load models and create wrapper for stable diffusion
564 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
565 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
566 unet = UNet2DConditionModel.from_pretrained(
567 args.pretrained_model_name_or_path, subfolder='unet')
568
569 ema_unet = EMAModel(
570 unet,
571 inv_gamma=args.ema_inv_gamma,
572 power=args.ema_power,
573 max_value=args.ema_max_decay,
574 device=accelerator.device
575 ) if args.use_ema else None
576
577 prompt_processor = PromptProcessor(tokenizer, text_encoder)
578
579 if args.gradient_checkpointing:
580 unet.enable_gradient_checkpointing()
581 text_encoder.gradient_checkpointing_enable()
582
583 # slice_size = unet.config.attention_head_dim // 2
584 # unet.set_attention_slice(slice_size)
585
586 # Resize the token embeddings as we are adding new special tokens to the tokenizer
587 text_encoder.resize_token_embeddings(len(tokenizer))
588
589 # Initialise the newly added placeholder token with the embeddings of the initializer token
590 token_embeds = text_encoder.get_input_embeddings().weight.data
591 original_token_embeds = token_embeds.detach().clone().to(accelerator.device)
592 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
593 token_embeds[placeholder_token_id] = initializer_token_embeddings
594
595 # Freeze vae and unet
596 freeze_params(vae.parameters())
597 # Freeze all parameters except for the token embeddings in text encoder
598 params_to_freeze = itertools.chain(
599 text_encoder.text_model.encoder.parameters(),
600 text_encoder.text_model.final_layer_norm.parameters(),
601 text_encoder.text_model.embeddings.position_embedding.parameters(),
602 )
603 freeze_params(params_to_freeze)
604
605 if args.scale_lr:
606 args.learning_rate_unet = (
607 args.learning_rate_unet * args.gradient_accumulation_steps *
608 args.train_batch_size * accelerator.num_processes
609 )
610 args.learning_rate_text = (
611 args.learning_rate_text * args.gradient_accumulation_steps *
612 args.train_batch_size * accelerator.num_processes
613 )
614
615 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
616 if args.use_8bit_adam:
617 try:
618 import bitsandbytes as bnb
619 except ImportError:
620 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
621
622 optimizer_class = bnb.optim.AdamW8bit
623 else:
624 optimizer_class = torch.optim.AdamW
625
626 # Initialize the optimizer
627 optimizer = optimizer_class(
628 [
629 {
630 'params': unet.parameters(),
631 'lr': args.learning_rate_unet,
632 },
633 {
634 'params': text_encoder.get_input_embeddings().parameters(),
635 'lr': args.learning_rate_text,
636 }
637 ],
638 betas=(args.adam_beta1, args.adam_beta2),
639 weight_decay=args.adam_weight_decay,
640 eps=args.adam_epsilon,
641 )
642
643 noise_scheduler = DDPMScheduler(
644 beta_start=0.00085,
645 beta_end=0.012,
646 beta_schedule="scaled_linear",
647 num_train_timesteps=args.noise_timesteps
648 )
649
650 weight_dtype = torch.float32
651 if args.mixed_precision == "fp16":
652 weight_dtype = torch.float16
653 elif args.mixed_precision == "bf16":
654 weight_dtype = torch.bfloat16
655
656 def collate_fn(examples):
657 prompts = [example["prompts"] for example in examples]
658 nprompts = [example["nprompts"] for example in examples]
659 input_ids = [example["instance_prompt_ids"] for example in examples]
660 pixel_values = [example["instance_images"] for example in examples]
661
662 # concat class and instance examples for prior preservation
663 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
664 input_ids += [example["class_prompt_ids"] for example in examples]
665 pixel_values += [example["class_images"] for example in examples]
666
667 pixel_values = torch.stack(pixel_values)
668 pixel_values = pixel_values.to(dtype=weight_dtype, memory_format=torch.contiguous_format)
669
670 input_ids = prompt_processor.unify_input_ids(input_ids)
671
672 batch = {
673 "prompts": prompts,
674 "nprompts": nprompts,
675 "input_ids": input_ids,
676 "pixel_values": pixel_values,
677 }
678 return batch
679
680 datamodule = CSVDataModule(
681 data_file=args.train_data_file,
682 batch_size=args.train_batch_size,
683 prompt_processor=prompt_processor,
684 instance_identifier=args.instance_identifier,
685 class_identifier=args.class_identifier,
686 class_subdir="cls",
687 num_class_images=args.num_class_images,
688 size=args.resolution,
689 repeats=args.repeats,
690 center_crop=args.center_crop,
691 valid_set_size=args.sample_batch_size*args.sample_batches,
692 collate_fn=collate_fn
693 )
694
695 datamodule.prepare_data()
696 datamodule.setup()
697
698 if args.num_class_images != 0:
699 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
700
701 if len(missing_data) != 0:
702 batched_data = [
703 missing_data[i:i+args.sample_batch_size]
704 for i in range(0, len(missing_data), args.sample_batch_size)
705 ]
706
707 scheduler = EulerAScheduler(
708 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
709 )
710
711 pipeline = VlpnStableDiffusion(
712 text_encoder=text_encoder,
713 vae=vae,
714 unet=unet,
715 tokenizer=tokenizer,
716 scheduler=scheduler,
717 ).to(accelerator.device)
718 pipeline.set_progress_bar_config(dynamic_ncols=True)
719
720 with torch.inference_mode():
721 for batch in batched_data:
722 image_name = [item.class_image_path for item in batch]
723 prompt = [item.prompt.format(args.class_identifier) for item in batch]
724 nprompt = [item.nprompt for item in batch]
725
726 images = pipeline(
727 prompt=prompt,
728 negative_prompt=nprompt,
729 num_inference_steps=args.sample_steps
730 ).images
731
732 for i, image in enumerate(images):
733 image.save(image_name[i])
734
735 del pipeline
736
737 if torch.cuda.is_available():
738 torch.cuda.empty_cache()
739
740 train_dataloader = datamodule.train_dataloader()
741 val_dataloader = datamodule.val_dataloader()
742
743 # Scheduler and math around the number of training steps.
744 overrode_max_train_steps = False
745 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
746 if args.max_train_steps is None:
747 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
748 overrode_max_train_steps = True
749 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
750
751 if args.lr_scheduler == "cosine_with_restarts":
752 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
753 optimizer=optimizer,
754 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
755 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
756 num_cycles=args.lr_cycles or math.ceil(math.sqrt(
757 ((args.max_train_steps - args.lr_warmup_steps) / num_update_steps_per_epoch))),
758 )
759 else:
760 lr_scheduler = get_scheduler(
761 args.lr_scheduler,
762 optimizer=optimizer,
763 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
764 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
765 )
766
767 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
768 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
769 )
770
771 # Move vae and unet to device
772 vae.to(accelerator.device, dtype=weight_dtype)
773
774 # Keep vae and unet in eval mode as we don't train these
775 vae.eval()
776
777 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
778 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
779 if overrode_max_train_steps:
780 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
781
782 num_val_steps_per_epoch = len(val_dataloader)
783 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
784 val_steps = num_val_steps_per_epoch * num_epochs
785
786 # We need to initialize the trackers we use, and also store our configuration.
787 # The trackers initializes automatically on the main process.
788 if accelerator.is_main_process:
789 accelerator.init_trackers("dreambooth_plus", config=vars(args))
790
791 # Train!
792 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
793
794 logger.info("***** Running training *****")
795 logger.info(f" Num Epochs = {num_epochs}")
796 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
797 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
798 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
799 logger.info(f" Total optimization steps = {args.max_train_steps}")
800 # Only show the progress bar once on each machine.
801
802 global_step = 0
803 min_val_loss = np.inf
804
805 checkpointer = Checkpointer(
806 datamodule=datamodule,
807 accelerator=accelerator,
808 vae=vae,
809 unet=unet,
810 ema_unet=ema_unet,
811 tokenizer=tokenizer,
812 text_encoder=text_encoder,
813 instance_identifier=args.instance_identifier,
814 placeholder_token=args.placeholder_token,
815 placeholder_token_id=placeholder_token_id,
816 output_dir=basepath,
817 sample_image_size=args.sample_image_size,
818 sample_batch_size=args.sample_batch_size,
819 sample_batches=args.sample_batches,
820 seed=args.seed
821 )
822
823 if accelerator.is_main_process:
824 checkpointer.save_samples(
825 0,
826 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
827
828 local_progress_bar = tqdm(
829 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
830 disable=not accelerator.is_local_main_process,
831 dynamic_ncols=True
832 )
833 local_progress_bar.set_description("Epoch X / Y")
834
835 global_progress_bar = tqdm(
836 range(args.max_train_steps + val_steps),
837 disable=not accelerator.is_local_main_process,
838 dynamic_ncols=True
839 )
840 global_progress_bar.set_description("Total progress")
841
842 try:
843 for epoch in range(num_epochs):
844 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
845 local_progress_bar.reset()
846
847 text_encoder.train()
848 train_loss = 0.0
849
850 sample_checkpoint = False
851
852 for step, batch in enumerate(train_dataloader):
853 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
854 # Convert images to latent space
855 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
856 latents = latents * 0.18215
857
858 # Sample noise that we'll add to the latents
859 noise = torch.randn_like(latents)
860 bsz = latents.shape[0]
861 # Sample a random timestep for each image
862 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
863 (bsz,), device=latents.device)
864 timesteps = timesteps.long()
865
866 # Add noise to the latents according to the noise magnitude at each timestep
867 # (this is the forward diffusion process)
868 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
869
870 # Get the text embedding for conditioning
871 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
872
873 # Predict the noise residual
874 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
875
876 if args.num_class_images != 0:
877 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
878 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
879 noise, noise_prior = torch.chunk(noise, 2, dim=0)
880
881 # Compute instance loss
882 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="none").mean([1, 2, 3]).mean()
883
884 # Compute prior loss
885 prior_loss = F.mse_loss(noise_pred_prior.float(), noise_prior.float(), reduction="mean")
886
887 # Add the prior loss to the instance loss.
888 loss = loss + args.prior_loss_weight * prior_loss
889 else:
890 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
891
892 accelerator.backward(loss)
893
894 # Keep the token embeddings fixed except the newly added
895 # embeddings for the concept, as we only want to optimize the concept embeddings
896 if accelerator.num_processes > 1:
897 token_embeds = text_encoder.module.get_input_embeddings().weight
898 else:
899 token_embeds = text_encoder.get_input_embeddings().weight
900
901 # Get the index for tokens that we want to freeze
902 index_fixed_tokens = torch.arange(len(tokenizer)) != placeholder_token_id
903 token_embeds.data[index_fixed_tokens, :] = original_token_embeds[index_fixed_tokens, :]
904
905 if accelerator.sync_gradients:
906 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
907
908 optimizer.step()
909 if not accelerator.optimizer_step_was_skipped:
910 lr_scheduler.step()
911 optimizer.zero_grad(set_to_none=True)
912
913 loss = loss.detach().item()
914 train_loss += loss
915
916 # Checks if the accelerator has performed an optimization step behind the scenes
917 if accelerator.sync_gradients:
918 if args.use_ema:
919 ema_unet.step(unet)
920
921 local_progress_bar.update(1)
922 global_progress_bar.update(1)
923
924 global_step += 1
925
926 if global_step % args.sample_frequency == 0:
927 sample_checkpoint = True
928
929 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
930 local_progress_bar.clear()
931 global_progress_bar.clear()
932
933 checkpointer.checkpoint(global_step + global_step_offset, "training")
934
935 logs = {
936 "train/loss": loss,
937 "lr/unet": lr_scheduler.get_last_lr()[0],
938 "lr/text": lr_scheduler.get_last_lr()[1]
939 }
940 if args.use_ema:
941 logs["ema_decay"] = ema_unet.decay
942
943 accelerator.log(logs, step=global_step)
944
945 local_progress_bar.set_postfix(**logs)
946
947 if global_step >= args.max_train_steps:
948 break
949
950 train_loss /= len(train_dataloader)
951
952 accelerator.wait_for_everyone()
953
954 text_encoder.eval()
955 val_loss = 0.0
956
957 for step, batch in enumerate(val_dataloader):
958 with torch.no_grad():
959 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
960 latents = latents * 0.18215
961
962 noise = torch.randn_like(latents)
963 bsz = latents.shape[0]
964 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
965 (bsz,), device=latents.device)
966 timesteps = timesteps.long()
967
968 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
969
970 encoder_hidden_states = prompt_processor.get_embeddings(batch["input_ids"])
971
972 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
973
974 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
975
976 loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
977
978 loss = loss.detach().item()
979 val_loss += loss
980
981 if accelerator.sync_gradients:
982 local_progress_bar.update(1)
983 global_progress_bar.update(1)
984
985 logs = {"val/loss": loss}
986 local_progress_bar.set_postfix(**logs)
987
988 val_loss /= len(val_dataloader)
989
990 accelerator.log({"val/loss": val_loss}, step=global_step)
991
992 local_progress_bar.clear()
993 global_progress_bar.clear()
994
995 if min_val_loss > val_loss:
996 accelerator.print(
997 f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
998 checkpointer.checkpoint(global_step + global_step_offset, "milestone")
999 min_val_loss = val_loss
1000
1001 if sample_checkpoint and accelerator.is_main_process:
1002 checkpointer.save_samples(
1003 global_step + global_step_offset,
1004 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
1005
1006 # Create the pipeline using using the trained modules and save it.
1007 if accelerator.is_main_process:
1008 print("Finished! Saving final checkpoint and resume state.")
1009 checkpointer.checkpoint(global_step + global_step_offset, "end")
1010 checkpointer.save_model()
1011 accelerator.end_training()
1012
1013 except KeyboardInterrupt:
1014 if accelerator.is_main_process:
1015 print("Interrupted, saving checkpoint and resume state...")
1016 checkpointer.checkpoint(global_step + global_step_offset, "end")
1017 checkpointer.save_model()
1018 accelerator.end_training()
1019 quit()
1020
1021
1022if __name__ == "__main__":
1023 main()