author     Volpeon <git@volpeon.ink>    2022-10-13 09:45:27 +0200
committer  Volpeon <git@volpeon.ink>    2022-10-13 09:45:27 +0200
commit     db0996c299fdd559ebf9cd48f9dbe47474ed7b07 (patch)
tree       0d306c661ed5629e7d69566a82d588aca5ed86a9
parent     Various updates (diff)
download   textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.tar.gz
           textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.tar.bz2
           textual-inversion-diff-db0996c299fdd559ebf9cd48f9dbe47474ed7b07.zip
Added TI+Dreambooth training
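The new dreambooth_plus.py script combines textual inversion (training the placeholder token embedding) with DreamBooth (fine-tuning the UNet), using separate learning rates for the two parameter groups. As a rough usage sketch, the flag names are taken from the script's argparse section shown below; the model path, data file, and token values are placeholders, and a plain single-process python invocation is assumed (the script can also be launched through Accelerate):

    # Minimal sketch: train a new placeholder token plus the UNet on a CSV dataset.
    # <model>, <data.csv>, <my-token> and <word> are placeholders, not script defaults.
    python dreambooth_plus.py \
        --pretrained_model_name_or_path <model> \
        --train_data_file <data.csv> \
        --placeholder_token "<my-token>" \
        --initializer_token <word> \
        --output_dir output/dreambooth-plus

The same arguments can also be collected in a JSON file (under an "args" key) and passed via --config.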
-rw-r--r--   dreambooth.py           15
-rw-r--r--   dreambooth_plus.py     939
-rw-r--r--   infer.py                 8
-rw-r--r--   textual_inversion.py    11
4 files changed, 961 insertions, 12 deletions
diff --git a/dreambooth.py b/dreambooth.py
index 775aea2..699313e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -129,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-6,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -167,7 +167,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=7 / 8
+        default=6 / 7
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -270,6 +270,11 @@ def parse_args():
         help="Max gradient norm."
     )
     parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
         "--config",
         type=str,
         default=None,
@@ -480,7 +485,8 @@ def main():
         unet,
         inv_gamma=args.ema_inv_gamma,
         power=args.ema_power,
-        max_value=args.ema_max_decay
+        max_value=args.ema_max_decay,
+        device=accelerator.device
     ) if args.use_ema else None
 
     if args.gradient_checkpointing:
@@ -523,7 +529,7 @@ def main():
         beta_start=0.00085,
         beta_end=0.012,
         beta_schedule="scaled_linear",
-        num_train_timesteps=1000
+        num_train_timesteps=args.noise_timesteps
     )
 
     def collate_fn(examples):
@@ -632,7 +638,6 @@ def main():
     # Move text_encoder and vae to device
     text_encoder.to(accelerator.device)
     vae.to(accelerator.device)
-    ema_unet.averaged_model.to(accelerator.device)
 
     # Keep text_encoder and vae in eval mode as we don't train these
     text_encoder.eval()
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
new file mode 100644
index 0000000..9e482b3
--- /dev/null
+++ b/dreambooth_plus.py
@@ -0,0 +1,939 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6import logging
7import json
8from pathlib import Path
9
10import numpy as np
11import torch
12import torch.nn.functional as F
13import torch.utils.checkpoint
14
15from accelerate import Accelerator
16from accelerate.logging import get_logger
17from accelerate.utils import LoggerType, set_seed
18from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, UNet2DConditionModel
19from diffusers.optimization import get_scheduler
20from diffusers.training_utils import EMAModel
21from PIL import Image
22from tqdm.auto import tqdm
23from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify
25
26from schedulers.scheduling_euler_a import EulerAScheduler
27from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
28from data.csv import CSVDataModule
29
30logger = get_logger(__name__)
31
32
33torch.backends.cuda.matmul.allow_tf32 = True
34
35
36def parse_args():
37 parser = argparse.ArgumentParser(
38 description="Simple example of a training script."
39 )
40 parser.add_argument(
41 "--pretrained_model_name_or_path",
42 type=str,
43 default=None,
44 help="Path to pretrained model or model identifier from huggingface.co/models.",
45 )
46 parser.add_argument(
47 "--tokenizer_name",
48 type=str,
49 default=None,
50 help="Pretrained tokenizer name or path if not the same as model_name",
51 )
52 parser.add_argument(
53 "--train_data_file",
54 type=str,
55 default=None,
56 help="A CSV file containing the training data."
57 )
58 parser.add_argument(
59 "--placeholder_token",
60 type=str,
61 default=None,
62 help="A token to use as a placeholder for the concept.",
63 )
64 parser.add_argument(
65 "--initializer_token",
66 type=str,
67 default=None,
68 help="A token to use as initializer word."
69 )
70 parser.add_argument(
71 "--num_class_images",
72 type=int,
73 default=400,
74 help="How many class images to generate per training image."
75 )
76 parser.add_argument(
77 "--repeats",
78 type=int,
79 default=1,
80 help="How many times to repeat the training data."
81 )
82 parser.add_argument(
83 "--output_dir",
84 type=str,
85 default="output/dreambooth-plus",
86 help="The output directory where the model predictions and checkpoints will be written.",
87 )
88 parser.add_argument(
89 "--seed",
90 type=int,
91 default=None,
92 help="A seed for reproducible training.")
93 parser.add_argument(
94 "--resolution",
95 type=int,
96 default=512,
97 help=(
98 "The resolution for input images, all the images in the train/validation dataset will be resized to this"
99 " resolution"
100 ),
101 )
102 parser.add_argument(
103 "--center_crop",
104 action="store_true",
105 help="Whether to center crop images before resizing to resolution"
106 )
107 parser.add_argument(
108 "--num_train_epochs",
109 type=int,
110 default=100
111 )
112 parser.add_argument(
113 "--max_train_steps",
114 type=int,
115 default=3000,
116 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
117 )
118 parser.add_argument(
119 "--gradient_accumulation_steps",
120 type=int,
121 default=1,
122 help="Number of updates steps to accumulate before performing a backward/update pass.",
123 )
124 parser.add_argument(
125 "--gradient_checkpointing",
126 action="store_true",
127 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
128 )
129 parser.add_argument(
130 "--learning_rate_unet",
131 type=float,
132 default=1e-5,
133 help="Initial learning rate (after the potential warmup period) to use.",
134 )
135 parser.add_argument(
136 "--learning_rate_text",
137 type=float,
138 default=1e-4,
139 help="Initial learning rate (after the potential warmup period) to use.",
140 )
141 parser.add_argument(
142 "--scale_lr",
143 action="store_true",
144 default=True,
145 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
146 )
147 parser.add_argument(
148 "--lr_scheduler",
149 type=str,
150 default="cosine",
151 help=(
152 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
153 ' "constant", "constant_with_warmup"]'
154 ),
155 )
156 parser.add_argument(
157 "--lr_warmup_steps",
158 type=int,
159 default=500,
160 help="Number of steps for the warmup in the lr scheduler."
161 )
162 parser.add_argument(
163 "--use_ema",
164 action="store_true",
165 default=True,
166 help="Whether to use EMA model."
167 )
168 parser.add_argument(
169 "--ema_inv_gamma",
170 type=float,
171 default=1.0
172 )
173 parser.add_argument(
174 "--ema_power",
175 type=float,
176 default=6 / 7
177 )
178 parser.add_argument(
179 "--ema_max_decay",
180 type=float,
181 default=0.9999
182 )
183 parser.add_argument(
184 "--use_8bit_adam",
185 action="store_true",
186 default=True,
187 help="Whether or not to use 8-bit Adam from bitsandbytes."
188 )
189 parser.add_argument(
190 "--adam_beta1",
191 type=float,
192 default=0.9,
193 help="The beta1 parameter for the Adam optimizer."
194 )
195 parser.add_argument(
196 "--adam_beta2",
197 type=float,
198 default=0.999,
199 help="The beta2 parameter for the Adam optimizer."
200 )
201 parser.add_argument(
202 "--adam_weight_decay",
203 type=float,
204 default=1e-2,
205 help="Weight decay to use."
206 )
207 parser.add_argument(
208 "--adam_epsilon",
209 type=float,
210 default=1e-08,
211 help="Epsilon value for the Adam optimizer"
212 )
213 parser.add_argument(
214 "--mixed_precision",
215 type=str,
216 default="no",
217 choices=["no", "fp16", "bf16"],
218 help=(
219 "Whether to use mixed precision. Choose"
220 "between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10."
221 "and an Nvidia Ampere GPU."
222 ),
223 )
224 parser.add_argument(
225 "--local_rank",
226 type=int,
227 default=-1,
228 help="For distributed training: local_rank"
229 )
230 parser.add_argument(
231 "--sample_frequency",
232 type=int,
233 default=100,
234 help="How often to save a checkpoint and sample image",
235 )
236 parser.add_argument(
237 "--sample_image_size",
238 type=int,
239 default=512,
240 help="Size of sample images",
241 )
242 parser.add_argument(
243 "--sample_batches",
244 type=int,
245 default=1,
246 help="Number of sample batches to generate per checkpoint",
247 )
248 parser.add_argument(
249 "--sample_batch_size",
250 type=int,
251 default=1,
252 help="Number of samples to generate per batch",
253 )
254 parser.add_argument(
255 "--train_batch_size",
256 type=int,
257 default=1,
258 help="Batch size (per device) for the training dataloader."
259 )
260 parser.add_argument(
261 "--sample_steps",
262 type=int,
263 default=30,
264 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
265 )
266 parser.add_argument(
267 "--prior_loss_weight",
268 type=float,
269 default=1.0,
270 help="The weight of prior preservation loss."
271 )
272 parser.add_argument(
273 "--max_grad_norm",
274 default=1.0,
275 type=float,
276 help="Max gradient norm."
277 )
278 parser.add_argument(
279 "--noise_timesteps",
280 type=int,
281 default=1000,
282 )
283 parser.add_argument(
284 "--config",
285 type=str,
286 default=None,
287 help="Path to a JSON configuration file containing arguments for invoking this script."
288 )
289
290 args = parser.parse_args()
291 if args.config is not None:
292 with open(args.config, 'rt') as f:
293 args = parser.parse_args(
294 namespace=argparse.Namespace(**json.load(f)["args"]))
295
296 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
297 if env_local_rank != -1 and env_local_rank != args.local_rank:
298 args.local_rank = env_local_rank
299
300 if args.train_data_file is None:
301 raise ValueError("You must specify --train_data_file")
302
303 if args.pretrained_model_name_or_path is None:
304 raise ValueError("You must specify --pretrained_model_name_or_path")
305
306 if args.placeholder_token is None:
307 raise ValueError("You must specify --placeholder_token")
308
309 if args.initializer_token is None:
310 raise ValueError("You must specify --initializer_token")
311
312 if args.output_dir is None:
313 raise ValueError("You must specify --output_dir")
314
315 return args
316
317
318def freeze_params(params):
319 for param in params:
320 param.requires_grad = False
321
322
323def make_grid(images, rows, cols):
324 w, h = images[0].size
325 grid = Image.new('RGB', size=(cols*w, rows*h))
326 for i, image in enumerate(images):
327 grid.paste(image, box=(i % cols*w, i//cols*h))
328 return grid
329
330
331class Checkpointer:
332 def __init__(
333 self,
334 datamodule,
335 accelerator,
336 vae,
337 unet,
338 ema_unet,
339 tokenizer,
340 text_encoder,
341 placeholder_token,
342 placeholder_token_id,
343 output_dir: Path,
344 sample_image_size,
345 sample_batches,
346 sample_batch_size,
347 seed
348 ):
349 self.datamodule = datamodule
350 self.accelerator = accelerator
351 self.vae = vae
352 self.unet = unet
353 self.ema_unet = ema_unet
354 self.tokenizer = tokenizer
355 self.text_encoder = text_encoder
356 self.placeholder_token = placeholder_token
357 self.placeholder_token_id = placeholder_token_id
358 self.output_dir = output_dir
359 self.sample_image_size = sample_image_size
360 self.seed = seed or torch.random.seed()
361 self.sample_batches = sample_batches
362 self.sample_batch_size = sample_batch_size
363
364 @torch.no_grad()
365 def checkpoint(self):
366 print("Saving model...")
367
368 unwrapped_unet = self.accelerator.unwrap_model(
369 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
370 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
371 pipeline = VlpnStableDiffusion(
372 text_encoder=unwrapped_text_encoder,
373 vae=self.vae,
374 unet=unwrapped_unet,
375 tokenizer=self.tokenizer,
376 scheduler=PNDMScheduler(
377 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
378 ),
379 )
380 pipeline.save_pretrained(self.output_dir.joinpath("model"))
381
382 del unwrapped_unet
383 del unwrapped_text_encoder
384 del pipeline
385
386 if torch.cuda.is_available():
387 torch.cuda.empty_cache()
388
389 @torch.no_grad()
390 def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
391 samples_path = Path(self.output_dir).joinpath("samples")
392
393 unwrapped_unet = self.accelerator.unwrap_model(
394 self.ema_unet.averaged_model if self.ema_unet is not None else self.unet)
395 unwrapped_text_encoder = self.accelerator.unwrap_model(self.text_encoder)
396 scheduler = EulerAScheduler(
397 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
398 )
399
400 # Save a sample image
401 pipeline = VlpnStableDiffusion(
402 text_encoder=unwrapped_text_encoder,
403 vae=self.vae,
404 unet=unwrapped_unet,
405 tokenizer=self.tokenizer,
406 scheduler=scheduler,
407 ).to(self.accelerator.device)
408 pipeline.set_progress_bar_config(dynamic_ncols=True)
409
410 train_data = self.datamodule.train_dataloader()
411 val_data = self.datamodule.val_dataloader()
412
413 generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
414 stable_latents = torch.randn(
415 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
416 device=pipeline.device,
417 generator=generator,
418 )
419
420 with torch.inference_mode():
421 for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
422 all_samples = []
423 file_path = samples_path.joinpath(pool, f"step_{step}.png")
424 file_path.parent.mkdir(parents=True, exist_ok=True)
425
426 data_enum = enumerate(data)
427
428 for i in range(self.sample_batches):
429 batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
430 prompt = [prompt.format(self.placeholder_token)
431 for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
432 nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
433
434 samples = pipeline(
435 prompt=prompt,
436 negative_prompt=nprompt,
437 height=self.sample_image_size,
438 width=self.sample_image_size,
439 latents=latents[:len(prompt)] if latents is not None else None,
440 generator=generator if latents is not None else None,
441 guidance_scale=guidance_scale,
442 eta=eta,
443 num_inference_steps=num_inference_steps,
444 output_type='pil'
445 )["sample"]
446
447 all_samples += samples
448
449 del samples
450
451 image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
452 image_grid.save(file_path)
453
454 del all_samples
455 del image_grid
456
457 del unwrapped_unet
458 del unwrapped_text_encoder
459 del scheduler
460 del pipeline
461 del generator
462 del stable_latents
463
464 if torch.cuda.is_available():
465 torch.cuda.empty_cache()
466
467
468def main():
469 args = parse_args()
470
471 global_step_offset = 0
472 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
473 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
474 basepath.mkdir(parents=True, exist_ok=True)
475
476 accelerator = Accelerator(
477 log_with=LoggerType.TENSORBOARD,
478 logging_dir=f"{basepath}",
479 gradient_accumulation_steps=args.gradient_accumulation_steps,
480 mixed_precision=args.mixed_precision
481 )
482
483 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
484
485 # If passed along, set the training seed now.
486 if args.seed is not None:
487 set_seed(args.seed)
488
 489 # Load the tokenizer and add the placeholder token as an additional special token
490 if args.tokenizer_name:
491 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
492 elif args.pretrained_model_name_or_path:
493 tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder='tokenizer')
494
495 # Add the placeholder token in tokenizer
496 num_added_tokens = tokenizer.add_tokens(args.placeholder_token)
497 if num_added_tokens == 0:
498 raise ValueError(
499 f"The tokenizer already contains the token {args.placeholder_token}. Please pass a different"
500 " `placeholder_token` that is not already in the tokenizer."
501 )
502
503 # Convert the initializer_token, placeholder_token to ids
504 initializer_token_ids = tokenizer.encode(args.initializer_token, add_special_tokens=False)
505 # Check if initializer_token is a single token or a sequence of tokens
506 if len(initializer_token_ids) > 1:
507 raise ValueError(
508 f"initializer_token_ids must not have more than 1 vector, but it's {len(initializer_token_ids)}.")
509
510 initializer_token_ids = torch.tensor(initializer_token_ids)
511 placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
512
513 # Load models and create wrapper for stable diffusion
514 text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder='text_encoder')
515 vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder='vae')
516 unet = UNet2DConditionModel.from_pretrained(
517 args.pretrained_model_name_or_path, subfolder='unet')
518
519 ema_unet = EMAModel(
520 unet,
521 inv_gamma=args.ema_inv_gamma,
522 power=args.ema_power,
523 max_value=args.ema_max_decay,
524 device=accelerator.device
525 ) if args.use_ema else None
526
527 if args.gradient_checkpointing:
528 unet.enable_gradient_checkpointing()
529
530 # slice_size = unet.config.attention_head_dim // 2
531 # unet.set_attention_slice(slice_size)
532
533 # Resize the token embeddings as we are adding new special tokens to the tokenizer
534 text_encoder.resize_token_embeddings(len(tokenizer))
535
536 # Initialise the newly added placeholder token with the embeddings of the initializer token
537 token_embeds = text_encoder.get_input_embeddings().weight.data
538 initializer_token_embeddings = text_encoder.get_input_embeddings()(initializer_token_ids)
539 token_embeds[placeholder_token_id] = initializer_token_embeddings
540
541 # Freeze vae and unet
542 freeze_params(vae.parameters())
543 # Freeze all parameters except for the token embeddings in text encoder
544 params_to_freeze = itertools.chain(
545 text_encoder.text_model.encoder.parameters(),
546 text_encoder.text_model.final_layer_norm.parameters(),
547 text_encoder.text_model.embeddings.position_embedding.parameters(),
548 )
549 freeze_params(params_to_freeze)
550
551 if args.scale_lr:
552 args.learning_rate_unet = (
553 args.learning_rate_unet * args.gradient_accumulation_steps *
554 args.train_batch_size * accelerator.num_processes
555 )
556 args.learning_rate_text = (
557 args.learning_rate_text * args.gradient_accumulation_steps *
558 args.train_batch_size * accelerator.num_processes
559 )
560
561 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
562 if args.use_8bit_adam:
563 try:
564 import bitsandbytes as bnb
565 except ImportError:
566 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
567
568 optimizer_class = bnb.optim.AdamW8bit
569 else:
570 optimizer_class = torch.optim.AdamW
571
572 # Initialize the optimizer
573 optimizer = optimizer_class(
574 [
575 {
576 'params': unet.parameters(),
577 'lr': args.learning_rate_unet,
578 },
579 {
580 'params': text_encoder.get_input_embeddings().parameters(),
581 'lr': args.learning_rate_text,
582 }
583 ],
584 betas=(args.adam_beta1, args.adam_beta2),
585 weight_decay=args.adam_weight_decay,
586 eps=args.adam_epsilon,
587 )
588
589 noise_scheduler = DDPMScheduler(
590 beta_start=0.00085,
591 beta_end=0.012,
592 beta_schedule="scaled_linear",
593 num_train_timesteps=args.noise_timesteps
594 )
595
596 def collate_fn(examples):
597 prompts = [example["prompts"] for example in examples]
598 nprompts = [example["nprompts"] for example in examples]
599 input_ids = [example["instance_prompt_ids"] for example in examples]
600 pixel_values = [example["instance_images"] for example in examples]
601
602 # concat class and instance examples for prior preservation
603 if args.num_class_images != 0 and "class_prompt_ids" in examples[0]:
604 input_ids += [example["class_prompt_ids"] for example in examples]
605 pixel_values += [example["class_images"] for example in examples]
606
607 pixel_values = torch.stack(pixel_values)
608 pixel_values = pixel_values.to(dtype=torch.float32, memory_format=torch.contiguous_format)
609
610 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
611
612 batch = {
613 "prompts": prompts,
614 "nprompts": nprompts,
615 "input_ids": input_ids,
616 "pixel_values": pixel_values,
617 }
618 return batch
619
620 datamodule = CSVDataModule(
621 data_file=args.train_data_file,
622 batch_size=args.train_batch_size,
623 tokenizer=tokenizer,
624 instance_identifier=args.placeholder_token,
625 class_identifier=args.initializer_token,
626 class_subdir="cls",
627 num_class_images=args.num_class_images,
628 size=args.resolution,
629 repeats=args.repeats,
630 center_crop=args.center_crop,
631 valid_set_size=args.sample_batch_size*args.sample_batches,
632 collate_fn=collate_fn
633 )
634
635 datamodule.prepare_data()
636 datamodule.setup()
637
638 if args.num_class_images != 0:
639 missing_data = [item for item in datamodule.data_train if not item.class_image_path.exists()]
640
641 if len(missing_data) != 0:
642 batched_data = [missing_data[i:i+args.sample_batch_size]
643 for i in range(0, len(missing_data), args.sample_batch_size)]
644
645 scheduler = EulerAScheduler(
646 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
647 )
648
649 pipeline = VlpnStableDiffusion(
650 text_encoder=text_encoder,
651 vae=vae,
652 unet=unet,
653 tokenizer=tokenizer,
654 scheduler=scheduler,
655 ).to(accelerator.device)
656 pipeline.set_progress_bar_config(dynamic_ncols=True)
657
658 with torch.inference_mode():
659 for batch in batched_data:
660 image_name = [p.class_image_path for p in batch]
661 prompt = [p.prompt.format(args.initializer_token) for p in batch]
662 nprompt = [p.nprompt for p in batch]
663
664 images = pipeline(
665 prompt=prompt,
666 negative_prompt=nprompt,
667 num_inference_steps=args.sample_steps
668 ).images
669
670 for i, image in enumerate(images):
671 image.save(image_name[i])
672
673 del pipeline
674
675 if torch.cuda.is_available():
676 torch.cuda.empty_cache()
677
678 train_dataloader = datamodule.train_dataloader()
679 val_dataloader = datamodule.val_dataloader()
680
681 # Scheduler and math around the number of training steps.
682 overrode_max_train_steps = False
683 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
684 if args.max_train_steps is None:
685 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
686 overrode_max_train_steps = True
687
688 lr_scheduler = get_scheduler(
689 args.lr_scheduler,
690 optimizer=optimizer,
691 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
692 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
693 )
694
695 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
696 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
697 )
698
699 # Move vae and unet to device
700 vae.to(accelerator.device)
701
702 # Keep vae and unet in eval mode as we don't train these
703 vae.eval()
704
705 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
706 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
707 if overrode_max_train_steps:
708 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
709
710 num_val_steps_per_epoch = len(val_dataloader)
711 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
712 val_steps = num_val_steps_per_epoch * num_epochs
713
714 # We need to initialize the trackers we use, and also store our configuration.
 715 # The trackers initialize automatically on the main process.
716 if accelerator.is_main_process:
717 accelerator.init_trackers("dreambooth_plus", config=vars(args))
718
719 # Train!
720 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
721
722 logger.info("***** Running training *****")
723 logger.info(f" Num Epochs = {num_epochs}")
724 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
725 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
726 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
727 logger.info(f" Total optimization steps = {args.max_train_steps}")
728 # Only show the progress bar once on each machine.
729
730 global_step = 0
731 min_val_loss = np.inf
732
733 checkpointer = Checkpointer(
734 datamodule=datamodule,
735 accelerator=accelerator,
736 vae=vae,
737 unet=unet,
738 ema_unet=ema_unet,
739 tokenizer=tokenizer,
740 text_encoder=text_encoder,
741 placeholder_token=args.placeholder_token,
742 placeholder_token_id=placeholder_token_id,
743 output_dir=basepath,
744 sample_image_size=args.sample_image_size,
745 sample_batch_size=args.sample_batch_size,
746 sample_batches=args.sample_batches,
747 seed=args.seed
748 )
749
750 if accelerator.is_main_process:
751 checkpointer.save_samples(
752 0,
753 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
754
755 local_progress_bar = tqdm(
756 range(num_update_steps_per_epoch + num_val_steps_per_epoch),
757 disable=not accelerator.is_local_main_process,
758 dynamic_ncols=True
759 )
760 local_progress_bar.set_description("Epoch X / Y")
761
762 global_progress_bar = tqdm(
763 range(args.max_train_steps + val_steps),
764 disable=not accelerator.is_local_main_process,
765 dynamic_ncols=True
766 )
767 global_progress_bar.set_description("Total progress")
768
769 try:
770 for epoch in range(num_epochs):
771 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
772 local_progress_bar.reset()
773
774 text_encoder.train()
775 train_loss = 0.0
776
777 sample_checkpoint = False
778
779 for step, batch in enumerate(train_dataloader):
780 with accelerator.accumulate(itertools.chain(unet, text_encoder)):
781 # Convert images to latent space
782 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
783 latents = latents * 0.18215
784
785 # Sample noise that we'll add to the latents
786 noise = torch.randn(latents.shape).to(latents.device)
787 bsz = latents.shape[0]
788 # Sample a random timestep for each image
789 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
790 (bsz,), device=latents.device)
791 timesteps = timesteps.long()
792
793 # Add noise to the latents according to the noise magnitude at each timestep
794 # (this is the forward diffusion process)
795 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
796
797 # Get the text embedding for conditioning
798 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
799
800 # Predict the noise residual
801 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
802
803 if args.num_class_images != 0:
804 # Chunk the noise and noise_pred into two parts and compute the loss on each part separately.
805 noise_pred, noise_pred_prior = torch.chunk(noise_pred, 2, dim=0)
806 noise, noise_prior = torch.chunk(noise, 2, dim=0)
807
808 # Compute instance loss
809 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
810
811 # Compute prior loss
812 prior_loss = F.mse_loss(noise_pred_prior, noise_prior, reduction="none").mean([1, 2, 3]).mean()
813
814 # Add the prior loss to the instance loss.
815 loss = loss + args.prior_loss_weight * prior_loss
816 else:
817 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
818
819 accelerator.backward(loss)
820
821 # Zero out the gradients for all token embeddings except the newly added
822 # embeddings for the concept, as we only want to optimize the concept embeddings
823 if accelerator.num_processes > 1:
824 grads = text_encoder.module.get_input_embeddings().weight.grad
825 else:
826 grads = text_encoder.get_input_embeddings().weight.grad
827 # Get the index for tokens that we want to zero the grads for
828 index_grads_to_zero = torch.arange(len(tokenizer)) != placeholder_token_id
829 grads.data[index_grads_to_zero, :] = grads.data[index_grads_to_zero, :].fill_(0)
830
831 if accelerator.sync_gradients:
832 accelerator.clip_grad_norm_(unet.parameters(), args.max_grad_norm)
833
834 optimizer.step()
835 if not accelerator.optimizer_step_was_skipped:
836 lr_scheduler.step()
837 optimizer.zero_grad(set_to_none=True)
838
839 loss = loss.detach().item()
840 train_loss += loss
841
842 # Checks if the accelerator has performed an optimization step behind the scenes
843 if accelerator.sync_gradients:
844 if args.use_ema:
845 ema_unet.step(unet)
846
847 local_progress_bar.update(1)
848 global_progress_bar.update(1)
849
850 global_step += 1
851
852 if global_step % args.sample_frequency == 0:
853 sample_checkpoint = True
854
855 logs = {
856 "train/loss": loss,
857 "lr/unet": lr_scheduler.get_last_lr()[0],
858 "lr/text": lr_scheduler.get_last_lr()[1]
859 }
860 if args.use_ema:
861 logs["ema_decay"] = ema_unet.decay
862
863 accelerator.log(logs, step=global_step)
864
865 local_progress_bar.set_postfix(**logs)
866
867 if global_step >= args.max_train_steps:
868 break
869
870 train_loss /= len(train_dataloader)
871
872 accelerator.wait_for_everyone()
873
874 text_encoder.eval()
875 val_loss = 0.0
876
877 for step, batch in enumerate(val_dataloader):
878 with torch.no_grad():
879 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
880 latents = latents * 0.18215
881
882 noise = torch.randn(latents.shape).to(latents.device)
883 bsz = latents.shape[0]
884 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
885 (bsz,), device=latents.device)
886 timesteps = timesteps.long()
887
888 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
889
890 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
891
892 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
893
894 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
895
896 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
897
898 loss = loss.detach().item()
899 val_loss += loss
900
901 if accelerator.sync_gradients:
902 local_progress_bar.update(1)
903 global_progress_bar.update(1)
904
905 logs = {"val/loss": loss}
906 local_progress_bar.set_postfix(**logs)
907
908 val_loss /= len(val_dataloader)
909
910 accelerator.log({"val/loss": val_loss}, step=global_step)
911
912 local_progress_bar.clear()
913 global_progress_bar.clear()
914
915 if min_val_loss > val_loss:
916 min_val_loss = val_loss
917
918 if sample_checkpoint and accelerator.is_main_process:
919 checkpointer.save_samples(
920 global_step + global_step_offset,
921 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
922
 923 # Create the pipeline using the trained modules and save it.
924 if accelerator.is_main_process:
925 print("Finished! Saving final checkpoint and resume state.")
926 checkpointer.checkpoint()
927
928 accelerator.end_training()
929
930 except KeyboardInterrupt:
931 if accelerator.is_main_process:
932 print("Interrupted, saving checkpoint and resume state...")
933 checkpointer.checkpoint()
934 accelerator.end_training()
935 quit()
936
937
938if __name__ == "__main__":
939 main()
diff --git a/infer.py b/infer.py
index 5bd4abc..63b16d8 100644
--- a/infer.py
+++ b/infer.py
@@ -205,10 +205,10 @@ def load_embeddings(tokenizer, text_encoder, embeddings_dir):
 def create_pipeline(model, scheduler, embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
-    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='/tokenizer', torch_dtype=dtype)
-    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='/text_encoder', torch_dtype=dtype)
-    vae = AutoencoderKL.from_pretrained(model, subfolder='/vae', torch_dtype=dtype)
-    unet = UNet2DConditionModel.from_pretrained(model, subfolder='/unet', torch_dtype=dtype)
+    tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
+    text_encoder = CLIPTextModel.from_pretrained(model, subfolder='text_encoder', torch_dtype=dtype)
+    vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
+    unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
 
     load_embeddings(tokenizer, text_encoder, embeddings_dir)
 
diff --git a/textual_inversion.py b/textual_inversion.py
index 3a3741d..181a318 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -111,7 +111,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3000,
+        default=10000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -128,7 +128,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=5e-5,
+        default=1e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -247,6 +247,11 @@ def parse_args():
         help="The weight of prior preservation loss."
     )
     parser.add_argument(
+        "--noise_timesteps",
+        type=int,
+        default=1000,
+    )
+    parser.add_argument(
         "--resume_from",
         type=str,
         default=None,
@@ -568,7 +573,7 @@ def main():
         beta_start=0.00085,
         beta_end=0.012,
         beta_schedule="scaled_linear",
-        num_train_timesteps=1000
+        num_train_timesteps=args.noise_timesteps
     )
 
     def collate_fn(examples):