-rw-r--r--   data/dreambooth/csv.py        |   1
-rw-r--r--   data/textual_inversion/csv.py |  98
-rw-r--r--   dreambooth.py                 |  26
-rw-r--r--   textual_inversion.py          | 308

4 files changed, 230 insertions, 203 deletions

diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 99bcf12..1676d35 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -2,7 +2,6 @@ import math
 import os
 import pandas as pd
 from pathlib import Path
-import PIL
 import pytorch_lightning as pl
 from PIL import Image
 from torch.utils.data import Dataset, DataLoader, random_split
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 0d1e96e..f306c7a 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -1,11 +1,10 @@
 import os
 import numpy as np
 import pandas as pd
-import random
-import PIL
+from pathlib import Path
+import math
 import pytorch_lightning as pl
 from PIL import Image
-import torch
 from torch.utils.data import Dataset, DataLoader, random_split
 from torchvision import transforms
 
@@ -13,29 +12,32 @@ from torchvision import transforms
 class CSVDataModule(pl.LightningDataModule):
     def __init__(self,
                  batch_size,
-                 data_root,
+                 data_file,
                  tokenizer,
                  size=512,
                  repeats=100,
                  interpolation="bicubic",
                  placeholder_token="*",
-                 flip_p=0.5,
                  center_crop=False):
         super().__init__()
 
-        self.data_root = data_root
+        self.data_file = Path(data_file)
+
+        if not self.data_file.is_file():
+            raise ValueError("data_file must be a file")
+
+        self.data_root = self.data_file.parent
         self.tokenizer = tokenizer
         self.size = size
         self.repeats = repeats
         self.placeholder_token = placeholder_token
         self.center_crop = center_crop
-        self.flip_p = flip_p
         self.interpolation = interpolation
 
         self.batch_size = batch_size
 
     def prepare_data(self):
-        metadata = pd.read_csv(f'{self.data_root}/list.csv')
+        metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
         captions = [caption for caption in metadata['caption'].values]
         skips = [skip for skip in metadata['skip'].values]
@@ -47,9 +49,9 @@ class CSVDataModule(pl.LightningDataModule):
         self.data_train, self.data_val = random_split(self.data_full, [train_set_size, valid_set_size])
 
         train_dataset = CSVDataset(self.data_train, self.tokenizer, size=self.size, repeats=self.repeats, interpolation=self.interpolation,
-                                   flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                   placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
-                                 flip_p=self.flip_p, placeholder_token=self.placeholder_token, center_crop=self.center_crop)
+                                 placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
         self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
 
@@ -67,48 +69,54 @@ class CSVDataset(Dataset):
                  size=512,
                  repeats=1,
                  interpolation="bicubic",
-                 flip_p=0.5,
                  placeholder_token="*",
                  center_crop=False,
+                 batch_size=1,
                  ):
 
         self.data = data
         self.tokenizer = tokenizer
-
-        self.num_images = len(self.data)
-        self._length = self.num_images * repeats
-
         self.placeholder_token = placeholder_token
+        self.batch_size = batch_size
+        self.cache = {}
 
-        self.size = size
-        self.center_crop = center_crop
-        self.interpolation = {"linear": PIL.Image.LINEAR,
-                              "bilinear": PIL.Image.BILINEAR,
-                              "bicubic": PIL.Image.BICUBIC,
-                              "lanczos": PIL.Image.LANCZOS,
-                              }[interpolation]
-        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
+        self.num_instance_images = len(self.data)
+        self._length = self.num_instance_images * repeats
 
-        self.cache = {}
+        self.interpolation = {"linear": transforms.InterpolationMode.NEAREST,
+                              "bilinear": transforms.InterpolationMode.BILINEAR,
+                              "bicubic": transforms.InterpolationMode.BICUBIC,
+                              "lanczos": transforms.InterpolationMode.LANCZOS,
+                              }[interpolation]
+        self.image_transforms = transforms.Compose(
+            [
+                transforms.Resize(size, interpolation=self.interpolation),
+                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.RandomHorizontalFlip(),
+                transforms.ToTensor(),
+                transforms.Normalize([0.5], [0.5]),
+            ]
+        )
 
     def __len__(self):
-        return self._length
+        return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_example(self, i, flipped):
-        image_path, text = self.data[i % self.num_images]
+    def get_example(self, i):
+        image_path, text = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
 
         example = {}
