summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--data/dreambooth/csv.py1
-rw-r--r--data/textual_inversion/csv.py98
-rw-r--r--dreambooth.py26
-rw-r--r--textual_inversion.py308
4 files changed, 230 insertions, 203 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 99bcf12..1676d35 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -2,7 +2,6 @@ import math
2import os 2import os
3import pandas as pd 3import pandas as pd
4from pathlib import Path 4from pathlib import Path
5import PIL
6import pytorch_lightning as pl 5import pytorch_lightning as pl
7from PIL import Image 6from PIL import Image
8from torch.utils.data import Dataset, DataLoader, random_split 7from torch.utils.data import Dataset, DataLoader, random_split
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -1,11 +1,10 @@
1import os 1import os
2import numpy as np 2import numpy as np
3import pandas as pd 3import pandas as pd
4import random 4from pathlib import Path
5import PIL 5import math
6import pytorch_lightning as pl 6import pytorch_lightning as pl
7from PIL import Image 7from PIL import Image
8import torch
9from torch.utils.data import Dataset, DataLoader, random_split 8from torch.utils.data import Dataset, DataLoader, random_split
10from torchvision import transforms 9from torchvision import transforms
11 10
@@ -13,29 +12,32 @@ from torchvision import transforms
13class CSVDataModule(pl.LightningDataModule): 12class CSVDataModule(pl.LightningDataModule):
14 def __init__(self, 13 def __init__(self,
15 batch_size, 14 batch_size,
16 data_root, 15 data_file,
17 tokenizer, 16 tokenizer,
18 size=512, 17 size=512,
19 repeats=100, 18 repeats=100,
20 interpolation="bicubic", 19 interpolation="bicubic",
21 placeholder_token="*", 20 placeholder_token="*",
22 flip_p=0.5,
23 center_crop=False): 21 center_crop=False):
24 super().__init__() 22 super().__init__()
25 23
26 self.data_root = data_root 24 self.data_file = Path(data_file)
25
26 if not self.data_file.is_file():
27 raise ValueError("data_file must be a file")
28
29 self.data_root = self.data_file.parent
27 self.tokenizer = tokenizer 30 self.tokenizer = tokenizer
28 self.size = size 31 self.size = size
29 self.repeats = repeats 32 self.repeats = repeats
30 self.placeholder_token = placeholder_token 33 self.placeholder_token = placeholder_token
31 self.center_crop = center_crop 34 self.center_crop = center_crop
32 self.flip_p = flip_p
33 self.interpolation = interpolation 35 self.interpolation = interpolation
34 36
35 self.batch_size = batch_size 37 self.batch_size = batch_size
36 38
37 def prepare_data(self): 39 def prepare_data(self):
38 metadata = pd.read_csv(f'{self.data_root}/list.csv') 40 metadata = pd.read_csv(self.data_file)
39 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] 41 image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
40 captions = [caption for caption in metadata['caption'].values] 42 captions = [caption for caption in metadata['caption'].values]
41 skips = [skip for skip in metadata['skip'].values] 43 skips = [skip for skip in metadata['skip'].values]
@@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule):
47 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) 49 self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
48 50
49 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, 51 train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
50 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) 52 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
51 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, 53 val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
52 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) 54 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
53 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) 55 self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
54 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) 56 self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
55 57
@@ -67,48 +69,54 @@ class CSVDataset(Dataset):
67 size=512, 69 size=512,
68 repeats=1, 70 repeats=1,
69 interpolation="bicubic", 71 interpolation="bicubic",
70 flip_p=0.5,
71 placeholder_token="*", 72 placeholder_token="*",
72 center_crop=False, 73 center_crop=False,
74 batch_size=1,
73 ): 75 ):
74 76
75 self.data = data 77 self.data = data
76 self.tokenizer = tokenizer 78 self.tokenizer = tokenizer
77
78 self.num_images = len(self.data)
79 self._length = self.num_images * repeats
80
81 self.placeholder_token = placeholder_token 79 self.placeholder_token = placeholder_token
80 self.batch_size = batch_size
81 self.cache = {}
82 82
83 self.size = size 83 self.num_instance_images = len(self.data)
84 self.center_crop = center_crop 84 self._length = self.num_instance_images * repeats
85 self.interpolation = {"linear": PIL.Image.LINEAR,
86 "bilinear": PIL.Image.BILINEAR,
87 "bicubic": PIL.Image.BICUBIC,
88 "lanczos": PIL.Image.LANCZOS,
89 }[interpolation]
90 self.flip = transforms.RandomHorizontalFlip(p=flip_p)
91 85
92 self.cache = {} 86 self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
87 "bilinear": transforms.InterpolationMode.BILINEAR,
88 "bicubic": transforms.InterpolationMode.BICUBIC,
89 "lanczos": transforms.InterpolationMode.LANCZOS,
90 }[interpolation]
91 self.image_transforms = transforms.Compose(
92 [
93 transforms.Resize(size, interpolation=self.interpolation),
94 transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
95 transforms.RandomHorizontalFlip(),
96 transforms.ToTensor(),
97 transforms.Normalize([0.5], [0.5]),
98 ]
99 )
93 100
94 def __len__(self): 101 def __len__(self):
95 return self._length 102 return math.ceil(self._length / self.batch_size) * self.batch_size
96 103
97 def get_example(self, i, flipped): 104 def get_example(self, i):
98 image_path, text = self.data[i % self.num_images] 105 image_path, text = self.data[i % self.num_instance_images]
99 106
100 if image_path in self.cache: 107 if image_path in self.cache:
101 return self.cache[image_path] 108 return self.cache[image_path]
102 109
103 example = {} 110 example = {}
104 image = Image.open(image_path)
105 111
106 if not image.mode == "RGB": 112 instance_image = Image.open(image_path)
107 image = image.convert("RGB") 113 if not instance_image.mode == "RGB":
114 instance_image = instance_image.convert("RGB")
108 115
109 text = text.format(self.placeholder_token) 116 text = text.format(self.placeholder_token)
110 117
111 example["prompt"] = text 118 example["prompts"] = text
119 example["pixel_values"] = instance_image
112 example["input_ids"] = self.tokenizer( 120 example["input_ids"] = self.tokenizer(
113 text, 121 text,
114 padding="max_length", 122 padding="max_length",
@@ -117,29 +125,15 @@ class CSVDataset(Dataset):
117 return_tensors="pt", 125 return_tensors="pt",
118 ).input_ids[0] 126 ).input_ids[0]
119 127
120 # default to score-sde preprocessing
121 img = np.array(image).astype(np.uint8)
122
123 if self.center_crop:
124 crop = min(img.shape[0], img.shape[1])
125 h, w, = img.shape[0], img.shape[1]
126 img = img[(h - crop) // 2:(h + crop) // 2,
127 (w - crop) // 2:(w + crop) // 2]
128
129 image = Image.fromarray(img)
130 image = image.resize((self.size, self.size),
131 resample=self.interpolation)
132 image = self.flip(image)
133 image = np.array(image).astype(np.uint8)
134 image = (image / 127.5 - 1.0).astype(np.float32)
135
136 example["key"] = "-".join([image_path, "-", str(flipped)])
137 example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
138
139 self.cache[image_path] = example 128 self.cache[image_path] = example
140 return example 129 return example
141 130
142 def __getitem__(self, i): 131 def __getitem__(self, i):
143 flipped = random.choice([False, True]) 132 example = {}
144 example = self.get_example(i, flipped) 133 unprocessed_example = self.get_example(i)
134
135 example["prompts"] = unprocessed_example["prompts"]
136 example["input_ids"] = unprocessed_example["input_ids"]
137 example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
138
145 return example 139 return example
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
14from accelerate.logging import get_logger 14from accelerate.logging import get_logger
15from accelerate.utils import LoggerType, set_seed 15from accelerate.utils import LoggerType, set_seed
16from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel 16from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
17from schedulers.scheduling_euler_a import EulerAScheduler
17from diffusers.optimization import get_scheduler 18from diffusers.optimization import get_scheduler
18from pipelines.stable_diffusion.no_check import NoCheck 19from pipelines.stable_diffusion.no_check import NoCheck
19from PIL import Image 20from PIL import Image
20from tqdm.auto import tqdm 21from tqdm.auto import tqdm
21from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
22from slugify import slugify 23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
23import json 25import json
24 26
25from data.dreambooth.csv import CSVDataModule 27from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
215 parser.add_argument( 217 parser.add_argument(
216 "--sample_steps", 218 "--sample_steps",
217 type=int, 219 type=int,
218 default=80, 220 default=30,
219 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 221 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
220 ) 222 )
221 parser.add_argument( 223 parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
377 samples_path = Path(self.output_dir).joinpath("samples") 379 samples_path = Path(self.output_dir).joinpath("samples")
378 380
379 unwrapped = self.accelerator.unwrap_model(self.unet) 381 unwrapped = self.accelerator.unwrap_model(self.unet)
380 pipeline = StableDiffusionPipeline( 382 scheduler = EulerAScheduler(
383 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
384 )
385
386 pipeline = VlpnStableDiffusion(
381 text_encoder=self.text_encoder, 387 text_encoder=self.text_encoder,
382 vae=self.vae, 388 vae=self.vae,
383 unet=unwrapped, 389 unet=unwrapped,
384 tokenizer=self.tokenizer, 390 tokenizer=self.tokenizer,
385 scheduler=LMSDiscreteScheduler( 391 scheduler=scheduler,
386 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
387 ),
388 safety_checker=NoCheck(),
389 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 392 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
390 ).to(self.accelerator.device) 393 ).to(self.accelerator.device)
391 pipeline.enable_attention_slicing() 394 pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
411 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 414 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
412 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] 415 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
413 416
417 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
418
414 with self.accelerator.autocast(): 419 with self.accelerator.autocast():
415 samples = pipeline( 420 samples = pipeline(
416 prompt=prompt, 421 prompt=prompt,
@@ -420,10 +425,13 @@ class Checkpointer:
420 guidance_scale=guidance_scale, 425 guidance_scale=guidance_scale,
421 eta=eta, 426 eta=eta,
422 num_inference_steps=num_inference_steps, 427 num_inference_steps=num_inference_steps,
428 generator=generator,
423 output_type='pil' 429 output_type='pil'
424 )["sample"] 430 )["sample"]
425 431
426 all_samples += samples 432 all_samples += samples
433
434 del generator
427 del samples 435 del samples
428 436
429 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 437 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
444 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 452 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
445 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] 453 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
446 454
455 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
456
447 with self.accelerator.autocast(): 457 with self.accelerator.autocast():
448 samples = pipeline( 458 samples = pipeline(
449 prompt=prompt, 459 prompt=prompt,
@@ -452,10 +462,13 @@ class Checkpointer:
452 guidance_scale=guidance_scale, 462 guidance_scale=guidance_scale,
453 eta=eta, 463 eta=eta,
454 num_inference_steps=num_inference_steps, 464 num_inference_steps=num_inference_steps,
465 generator=generator,
455 output_type='pil' 466 output_type='pil'
456 )["sample"] 467 )["sample"]
457 468
458 all_samples += samples 469 all_samples += samples
470
471 del generator
459 del samples 472 del samples
460 473
461 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 474 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
465 del image_grid 478 del image_grid
466 479
467 del unwrapped 480 del unwrapped
481 del scheduler
468 del pipeline 482 del pipeline
469 483
470 if torch.cuda.is_available(): 484 if torch.cuda.is_available():
diff --git a/textual_inversion.py b/textual_inversion.py
index 399d876..7a7d7fc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -3,6 +3,8 @@ import itertools
3import math 3import math
4import os 4import os
5import datetime 5import datetime
6import logging
7from pathlib import Path
6 8
7import numpy as np 9import numpy as np
8import torch 10import torch
@@ -13,12 +15,13 @@ from accelerate import Accelerator
13from accelerate.logging import get_logger 15from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 16from accelerate.utils import LoggerType, set_seed
15from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel 17from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
18from schedulers.scheduling_euler_a import EulerAScheduler
16from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
17from pipelines.stable_diffusion.no_check import NoCheck
18from PIL import Image 20from PIL import Image
19from tqdm.auto import tqdm 21from tqdm.auto import tqdm
20from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 22from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
21from slugify import slugify 23from slugify import slugify
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
22import json 25import json
23import os 26import os
24 27
@@ -44,10 +47,10 @@ def parse_args():
44 help="Pretrained tokenizer name or path if not the same as model_name", 47 help="Pretrained tokenizer name or path if not the same as model_name",
45 ) 48 )
46 parser.add_argument( 49 parser.add_argument(
47 "--train_data_dir", 50 "--train_data_file",
48 type=str, 51 type=str,
49 default=None, 52 default=None,
50 help="A folder containing the training data." 53 help="A CSV file containing the training data."
51 ) 54 )
52 parser.add_argument( 55 parser.add_argument(
53 "--placeholder_token", 56 "--placeholder_token",
@@ -146,6 +149,11 @@ def parse_args():
146 help="Number of steps for the warmup in the lr scheduler." 149 help="Number of steps for the warmup in the lr scheduler."
147 ) 150 )
148 parser.add_argument( 151 parser.add_argument(
152 "--use_8bit_adam",
153 action="store_true",
154 help="Whether or not to use 8-bit Adam from bitsandbytes."
155 )
156 parser.add_argument(
149 "--adam_beta1", 157 "--adam_beta1",
150 type=float, 158 type=float,
151 default=0.9, 159 default=0.9,
@@ -225,7 +233,7 @@ def parse_args():
225 parser.add_argument( 233 parser.add_argument(
226 "--sample_steps", 234 "--sample_steps",
227 type=int, 235 type=int,
228 default=50, 236 default=30,
229 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", 237 help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
230 ) 238 )
231 parser.add_argument( 239 parser.add_argument(
@@ -261,8 +269,8 @@ def parse_args():
261 if env_local_rank != -1 and env_local_rank != args.local_rank: 269 if env_local_rank != -1 and env_local_rank != args.local_rank:
262 args.local_rank = env_local_rank 270 args.local_rank = env_local_rank
263 271
264 if args.train_data_dir is None: 272 if args.train_data_file is None:
265 raise ValueError("You must specify --train_data_dir") 273 raise ValueError("You must specify --train_data_file")
266 274
267 if args.pretrained_model_name_or_path is None: 275 if args.pretrained_model_name_or_path is None:
268 raise ValueError("You must specify --pretrained_model_name_or_path") 276 raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -333,53 +341,51 @@ class Checkpointer:
333 @torch.no_grad() 341 @torch.no_grad()
334 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): 342 def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
335 print("Saving checkpoint for step %d..." % step) 343 print("Saving checkpoint for step %d..." % step)
336 with self.accelerator.autocast(): 344
337 if path is None: 345 if path is None:
338 checkpoints_path = f"{self.output_dir}/checkpoints" 346 checkpoints_path = f"{self.output_dir}/checkpoints"
339 os.makedirs(checkpoints_path, exist_ok=True) 347 os.makedirs(checkpoints_path, exist_ok=True)
340 348
341 unwrapped = self.accelerator.unwrap_model(text_encoder) 349 unwrapped = self.accelerator.unwrap_model(text_encoder)
342 350
343 # Save a checkpoint 351 # Save a checkpoint
344 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] 352 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
345 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} 353 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
346 354
347 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) 355 filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
348 if path is not None: 356 if path is not None:
349 torch.save(learned_embeds_dict, path) 357 torch.save(learned_embeds_dict, path)
350 else: 358 else:
351 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") 359 torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
352 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") 360 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
353 del unwrapped 361
354 del learned_embeds 362 del unwrapped
363 del learned_embeds
355 364
356 @torch.no_grad() 365 @torch.no_grad()
357 def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): 366 def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
358 samples_path = f"{self.output_dir}/samples/{mode}" 367 samples_path = Path(self.output_dir).joinpath("samples")
359 os.makedirs(samples_path, exist_ok=True)
360 checker = NoCheck()
361 368
362 unwrapped = self.accelerator.unwrap_model(text_encoder) 369 unwrapped = self.accelerator.unwrap_model(text_encoder)
370 scheduler = EulerAScheduler(
371 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
372 )
373
363 # Save a sample image 374 # Save a sample image
364 pipeline = StableDiffusionPipeline( 375 pipeline = VlpnStableDiffusion(
365 text_encoder=unwrapped, 376 text_encoder=unwrapped,
366 vae=self.vae, 377 vae=self.vae,
367 unet=self.unet, 378 unet=self.unet,
368 tokenizer=self.tokenizer, 379 tokenizer=self.tokenizer,
369 scheduler=LMSDiscreteScheduler( 380 scheduler=scheduler,
370 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
371 ),
372 safety_checker=NoCheck(),
373 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), 381 feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
374 ).to(self.accelerator.device) 382 ).to(self.accelerator.device)
375 pipeline.enable_attention_slicing() 383 pipeline.enable_attention_slicing()
376 384
377 data = { 385 train_data = self.datamodule.train_dataloader()
378 "training": self.datamodule.train_dataloader(), 386 val_data = self.datamodule.val_dataloader()
379 "validation": self.datamodule.val_dataloader(),
380 }[mode]
381 387
382 if mode == "validation" and self.stable_sample_batches > 0 and step > 0: 388 if self.stable_sample_batches > 0:
383 stable_latents = torch.randn( 389 stable_latents = torch.randn(
384 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), 390 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
385 device=pipeline.device, 391 device=pipeline.device,
@@ -387,14 +393,17 @@ class Checkpointer:
387 ) 393 )
388 394
389 all_samples = [] 395 all_samples = []
390 filename = f"stable_step_%d.png" % (step) 396 file_path = samples_path.joinpath("stable", f"step_{step}.png")
397 file_path.parent.mkdir(parents=True, exist_ok=True)
391 398
392 data_enum = enumerate(data) 399 data_enum = enumerate(val_data)
393 400
394 # Generate and save stable samples 401 # Generate and save stable samples
395 for i in range(0, self.stable_sample_batches): 402 for i in range(0, self.stable_sample_batches):
396 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 403 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
397 batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] 404 batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
405
406 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
398 407
399 with self.accelerator.autocast(): 408 with self.accelerator.autocast():
400 samples = pipeline( 409 samples = pipeline(
@@ -405,67 +414,64 @@ class Checkpointer:
405 guidance_scale=guidance_scale, 414 guidance_scale=guidance_scale,
406 eta=eta, 415 eta=eta,
407 num_inference_steps=num_inference_steps, 416 num_inference_steps=num_inference_steps,
417 generator=generator,
408 output_type='pil' 418 output_type='pil'
409 )["sample"] 419 )["sample"]
410 420
411 all_samples += samples 421 all_samples += samples
422
423 del generator
412 del samples 424 del samples
413 425
414 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) 426 image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
415 image_grid.save(f"{samples_path}/{filename}") 427 image_grid.save(file_path)
416 428
417 del all_samples 429 del all_samples
418 del image_grid 430 del image_grid
419 del stable_latents 431 del stable_latents
420 432
421 all_samples = [] 433 for data, pool in [(val_data, "val"), (train_data, "train")]:
422 filename = f"step_%d.png" % (step) 434 all_samples = []
435 file_path = samples_path.joinpath(pool, f"step_{step}.png")
436 file_path.parent.mkdir(parents=True, exist_ok=True)
423 437
424 data_enum = enumerate(data) 438 data_enum = enumerate(data)
425 439
426 # Generate and save random samples 440 for i in range(0, self.random_sample_batches):
427 for i in range(0, self.random_sample_batches): 441 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
428 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( 442 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
429 batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
430 443
431 with self.accelerator.autocast(): 444 generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
432 samples = pipeline(
433 prompt=prompt,
434 height=self.sample_image_size,
435 width=self.sample_image_size,
436 guidance_scale=guidance_scale,
437 eta=eta,
438 num_inference_steps=num_inference_steps,
439 output_type='pil'
440 )["sample"]
441 445
442 all_samples += samples 446 with self.accelerator.autocast():
443 del samples 447 samples = pipeline(
448 prompt=prompt,
449 height=self.sample_image_size,
450 width=self.sample_image_size,
451 guidance_scale=guidance_scale,
452 eta=eta,
453 num_inference_steps=num_inference_steps,
454 generator=generator,
455 output_type='pil'
456 )["sample"]
444 457
445 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) 458 all_samples += samples
446 image_grid.save(f"{samples_path}/{filename}")
447 459
448 del all_samples 460 del generator
449 del image_grid 461 del samples
450 462
451 del checker 463 image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
452 del unwrapped 464 image_grid.save(file_path)
453 del pipeline
454 torch.cuda.empty_cache()
455 465
466 del all_samples
467 del image_grid
456 468
457class ImageToLatents(): 469 del unwrapped
458 def __init__(self, vae): 470 del scheduler
459 self.vae = vae 471 del pipeline
460 self.encoded_pixel_values_cache = {}
461 472
462 @torch.no_grad() 473 if torch.cuda.is_available():
463 def __call__(self, batch): 474 torch.cuda.empty_cache()
464 key = "|".join(batch["key"])
465 if self.encoded_pixel_values_cache.get(key, None) is None:
466 self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
467 latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
468 return latents
469 475
470 476
471def main(): 477def main():
@@ -473,17 +479,17 @@ def main():
473 479
474 global_step_offset = 0 480 global_step_offset = 0
475 if args.resume_from is not None: 481 if args.resume_from is not None:
476 basepath = f"{args.resume_from}" 482 basepath = Path(args.resume_from)
477 print("Resuming state from %s" % args.resume_from) 483 print("Resuming state from %s" % args.resume_from)
478 with open(f"{basepath}/resume.json", 'r') as f: 484 with open(basepath.joinpath("resume.json"), 'r') as f:
479 state = json.load(f) 485 state = json.load(f)
480 global_step_offset = state["args"].get("global_step", 0) 486 global_step_offset = state["args"].get("global_step", 0)
481 487
482 print("We've trained %d steps so far" % global_step_offset) 488 print("We've trained %d steps so far" % global_step_offset)
483 else: 489 else:
484 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") 490 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
485 basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" 491 basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
486 os.makedirs(basepath, exist_ok=True) 492 basepath.mkdir(parents=True, exist_ok=True)
487 493
488 accelerator = Accelerator( 494 accelerator = Accelerator(
489 log_with=LoggerType.TENSORBOARD, 495 log_with=LoggerType.TENSORBOARD,
@@ -492,6 +498,8 @@ def main():
492 mixed_precision=args.mixed_precision 498 mixed_precision=args.mixed_precision
493 ) 499 )
494 500
501 logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
502
495 # If passed along, set the training seed now. 503 # If passed along, set the training seed now.
496 if args.seed is not None: 504 if args.seed is not None:
497 set_seed(args.seed) 505 set_seed(args.seed)
@@ -570,8 +578,19 @@ def main():
570 args.train_batch_size * accelerator.num_processes 578 args.train_batch_size * accelerator.num_processes
571 ) 579 )
572 580
581 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
582 if args.use_8bit_adam:
583 try:
584 import bitsandbytes as bnb
585 except ImportError:
586 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
587
588 optimizer_class = bnb.optim.AdamW8bit
589 else:
590 optimizer_class = torch.optim.AdamW
591
573 # Initialize the optimizer 592 # Initialize the optimizer
574 optimizer = torch.optim.AdamW( 593 optimizer = optimizer_class(
575 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings 594 text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings
576 lr=args.learning_rate, 595 lr=args.learning_rate,
577 betas=(args.adam_beta1, args.adam_beta2), 596 betas=(args.adam_beta1, args.adam_beta2),
@@ -585,7 +604,7 @@ def main():
585 ) 604 )
586 605
587 datamodule = CSVDataModule( 606 datamodule = CSVDataModule(
588 data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, 607 data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer,
589 size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, 608 size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
590 center_crop=args.center_crop) 609 center_crop=args.center_crop)
591 610
@@ -608,13 +627,12 @@ def main():
608 sample_batch_size=args.sample_batch_size, 627 sample_batch_size=args.sample_batch_size,
609 random_sample_batches=args.random_sample_batches, 628 random_sample_batches=args.random_sample_batches,
610 stable_sample_batches=args.stable_sample_batches, 629 stable_sample_batches=args.stable_sample_batches,
611 seed=args.seed 630 seed=args.seed or torch.random.seed()
612 ) 631 )
613 632
614 # Scheduler and math around the number of training steps. 633 # Scheduler and math around the number of training steps.
615 overrode_max_train_steps = False 634 overrode_max_train_steps = False
616 num_update_steps_per_epoch = math.ceil( 635 num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
617 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
618 if args.max_train_steps is None: 636 if args.max_train_steps is None:
619 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 637 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
620 overrode_max_train_steps = True 638 overrode_max_train_steps = True
@@ -643,9 +661,10 @@ def main():
643 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) 661 (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
644 if overrode_max_train_steps: 662 if overrode_max_train_steps:
645 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch 663 args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
646 # Afterwards we recalculate our number of training epochs 664
647 args.num_train_epochs = math.ceil( 665 num_val_steps_per_epoch = len(val_dataloader)
648 args.max_train_steps / num_update_steps_per_epoch) 666 num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
667 val_steps = num_val_steps_per_epoch * num_epochs
649 668
650 # We need to initialize the trackers we use, and also store our configuration. 669 # We need to initialize the trackers we use, and also store our configuration.
651 # The trackers initializes automatically on the main process. 670 # The trackers initializes automatically on the main process.
@@ -656,7 +675,7 @@ def main():
656 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps 675 total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
657 676
658 logger.info("***** Running training *****") 677 logger.info("***** Running training *****")
659 logger.info(f" Num Epochs = {args.num_train_epochs}") 678 logger.info(f" Num Epochs = {num_epochs}")
660 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") 679 logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
661 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") 680 logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
662 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") 681 logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -666,22 +685,22 @@ def main():
666 global_step = 0 685 global_step = 0
667 min_val_loss = np.inf 686 min_val_loss = np.inf
668 687
669 imageToLatents = ImageToLatents(vae) 688 if accelerator.is_main_process:
670 689 checkpointer.save_samples(
671 checkpointer.save_samples( 690 0,
672 "validation", 691 text_encoder,
673 0, 692 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
674 text_encoder,
675 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
676 693
677 progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) 694 local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
678 progress_bar.set_description("Global steps") 695 disable=not accelerator.is_local_main_process)
696 local_progress_bar.set_description("Batch X out of Y")
679 697
680 local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) 698 global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
681 local_progress_bar.set_description("Steps") 699 global_progress_bar.set_description("Total progress")
682 700
683 try: 701 try:
684 for epoch in range(args.num_train_epochs): 702 for epoch in range(num_epochs):
703 local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
685 local_progress_bar.reset() 704 local_progress_bar.reset()
686 705
687 text_encoder.train() 706 text_encoder.train()
@@ -689,27 +708,30 @@ def main():
689 708
690 for step, batch in enumerate(train_dataloader): 709 for step, batch in enumerate(train_dataloader):
691 with accelerator.accumulate(text_encoder): 710 with accelerator.accumulate(text_encoder):
692 with accelerator.autocast(): 711 # Convert images to latent space
693 # Convert images to latent space 712 with torch.no_grad():
694 latents = imageToLatents(batch) 713 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
714 latents = latents * 0.18215
695 715
696 # Sample noise that we'll add to the latents 716 # Sample noise that we'll add to the latents
697 noise = torch.randn(latents.shape).to(latents.device) 717 noise = torch.randn(latents.shape).to(latents.device)
698 bsz = latents.shape[0] 718 bsz = latents.shape[0]
699 # Sample a random timestep for each image 719 # Sample a random timestep for each image
700 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, 720 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
701 (bsz,), device=latents.device).long() 721 (bsz,), device=latents.device)
722 timesteps = timesteps.long()
702 723
703 # Add noise to the latents according to the noise magnitude at each timestep 724 # Add noise to the latents according to the noise magnitude at each timestep
704 # (this is the forward diffusion process) 725 # (this is the forward diffusion process)
705 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 726 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
706 727
707 # Get the text embedding for conditioning 728 # Get the text embedding for conditioning
708 encoder_hidden_states = text_encoder(batch["input_ids"])[0] 729 encoder_hidden_states = text_encoder(batch["input_ids"])[0]
709 730
710 # Predict the noise residual 731 # Predict the noise residual
711 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample 732 noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
712 733
734 with accelerator.autocast():
713 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 735 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
714 736
715 accelerator.backward(loss) 737 accelerator.backward(loss)
@@ -727,32 +749,27 @@ def main():
727 optimizer.step() 749 optimizer.step()
728 if not accelerator.optimizer_step_was_skipped: 750 if not accelerator.optimizer_step_was_skipped:
729 lr_scheduler.step() 751 lr_scheduler.step()
730 optimizer.zero_grad() 752 optimizer.zero_grad(set_to_none=True)
731 753
732 loss = loss.detach().item() 754 loss = loss.detach().item()
733 train_loss += loss 755 train_loss += loss
734 756
735 # Checks if the accelerator has performed an optimization step behind the scenes 757 # Checks if the accelerator has performed an optimization step behind the scenes
736 if accelerator.sync_gradients: 758 if accelerator.sync_gradients:
737 progress_bar.update(1)
738 local_progress_bar.update(1) 759 local_progress_bar.update(1)
760 global_progress_bar.update(1)
739 761
740 global_step += 1 762 global_step += 1
741 763
742 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: 764 if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
743 progress_bar.clear()
744 local_progress_bar.clear() 765 local_progress_bar.clear()
766 global_progress_bar.clear()
745 767
746 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) 768 checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
747 save_resume_file(basepath, args, { 769 save_resume_file(basepath, args, {
748 "global_step": global_step + global_step_offset, 770 "global_step": global_step + global_step_offset,
749 "resume_checkpoint": f"{basepath}/checkpoints/last.bin" 771 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
750 }) 772 })
751 checkpointer.save_samples(
752 "training",
753 global_step + global_step_offset,
754 text_encoder,
755 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
756 773
757 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} 774 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
758 local_progress_bar.set_postfix(**logs) 775 local_progress_bar.set_postfix(**logs)
@@ -762,17 +779,21 @@ def main():
762 779
763 train_loss /= len(train_dataloader) 780 train_loss /= len(train_dataloader)
764 781
782 accelerator.wait_for_everyone()
783
765 text_encoder.eval() 784 text_encoder.eval()
766 val_loss = 0.0 785 val_loss = 0.0
767 786
768 for step, batch in enumerate(val_dataloader): 787 for step, batch in enumerate(val_dataloader):
769 with torch.no_grad(), accelerator.autocast(): 788 with torch.no_grad():
770 latents = imageToLatents(batch) 789 latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
790 latents = latents * 0.18215
771 791
772 noise = torch.randn(latents.shape).to(latents.device) 792 noise = torch.randn(latents.shape).to(latents.device)
773 bsz = latents.shape[0] 793 bsz = latents.shape[0]
774 timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, 794 timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
775 (bsz,), device=latents.device).long() 795 (bsz,), device=latents.device)
796 timesteps = timesteps.long()
776 797
777 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) 798 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
778 799
@@ -782,14 +803,15 @@ def main():
782 803
783 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) 804 noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
784 805
785 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() 806 with accelerator.autocast():
807 loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
786 808
787 loss = loss.detach().item() 809 loss = loss.detach().item()
788 val_loss += loss 810 val_loss += loss
789 811
790 if accelerator.sync_gradients: 812 if accelerator.sync_gradients:
791 progress_bar.update(1)
792 local_progress_bar.update(1) 813 local_progress_bar.update(1)
814 global_progress_bar.update(1)
793 815
794 logs = {"mode": "validation", "loss": loss} 816 logs = {"mode": "validation", "loss": loss}
795 local_progress_bar.set_postfix(**logs) 817 local_progress_bar.set_postfix(**logs)
@@ -798,21 +820,19 @@ def main():
798 820
799 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) 821 accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
800 822
801 progress_bar.clear()
802 local_progress_bar.clear() 823 local_progress_bar.clear()
824 global_progress_bar.clear()
803 825
804 if min_val_loss > val_loss: 826 if min_val_loss > val_loss:
805 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") 827 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
806 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) 828 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
807 min_val_loss = val_loss 829 min_val_loss = val_loss
808 830
809 checkpointer.save_samples( 831 if accelerator.is_main_process:
810 "validation", 832 checkpointer.save_samples(
811 global_step + global_step_offset, 833 global_step + global_step_offset,
812 text_encoder, 834 text_encoder,
813 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) 835 args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
814
815 accelerator.wait_for_everyone()
816 836
817 # Create the pipeline using using the trained modules and save it. 837 # Create the pipeline using using the trained modules and save it.
818 if accelerator.is_main_process: 838 if accelerator.is_main_process: