-rw-r--r--  data/dreambooth/csv.py         |  15
-rw-r--r--  data/textual_inversion/csv.py  |  17
-rw-r--r--  dreambooth.py                  |  81
-rw-r--r--  textual_inversion.py           |  83
4 files changed, 63 insertions, 133 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 08ed49c..71aa1eb 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -135,7 +136,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -146,9 +147,10 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.identifier)
+        prompt = prompt.format(self.identifier)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["instance_images"] = instance_image
         example["instance_prompt_ids"] = self.tokenizer(
             self.instance_prompt,
@@ -178,6 +180,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
         example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
         example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"]
 
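For context, a minimal, self-contained sketch (not part of the diff) of what the reworked `prepare_data` expects from the metadata CSV: a `prompt` column replaces `caption`, while `nprompt` and `skip` are optional and fall back to empty strings. The file names and rows below are made up.

```python
import pandas as pd

# Hypothetical metadata without an 'nprompt' column, so the fallback applies.
metadata = pd.DataFrame({
    "image":  ["01.jpg", "02.jpg", "03.jpg"],
    "prompt": ["a photo of {}", "a painting of {}", "a sketch of {}"],
    "skip":   ["", "x", ""],  # "x" marks rows to exclude
})

image_paths = metadata["image"].values
prompts = metadata["prompt"].values
nprompts = metadata["nprompt"].values if "nprompt" in metadata else [""] * len(image_paths)
skips = metadata["skip"].values if "skip" in metadata else [""] * len(image_paths)

data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
print(data_full)
# [('01.jpg', 'a photo of {}', ''), ('03.jpg', 'a sketch of {}', '')]
```

The `"nprompt" in metadata` test checks for the column by name, so older CSVs without negative prompts keep working.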
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 3ac57df..64f0c28 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule):
     def prepare_data(self):
         metadata = pd.read_csv(self.data_file)
         image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values]
-        captions = [caption for caption in metadata['caption'].values]
-        skips = [skip for skip in metadata['skip'].values]
-        self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"]
+        prompts = metadata['prompt'].values
+        nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths)
+        skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths)
+        self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"]
 
     def setup(self, stage=None):
         valid_set_size = int(len(self.data_full) * 0.2)
@@ -109,7 +110,7 @@ class CSVDataset(Dataset):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
     def get_example(self, i):
-        image_path, text = self.data[i % self.num_instance_images]
+        image_path, prompt, nprompt = self.data[i % self.num_instance_images]
 
         if image_path in self.cache:
             return self.cache[image_path]
@@ -120,12 +121,13 @@ class CSVDataset(Dataset):
         if not instance_image.mode == "RGB":
             instance_image = instance_image.convert("RGB")
 
-        text = text.format(self.placeholder_token)
+        prompt = prompt.format(self.placeholder_token)
 
-        example["prompts"] = text
+        example["prompts"] = prompt
+        example["nprompts"] = nprompt
         example["pixel_values"] = instance_image
         example["input_ids"] = self.tokenizer(
-            text,
+            prompt,
             padding="max_length",
             truncation=True,
             max_length=self.tokenizer.model_max_length,
@@ -140,6 +142,7 @@ class CSVDataset(Dataset):
         unprocessed_example = self.get_example(i)
 
         example["prompts"] = unprocessed_example["prompts"]
+        example["nprompts"] = unprocessed_example["nprompts"]
        example["input_ids"] = unprocessed_example["input_ids"]
         example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"])
 
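The textual-inversion dataset gets the same change, plus the templating step: a tiny sketch (not part of the diff) of how the `{}` placeholder in a stored prompt is filled with the learned token before tokenization. The token and prompt strings here are invented for illustration.

```python
placeholder_token = "<my-token>"                # hypothetical learned token
prompt_template = "a photo of {} on a beach"    # as stored in the 'prompt' column
nprompt = "blurry, low quality"                 # as stored in the optional 'nprompt' column

prompt = prompt_template.format(placeholder_token)
print(prompt)   # a photo of <my-token> on a beach
# nprompt is passed through unchanged and later used as the negative prompt
```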
diff --git a/dreambooth.py b/dreambooth.py
index 75602dc..5fbf172 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -191,16 +191,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -331,9 +325,8 @@ class Checkpointer:
         text_encoder,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -345,9 +338,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self):
@@ -396,63 +388,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -463,7 +425,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -630,7 +592,7 @@ def main():
         identifier=args.identifier,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches,
+        valid_set_size=args.sample_batch_size*args.sample_batches,
         collate_fn=collate_fn)
 
     datamodule.prepare_data()
@@ -649,8 +611,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
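The old checkpointing code ran one loop for fixed-seed ("stable") samples and a second one for random-seed samples; the new code drives all three pools ("stable", "val", "train") through a single loop, and only the stable pool reuses the seeded latents. Below is a self-contained sketch (not part of the diff) of the new prompt-collection logic on made-up batches; in the real code `data` is a PyTorch DataLoader and `data.batch_size` its batch size.

```python
sample_batch_size = 4
batch_size = 2  # stands in for data.batch_size

data = [  # each dict mimics one collated batch from the dataloader
    {"prompts": ["a photo of {}", "a painting of {}"], "nprompts": ["blurry", ""]},
    {"prompts": ["a sketch of {}", "a render of {}"], "nprompts": ["", "low quality"]},
    {"prompts": ["an etching of {}", "a poster of {}"], "nprompts": ["", ""]},
]

# Take just enough batches to fill one sample batch, then flatten and trim.
batches = [batch for j, batch in enumerate(data) if j * batch_size < sample_batch_size]
prompt = [p for batch in batches for p in batch["prompts"]][:sample_batch_size]
nprompt = [p for batch in batches for p in batch["nprompts"]][:sample_batch_size]

print(prompt)   # ['a photo of {}', 'a painting of {}', 'a sketch of {}', 'a render of {}']
print(nprompt)  # ['blurry', '', '', 'low quality']
```

Compared with the old nested comprehension, this keeps prompts and negative prompts aligned, since both lists are built from the same `batches` slice and trimmed to the same length.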
diff --git a/textual_inversion.py b/textual_inversion.py
index 285aa0a..00d460f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -207,16 +207,10 @@ def parse_args():
         help="Size of sample images",
     )
     parser.add_argument(
-        "--stable_sample_batches",
+        "--sample_batches",
         type=int,
         default=1,
-        help="Number of fixed seed sample batches to generate per checkpoint",
-    )
-    parser.add_argument(
-        "--random_sample_batches",
-        type=int,
-        default=1,
-        help="Number of random seed sample batches to generate per checkpoint",
+        help="Number of sample batches to generate per checkpoint",
     )
     parser.add_argument(
         "--sample_batch_size",
@@ -319,9 +313,8 @@ class Checkpointer:
         placeholder_token_id,
         output_dir,
         sample_image_size,
-        random_sample_batches,
+        sample_batches,
         sample_batch_size,
-        stable_sample_batches,
         seed
     ):
         self.datamodule = datamodule
@@ -334,9 +327,8 @@ class Checkpointer:
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
         self.seed = seed
-        self.random_sample_batches = random_sample_batches
+        self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
-        self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
     def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
@@ -385,63 +377,33 @@ class Checkpointer:
         train_data = self.datamodule.train_dataloader()
         val_data = self.datamodule.val_dataloader()
 
-        if self.stable_sample_batches > 0:
-            stable_latents = torch.randn(
-                (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
-                device=pipeline.device,
-                generator=torch.Generator(device=pipeline.device).manual_seed(self.seed),
-            )
-
-            all_samples = []
-            file_path = samples_path.joinpath("stable", f"step_{step}.png")
-            file_path.parent.mkdir(parents=True, exist_ok=True)
-
-            data_enum = enumerate(val_data)
-
-            # Generate and save stable samples
-            for i in range(0, self.stable_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size]
-
-                with self.accelerator.autocast():
-                    samples = pipeline(
-                        prompt=prompt,
-                        height=self.sample_image_size,
-                        latents=stable_latents[:len(prompt)],
-                        width=self.sample_image_size,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    )["sample"]
-
-                all_samples += samples
-
-                del samples
-
-            image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size)
-            image_grid.save(file_path)
-
-            del all_samples
-            del image_grid
-            del stable_latents
+        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
+        stable_latents = torch.randn(
+            (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
+            device=pipeline.device,
+            generator=generator,
+        )
 
-        for data, pool in [(val_data, "val"), (train_data, "train")]:
+        for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
             all_samples = []
             file_path = samples_path.joinpath(pool, f"step_{step}.png")
             file_path.parent.mkdir(parents=True, exist_ok=True)
 
             data_enum = enumerate(data)
 
-            for i in range(0, self.random_sample_batches):
-                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
-                    batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size]
+            for i in range(self.sample_batches):
+                batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size]
+                prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size]
+                nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
+                        negative_prompt=nprompt,
                         height=self.sample_image_size,
                         width=self.sample_image_size,
+                        latents=latents[:len(prompt)] if latents is not None else None,
+                        generator=generator if latents is not None else None,
                         guidance_scale=guidance_scale,
                         eta=eta,
                         num_inference_steps=num_inference_steps,
@@ -452,7 +414,7 @@ class Checkpointer:
 
                 del samples
 
-            image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size)
+            image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
             image_grid.save(file_path)
 
             del all_samples
@@ -461,6 +423,8 @@ class Checkpointer:
         del unwrapped
         del scheduler
         del pipeline
+        del generator
+        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
@@ -603,7 +567,7 @@ def main():
         placeholder_token=args.placeholder_token,
         repeats=args.repeats,
         center_crop=args.center_crop,
-        valid_set_size=args.sample_batch_size*args.stable_sample_batches
+        valid_set_size=args.sample_batch_size*args.sample_batches
     )
 
     datamodule.prepare_data()
@@ -623,8 +587,7 @@ def main():
         output_dir=basepath,
         sample_image_size=args.sample_image_size,
         sample_batch_size=args.sample_batch_size,
-        random_sample_batches=args.random_sample_batches,
-        stable_sample_batches=args.stable_sample_batches,
+        sample_batches=args.sample_batches,
         seed=args.seed or torch.random.seed()
     )
 
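textual_inversion.py gets the same consolidation as dreambooth.py. One point worth noting: the generator is re-seeded with the same value at every checkpoint, so the "stable" pool always starts from identical noise and its sample grid tracks training progress on fixed latents. A self-contained sketch (not part of the diff; the shapes and seed below are illustrative only):

```python
import torch

def stable_latents_for_checkpoint(seed, sample_batch_size=4):
    # Mirrors the pattern in checkpoint(): a freshly seeded generator per call.
    generator = torch.Generator(device="cpu").manual_seed(seed)
    return torch.randn((sample_batch_size, 4, 64, 64), generator=generator)

a = stable_latents_for_checkpoint(seed=42)
b = stable_latents_for_checkpoint(seed=42)
print(torch.equal(a, b))  # True: the stable grid reuses the same noise at every checkpoint
```

In both scripts the seed comes from `args.seed or torch.random.seed()`, so even without an explicit `--seed` a single value is fixed for the whole run and the stable pool stays comparable across checkpoints.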