-        image = Image.open(image_path)
 
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
+        instance_image = Image.open(image_path)
+        if not instance_image.mode == "RGB":
+            instance_image = instance_image.convert("RGB")
 
         text = text.format(self.placeholder_token)
 
-        example["prompt"] = text
+        example["prompts"] = text
+        example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
             text,
             padding="max_length",
@@ -117,29 +125,15 @@ class CSVDataset(Dataset):
             return_tensors="pt",
         ).input_ids[0]
 
-        # default to score-sde preprocessing
-        img = np.array(image).astype(np.uint8)
-
-        if self.center_crop:
-            crop = min(img.shape[0], img.shape[1])
-            h, w, = img.shape[0], img.shape[1]
-            img = img[(h - crop) // 2:(h + crop) // 2,
-                      (w - crop) // 2:(w + crop) // 2]
-
-        image = Image.fromarray(img)
-        image = image.resize((self.size, self.size),
-                             resample=self.interpolation)
-        image = self.flip(image)
-        image = np.array(image).astype(np.uint8)
-        image = (image / 127.5 - 1.0).astype(np.float32)
-
-        example["key"] = "-".join([image_path, "-", str(flipped)])
-        example["pixel_values"] = torch.from_numpy(image).permute(2, 0, 1)
-
         self.cache[image_path] = example
         return example
 
     def __getitem__(self, i):
-        flipped = random.choice([False, True])
-        example = self.get_example(i, flipped)
+        example = {}
+        unprocessed_example = self.get_example(i)
+
+        example["prompts"] = unprocessed_example["prompts"]
+        example["input_ids"] = unprocessed_example["input_ids"]
+        example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
+
         return example
diff --git a/dreambooth.py b/dreambooth.py
index 4d7366c..744d1bc 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -14,12 +14,14 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 
 from data.dreambooth.csv import CSVDataModule
@@ -215,7 +217,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=80,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -377,15 +379,16 @@ class Checkpointer:
         samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=unwrapped,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
@@ -411,6 +414,8 @@ class Checkpointer:
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                     batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
 
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
@@ -420,10 +425,13 @@
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                     all_samples += samples
+
+                    del generator
                     del samples
 
             image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
@@ -444,6 +452,8 @@ class Checkpointer:
             prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
                 batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
+            generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
+
             with self.accelerator.autocast():
                 samples = pipeline(
                     prompt=prompt,
@@ -452,10 +462,13 @@
                     guidance_scale=guidance_scale,
                     eta=eta,
                     num_inference_steps=num_inference_steps,
+                    generator=generator,
                     output_type='pil'
                 )["sample"]
 
                 all_samples += samples
+
+                del generator
                 del samples
 
         image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
@@ -465,6 +478,7 @@ class Checkpointer:
         del image_grid
 
         del unwrapped
+        del scheduler
         del pipeline
 
         if torch.cuda.is_available():
diff --git a/textual_inversion.py b/textual_inversion.py
index 399d876..7a7d7fc 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -3,6 +3,8 @@ import itertools
 import math
 import os
 import datetime
+import logging
+from pathlib import Path
 
 import numpy as np
 import torch
@@ -13,12 +15,13 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
-from pipelines.stable_diffusion.no_check import NoCheck
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
 import os
 
@@ -44,10 +47,10 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir",
+        "--train_data_file",
        type=str,
         default=None,
-        help="A folder containing the training data."
+        help="A CSV file containing the training data."
     )
     parser.add_argument(
         "--placeholder_token",
@@ -146,6 +149,11 @@ def parse_args():
         help="Number of steps for the warmup in the lr scheduler."
     )
     parser.add_argument(
+        "--use_8bit_adam",
+        action="store_true",
+        help="Whether or not to use 8-bit Adam from bitsandbytes."
+    )
+    parser.add_argument(
         "--adam_beta1",
         type=float,
         default=0.9,
@@ -225,7 +233,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=50,
+        default=30,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -261,8 +269,8 @@ def parse_args():
     if env_local_rank != -1 and env_local_rank != args.local_rank:
         args.local_rank = env_local_rank
 
-    if args.train_data_dir is None:
-        raise ValueError("You must specify --train_data_dir")
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
 
     if args.pretrained_model_name_or_path is None:
         raise ValueError("You must specify --pretrained_model_name_or_path")
@@ -333,53 +341,51 @@ class Checkpointer:
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
         print("Saving checkpoint for step %d..." % step)
-        with self.accelerator.autocast():
-            if path is None:
-                checkpoints_path = f"{self.output_dir}/checkpoints"
-                os.makedirs(checkpoints_path, exist_ok=True)
+
+        if path is None:
+            checkpoints_path = f"{self.output_dir}/checkpoints"
+            os.makedirs(checkpoints_path, exist_ok=True)
 
-            unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(text_encoder)
 
-            # Save a checkpoint
-            learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
-            learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
-            filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-            if path is not None:
-                torch.save(learned_embeds_dict, path)
-            else:
-                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
-                torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
-            del unwrapped
-            del learned_embeds
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        if path is not None:
+            torch.save(learned_embeds_dict, path)
+        else:
+            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
+            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+
+        del unwrapped
+        del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, mode, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
-        samples_path = f"{self.output_dir}/samples/{mode}"
-        os.makedirs(samples_path, exist_ok=True)
-        checker = NoCheck()
+    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+        samples_path = Path(self.output_dir).joinpath("samples")
 
         unwrapped = self.accelerator.unwrap_model(text_encoder)
+        scheduler = EulerAScheduler(
+            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+        )
+
         # Save a sample image
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=unwrapped,
             vae=self.vae,
             unet=self.unet,
             tokenizer=self.tokenizer,
-            scheduler=LMSDiscreteScheduler(
-                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
-            ),
-            safety_checker=NoCheck(),
+            scheduler=scheduler,
             feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
-        data = {
-            "training": self.datamodule.train_dataloader(),
-            "validation": self.datamodule.val_dataloader(),
-        }[mode]
+        train_data = self.datamodule.train_dataloader()
+        val_data = self.datamodule.val_dataloader()
 
-        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
+        if self.stable_sample_batches > 0:
             stable_latents = torch.randn(
                 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                 device=pipeline.device,
@@ -387,14 +393,17 @@
             )
 
             all_samples = []
-            filename = f"stable_step_%d.png" % (step)
+            file_path = samples_path.joinpath("stable", f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
-            data_enum = enumerate(data)
+            data_enum = enumerate(val_data)
 
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
                 prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
+                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
+
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
 
                 with self.accelerator.autocast():
                     samples = pipeline(
@@ -405,67 +414,64 @@
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
+                        generator=generator,
                         output_type='pil'
                     )["sample"]
 
                     all_samples += samples
+
+                    del generator
                     del samples
 
             image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(f"{samples_path}/{filename}")
+            image_grid.save(file_path)
 
             del all_samples
             del image_grid
             del stable_latents
 
-        all_samples = []
-        filename = f"step_%d.png" % (step)
+        for data, pool in [(val_data, "val"), (train_data, "train")]:
+            all_samples = []
+            file_path = samples_path.joinpath(pool, f"step_{step}.png")
+            file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        data_enum = enumerate(data)
+            data_enum = enumerate(data)
 
-        # Generate and save random samples
-        for i in range(0, self.random_sample_batches):
-            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(0, self.random_sample_batches):
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
 
-            with self.accelerator.autocast():
-                samples = pipeline(
-                    prompt=prompt,
-                    height=self.sample_image_size,
-                    width=self.sample_image_size,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
-                    output_type='pil'
-                )["sample"]
+                generator = torch.Generator(device="cuda").manual_seed(self.seed + i)
 
-            all_samples += samples
-            del samples
+                with self.accelerator.autocast():
+                    samples = pipeline(
+                        prompt=prompt,
+                        height=self.sample_image_size,
+                        width=self.sample_image_size,
+                        guidance_scale=guidance_scale,
+                        eta=eta,
+                        num_inference_steps=num_inference_steps,
+                        generator=generator,
+                        output_type='pil'
+                    )["sample"]
 
-        image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
-        image_grid.save(f"{samples_path}/{filename}")
+                    all_samples += samples
 
-        del all_samples
-        del image_grid
+                    del generator
+                    del samples
 
-        del checker
-        del unwrapped
-        del pipeline
-        torch.cuda.empty_cache()
+            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid.save(file_path)
 
+            del all_samples
+            del image_grid
 
-class ImageToLatents():
-    def __init__(self, vae):
-        self.vae = vae
-        self.encoded_pixel_values_cache = {}
+        del unwrapped
+        del scheduler
+        del pipeline
 
-    @torch.no_grad()
-    def __call__(self, batch):
-        key = "|".join(batch["key"])
-        if self.encoded_pixel_values_cache.get(key, None) is None:
-            self.encoded_pixel_values_cache[key] = self.vae.encode(batch["pixel_values"]).latent_dist
-        latents = self.encoded_pixel_values_cache[key].sample().detach().half() * 0.18215
-        return latents
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
 
 def main():
@@ -473,17 +479,17 @@ def main():
 
     global_step_offset = 0
     if args.resume_from is not None:
-        basepath = f"{args.resume_from}"
+        basepath = Path(args.resume_from)
         print("Resuming state from %s" % args.resume_from)
-        with open(f"{basepath}/resume.json", 'r') as f:
+        with open(basepath.joinpath("resume.json"), 'r') as f:
             state = json.load(f)
             global_step_offset = state["args"].get("global_step", 0)
 
         print("We've trained %d steps so far" % global_step_offset)
     else:
         now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-        basepath = f"{args.output_dir}/{slugify(args.placeholder_token)}/{now}"
-        os.makedirs(basepath, exist_ok=True)
+        basepath = Path(args.output_dir).joinpath(slugify(args.placeholder_token), now)
+        basepath.mkdir(parents=True, exist_ok=True)
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
@@ -492,6 +498,8 @@ def main():
         mixed_precision=args.mixed_precision
     )
 
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
     # If passed along, set the training seed now.
     if args.seed is not None:
         set_seed(args.seed)
@@ -570,8 +578,19 @@ def main():
         args.train_batch_size * accelerator.num_processes
     )
 
