-rw-r--r--  data/dreambooth/csv.py        |   1
-rw-r--r--  data/textual_inversion/csv.py |  98
-rw-r--r--  dreambooth.py                 |  26
-rw-r--r--  textual_inversion.py          | 302
4 files changed, 227 insertions, 200 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 99bcf12..1676d35 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
| @@ -2,7 +2,6 @@ import math | |||
| 2 | import os | 2 | import os |
| 3 | import pandas as pd | 3 | import pandas as pd |
| 4 | from pathlib import Path | 4 | from pathlib import Path |
| 5 | import PIL | ||
| 6 | import pytorch_lightning as pl | 5 | import pytorch_lightning as pl |
| 7 | from PIL import Image | 6 | from PIL import Image |
| 8 | from torch.utils.data import Dataset, DataLoader, random_split | 7 | from torch.utils.data import Dataset, DataLoader, random_split |
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
| @@ -1,11 +1,10 @@ | |||
| 1 | import os | 1 | import os |
| 2 | import numpy as np | 2 | import numpy as np |
| 3 | import pandas as pd | 3 | import pandas as pd |
| 4 | import random | 4 | from pathlib import Path |
| 5 | import PIL | 5 | import math |
| 6 | import pytorch_lightning as pl | 6 | import pytorch_lightning as pl |
| 7 | from PIL import Image | 7 | from PIL import Image |
| 8 | import torch | ||
| 9 | from torch.utils.data import Dataset, DataLoader, random_split | 8 | from torch.utils.data import Dataset, DataLoader, random_split |
| 10 | from torchvision import transforms | 9 | from torchvision import transforms |
| 11 | 10 | ||
| @@ -13,29 +12,32 @@ from torchvision import transforms | |||
| 13 | class CSVDataModule(pl.LightningDataModule): | 12 | class CSVDataModule(pl.LightningDataModule): |
| 14 | def __init__(self, | 13 | def __init__(self, |
| 15 | batch_size, | 14 | batch_size, |
| 16 | data_root, | 15 | data_file, |
| 17 | tokenizer, | 16 | tokenizer, |
| 18 | size=512, | 17 | size=512, |
| 19 | repeats=100, | 18 | repeats=100, |
| 20 | interpolation="bicubic", | 19 | interpolation="bicubic", |
| 21 | placeholder_token="*", | 20 | placeholder_token="*", |
| 22 | flip_p=0.5, | ||
| 23 | center_crop=False): | 21 | center_crop=False): |
| 24 | super().__init__() | 22 | super().__init__() |
| 25 | 23 | ||
| 26 | self.data_root = data_root | 24 | self.data_file = Path(data_file) |
| 25 | |||
| 26 | if not self.data_file.is_file(): | ||
| 27 | raise ValueError("data_file must be a file") | ||
| 28 | |||
| 29 | self.data_root = self.data_file.parent | ||
| 27 | self.tokenizer = tokenizer | 30 | self.tokenizer = tokenizer |
| 28 | self.size = size | 31 | self.size = size |
| 29 | self.repeats = repeats | 32 | self.repeats = repeats |
| 30 | self.placeholder_token = placeholder_token | 33 | self.placeholder_token = placeholder_token |
| 31 | self.center_crop = center_crop | 34 | self.center_crop = center_crop |
| 32 | self.flip_p = flip_p | ||
| 33 | self.interpolation = interpolation | 35 | self.interpolation = interpolation |
| 34 | 36 | ||
| 35 | self.batch_size = batch_size | 37 | self.batch_size = batch_size |
| 36 | 38 | ||
| 37 | def prepare_data(self): | 39 | def prepare_data(self): |
| 38 | metadata = pd.read_csv(f'{self.data_root}/list.csv') | 40 | metadata = pd.read_csv(self.data_file) |
| 39 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 41 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
| 40 | captions = [caption for caption in metadata['caption'].values] | 42 | captions = [caption for caption in metadata['caption'].values] |
| 41 | skips = [skip for skip in metadata['skip'].values] | 43 | skips = [skip for skip in metadata['skip'].values] |
| @@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule): | |||
| 47 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) | 49 | self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size]) |
| 48 | 50 | ||
| 49 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, | 51 | train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation, |
| 50 | flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 52 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 51 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, | 53 | val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation, |
| 52 | flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop) | 54 | placeholder_token=self.placeholder_token, center_crop=self.center_crop) |
| 53 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) | 55 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True) |
| 54 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) | 56 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size) |
| 55 | 57 | ||
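The hunks above switch CSVDataModule from a data_root directory with an implicit list.csv to an explicit data_file argument that is validated and used both for reading the metadata and for resolving image paths. A minimal sketch of the metadata handling this implies, assuming a CSV with image, caption and skip columns as read in prepare_data (the concrete path is only an example):

    import pandas as pd
    from pathlib import Path

    data_file = Path("data/example/list.csv")   # hypothetical location
    if not data_file.is_file():
        raise ValueError("data_file must be a file")
    data_root = data_file.parent                # images are resolved relative to the CSV

    metadata = pd.read_csv(data_file)           # expected columns: image, caption, skip
    image_paths = [data_root / f for f in metadata["image"].values]
    captions = metadata["caption"].tolist()
    skips = metadata["skip"].tolist()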
| @@ -67,48 +69,54 @@ class CSVDataset(Dataset): | |||
| 67 | size=512, | 69 | size=512, |
| 68 | repeats=1, | 70 | repeats=1, |
| 69 | interpolation="bicubic", | 71 | interpolation="bicubic", |
| 70 | flip_p=0.5, | ||
| 71 | placeholder_token="*", | 72 | placeholder_token="*", |
| 72 | center_crop=False, | 73 | center_crop=False, |
| 74 | batch_size=1, | ||
| 73 | ): | 75 | ): |
| 74 | 76 | ||
| 75 | self.data = data | 77 | self.data = data |
| 76 | self.tokenizer = tokenizer | 78 | self.tokenizer = tokenizer |
| 77 | |||
| 78 | self.num_images = len(self.data) | ||
| 79 | self._length = self.num_images * repeats | ||
| 80 | |||
| 81 | self.placeholder_token = placeholder_token | 79 | self.placeholder_token = placeholder_token |
| 80 | self.batch_size = batch_size | ||
| 81 | self.cache = {} | ||
| 82 | 82 | ||
| 83 | self.size = size | 83 | self.num_instance_images = len(self.data) |
| 84 | self.center_crop = center_crop | 84 | self._length = self.num_instance_images * repeats |
| 85 | self.interpolation = {"linear": PIL.Image.LINEAR, | ||
| 86 | "bilinear": PIL.Image.BILINEAR, | ||
| 87 | "bicubic": PIL.Image.BICUBIC, | ||
| 88 | "lanczos": PIL.Image.LANCZOS, | ||
| 89 | }[interpolation] | ||
| 90 | self.flip = transforms.RandomHorizontalFlip(p=flip_p) | ||
| 91 | 85 | ||
| 92 | self.cache = {} | 86 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, |
| 87 | "bilinear": transforms.InterpolationMode.BILINEAR, | ||
| 88 | "bicubic": transforms.InterpolationMode.BICUBIC, | ||
| 89 | "lanczos": transforms.InterpolationMode.LANCZOS, | ||
| 90 | }[interpolation] | ||
| 91 | self.image_transforms = transforms.Compose( | ||
| 92 | [ | ||
| 93 | transforms.Resize(size, interpolation=self.interpolation), | ||
| 94 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | ||
| 95 | transforms.RandomHorizontalFlip(), | ||
| 96 | transforms.ToTensor(), | ||
| 97 | transforms.Normalize([0.5], [0.5]), | ||
| 98 | ] | ||
| 99 | ) | ||
| 93 | 100 | ||
| 94 | def __len__(self): | 101 | def __len__(self): |
| 95 | return self._length | 102 | return math.ceil(self._length / self.batch_size) * self.batch_size |
| 96 | 103 | ||
| 97 | def get_example(self, i, flipped): | 104 | def get_example(self, i): |
| 98 | image_path, text = self.data[i % self.num_images] | 105 | image_path, text = self.data[i % self.num_instance_images] |
| 99 | 106 | ||
| 100 | if image_path in self.cache: | 107 | if image_path in self.cache: |
| 101 | return self.cache[image_path] | 108 | return self.cache[image_path] |
| 102 | 109 | ||
| 103 | example = {} | 110 | example = {} |
| 104 | image = Image.open(image_path) | ||
| 105 | 111 | ||
| 106 | if not image.mode == "RGB": | 112 | instance_image = Image.open(image_path) |
| 107 | image = image.convert("RGB") | 113 | if not instance_image.mode == "RGB": |
| 114 | instance_image = instance_image.convert("RGB") | ||
| 108 | 115 | ||
| 109 | text = text.format(self.placeholder_token) | 116 | text = text.format(self.placeholder_token) |
| 110 | 117 | ||
| 111 | example["prompt"] = text | 118 | example["prompts"] = text |
| 119 | example["pixel_values"] = instance_image | ||
| 112 | example["input_ids"] = self.tokenizer( | 120 | example["input_ids"] = self.tokenizer( |
| 113 | text, | 121 | text, |
| 114 | padding="max_length", | 122 | padding="max_length", |
| @@ -117,29 +125,15 @@ class CSVDataset(Dataset): | |||
| 117 | return_tensors="pt", | 125 | return_tensors="pt", |
| 118 | ).input_ids[0] | 126 | ).input_ids[0] |
| 119 | 127 | ||
| 120 | # default to score-sde preprocessing | ||
| 121 | img = np.array(image).astype(np.uint8) | ||
| 122 | |||
| 123 | if self.center_crop: | ||
| 124 | crop = min(img.shape[0], img.shape[1]) | ||
| 125 | h, w, = img.shape[0], img.shape[1] | ||
| 126 | img = img[(h - crop) // 2:(h + crop) // 2, | ||
| 127 | (w - crop) // 2:(w + crop) // 2] | ||
| 128 | |||
| 129 | image = Image.fromarray(img) | ||
| 130 | image = image.resize((self.size, self.size), | ||
| 131 | resample=self.interpolation) | ||
| 132 | image = self.flip(image) | ||
| 133 | image = np.array(image).astype(np.uint8) | ||
| 134 | image = (image / 127.5 - 1.0).astype(np.float32) | ||
| 135 | |||
| 136 | example["key"] = "-".join([image_path, "-", str(flipped)]) | ||
| 137 | example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1) | ||
| 138 | |||
| 139 | self.cache[image_path] = example | 128 | self.cache[image_path] = example |
| 140 | return example | 129 | return example |
| 141 | 130 | ||
| 142 | def __getitem__(self, i): | 131 | def __getitem__(self, i): |
| 143 | flipped = random.choice([False, True]) | 132 | example = {} |
| 144 | example = self.get_example(i, flipped) | 133 | unprocessed_example = self.get_example(i) |
| 134 | |||
| 135 | example["prompts"] = unprocessed_example["prompts"] | ||
| 136 | example["input_ids"] = unprocessed_example["input_ids"] | ||
| 137 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | ||
| 138 | |||
| 145 | return example | 139 | return example |
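The rewritten CSVDataset above drops the manual PIL/NumPy preprocessing (center crop, resize, flip, rescale to [-1, 1]) in favour of a torchvision transform pipeline applied in __getitem__. A standalone sketch of that path, with size, center_crop and the input file as assumed placeholders:

    from PIL import Image
    from torchvision import transforms

    size, center_crop = 512, False
    image_transforms = transforms.Compose([
        transforms.Resize(size, interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),               # HWC uint8 -> CHW float in [0, 1]
        transforms.Normalize([0.5], [0.5]),  # rescale to [-1, 1], matching the old (x / 127.5 - 1.0)
    ])

    instance_image = Image.open("example.jpg").convert("RGB")  # hypothetical file
    pixel_values = image_transforms(instance_image)            # tensor of shape (3, size, size)

Note that the cache now stores the untransformed PIL image and the transforms run on every access, so the random crop and flip can differ between epochs.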
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
| @@ -14,12 +14,14 @@ from accelerate import Accelerator | |||
| 14 | from accelerate.logging import get_logger | 14 | from accelerate.logging import get_logger |
| 15 | from accelerate.utils import LoggerType, set_seed | 15 | from accelerate.utils import LoggerType, set_seed |
| 16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel |
| 17 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 17 | from diffusers.optimization import get_scheduler | 18 | from diffusers.optimization import get_scheduler |
| 18 | from pipelines.stable_diffusion.no_check import NoCheck | 19 | from pipelines.stable_diffusion.no_check import NoCheck |
| 19 | from PIL import Image | 20 | from PIL import Image |
| 20 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| 21 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
| 22 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 23 | import json | 25 | import json |
| 24 | 26 | ||
| 25 | from data.dreambooth.csv import CSVDataModule | 27 | from data.dreambooth.csv import CSVDataModule |
| @@ -215,7 +217,7 @@ def parse_args(): | |||
| 215 | parser.add_argument( | 217 | parser.add_argument( |
| 216 | "--sample_steps", | 218 | "--sample_steps", |
| 217 | type=int, | 219 | type=int, |
| 218 | default=80, | 220 | default=30, |
| 219 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 221 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 220 | ) | 222 | ) |
| 221 | parser.add_argument( | 223 | parser.add_argument( |
| @@ -377,15 +379,16 @@ class Checkpointer: | |||
| 377 | samples_path = Path(self.output_dir).joinpath("samples") | 379 | samples_path = Path(self.output_dir).joinpath("samples") |
| 378 | 380 | ||
| 379 | unwrapped = self.accelerator.unwrap_model(self.unet) | 381 | unwrapped = self.accelerator.unwrap_model(self.unet) |
| 380 | pipeline = StableDiffusionPipeline( | 382 | scheduler = EulerAScheduler( |
| 383 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 384 | ) | ||
| 385 | |||
| 386 | pipeline = VlpnStableDiffusion( | ||
| 381 | text_encoder=self.text_encoder, | 387 | text_encoder=self.text_encoder, |
| 382 | vae=self.vae, | 388 | vae=self.vae, |
| 383 | unet=unwrapped, | 389 | unet=unwrapped, |
| 384 | tokenizer=self.tokenizer, | 390 | tokenizer=self.tokenizer, |
| 385 | scheduler=LMSDiscreteScheduler( | 391 | scheduler=scheduler, |
| 386 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 387 | ), | ||
| 388 | safety_checker=NoCheck(), | ||
| 389 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | 392 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), |
| 390 | ).to(self.accelerator.device) | 393 | ).to(self.accelerator.device) |
| 391 | pipeline.enable_attention_slicing() | 394 | pipeline.enable_attention_slicing() |
| @@ -411,6 +414,8 @@ class Checkpointer: | |||
| 411 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 412 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | 415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
| 413 | 416 | ||
| 417 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 418 | |||
| 414 | with self.accelerator.autocast(): | 419 | with self.accelerator.autocast(): |
| 415 | samples = pipeline( | 420 | samples = pipeline( |
| 416 | prompt=prompt, | 421 | prompt=prompt, |
| @@ -420,10 +425,13 @@ class Checkpointer: | |||
| 420 | guidance_scale=guidance_scale, | 425 | guidance_scale=guidance_scale, |
| 421 | eta=eta, | 426 | eta=eta, |
| 422 | num_inference_steps=num_inference_steps, | 427 | num_inference_steps=num_inference_steps, |
| 428 | generator=generator, | ||
| 423 | output_type='pil' | 429 | output_type='pil' |
| 424 | )["sample"] | 430 | )["sample"] |
| 425 | 431 | ||
| 426 | all_samples += samples | 432 | all_samples += samples |
| 433 | |||
| 434 | del generator | ||
| 427 | del samples | 435 | del samples |
| 428 | 436 | ||
| 429 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 437 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| @@ -444,6 +452,8 @@ class Checkpointer: | |||
| 444 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 452 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 445 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 453 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
| 446 | 454 | ||
| 455 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 456 | |||
| 447 | with self.accelerator.autocast(): | 457 | with self.accelerator.autocast(): |
| 448 | samples = pipeline( | 458 | samples = pipeline( |
| 449 | prompt=prompt, | 459 | prompt=prompt, |
| @@ -452,10 +462,13 @@ class Checkpointer: | |||
| 452 | guidance_scale=guidance_scale, | 462 | guidance_scale=guidance_scale, |
| 453 | eta=eta, | 463 | eta=eta, |
| 454 | num_inference_steps=num_inference_steps, | 464 | num_inference_steps=num_inference_steps, |
| 465 | generator=generator, | ||
| 455 | output_type='pil' | 466 | output_type='pil' |
| 456 | )["sample"] | 467 | )["sample"] |
| 457 | 468 | ||
| 458 | all_samples += samples | 469 | all_samples += samples |
| 470 | |||
| 471 | del generator | ||
| 459 | del samples | 472 | del samples |
| 460 | 473 | ||
| 461 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 474 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
| @@ -465,6 +478,7 @@ class Checkpointer: | |||
| 465 | del image_grid | 478 | del image_grid |
| 466 | 479 | ||
| 467 | del unwrapped | 480 | del unwrapped |
| 481 | del scheduler | ||
| 468 | del pipeline | 482 | del pipeline |
| 469 | 483 | ||
| 470 | if torch.cuda.is_available(): | 484 | if torch.cuda.is_available(): |
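Both sampling loops in dreambooth.py now pass an explicitly seeded torch.Generator into the pipeline call, so each sample batch draws from a reproducible noise sequence across checkpoints. A small sketch of the pattern, with the seed, device and shapes as assumptions:

    import torch

    seed = 42
    device = "cuda" if torch.cuda.is_available() else "cpu"

    for i in range(4):  # e.g. one Generator per sample batch
        generator = torch.Generator(device=device).manual_seed(seed + i)
        # the training script passes generator=generator to the pipeline call;
        # torch.randn below only demonstrates that the draws are reproducible
        latents = torch.randn((1, 4, 64, 64), generator=generator, device=device)
        del generator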
diff --git a/textual_inversion.py b/textual_inversion.py
index 399d876..7a7d7fc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
| @@ -3,6 +3,8 @@ import itertools | |||
| 3 | import math | 3 | import math |
| 4 | import os | 4 | import os |
| 5 | import datetime | 5 | import datetime |
| 6 | import logging | ||
| 7 | from pathlib import Path | ||
| 6 | 8 | ||
| 7 | import numpy as np | 9 | import numpy as np |
| 8 | import torch | 10 | import torch |
| @@ -13,12 +15,13 @@ from accelerate import Accelerator | |||
| 13 | from accelerate.logging import get_logger | 15 | from accelerate.logging import get_logger |
| 14 | from accelerate.utils import LoggerType, set_seed | 16 | from accelerate.utils import LoggerType, set_seed |
| 15 | from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | 17 | from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel |
| 18 | from schedulers.scheduling_euler_a import EulerAScheduler | ||
| 16 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
| 17 | from pipelines.stable_diffusion.no_check import NoCheck | ||
| 18 | from PIL import Image | 20 | from PIL import Image |
| 19 | from tqdm.auto import tqdm | 21 | from tqdm.auto import tqdm |
| 20 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | 22 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 23 | from slugify import slugify |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 22 | import json | 25 | import json |
| 23 | import os | 26 | import os |
| 24 | 27 | ||
| @@ -44,10 +47,10 @@ def parse_args(): | |||
| 44 | help="Pretrained tokenizer name or path if not the same as model_name", | 47 | help="Pretrained tokenizer name or path if not the same as model_name", |
| 45 | ) | 48 | ) |
| 46 | parser.add_argument( | 49 | parser.add_argument( |
| 47 | "--train_data_dir", | 50 | "--train_data_file", |
| 48 | type=str, | 51 | type=str, |
| 49 | default=None, | 52 | default=None, |
| 50 | help="A folder containing the training data." | 53 | help="A CSV file containing the training data." |
| 51 | ) | 54 | ) |
| 52 | parser.add_argument( | 55 | parser.add_argument( |
| 53 | "--placeholder_token", | 56 | "--placeholder_token", |
| @@ -146,6 +149,11 @@ def parse_args(): | |||
| 146 | help="Number of steps for the warmup in the lr scheduler." | 149 | help="Number of steps for the warmup in the lr scheduler." |
| 147 | ) | 150 | ) |
| 148 | parser.add_argument( | 151 | parser.add_argument( |
| 152 | "--use_8bit_adam", | ||
| 153 | action="store_true", | ||
| 154 | help="Whether or not to use 8-bit Adam from bitsandbytes." | ||
| 155 | ) | ||
| 156 | parser.add_argument( | ||
| 149 | "--adam_beta1", | 157 | "--adam_beta1", |
| 150 | type=float, | 158 | type=float, |
| 151 | default=0.9, | 159 | default=0.9, |
| @@ -225,7 +233,7 @@ def parse_args(): | |||
| 225 | parser.add_argument( | 233 | parser.add_argument( |
| 226 | "--sample_steps", | 234 | "--sample_steps", |
| 227 | type=int, | 235 | type=int, |
| 228 | default=50, | 236 | default=30, |
| 229 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", | 237 | help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.", |
| 230 | ) | 238 | ) |
| 231 | parser.add_argument( | 239 | parser.add_argument( |
| @@ -261,8 +269,8 @@ def parse_args(): | |||
| 261 | if env_local_rank != -1 and env_local_rank != args.local_rank: | 269 | if env_local_rank != -1 and env_local_rank != args.local_rank: |
| 262 | args.local_rank = env_local_rank | 270 | args.local_rank = env_local_rank |
| 263 | 271 | ||
| 264 | if args.train_data_dir is None: | 272 | if args.train_data_file is None: |
| 265 | raise ValueError("You must specify --train_data_dir") | 273 | raise ValueError("You must specify --train_data_file") |
| 266 | 274 | ||
| 267 | if args.pretrained_model_name_or_path is None: | 275 | if args.pretrained_model_name_or_path is None: |
| 268 | raise ValueError("You must specify --pretrained_model_name_or_path") | 276 | raise ValueError("You must specify --pretrained_model_name_or_path") |
| @@ -333,53 +341,51 @@ class Checkpointer: | |||
| 333 | @torch.no_grad() | 341 | @torch.no_grad() |
| 334 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): | 342 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): |
| 335 | print("Saving checkpoint for step %d..." % step) | 343 | print("Saving checkpoint for step %d..." % step) |
| 336 | with self.accelerator.autocast(): | ||
| 337 | if path is None: | ||
| 338 | checkpoints_path = f"{self.output_dir}/checkpoints" | ||
| 339 | os.makedirs(checkpoints_path, exist_ok=True) | ||
| 340 | 344 | ||
| 341 | unwrapped = self.accelerator.unwrap_model(text_encoder) | 345 | if path is None: |
| 346 | checkpoints_path = f"{self.output_dir}/checkpoints" | ||
| 347 | os.makedirs(checkpoints_path, exist_ok=True) | ||
| 348 | |||
| 349 | unwrapped = self.accelerator.unwrap_model(text_encoder) | ||
| 350 | |||
| 351 | # Save a checkpoint | ||
| 352 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | ||
| 353 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | ||
| 342 | 354 | ||
| 343 | # Save a checkpoint | 355 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) |
| 344 | learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] | 356 | if path is not None: |
| 345 | learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} | 357 | torch.save(learned_embeds_dict, path) |
| 358 | else: | ||
| 359 | torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") | ||
| 360 | torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") | ||
| 346 | 361 | ||
| 347 | filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) | 362 | del unwrapped |
| 348 | if path is not None: | 363 | del learned_embeds |
| 349 | torch.save(learned_embeds_dict, path) | ||
| 350 | else: | ||
| 351 | torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") | ||
| 352 | torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") | ||
| 353 | del unwrapped | ||
| 354 | del learned_embeds | ||
| 355 | 364 | ||
| 356 | @torch.no_grad() | 365 | @torch.no_grad() |
| 357 | def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): | 366 | def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps): |
| 358 | samples_path = f"{self.output_dir}/samples/{mode}" | 367 | samples_path = Path(self.output_dir).joinpath("samples") |
| 359 | os.makedirs(samples_path, exist_ok=True) | ||
| 360 | checker = NoCheck() | ||
| 361 | 368 | ||
| 362 | unwrapped = self.accelerator.unwrap_model(text_encoder) | 369 | unwrapped = self.accelerator.unwrap_model(text_encoder) |
| 370 | scheduler = EulerAScheduler( | ||
| 371 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 372 | ) | ||
| 373 | |||
| 363 | # Save a sample image | 374 | # Save a sample image |
| 364 | pipeline = StableDiffusionPipeline( | 375 | pipeline = VlpnStableDiffusion( |
| 365 | text_encoder=unwrapped, | 376 | text_encoder=unwrapped, |
| 366 | vae=self.vae, | 377 | vae=self.vae, |
| 367 | unet=self.unet, | 378 | unet=self.unet, |
| 368 | tokenizer=self.tokenizer, | 379 | tokenizer=self.tokenizer, |
| 369 | scheduler=LMSDiscreteScheduler( | 380 | scheduler=scheduler, |
| 370 | beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear" | ||
| 371 | ), | ||
| 372 | safety_checker=NoCheck(), | ||
| 373 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), | 381 | feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"), |
| 374 | ).to(self.accelerator.device) | 382 | ).to(self.accelerator.device) |
| 375 | pipeline.enable_attention_slicing() | 383 | pipeline.enable_attention_slicing() |
| 376 | 384 | ||
| 377 | data = { | 385 | train_data = self.datamodule.train_dataloader() |
| 378 | "training": self.datamodule.train_dataloader(), | 386 | val_data = self.datamodule.val_dataloader() |
| 379 | "validation": self.datamodule.val_dataloader(), | ||
| 380 | }[mode] | ||
| 381 | 387 | ||
| 382 | if mode == "validation" and self.stable_sample_batches > 0 and step > 0: | 388 | if self.stable_sample_batches > 0: |
| 383 | stable_latents = torch.randn( | 389 | stable_latents = torch.randn( |
| 384 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 390 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
| 385 | device=pipeline.device, | 391 | device=pipeline.device, |
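Checkpointer.checkpoint above stores only the learned embedding of the placeholder token as a small .bin file. A sketch of saving and re-loading such an embedding into a CLIP text encoder; the model name, token and file name are illustrative assumptions, not taken from the diff:

    import torch
    from transformers import CLIPTextModel, CLIPTokenizer

    model_name = "openai/clip-vit-large-patch14"   # assumed text encoder
    placeholder_token = "<my-token>"

    tokenizer = CLIPTokenizer.from_pretrained(model_name)
    text_encoder = CLIPTextModel.from_pretrained(model_name)
    tokenizer.add_tokens(placeholder_token)
    text_encoder.resize_token_embeddings(len(tokenizer))
    placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)

    # save (mirrors the checkpoint() body above)
    learned_embeds = text_encoder.get_input_embeddings().weight[placeholder_token_id]
    torch.save({placeholder_token: learned_embeds.detach().cpu()}, "last.bin")

    # load the embedding back into a (re-initialised) text encoder
    state = torch.load("last.bin", map_location="cpu")
    token, embedding = next(iter(state.items()))
    text_encoder.get_input_embeddings().weight.data[placeholder_token_id] = embedding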
| @@ -387,14 +393,17 @@ class Checkpointer: | |||
| 387 | ) | 393 | ) |
| 388 | 394 | ||
| 389 | all_samples = [] | 395 | all_samples = [] |
| 390 | filename = f"stable_step_%d.png" % (step) | 396 | file_path = samples_path.joinpath("stable", f"step_{step}.png") |
| 397 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 391 | 398 | ||
| 392 | data_enum = enumerate(data) | 399 | data_enum = enumerate(val_data) |
| 393 | 400 | ||
| 394 | # Generate and save stable samples | 401 | # Generate and save stable samples |
| 395 | for i in range(0, self.stable_sample_batches): | 402 | for i in range(0, self.stable_sample_batches): |
| 396 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 403 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 397 | batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] | 404 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] |
| 405 | |||
| 406 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) | ||
| 398 | 407 | ||
| 399 | with self.accelerator.autocast(): | 408 | with self.accelerator.autocast(): |
| 400 | samples = pipeline( | 409 | samples = pipeline( |
| @@ -405,67 +414,64 @@ class Checkpointer: | |||
| 405 | guidance_scale=guidance_scale, | 414 | guidance_scale=guidance_scale, |
| 406 | eta=eta, | 415 | eta=eta, |
| 407 | num_inference_steps=num_inference_steps, | 416 | num_inference_steps=num_inference_steps, |
| 417 | generator=generator, | ||
| 408 | output_type='pil' | 418 | output_type='pil' |
| 409 | )["sample"] | 419 | )["sample"] |
| 410 | 420 | ||
| 411 | all_samples += samples | 421 | all_samples += samples |
| 422 | |||
| 423 | del generator | ||
| 412 | del samples | 424 | del samples |
| 413 | 425 | ||
| 414 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | 426 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) |
| 415 | image_grid.save(f"{samples_path}/{filename}") | 427 | image_grid.save(file_path) |
| 416 | 428 | ||
| 417 | del all_samples | 429 | del all_samples |
| 418 | del image_grid | 430 | del image_grid |
| 419 | del stable_latents | 431 | del stable_latents |
| 420 | 432 | ||
| 421 | all_samples = [] | 433 | for data, pool in [(val_data, "val"), (train_data, "train")]: |
| 422 | filename = f"step_%d.png" % (step) | 434 | all_samples = [] |
| 435 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | ||
| 436 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
| 423 | 437 | ||
| 424 | data_enum = enumerate(data) | 438 | data_enum = enumerate(data) |
| 425 | 439 | ||
| 426 | # Generate and save random samples | 440 | for i in range(0, self.random_sample_batches): |
| 427 | for i in range(0, self.random_sample_batches): | 441 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( |
| 428 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 442 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] |
| 429 | batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] | ||
| 430 | 443 | ||
| 431 | with self.accelerator.autocast(): | 444 | generator = torch.Generator(device="cuda").manual_seed(self.seed + i) |
| 432 | samples = pipeline( | ||
| 433 | prompt=prompt, | ||
| 434 | height=self.sample_image_size, | ||
| 435 | width=self.sample_image_size, | ||
| 436 | guidance_scale=guidance_scale, | ||
| 437 | eta=eta, | ||
| 438 | num_inference_steps=num_inference_steps, | ||
| 439 | output_type='pil' | ||
| 440 | )["sample"] | ||
| 441 | 445 | ||
| 442 | all_samples += samples | 446 | with self.accelerator.autocast(): |
| 443 | del samples | 447 | samples = pipeline( |
| 448 | prompt=prompt, | ||
| 449 | height=self.sample_image_size, | ||
| 450 | width=self.sample_image_size, | ||
| 451 | guidance_scale=guidance_scale, | ||
| 452 | eta=eta, | ||
| 453 | num_inference_steps=num_inference_steps, | ||
| 454 | generator=generator, | ||
| 455 | output_type='pil' | ||
| 456 | )["sample"] | ||
| 444 | 457 | ||
| 445 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 458 | all_samples += samples |
| 446 | image_grid.save(f"{samples_path}/{filename}") | ||
| 447 | 459 | ||
| 448 | del all_samples | 460 | del generator |
| 449 | del image_grid | 461 | del samples |
| 450 | 462 | ||
| 451 | del checker | 463 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) |
| 452 | del unwrapped | 464 | image_grid.save(file_path) |
| 453 | del pipeline | ||
| 454 | torch.cuda.empty_cache() | ||
| 455 | 465 | ||
| 466 | del all_samples | ||
| 467 | del image_grid | ||
| 456 | 468 | ||
| 457 | class ImageToLatents(): | 469 | del unwrapped |
| 458 | def __init__(self, vae): | 470 | del scheduler |
| 459 | self.vae = vae | 471 | del pipeline |
| 460 | self.encoded_pixel_values_cache = {} | ||
| 461 | 472 | ||
| 462 | @torch.no_grad() | 473 | if torch.cuda.is_available(): |
| 463 | def __call__(self, batch): | 474 | torch.cuda.empty_cache() |
| 464 | key = "|".join(batch["key"]) | ||
| 465 | if self.encoded_pixel_values_cache.get(key, None) is None: | ||
| 466 | self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist | ||
| 467 | latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215 | ||
| 468 | return latents | ||
| 469 | 475 | ||
| 470 | 476 | ||
| 471 | def main(): | 477 | def main(): |
| @@ -473,17 +479,17 @@ def main(): | |||
| 473 | 479 | ||
| 474 | global_step_offset = 0 | 480 | global_step_offset = 0 |
| 475 | if args.resume_from is not None: | 481 | if args.resume_from is not None: |
| 476 | basepath = f"{args.resume_from}" | 482 | basepath = Path(args.resume_from) |
| 477 | print("Resuming state from %s" % args.resume_from) | 483 | print("Resuming state from %s" % args.resume_from) |
| 478 | with open(f"{basepath}/resume.json", 'r') as f: | 484 | with open(basepath.joinpath("resume.json"), 'r') as f: |
| 479 | state = json.load(f) | 485 | state = json.load(f) |
| 480 | global_step_offset = state["args"].get("global_step", 0) | 486 | global_step_offset = state["args"].get("global_step", 0) |
| 481 | 487 | ||
| 482 | print("We've trained %d steps so far" % global_step_offset) | 488 | print("We've trained %d steps so far" % global_step_offset) |
| 483 | else: | 489 | else: |
| 484 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 490 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 485 | basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}" | 491 | basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now) |
| 486 | os.makedirs(basepath, exist_ok=True) | 492 | basepath.mkdir(parents=True, exist_ok=True) |
| 487 | 493 | ||
| 488 | accelerator = Accelerator( | 494 | accelerator = Accelerator( |
| 489 | log_with=LoggerType.TENSORBOARD, | 495 | log_with=LoggerType.TENSORBOARD, |
| @@ -492,6 +498,8 @@ def main(): | |||
| 492 | mixed_precision=args.mixed_precision | 498 | mixed_precision=args.mixed_precision |
| 493 | ) | 499 | ) |
| 494 | 500 | ||
| 501 | logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG) | ||
| 502 | |||
| 495 | # If passed along, set the training seed now. | 503 | # If passed along, set the training seed now. |
| 496 | if args.seed is not None: | 504 | if args.seed is not None: |
| 497 | set_seed(args.seed) | 505 | set_seed(args.seed) |
| @@ -570,8 +578,19 @@ def main(): | |||
| 570 | args.train_batch_size * accelerator.num_processes | 578 | args.train_batch_size * accelerator.num_processes |
| 571 | ) | 579 | ) |
| 572 | 580 | ||
| 581 | # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs | ||
| 582 | if args.use_8bit_adam: | ||
| 583 | try: | ||
| 584 | import bitsandbytes as bnb | ||
| 585 | except ImportError: | ||
| 586 | raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.") | ||
| 587 | |||
| 588 | optimizer_class = bnb.optim.AdamW8bit | ||
| 589 | else: | ||
| 590 | optimizer_class = torch.optim.AdamW | ||
| 591 | |||
| 573 | # Initialize the optimizer | 592 | # Initialize the optimizer |
| 574 | optimizer = torch.optim.AdamW( | 593 | optimizer = optimizer_class( |
| 575 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings | 594 | text_encoder.get_input_embeddings().parameters(), # only optimize the embeddings |
| 576 | lr=args.learning_rate, | 595 | lr=args.learning_rate, |
| 577 | betas=(args.adam_beta1, args.adam_beta2), | 596 | betas=(args.adam_beta1, args.adam_beta2), |
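The new --use_8bit_adam path swaps torch.optim.AdamW for bitsandbytes' AdamW8bit while keeping the same constructor arguments. A minimal sketch of the selection logic, with dummy parameters and placeholder hyperparameters:

    import torch

    use_8bit_adam = True  # corresponds to --use_8bit_adam

    if use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError("To use 8-bit Adam, please install bitsandbytes: pip install bitsandbytes")
        optimizer_class = bnb.optim.AdamW8bit
    else:
        optimizer_class = torch.optim.AdamW

    params = [torch.nn.Parameter(torch.zeros(768))]  # stand-in for the embedding parameters
    optimizer = optimizer_class(params, lr=1e-4, betas=(0.9, 0.999),
                                weight_decay=1e-2, eps=1e-8)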
| @@ -585,7 +604,7 @@ def main(): | |||
| 585 | ) | 604 | ) |
| 586 | 605 | ||
| 587 | datamodule = CSVDataModule( | 606 | datamodule = CSVDataModule( |
| 588 | data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer, | 607 | data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer, |
| 589 | size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, | 608 | size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats, |
| 590 | center_crop=args.center_crop) | 609 | center_crop=args.center_crop) |
| 591 | 610 | ||
| @@ -608,13 +627,12 @@ def main(): | |||
| 608 | sample_batch_size=args.sample_batch_size, | 627 | sample_batch_size=args.sample_batch_size, |
| 609 | random_sample_batches=args.random_sample_batches, | 628 | random_sample_batches=args.random_sample_batches, |
| 610 | stable_sample_batches=args.stable_sample_batches, | 629 | stable_sample_batches=args.stable_sample_batches, |
| 611 | seed=args.seed | 630 | seed=args.seed or torch.random.seed() |
| 612 | ) | 631 | ) |
| 613 | 632 | ||
| 614 | # Scheduler and math around the number of training steps. | 633 | # Scheduler and math around the number of training steps. |
| 615 | overrode_max_train_steps = False | 634 | overrode_max_train_steps = False |
| 616 | num_update_steps_per_epoch = math.ceil( | 635 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) |
| 617 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | ||
| 618 | if args.max_train_steps is None: | 636 | if args.max_train_steps is None: |
| 619 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 637 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| 620 | overrode_max_train_steps = True | 638 | overrode_max_train_steps = True |
| @@ -643,9 +661,10 @@ def main(): | |||
| 643 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) | 661 | (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps) |
| 644 | if overrode_max_train_steps: | 662 | if overrode_max_train_steps: |
| 645 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 663 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
| 646 | # Afterwards we recalculate our number of training epochs | 664 | |
| 647 | args.num_train_epochs = math.ceil( | 665 | num_val_steps_per_epoch = len(val_dataloader) |
| 648 | args.max_train_steps / num_update_steps_per_epoch) | 666 | num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) |
| 667 | val_steps = num_val_steps_per_epoch * num_epochs | ||
| 649 | 668 | ||
| 650 | # We need to initialize the trackers we use, and also store our configuration. | 669 | # We need to initialize the trackers we use, and also store our configuration. |
| 651 | # The trackers initializes automatically on the main process. | 670 | # The trackers initializes automatically on the main process. |
| @@ -656,7 +675,7 @@ def main(): | |||
| 656 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps | 675 | total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps |
| 657 | 676 | ||
| 658 | logger.info("***** Running training *****") | 677 | logger.info("***** Running training *****") |
| 659 | logger.info(f" Num Epochs = {args.num_train_epochs}") | 678 | logger.info(f" Num Epochs = {num_epochs}") |
| 660 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") | 679 | logger.info(f" Instantaneous batch size per device = {args.train_batch_size}") |
| 661 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") | 680 | logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}") |
| 662 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") | 681 | logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}") |
| @@ -666,22 +685,22 @@ def main(): | |||
| 666 | global_step = 0 | 685 | global_step = 0 |
| 667 | min_val_loss = np.inf | 686 | min_val_loss = np.inf |
| 668 | 687 | ||
| 669 | imageToLatents = ImageToLatents(vae) | 688 | if accelerator.is_main_process: |
| 670 | 689 | checkpointer.save_samples( | |
| 671 | checkpointer.save_samples( | 690 | 0, |
| 672 | "validation", | 691 | text_encoder, |
| 673 | 0, | 692 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 674 | text_encoder, | ||
| 675 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 676 | 693 | ||
| 677 | progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) | 694 | local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), |
| 678 | progress_bar.set_description("Global steps") | 695 | disable=not accelerator.is_local_main_process) |
| 696 | local_progress_bar.set_description("Batch X out of Y") | ||
| 679 | 697 | ||
| 680 | local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process) | 698 | global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) |
| 681 | local_progress_bar.set_description("Steps") | 699 | global_progress_bar.set_description("Total progress") |
| 682 | 700 | ||
| 683 | try: | 701 | try: |
| 684 | for epoch in range(args.num_train_epochs): | 702 | for epoch in range(num_epochs): |
| 703 | local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}") | ||
| 685 | local_progress_bar.reset() | 704 | local_progress_bar.reset() |
| 686 | 705 | ||
| 687 | text_encoder.train() | 706 | text_encoder.train() |
| @@ -689,27 +708,30 @@ def main(): | |||
| 689 | 708 | ||
| 690 | for step, batch in enumerate(train_dataloader): | 709 | for step, batch in enumerate(train_dataloader): |
| 691 | with accelerator.accumulate(text_encoder): | 710 | with accelerator.accumulate(text_encoder): |
| 692 | with accelerator.autocast(): | 711 | # Convert images to latent space |
| 693 | # Convert images to latent space | 712 | with torch.no_grad(): |
| 694 | latents = imageToLatents(batch) | 713 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 714 | latents = latents * 0.18215 | ||
| 695 | 715 | ||
| 696 | # Sample noise that we'll add to the latents | 716 | # Sample noise that we'll add to the latents |
| 697 | noise = torch.randn(latents.shape).to(latents.device) | 717 | noise = torch.randn(latents.shape).to(latents.device) |
| 698 | bsz = latents.shape[0] | 718 | bsz = latents.shape[0] |
| 699 | # Sample a random timestep for each image | 719 | # Sample a random timestep for each image |
| 700 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | 720 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
| 701 | (bsz,), device=latents.device).long() | 721 | (bsz,), device=latents.device) |
| 722 | timesteps = timesteps.long() | ||
| 702 | 723 | ||
| 703 | # Add noise to the latents according to the noise magnitude at each timestep | 724 | # Add noise to the latents according to the noise magnitude at each timestep |
| 704 | # (this is the forward diffusion process) | 725 | # (this is the forward diffusion process) |
| 705 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 726 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 706 | 727 | ||
| 707 | # Get the text embedding for conditioning | 728 | # Get the text embedding for conditioning |
| 708 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] | 729 | encoder_hidden_states = text_encoder(batch["input_ids"])[0] |
| 709 | 730 | ||
| 710 | # Predict the noise residual | 731 | # Predict the noise residual |
| 711 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample | 732 | noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample |
| 712 | 733 | ||
| 734 | with accelerator.autocast(): | ||
| 713 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 735 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() |
| 714 | 736 | ||
| 715 | accelerator.backward(loss) | 737 | accelerator.backward(loss) |
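The reworked training step above encodes images with the VAE (scaled by 0.18215), adds scheduler noise at random timesteps and regresses the UNet output against that noise. A compact sketch of the objective with random tensors standing in for the VAE latents and the UNet prediction (shapes and scheduler settings are assumptions consistent with the script's defaults):

    import torch
    import torch.nn.functional as F
    from diffusers import DDPMScheduler

    noise_scheduler = DDPMScheduler(beta_start=0.00085, beta_end=0.012,
                                    beta_schedule="scaled_linear", num_train_timesteps=1000)

    latents = torch.randn(4, 4, 64, 64) * 0.18215  # stand-in for vae.encode(x).latent_dist.sample() * 0.18215
    noise = torch.randn(latents.shape)
    bsz = latents.shape[0]
    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, (bsz,)).long()

    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
    noise_pred = torch.randn(latents.shape)        # stand-in for unet(noisy_latents, timesteps, hidden_states).sample
    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()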
| @@ -727,32 +749,27 @@ def main(): | |||
| 727 | optimizer.step() | 749 | optimizer.step() |
| 728 | if not accelerator.optimizer_step_was_skipped: | 750 | if not accelerator.optimizer_step_was_skipped: |
| 729 | lr_scheduler.step() | 751 | lr_scheduler.step() |
| 730 | optimizer.zero_grad() | 752 | optimizer.zero_grad(set_to_none=True) |
| 731 | 753 | ||
| 732 | loss = loss.detach().item() | 754 | loss = loss.detach().item() |
| 733 | train_loss += loss | 755 | train_loss += loss |
| 734 | 756 | ||
| 735 | # Checks if the accelerator has performed an optimization step behind the scenes | 757 | # Checks if the accelerator has performed an optimization step behind the scenes |
| 736 | if accelerator.sync_gradients: | 758 | if accelerator.sync_gradients: |
| 737 | progress_bar.update(1) | ||
| 738 | local_progress_bar.update(1) | 759 | local_progress_bar.update(1) |
| 760 | global_progress_bar.update(1) | ||
| 739 | 761 | ||
| 740 | global_step += 1 | 762 | global_step += 1 |
| 741 | 763 | ||
| 742 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: | 764 | if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: |
| 743 | progress_bar.clear() | ||
| 744 | local_progress_bar.clear() | 765 | local_progress_bar.clear() |
| 766 | global_progress_bar.clear() | ||
| 745 | 767 | ||
| 746 | checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) | 768 | checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) |
| 747 | save_resume_file(basepath, args, { | 769 | save_resume_file(basepath, args, { |
| 748 | "global_step": global_step + global_step_offset, | 770 | "global_step": global_step + global_step_offset, |
| 749 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" | 771 | "resume_checkpoint": f"{basepath}/checkpoints/last.bin" |
| 750 | }) | 772 | }) |
| 751 | checkpointer.save_samples( | ||
| 752 | "training", | ||
| 753 | global_step + global_step_offset, | ||
| 754 | text_encoder, | ||
| 755 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | ||
| 756 | 773 | ||
| 757 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} | 774 | logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} |
| 758 | local_progress_bar.set_postfix(**logs) | 775 | local_progress_bar.set_postfix(**logs) |
| @@ -762,17 +779,21 @@ def main(): | |||
| 762 | 779 | ||
| 763 | train_loss /= len(train_dataloader) | 780 | train_loss /= len(train_dataloader) |
| 764 | 781 | ||
| 782 | accelerator.wait_for_everyone() | ||
| 783 | |||
| 765 | text_encoder.eval() | 784 | text_encoder.eval() |
| 766 | val_loss = 0.0 | 785 | val_loss = 0.0 |
| 767 | 786 | ||
| 768 | for step, batch in enumerate(val_dataloader): | 787 | for step, batch in enumerate(val_dataloader): |
| 769 | with torch.no_grad(), accelerator.autocast(): | 788 | with torch.no_grad(): |
| 770 | latents = imageToLatents(batch) | 789 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 790 | latents = latents * 0.18215 | ||
| 771 | 791 | ||
| 772 | noise = torch.randn(latents.shape).to(latents.device) | 792 | noise = torch.randn(latents.shape).to(latents.device) |
| 773 | bsz = latents.shape[0] | 793 | bsz = latents.shape[0] |
| 774 | timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, | 794 | timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps, |
| 775 | (bsz,), device=latents.device).long() | 795 | (bsz,), device=latents.device) |
| 796 | timesteps = timesteps.long() | ||
| 776 | 797 | ||
| 777 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) | 798 | noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) |
| 778 | 799 | ||
| @@ -782,14 +803,15 @@ def main(): | |||
| 782 | 803 | ||
| 783 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) | 804 | noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise)) |
| 784 | 805 | ||
| 785 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | 806 | with accelerator.autocast(): |
| 807 | loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean() | ||
| 786 | 808 | ||
| 787 | loss = loss.detach().item() | 809 | loss = loss.detach().item() |
| 788 | val_loss += loss | 810 | val_loss += loss |
| 789 | 811 | ||
| 790 | if accelerator.sync_gradients: | 812 | if accelerator.sync_gradients: |
| 791 | progress_bar.update(1) | ||
| 792 | local_progress_bar.update(1) | 813 | local_progress_bar.update(1) |
| 814 | global_progress_bar.update(1) | ||
| 793 | 815 | ||
| 794 | logs = {"mode": "validation", "loss": loss} | 816 | logs = {"mode": "validation", "loss": loss} |
| 795 | local_progress_bar.set_postfix(**logs) | 817 | local_progress_bar.set_postfix(**logs) |
| @@ -798,21 +820,19 @@ def main(): | |||
| 798 | 820 | ||
| 799 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) | 821 | accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step) |
| 800 | 822 | ||
| 801 | progress_bar.clear() | ||
| 802 | local_progress_bar.clear() | 823 | local_progress_bar.clear() |
| 824 | global_progress_bar.clear() | ||
| 803 | 825 | ||
| 804 | if min_val_loss > val_loss: | 826 | if min_val_loss > val_loss: |
| 805 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") | 827 | accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") |
| 806 | checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) | 828 | checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) |
| 807 | min_val_loss = val_loss | 829 | min_val_loss = val_loss |
| 808 | 830 | ||
| 809 | checkpointer.save_samples( | 831 | if accelerator.is_main_process: |
| 810 | "validation", | 832 | checkpointer.save_samples( |
| 811 | global_step + global_step_offset, | 833 | global_step + global_step_offset, |
| 812 | text_encoder, | 834 | text_encoder, |
| 813 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) | 835 | args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) |
| 814 | |||
| 815 | accelerator.wait_for_everyone() | ||
| 816 | 836 | ||
| 817 | # Create the pipeline using using the trained modules and save it. | 837 | # Create the pipeline using using the trained modules and save it. |
| 818 | if accelerator.is_main_process: | 838 | if accelerator.is_main_process: |
