-rw-r--r--  .gitignore                |   1
-rw-r--r--  data/dreambooth/csv.py    | 177
-rw-r--r--  data/dreambooth/prompt.py |  16
-rw-r--r--  dreambooth.py             | 825
4 files changed, 1019 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
index a8893c3..91a5e07 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,6 @@ cython_debug/
160  #.idea/
161
162  text-inversion-model/
163 +dreambooth-model/
164  conf*.json
165  v1-inference.yaml*
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
new file mode 100644
index 0000000..04df4c6
--- /dev/null
+++ b/data/dreambooth/csv.py
@@ -0,0 +1,177 @@
1import os
2import pandas as pd
3from pathlib import Path
4import PIL
5import pytorch_lightning as pl
6from PIL import Image
7from torch.utils.data import Dataset, DataLoader, random_split
8from torchvision import transforms
9
10
11class CSVDataModule(pl.LightningDataModule):
12 def __init__(self,
13 batch_size,
14 data_root,
15 tokenizer,
16 instance_prompt,
17 class_data_root=None,
18 class_prompt=None,
19 size=512,
20 repeats=100,
21 interpolation="bicubic",
22 identifier="*",
23 center_crop=False,
24 collate_fn=None):
25 super().__init__()
26
27 self.data_root = data_root
28 self.tokenizer = tokenizer
29 self.instance_prompt = instance_prompt
30 self.class_data_root = class_data_root
31 self.class_prompt = class_prompt
32 self.size = size
33 self.repeats = repeats
34 self.identifier = identifier
35 self.center_crop = center_crop
36 self.interpolation = interpolation
37 self.collate_fn = collate_fn
38 self.batch_size = batch_size
39
40 def prepare_data(self):
41 metadata = pd.read_csv(f'{self.data_root}/list.csv')
42 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
43 captions = [caption for caption in metadata['caption'].values]
44 skips = [skip for skip in metadata['skip'].values]
45 self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
46
47 def setup(self, stage=None):
48 train_set_size = int(len(self.data_full) * 0.8)
49 valid_set_size = len(self.data_full) - train_set_size
50 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
51
52 train_dataset = CSVDataset(self.data_train, self.tokenizer, instance_prompt=self.instance_prompt,
53 class_data_root=self.class_data_root,
54 class_prompt=self.class_prompt, size=self.size, repeats=self.repeats,
55 interpolation=self.interpolation, identifier=self.identifier,
56 center_crop=self.center_crop)
57 val_dataset = CSVDataset(self.data_val, self.tokenizer, instance_prompt=self.instance_prompt,
58 class_data_root=self.class_data_root,
59 class_prompt=self.class_prompt, size=self.size, interpolation=self.interpolation,
60 identifier=self.identifier, center_crop=self.center_crop)
61 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
62 shuffle=True, collate_fn=self.collate_fn)
63 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
64
65 def train_dataloader(self):
66 return self.train_dataloader_
67
68 def val_dataloader(self):
69 return self.val_dataloader_
70
71
72class CSVDataset(Dataset):
73 def __init__(self,
74 data,
75 tokenizer,
76 instance_prompt,
77 class_data_root=None,
78 class_prompt=None,
79 size=512,
80 repeats=1,
81 interpolation="bicubic",
82 identifier="*",
83 center_crop=False,
84 ):
85
86 self.data = data
87 self.tokenizer = tokenizer
88 self.instance_prompt = instance_prompt
89
90 self.num_instance_images = len(self.data)
91 self._length = self.num_instance_images * repeats
92
93 self.identifier = identifier
94
95 if class_data_root is not None:
96 self.class_data_root = Path(class_data_root)
97 self.class_data_root.mkdir(parents=True, exist_ok=True)
98
99 self.class_images = list(Path(class_data_root).iterdir())
100 self.num_class_images = len(self.class_images)
101 self._length = max(self.num_class_images, self.num_instance_images)
102
103 self.class_prompt = class_prompt
104 else:
105 self.class_data_root = None
106
107 self.interpolation = {"linear": PIL.Image.LINEAR,
108 "bilinear": PIL.Image.BILINEAR,
109 "bicubic": PIL.Image.BICUBIC,
110 "lanczos": PIL.Image.LANCZOS,
111 }[interpolation]
112 self.image_transforms = transforms.Compose(
113 [
114 transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR),
115 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
116 transforms.ToTensor(),
117 transforms.Normalize([0.5], [0.5]),
118 ]
119 )
120
121 self.cache = {}
122
123 def __len__(self):
124 return self._length
125
126 def get_example(self, i):
127 image_path, text = self.data[i % self.num_instance_images]
128
129 if image_path in self.cache:
130 return self.cache[image_path]
131
132 example = {}
133
134 instance_image = Image.open(image_path)
135 if not instance_image.mode == "RGB":
136 instance_image = instance_image.convert("RGB")
137
138 text = text.format(self.identifier)
139
140 example["prompts"] = text
141 example["instance_images"] = instance_image
142 example["instance_prompt_ids"] = self.tokenizer(
143 self.instance_prompt,
144 padding="do_not_pad",
145 truncation=True,
146 max_length=self.tokenizer.model_max_length,
147 ).input_ids
148
149 if self.class_data_root:
150 class_image = Image.open(self.class_images[i % self.num_class_images])
151 if not class_image.mode == "RGB":
152 class_image = class_image.convert("RGB")
153
154 example["class_images"] = class_image
155 example["class_prompt_ids"] = self.tokenizer(
156 self.class_prompt,
157 padding="do_not_pad",
158 truncation=True,
159 max_length=self.tokenizer.model_max_length,
160 ).input_ids
161
162 self.cache[image_path] = example
163 return example
164
165 def __getitem__(self, i):
166 example = {}
167 unprocessed_example = self.get_example(i)
168
169 example["prompts"] = unprocessed_example["prompts"]
170 example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
171 example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
172
173 if self.class_data_root:
174 example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
175 example["class_prompt_ids"] = unprocessed_example["class_prompt_ids"]
176
177 return example
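
Note: prepare_data() above expects a list.csv in the data root with image, caption and skip columns, and each caption is later formatted with the identifier. Below is a minimal usage sketch; the directory layout, paths, tokenizer location and the "sks" identifier are illustrative assumptions, not part of this commit.

# Hypothetical data layout assumed by CSVDataModule.prepare_data():
#
#   data/portraits/list.csv
#   data/portraits/img001.jpg
#   ...
#
# list.csv columns (rows whose skip column is "x" are dropped):
#
#   image,caption,skip
#   img001.jpg,a photo of {},
#   img002.jpg,a photo of {} at the beach,x
#
from transformers import CLIPTokenizer
from data.dreambooth.csv import CSVDataModule

tokenizer = CLIPTokenizer.from_pretrained("models/stable-diffusion-v1-4/tokenizer")  # assumed local path

datamodule = CSVDataModule(
    batch_size=1,
    data_root="data/portraits",         # hypothetical data root containing list.csv
    tokenizer=tokenizer,
    instance_prompt="a photo of sks",   # hypothetical instance prompt
    identifier="sks",                   # substituted into each caption via text.format(identifier)
    size=512,
)

datamodule.prepare_data()  # read and filter list.csv
datamodule.setup()         # 80/20 train/validation split, build dataloaders

batch = next(iter(datamodule.train_dataloader()))
print(batch["prompts"])                # caption strings with the identifier filled in
print(batch["instance_images"].shape)  # (batch_size, 3, 512, 512)
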
diff --git a/data/dreambooth/prompt.py b/data/dreambooth/prompt.py
new file mode 100644
index 0000000..34f510d
--- /dev/null
+++ b/data/dreambooth/prompt.py
@@ -0,0 +1,16 @@
1from torch.utils.data import Dataset
2
3
4class PromptDataset(Dataset):
5 def __init__(self, prompt, num_samples):
6 self.prompt = prompt
7 self.num_samples = num_samples
8
9 def __len__(self):
10 return self.num_samples
11
12 def __getitem__(self, index):
13 example = {}
14 example["prompt"] = self.prompt
15 example["index"] = index
16 return example
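
Note: PromptDataset is a thin wrapper that repeats one prompt num_samples times; dreambooth.py below pushes it through a plain DataLoader to generate missing class images, using the index to name the output files. A small sketch of that consumption pattern, with a made-up class prompt:

from torch.utils.data import DataLoader
from data.dreambooth.prompt import PromptDataset

dataset = PromptDataset("a photo of a person", num_samples=4)  # hypothetical class prompt
loader = DataLoader(dataset, batch_size=2)

for batch in loader:
    # batch["prompt"] -> ["a photo of a person", "a photo of a person"]
    # batch["index"]  -> tensor([0, 1]) on the first batch, tensor([2, 3]) on the second
    print(batch["prompt"], batch["index"])
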
diff --git a/dreambooth.py b/dreambooth.py
new file mode 100644
index 0000000..b6b3594
--- /dev/null
+++ b/dreambooth.py
@@ -0,0 +1,825 @@
1import argparse
2import itertools
3import math
4import os
5import datetime
6from pathlib import Path
7
8import numpy as np
9import torch
10import torch.nn.functional as F
11import torch.utils.checkpoint
12
13from accelerate import Accelerator
14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
17from diffusers.optimization import get_scheduler
18from pipelines.stable_diffusion.no_check import NoCheck
19from PIL import Image
20from tqdm.auto import tqdm
21from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
22from slugify import slugify
23import json
24import os
25
26from data.dreambooth.csv import CSVDataModule
27from data.dreambooth.prompt import PromptDataset
28
29logger = get_logger(__name__)
30
31
32def parse_args():
33 parser = argparse.ArgumentParser(
34 description="Simple example of a training script."
35 )
36 parser.add_argument(
37 "--pretrained_model_name_or_path",
38 type=str,
39 default=None,
40 help="Path to pretrained model or model identifier from huggingface.co/models.",
41 )
42 parser.add_argument(
43 "--tokenizer_name",
44 type=str,
45 default=None,
46 help="Pretrained tokenizer name or path if not the same as model_name",
47 )
48 parser.add_argument(
49 "--train_data_dir",
50 type=str,
51 default=None,
52 help="A folder containing the training data."
53 )
54 parser.add_argument(
55 "--identifier",
56 type=str,
57 default=None,
58 help="A token to use as a placeholder for the concept.",
59 )
60 parser.add_argument(
61 "--repeats",
62 type=int,
63 default=100,
64 help="How many times to repeat the training data.")
65 parser.add_argument(
66 "--output_dir",
67 type=str,
68 default="dreambooth-model",
69 help="The output directory where the model predictions and checkpoints will be written.",
70 )
71 parser.add_argument(
72 "--seed",
73 type=int,
74 default=None,
75 help="A seed for reproducible training.")
76 parser.add_argument(
77 "--resolution",
78 type=int,
79 default=512,
80 help=(
81            "The resolution for input images; all the images in the train/validation dataset will be resized to this"
82 " resolution"
83 ),
84 )
85 parser.add_argument(
86 "--center_crop",
87 action="store_true",
88 help="Whether to center crop images before resizing to resolution"
89 )
90 parser.add_argument(
91 "--num_train_epochs",
92 type=int,
93 default=100)
94 parser.add_argument(
95 "--max_train_steps",
96 type=int,
97 default=5000,
98 help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
99 )
100 parser.add_argument(
101 "--gradient_accumulation_steps",
102 type=int,
103 default=1,
104        help="Number of update steps to accumulate before performing a backward/update pass.",
105 )
106 parser.add_argument(
107 "--gradient_checkpointing",
108 action="store_true",
109 help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
110 )
111 parser.add_argument(
112 "--learning_rate",
113 type=float,
114 default=1e-4,
115 help="Initial learning rate (after the potential warmup period) to use.",
116 )
117 parser.add_argument(
118 "--scale_lr",
119 action="store_true",
120 default=True,
121 help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
122 )
123 parser.add_argument(
124 "--lr_scheduler",
125 type=str,
126 default="constant",
127 help=(
128 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
129 ' "constant", "constant_with_warmup"]'
130 ),
131 )
132 parser.add_argument(
133 "--lr_warmup_steps",
134 type=int,
135 default=500,
136 help="Number of steps for the warmup in the lr scheduler."
137 )
138 parser.add_argument(
139 "--adam_beta1",
140 type=float,
141 default=0.9,
142 help="The beta1 parameter for the Adam optimizer."
143 )
144 parser.add_argument(
145 "--adam_beta2",
146 type=float,
147 default=0.999,
148 help="The beta2 parameter for the Adam optimizer."
149 )
150 parser.add_argument(
151 "--adam_weight_decay",
152 type=float,
153 default=1e-2,
154 help="Weight decay to use."
155 )
156 parser.add_argument(
157 "--adam_epsilon",
158 type=float,
159 default=1e-08,
160 help="Epsilon value for the Adam optimizer"
161 )
162 parser.add_argument(
163 "--mixed_precision",
164 type=str,
165 default="no",
166 choices=["no", "fp16", "bf16"],
167 help=(
168            "Whether to use mixed precision. Choose"
169            " between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >= 1.10"
170            " and an Nvidia Ampere GPU."
171 ),
172 )
173 parser.add_argument(
174 "--local_rank",
175 type=int,
176 default=-1,
177 help="For distributed training: local_rank"
178 )
179 parser.add_argument(
180 "--checkpoint_frequency",
181 type=int,
182 default=500,
183 help="How often to save a checkpoint and sample image",
184 )
185 parser.add_argument(
186 "--sample_image_size",
187 type=int,
188 default=512,
189 help="Size of sample images",
190 )
191 parser.add_argument(
192 "--stable_sample_batches",
193 type=int,
194 default=1,
195 help="Number of fixed seed sample batches to generate per checkpoint",
196 )
197 parser.add_argument(
198 "--random_sample_batches",
199 type=int,
200 default=1,
201 help="Number of random seed sample batches to generate per checkpoint",
202 )
203 parser.add_argument(
204 "--sample_batch_size",
205 type=int,
206 default=1,
207 help="Number of samples to generate per batch",
208 )
209 parser.add_argument(
210 "--train_batch_size",
211 type=int,
212 default=1,
213 help="Batch size (per device) for the training dataloader."
214 )
215 parser.add_argument(
216 "--sample_steps",
217 type=int,
218 default=50,
219 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
220 )
221 parser.add_argument(
222 "--instance_prompt",
223 type=str,
224 default=None,
225        help="The prompt with identifier specifying the instance",
226 )
227 parser.add_argument(
228 "--class_data_dir",
229 type=str,
230 default=None,
231 required=False,
232 help="A folder containing the training data of class images.",
233 )
234 parser.add_argument(
235 "--class_prompt",
236 type=str,
237 default=None,
238        help="The prompt to specify images in the same class as provided instance images.",
239 )
240 parser.add_argument(
241 "--with_prior_preservation",
242 default=False,
243 action="store_true",
244        help="Flag to add prior preservation loss.",
245 )
246 parser.add_argument(
247 "--num_class_images",
248 type=int,
249 default=100,
250 help=(
251            "Minimum number of class images for the prior preservation loss. If there are not enough images, additional"
252            " images will be sampled with class_prompt."
253 ),
254 )
255 parser.add_argument(
256 "--config",
257 type=str,
258 default=None,
259 help="Path to a JSON configuration file containing arguments for invoking this script."
260 )
261
262 args = parser.parse_args()
263 if args.config is not None:
264 with open(args.config, 'rt') as f:
265 args = parser.parse_args(
266 namespace=argparse.Namespace(**json.load(f)["args"]))
267
268 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
269 if env_local_rank != -1 and env_local_rank != args.local_rank:
270 args.local_rank = env_local_rank
271
272 if args.train_data_dir is None:
273 raise ValueError("You must specify --train_data_dir")
274
275 if args.pretrained_model_name_or_path is None:
276 raise ValueError("You must specify --pretrained_model_name_or_path")
277
278 if args.instance_prompt is None:
279 raise ValueError("You must specify --instance_prompt")
280
281 if args.identifier is None:
282 raise ValueError("You must specify --identifier")
283
284 if args.output_dir is None:
285 raise ValueError("You must specify --output_dir")
286
287 if args.with_prior_preservation:
288 if args.class_data_dir is None:
289 raise ValueError("You must specify --class_data_dir")
290 if args.class_prompt is None:
291 raise ValueError("You must specify --class_prompt")
292
293 return args
294
295
296def freeze_params(params):
297 for param in params:
298 param.requires_grad = False
299
300
301def make_grid(images, rows, cols):
302 w, h = images[0].size
303 grid = Image.new('RGB', size=(cols*w, rows*h))
304 for i, image in enumerate(images):
305 grid.paste(image, box=(i % cols*w, i//cols*h))
306 return grid
307
308
309class Checkpointer:
310 def __init__(
311 self,
312 datamodule,
313 accelerator,
314 vae,
315 unet,
316 tokenizer,
317 text_encoder,
318 output_dir,
319 sample_image_size,
320 random_sample_batches,
321 sample_batch_size,
322 stable_sample_batches,
323 seed
324 ):
325 self.datamodule = datamodule
326 self.accelerator = accelerator
327 self.vae = vae
328 self.unet = unet
329 self.tokenizer = tokenizer
330 self.text_encoder = text_encoder
331 self.output_dir = output_dir
332 self.sample_image_size = sample_image_size
333 self.seed = seed
334 self.random_sample_batches = random_sample_batches
335 self.sample_batch_size = sample_batch_size
336 self.stable_sample_batches = stable_sample_batches
337
338 @torch.no_grad()
339 def checkpoint(self):
340 print("Saving model...")
341
342 unwrapped = self.accelerator.unwrap_model(self.unet)
343 pipeline = StableDiffusionPipeline(
344 text_encoder=self.text_encoder,
345 vae=self.vae,
346            unet=unwrapped,
347 tokenizer=self.tokenizer,
348 scheduler=PNDMScheduler(
349 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
350 ),
351 safety_checker=NoCheck(),
352 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
353 )
354 pipeline.enable_attention_slicing()
355 pipeline.save_pretrained(f"{self.output_dir}/model.ckpt")
356
357 del unwrapped
358 del pipeline
359
360 if torch.cuda.is_available():
361 torch.cuda.empty_cache()
362
363 @torch.no_grad()
364 def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps):
365 samples_path = f"{self.output_dir}/samples/{mode}"
366 os.makedirs(samples_path, exist_ok=True)
367 checker = NoCheck()
368
369 unwrapped = self.accelerator.unwrap_model(self.unet)
370 pipeline = StableDiffusionPipeline(
371 text_encoder=self.text_encoder,
372 vae=self.vae,
373 unet=unwrapped,
374 tokenizer=self.tokenizer,
375 scheduler=LMSDiscreteScheduler(
376 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
377 ),
378 safety_checker=NoCheck(),
379 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
380 ).to(self.accelerator.device)
381 pipeline.enable_attention_slicing()
382
383 data = {
384 "training": self.datamodule.train_dataloader(),
385 "validation": self.datamodule.val_dataloader(),
386 }[mode]
387
388 if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
389 stable_latents = torch.randn(
390 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
391 device=pipeline.device,
392 generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
393 )
394
395 all_samples = []
396            filename = f"stable_step_{step}.png"
397
398 data_enum = enumerate(data)
399
400 # Generate and save stable samples
401 for i in range(0, self.stable_sample_batches):
402 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
403 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
404
405 with self.accelerator.autocast():
406 samples = pipeline(
407 prompt=prompt,
408 height=self.sample_image_size,
409 latents=stable_latents[:len(prompt)],
410 width=self.sample_image_size,
411 guidance_scale=guidance_scale,
412 eta=eta,
413 num_inference_steps=num_inference_steps,
414 output_type='pil'
415 )["sample"]
416
417 all_samples += samples
418 del samples
419
420 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
421 image_grid.save(f"{samples_path}/{filename}")
422
423 del all_samples
424 del image_grid
425 del stable_latents
426
427 all_samples = []
428        filename = f"step_{step}.png"
429
430 data_enum = enumerate(data)
431
432 # Generate and save random samples
433 for i in range(0, self.random_sample_batches):
434 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
435 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
436
437 with self.accelerator.autocast():
438 samples = pipeline(
439 prompt=prompt,
440 height=self.sample_image_size,
441 width=self.sample_image_size,
442 guidance_scale=guidance_scale,
443 eta=eta,
444 num_inference_steps=num_inference_steps,
445 output_type='pil'
446 )["sample"]
447
448 all_samples += samples
449 del samples
450
451 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
452 image_grid.save(f"{samples_path}/{filename}")
453
454 del all_samples
455 del image_grid
456
457 del checker
458 del unwrapped
459 del pipeline
460
461 if torch.cuda.is_available():
462 torch.cuda.empty_cache()
463
464
465def main():
466 args = parse_args()
467
468 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
469 basepath = f"{args.output_dir}/{slugify(args.identifier)}/{now}"
470 os.makedirs(basepath, exist_ok=True)
471
472 accelerator = Accelerator(
473 log_with=LoggerType.TENSORBOARD,
474 logging_dir=f"{basepath}",
475 gradient_accumulation_steps=args.gradient_accumulation_steps,
476 mixed_precision=args.mixed_precision
477 )
478
479 # If passed along, set the training seed now.
480 if args.seed is not None:
481 set_seed(args.seed)
482
483 if args.with_prior_preservation:
484 class_images_dir = Path(args.class_data_dir)
485 if not class_images_dir.exists():
486 class_images_dir.mkdir(parents=True)
487 cur_class_images = len(list(class_images_dir.iterdir()))
488
489 if cur_class_images < args.num_class_images:
490 torch_dtype = torch.float16 if accelerator.device.type == "cuda" else torch.float32
491 pipeline = StableDiffusionPipeline.from_pretrained(
492 args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
493 pipeline.set_progress_bar_config(disable=True)
494
495 num_new_images = args.num_class_images - cur_class_images
496 logger.info(f"Number of class images to sample: {num_new_images}.")
497
498 sample_dataset = PromptDataset(args.class_prompt, num_new_images)
499 sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
500
501 sample_dataloader = accelerator.prepare(sample_dataloader)
502 pipeline.to(accelerator.device)
503
504 for example in tqdm(
505 sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
506 ):
507 with accelerator.autocast():
508 images = pipeline(example["prompt"]).images
509
510 for i, image in enumerate(images):
511 image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
512
513 del pipeline
514
515 if torch.cuda.is_available():
516 torch.cuda.empty_cache()
517
518    # Load the tokenizer
519 if args.tokenizer_name:
520 tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
521 elif args.pretrained_model_name_or_path:
522 tokenizer = CLIPTokenizer.from_pretrained(
523 args.pretrained_model_name_or_path + '/tokenizer'
524 )
525
526 # Load models and create wrapper for stable diffusion
527 text_encoder = CLIPTextModel.from_pretrained(
528 args.pretrained_model_name_or_path + '/text_encoder',
529 )
530 vae = AutoencoderKL.from_pretrained(
531 args.pretrained_model_name_or_path + '/vae',
532 )
533 unet = UNet2DConditionModel.from_pretrained(
534 args.pretrained_model_name_or_path + '/unet',
535 )
536
537 if args.gradient_checkpointing:
538 unet.enable_gradient_checkpointing()
539
540 # slice_size = unet.config.attention_head_dim // 2
541 # unet.set_attention_slice(slice_size)
542
543 # Freeze vae and unet
544 # freeze_params(vae.parameters())
545 # freeze_params(text_encoder.parameters())
546
547 if args.scale_lr:
548 args.learning_rate = (
549 args.learning_rate * args.gradient_accumulation_steps *
550 args.train_batch_size * accelerator.num_processes
551 )
552
553 # Initialize the optimizer
554 optimizer = torch.optim.AdamW(
555 unet.parameters(), # only optimize unet
556 lr=args.learning_rate,
557 betas=(args.adam_beta1, args.adam_beta2),
558 weight_decay=args.adam_weight_decay,
559 eps=args.adam_epsilon,
560 )
561
562    # TODO (patil-suraj): load scheduler using args
563 noise_scheduler = DDPMScheduler(
564 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt"
565 )
566
567 def collate_fn(examples):
568 prompts = [example["prompts"] for example in examples]
569 input_ids = [example["instance_prompt_ids"] for example in examples]
570 pixel_values = [example["instance_images"] for example in examples]
571
572 # concat class and instance examples for prior preservation
573 if args.with_prior_preservation:
574 input_ids += [example["class_prompt_ids"] for example in examples]
575 pixel_values += [example["class_images"] for example in examples]
576
577 pixel_values = torch.stack(pixel_values)
578 pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
579
580 input_ids = tokenizer.pad({"input_ids": input_ids}, padding=True, return_tensors="pt").input_ids
581
582 batch = {
583 "prompts": prompts,
584 "input_ids": input_ids,
585 "pixel_values": pixel_values,
586 }
587 return batch
588
589 datamodule = CSVDataModule(
590 data_root=args.train_data_dir,
591 batch_size=args.train_batch_size,
592 tokenizer=tokenizer,
593 instance_prompt=args.instance_prompt,
594 class_data_root=args.class_data_dir if args.with_prior_preservation else None,
595 class_prompt=args.class_prompt,
596 size=args.resolution,
597 identifier=args.identifier,
598 repeats=args.repeats,
599 center_crop=args.center_crop,
600 collate_fn=collate_fn)
601
602 datamodule.prepare_data()
603 datamodule.setup()
604
605 train_dataloader = datamodule.train_dataloader()
606 val_dataloader = datamodule.val_dataloader()
607
608 checkpointer = Checkpointer(
609 datamodule=datamodule,
610 accelerator=accelerator,
611 vae=vae,
612 unet=unet,
613 tokenizer=tokenizer,
614 text_encoder=text_encoder,
615 output_dir=basepath,
616 sample_image_size=args.sample_image_size,
617 sample_batch_size=args.sample_batch_size,
618 random_sample_batches=args.random_sample_batches,
619 stable_sample_batches=args.stable_sample_batches,
620 seed=args.seed
621 )
622
623 # Scheduler and math around the number of training steps.
624 overrode_max_train_steps = False
625 num_update_steps_per_epoch = math.ceil(
626 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
627 if args.max_train_steps is None:
628 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
629 overrode_max_train_steps = True
630
631 lr_scheduler = get_scheduler(
632 args.lr_scheduler,
633 optimizer=optimizer,
634 num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
635 num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
636 )
637
638 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
639 unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
640 )
641
642 # Move vae and unet to device
643 text_encoder.to(accelerator.device)
644 vae.to(accelerator.device)
645
646 # Keep text_encoder and vae in eval mode as we don't train these
647 # text_encoder.eval()
648 # vae.eval()
649
650 # We need to recalculate our total training steps as the size of the training dataloader may have changed.
651 num_update_steps_per_epoch = math.ceil(
652 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
653 if overrode_max_train_steps:
654 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
655 # Afterwards we recalculate our number of training epochs
656 args.num_train_epochs = math.ceil(
657 args.max_train_steps / num_update_steps_per_epoch)
658
659 # We need to initialize the trackers we use, and also store our configuration.
660    # The trackers are initialized automatically on the main process.
661 if accelerator.is_main_process:
662 accelerator.init_trackers("dreambooth", config=vars(args))
663
664 # Train!
665 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
666
667 logger.info("***** Running training *****")
668 logger.info(f" Num Epochs = {args.num_train_epochs}")
669 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
670 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
671 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
672 logger.info(f" Total optimization steps = {args.max_train_steps}")
673 # Only show the progress bar once on each machine.
674
675 global_step = 0
676 min_val_loss = np.inf
677
678 checkpointer.save_samples(
679 "validation",
680 0,
681 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
682
683 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
684 progress_bar.set_description("Global steps")
685
686 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
687 local_progress_bar.set_description("Steps")
688
689 try:
690 for epoch in range(args.num_train_epochs):
691 local_progress_bar.reset()
692
693 unet.train()
694 train_loss = 0.0
695
696 for step, batch in enumerate(train_dataloader):
697 with accelerator.accumulate(unet):
698 with accelerator.autocast():
699 # Convert images to latent space
700 with torch.no_grad():
701 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
702 latents = latents * 0.18215
703
704 # Sample noise that we'll add to the latents
705 noise = torch.randn(latents.shape).to(latents.device)
706 bsz = latents.shape[0]
707 # Sample a random timestep for each image
708 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
709 (bsz,), device=latents.device).long()
710
711 # Add noise to the latents according to the noise magnitude at each timestep
712 # (this is the forward diffusion process)
713 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
714
715 # Get the text embedding for conditioning
716 with torch.no_grad():
717 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
718
719 # Predict the noise residual
720 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
721
722 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
723
724 accelerator.backward(loss)
725
726 optimizer.step()
727 if not accelerator.optimizer_step_was_skipped:
728 lr_scheduler.step()
729 optimizer.zero_grad(set_to_none=True)
730
731 loss = loss.detach().item()
732 train_loss += loss
733
734 # Checks if the accelerator has performed an optimization step behind the scenes
735 if accelerator.sync_gradients:
736 progress_bar.update(1)
737 local_progress_bar.update(1)
738
739 global_step += 1
740
741 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
742 progress_bar.clear()
743 local_progress_bar.clear()
744
745 checkpointer.save_samples(
746 "training",
747 global_step,
748 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
749
750 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
751 local_progress_bar.set_postfix(**logs)
752
753 if global_step >= args.max_train_steps:
754 break
755
756 train_loss /= len(train_dataloader)
757
758 unet.eval()
759 val_loss = 0.0
760
761 for step, batch in enumerate(val_dataloader):
762 with torch.no_grad(), accelerator.autocast():
763 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
764 latents = latents * 0.18215
765
766 noise = torch.randn(latents.shape).to(latents.device)
767 bsz = latents.shape[0]
768 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
769 (bsz,), device=latents.device).long()
770
771 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
772
773 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
774
775 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
776
777 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
778
779 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
780
781 loss = loss.detach().item()
782 val_loss += loss
783
784 if accelerator.sync_gradients:
785 progress_bar.update(1)
786 local_progress_bar.update(1)
787
788 logs = {"mode": "validation", "loss": loss}
789 local_progress_bar.set_postfix(**logs)
790
791 val_loss /= len(val_dataloader)
792
793 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
794
795 progress_bar.clear()
796 local_progress_bar.clear()
797
798 if min_val_loss > val_loss:
799 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
800 min_val_loss = val_loss
801
802 checkpointer.save_samples(
803 "validation",
804 global_step,
805 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
806
807 accelerator.wait_for_everyone()
808
809    # Create the pipeline using the trained modules and save it.
810 if accelerator.is_main_process:
811 print("Finished! Saving final checkpoint and resume state.")
812 checkpointer.checkpoint()
813
814 accelerator.end_training()
815
816 except KeyboardInterrupt:
817 if accelerator.is_main_process:
818 print("Interrupted, saving checkpoint and resume state...")
819 checkpointer.checkpoint()
820 accelerator.end_training()
821 quit()
822
823
824if __name__ == "__main__":
825 main()
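
Note: parse_args() above accepts a --config flag pointing at a JSON file whose top-level "args" object is unpacked into the argument namespace (see the json.load(f)["args"] handling in parse_args()). A minimal sketch of writing such a file and the matching invocation; every path and hyperparameter below is a placeholder, not a recommended setting.

import json

config = {
    "args": {
        # Assumed local checkpoint layout with tokenizer/, text_encoder/, vae/ and unet/ subfolders,
        # matching how the script loads each component from subdirectories.
        "pretrained_model_name_or_path": "models/stable-diffusion-v1-4",
        "train_data_dir": "data/portraits",
        "identifier": "sks",
        "instance_prompt": "a photo of sks",
        "with_prior_preservation": True,
        "class_data_dir": "data/person-class-images",
        "class_prompt": "a photo of a person",
        "resolution": 512,
        "train_batch_size": 1,
        "learning_rate": 1e-6,
        "max_train_steps": 1200,
        "output_dir": "dreambooth-model",
    }
}

# conf*.json is listed in .gitignore above, so this file stays out of version control.
with open("conf-dreambooth.json", "w") as f:
    json.dump(config, f, indent=2)

# Afterwards, something like: python dreambooth.py --config conf-dreambooth.json
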