+    # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
+    if args.use_8bit_adam:
+        try:
+            import bitsandbytes as bnb
+        except ImportError:
+            raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
+
+        optimizer_class = bnb.optim.AdamW8bit
+    else:
+        optimizer_class = torch.optim.AdamW
+
     # Initialize the optimizer
-    optimizer = torch.optim.AdamW(
+    optimizer = optimizer_class(
         text_encoder.get_input_embeddings().parameters(),  # only optimize the embeddings
         lr=args.learning_rate,
         betas=(args.adam_beta1, args.adam_beta2),
@@ -585,7 +604,7 @@ def main():
     )
 
     datamodule = CSVDataModule(
-        data_root=args.train_data_dir, batch_size=args.train_batch_size, tokenizer=tokenizer,
+        data_file=args.train_data_file, batch_size=args.train_batch_size, tokenizer=tokenizer,
         size=args.resolution, placeholder_token=args.placeholder_token, repeats=args.repeats,
         center_crop=args.center_crop)
 
@@ -608,13 +627,12 @@ def main():
         sample_batch_size=args.sample_batch_size,
         random_sample_batches=args.random_sample_batches,
         stable_sample_batches=args.stable_sample_batches,
-        seed=args.seed
+        seed=args.seed or torch.random.seed()
     )
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(
-        (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -643,9 +661,10 @@ def main():
         (len(train_dataloader) + len(val_dataloader)) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
-    # Afterwards we recalculate our number of training epochs
-    args.num_train_epochs = math.ceil(
-        args.max_train_steps / num_update_steps_per_epoch)
+
+    num_val_steps_per_epoch = len(val_dataloader)
+    num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
+    val_steps = num_val_steps_per_epoch * num_epochs
 
     # We need to initialize the trackers we use, and also store our configuration.
     # The trackers initializes automatically on the main process.
@@ -656,7 +675,7 @@ def main():
     total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
 
     logger.info("***** Running training *****")
-    logger.info(f"  Num Epochs = {args.num_train_epochs}")
+    logger.info(f"  Num Epochs = {num_epochs}")
     logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
     logger.info(f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
     logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
@@ -666,22 +685,22 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
-    imageToLatents = ImageToLatents(vae)
-
-    checkpointer.save_samples(
-        "validation",
-        0,
-        text_encoder,
-        args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
+    if accelerator.is_main_process:
+        checkpointer.save_samples(
+            0,
+            text_encoder,
+            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
-    progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process)
-    progress_bar.set_description("Global steps")
+    local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch),
+                              disable=not accelerator.is_local_main_process)
+    local_progress_bar.set_description("Batch X out of Y")
 
-    local_progress_bar = tqdm(range(num_update_steps_per_epoch), disable=not accelerator.is_local_main_process)
-    local_progress_bar.set_description("Steps")
+    global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process)
+    global_progress_bar.set_description("Total progress")
 
     try:
-        for epoch in range(args.num_train_epochs):
+        for epoch in range(num_epochs):
+            local_progress_bar.set_description(f"Batch {epoch + 1} out of {num_epochs}")
             local_progress_bar.reset()
 
             text_encoder.train()
@@ -689,27 +708,30 @@ def main():
 
             for step, batch in enumerate(train_dataloader):
                 with accelerator.accumulate(text_encoder):
-                    with accelerator.autocast():
-                        # Convert images to latent space
-                        latents = imageToLatents(batch)
+                    # Convert images to latent space
+                    with torch.no_grad():
+                        latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                        latents = latents * 0.18215
 
-                        # Sample noise that we'll add to the latents
-                        noise = torch.randn(latents.shape).to(latents.device)
-                        bsz = latents.shape[0]
-                        # Sample a random timestep for each image
-                        timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                                  (bsz,), device=latents.device).long()
+                    # Sample noise that we'll add to the latents
+                    noise = torch.randn(latents.shape).to(latents.device)
+                    bsz = latents.shape[0]
+                    # Sample a random timestep for each image
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
 
-                        # Add noise to the latents according to the noise magnitude at each timestep
-                        # (this is the forward diffusion process)
-                        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+                    # Add noise to the latents according to the noise magnitude at each timestep
+                    # (this is the forward diffusion process)
+                    noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
-                        # Get the text embedding for conditioning
-                        encoder_hidden_states = text_encoder(batch["input_ids"])[0]
+                    # Get the text embedding for conditioning
+                    encoder_hidden_states = text_encoder(batch["input_ids"])[0]
 
-                        # Predict the noise residual
-                        noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
+                    # Predict the noise residual
+                    noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
 
+                    with accelerator.autocast():
                         loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     accelerator.backward(loss)
@@ -727,32 +749,27 @@ def main():
                     optimizer.step()
                     if not accelerator.optimizer_step_was_skipped:
                         lr_scheduler.step()
-                    optimizer.zero_grad()
+                    optimizer.zero_grad(set_to_none=True)
 
                 loss = loss.detach().item()
                 train_loss += loss
 
                 # Checks if the accelerator has performed an optimization step behind the scenes
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                     global_step += 1
 
                     if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
-                        progress_bar.clear()
                         local_progress_bar.clear()
+                        global_progress_bar.clear()
 
                         checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                         save_resume_file(basepath, args, {
                             "global_step": global_step + global_step_offset,
                             "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
                         })
-                        checkpointer.save_samples(
-                            "training",
-                            global_step + global_step_offset,
-                            text_encoder,
-                            args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
                 logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]}
                 local_progress_bar.set_postfix(**logs)
@@ -762,17 +779,21 @@ def main():
 
             train_loss /= len(train_dataloader)
 
+            accelerator.wait_for_everyone()
+
             text_encoder.eval()
             val_loss = 0.0
 
             for step, batch in enumerate(val_dataloader):
-                with torch.no_grad(), accelerator.autocast():
-                    latents = imageToLatents(batch)
+                with torch.no_grad():
+                    latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
+                    latents = latents * 0.18215
 
                     noise = torch.randn(latents.shape).to(latents.device)
                     bsz = latents.shape[0]
-                    timesteps = torch.randint(0, noise_scheduler.num_train_timesteps,
-                                              (bsz,), device=latents.device).long()
+                    timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
+                                              (bsz,), device=latents.device)
+                    timesteps = timesteps.long()
 
                     noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
 
@@ -782,14 +803,15 @@ def main():
 
                     noise_pred, noise = accelerator.gather_for_metrics((noise_pred, noise))
 
-                    loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
+                    with accelerator.autocast():
+                        loss = F.mse_loss(noise_pred, noise, reduction="none").mean([1, 2, 3]).mean()
 
                     loss = loss.detach().item()
                     val_loss += loss
 
                 if accelerator.sync_gradients:
-                    progress_bar.update(1)
                     local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
                 logs = {"mode": "validation", "loss": loss}
                 local_progress_bar.set_postfix(**logs)
@@ -798,21 +820,19 @@ def main():
 
             accelerator.log({"train/loss": train_loss, "val/loss": val_loss}, step=global_step)
 
-            progress_bar.clear()
             local_progress_bar.clear()
+            global_progress_bar.clear()
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
                 checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
-            checkpointer.save_samples(
-                "validation",
-                global_step + global_step_offset,
-                text_encoder,
-                args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
-
-            accelerator.wait_for_everyone()
+            if accelerator.is_main_process:
+                checkpointer.save_samples(
+                    global_step + global_step_offset,
+                    text_encoder,
+                    args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process: