diff options
author | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-10-03 14:47:01 +0200 |
commit | c90099f06e0b461660b326fb6d86b69d86e78289 (patch) | |
tree | df4ce274eed8f2a89bbd12f1a19c685ceac58ff2 | |
parent | Fixed euler_a generator argument (diff) | |
download | textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.gz textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.tar.bz2 textual-inversion-diff-c90099f06e0b461660b326fb6d86b69d86e78289.zip |
Added negative prompt support for training scripts
-rw-r--r-- | data/dreambooth/csv.py | 15 | ||||
-rw-r--r-- | data/textual_inversion/csv.py | 17 | ||||
-rw-r--r-- | dreambooth.py | 81 | ||||
-rw-r--r-- | textual_inversion.py | 83 |
4 files changed, 63 insertions, 133 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 08ed49c..71aa1eb 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py | |||
@@ -49,9 +49,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
49 | def prepare_data(self): | 49 | def prepare_data(self): |
50 | metadata = pd.read_csv(self.data_file) | 50 | metadata = pd.read_csv(self.data_file) |
51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 51 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
52 | captions = [caption for caption in metadata['caption'].values] | 52 | prompts = metadata['prompt'].values |
53 | skips = [skip for skip in metadata['skip'].values] | 53 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
54 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 54 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
55 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
55 | 56 | ||
56 | def setup(self, stage=None): | 57 | def setup(self, stage=None): |
57 | valid_set_size = int(len(self.data_full) * 0.2) | 58 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -135,7 +136,7 @@ class CSVDataset(Dataset): | |||
135 | return math.ceil(self._length / self.batch_size) * self.batch_size | 136 | return math.ceil(self._length / self.batch_size) * self.batch_size |
136 | 137 | ||
137 | def get_example(self, i): | 138 | def get_example(self, i): |
138 | image_path, text = self.data[i % self.num_instance_images] | 139 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
139 | 140 | ||
140 | if image_path in self.cache: | 141 | if image_path in self.cache: |
141 | return self.cache[image_path] | 142 | return self.cache[image_path] |
@@ -146,9 +147,10 @@ class CSVDataset(Dataset): | |||
146 | if not instance_image.mode == "RGB": | 147 | if not instance_image.mode == "RGB": |
147 | instance_image = instance_image.convert("RGB") | 148 | instance_image = instance_image.convert("RGB") |
148 | 149 | ||
149 | text = text.format(self.identifier) | 150 | prompt = prompt.format(self.identifier) |
150 | 151 | ||
151 | example["prompts"] = text | 152 | example["prompts"] = prompt |
153 | example["nprompts"] = nprompt | ||
152 | example["instance_images"] = instance_image | 154 | example["instance_images"] = instance_image |
153 | example["instance_prompt_ids"] = self.tokenizer( | 155 | example["instance_prompt_ids"] = self.tokenizer( |
154 | self.instance_prompt, | 156 | self.instance_prompt, |
@@ -178,6 +180,7 @@ class CSVDataset(Dataset): | |||
178 | unprocessed_example = self.get_example(i) | 180 | unprocessed_example = self.get_example(i) |
179 | 181 | ||
180 | example["prompts"] = unprocessed_example["prompts"] | 182 | example["prompts"] = unprocessed_example["prompts"] |
183 | example["nprompts"] = unprocessed_example["nprompts"] | ||
181 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) | 184 | example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"]) |
182 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] | 185 | example["instance_prompt_ids"] = unprocessed_example["instance_prompt_ids"] |
183 | 186 | ||
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py index 3ac57df..64f0c28 100644 --- a/data/textual_inversion/csv.py +++ b/data/textual_inversion/csv.py | |||
@@ -43,9 +43,10 @@ class CSVDataModule(pl.LightningDataModule): | |||
43 | def prepare_data(self): | 43 | def prepare_data(self): |
44 | metadata = pd.read_csv(self.data_file) | 44 | metadata = pd.read_csv(self.data_file) |
45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] | 45 | image_paths = [os.path.join(self.data_root, f_path) for f_path in metadata['image'].values] |
46 | captions = [caption for caption in metadata['caption'].values] | 46 | prompts = metadata['prompt'].values |
47 | skips = [skip for skip in metadata['skip'].values] | 47 | nprompts = metadata['nprompt'].values if 'nprompt' in metadata else [""] * len(image_paths) |
48 | self.data_full = [(img, cap) for img, cap, skip in zip(image_paths, captions, skips) if skip != "x"] | 48 | skips = metadata['skip'].values if 'skip' in metadata else [""] * len(image_paths) |
49 | self.data_full = [(i, p, n) for i, p, n, s in zip(image_paths, prompts, nprompts, skips) if s != "x"] | ||
49 | 50 | ||
50 | def setup(self, stage=None): | 51 | def setup(self, stage=None): |
51 | valid_set_size = int(len(self.data_full) * 0.2) | 52 | valid_set_size = int(len(self.data_full) * 0.2) |
@@ -109,7 +110,7 @@ class CSVDataset(Dataset): | |||
109 | return math.ceil(self._length / self.batch_size) * self.batch_size | 110 | return math.ceil(self._length / self.batch_size) * self.batch_size |
110 | 111 | ||
111 | def get_example(self, i): | 112 | def get_example(self, i): |
112 | image_path, text = self.data[i % self.num_instance_images] | 113 | image_path, prompt, nprompt = self.data[i % self.num_instance_images] |
113 | 114 | ||
114 | if image_path in self.cache: | 115 | if image_path in self.cache: |
115 | return self.cache[image_path] | 116 | return self.cache[image_path] |
@@ -120,12 +121,13 @@ class CSVDataset(Dataset): | |||
120 | if not instance_image.mode == "RGB": | 121 | if not instance_image.mode == "RGB": |
121 | instance_image = instance_image.convert("RGB") | 122 | instance_image = instance_image.convert("RGB") |
122 | 123 | ||
123 | text = text.format(self.placeholder_token) | 124 | prompt = prompt.format(self.placeholder_token) |
124 | 125 | ||
125 | example["prompts"] = text | 126 | example["prompts"] = prompt |
127 | example["nprompts"] = nprompt | ||
126 | example["pixel_values"] = instance_image | 128 | example["pixel_values"] = instance_image |
127 | example["input_ids"] = self.tokenizer( | 129 | example["input_ids"] = self.tokenizer( |
128 | text, | 130 | prompt, |
129 | padding="max_length", | 131 | padding="max_length", |
130 | truncation=True, | 132 | truncation=True, |
131 | max_length=self.tokenizer.model_max_length, | 133 | max_length=self.tokenizer.model_max_length, |
@@ -140,6 +142,7 @@ class CSVDataset(Dataset): | |||
140 | unprocessed_example = self.get_example(i) | 142 | unprocessed_example = self.get_example(i) |
141 | 143 | ||
142 | example["prompts"] = unprocessed_example["prompts"] | 144 | example["prompts"] = unprocessed_example["prompts"] |
145 | example["nprompts"] = unprocessed_example["nprompts"] | ||
143 | example["input_ids"] = unprocessed_example["input_ids"] | 146 | example["input_ids"] = unprocessed_example["input_ids"] |
144 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) | 147 | example["pixel_values"] = self.image_transforms(unprocessed_example["pixel_values"]) |
145 | 148 | ||
diff --git a/dreambooth.py b/dreambooth.py index 75602dc..5fbf172 100644 --- a/dreambooth.py +++ b/dreambooth.py | |||
@@ -191,16 +191,10 @@ def parse_args(): | |||
191 | help="Size of sample images", | 191 | help="Size of sample images", |
192 | ) | 192 | ) |
193 | parser.add_argument( | 193 | parser.add_argument( |
194 | "--stable_sample_batches", | 194 | "--sample_batches", |
195 | type=int, | 195 | type=int, |
196 | default=1, | 196 | default=1, |
197 | help="Number of fixed seed sample batches to generate per checkpoint", | 197 | help="Number of sample batches to generate per checkpoint", |
198 | ) | ||
199 | parser.add_argument( | ||
200 | "--random_sample_batches", | ||
201 | type=int, | ||
202 | default=1, | ||
203 | help="Number of random seed sample batches to generate per checkpoint", | ||
204 | ) | 198 | ) |
205 | parser.add_argument( | 199 | parser.add_argument( |
206 | "--sample_batch_size", | 200 | "--sample_batch_size", |
@@ -331,9 +325,8 @@ class Checkpointer: | |||
331 | text_encoder, | 325 | text_encoder, |
332 | output_dir, | 326 | output_dir, |
333 | sample_image_size, | 327 | sample_image_size, |
334 | random_sample_batches, | 328 | sample_batches, |
335 | sample_batch_size, | 329 | sample_batch_size, |
336 | stable_sample_batches, | ||
337 | seed | 330 | seed |
338 | ): | 331 | ): |
339 | self.datamodule = datamodule | 332 | self.datamodule = datamodule |
@@ -345,9 +338,8 @@ class Checkpointer: | |||
345 | self.output_dir = output_dir | 338 | self.output_dir = output_dir |
346 | self.sample_image_size = sample_image_size | 339 | self.sample_image_size = sample_image_size |
347 | self.seed = seed | 340 | self.seed = seed |
348 | self.random_sample_batches = random_sample_batches | 341 | self.sample_batches = sample_batches |
349 | self.sample_batch_size = sample_batch_size | 342 | self.sample_batch_size = sample_batch_size |
350 | self.stable_sample_batches = stable_sample_batches | ||
351 | 343 | ||
352 | @torch.no_grad() | 344 | @torch.no_grad() |
353 | def checkpoint(self): | 345 | def checkpoint(self): |
@@ -396,63 +388,33 @@ class Checkpointer: | |||
396 | train_data = self.datamodule.train_dataloader() | 388 | train_data = self.datamodule.train_dataloader() |
397 | val_data = self.datamodule.val_dataloader() | 389 | val_data = self.datamodule.val_dataloader() |
398 | 390 | ||
399 | if self.stable_sample_batches > 0: | 391 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
400 | stable_latents = torch.randn( | 392 | stable_latents = torch.randn( |
401 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 393 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
402 | device=pipeline.device, | 394 | device=pipeline.device, |
403 | generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), | 395 | generator=generator, |
404 | ) | 396 | ) |
405 | |||
406 | all_samples = [] | ||
407 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | ||
408 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
409 | |||
410 | data_enum = enumerate(val_data) | ||
411 | |||
412 | # Generate and save stable samples | ||
413 | for i in range(0, self.stable_sample_batches): | ||
414 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | ||
415 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | ||
416 | |||
417 | with self.accelerator.autocast(): | ||
418 | samples = pipeline( | ||
419 | prompt=prompt, | ||
420 | height=self.sample_image_size, | ||
421 | latents=stable_latents[:len(prompt)], | ||
422 | width=self.sample_image_size, | ||
423 | guidance_scale=guidance_scale, | ||
424 | eta=eta, | ||
425 | num_inference_steps=num_inference_steps, | ||
426 | output_type='pil' | ||
427 | )["sample"] | ||
428 | |||
429 | all_samples += samples | ||
430 | |||
431 | del samples | ||
432 | |||
433 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | ||
434 | image_grid.save(file_path) | ||
435 | |||
436 | del all_samples | ||
437 | del image_grid | ||
438 | del stable_latents | ||
439 | 397 | ||
440 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 398 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
441 | all_samples = [] | 399 | all_samples = [] |
442 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 400 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
443 | file_path.parent.mkdir(parents=True, exist_ok=True) | 401 | file_path.parent.mkdir(parents=True, exist_ok=True) |
444 | 402 | ||
445 | data_enum = enumerate(data) | 403 | data_enum = enumerate(data) |
446 | 404 | ||
447 | for i in range(0, self.random_sample_batches): | 405 | for i in range(self.sample_batches): |
448 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 406 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
449 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 407 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
408 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | ||
450 | 409 | ||
451 | with self.accelerator.autocast(): | 410 | with self.accelerator.autocast(): |
452 | samples = pipeline( | 411 | samples = pipeline( |
453 | prompt=prompt, | 412 | prompt=prompt, |
413 | negative_prompt=nprompt, | ||
454 | height=self.sample_image_size, | 414 | height=self.sample_image_size, |
455 | width=self.sample_image_size, | 415 | width=self.sample_image_size, |
416 | latents=latents[:len(prompt)] if latents is not None else None, | ||
417 | generator=generator if latents is not None else None, | ||
456 | guidance_scale=guidance_scale, | 418 | guidance_scale=guidance_scale, |
457 | eta=eta, | 419 | eta=eta, |
458 | num_inference_steps=num_inference_steps, | 420 | num_inference_steps=num_inference_steps, |
@@ -463,7 +425,7 @@ class Checkpointer: | |||
463 | 425 | ||
464 | del samples | 426 | del samples |
465 | 427 | ||
466 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 428 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
467 | image_grid.save(file_path) | 429 | image_grid.save(file_path) |
468 | 430 | ||
469 | del all_samples | 431 | del all_samples |
@@ -630,7 +592,7 @@ def main(): | |||
630 | identifier=args.identifier, | 592 | identifier=args.identifier, |
631 | repeats=args.repeats, | 593 | repeats=args.repeats, |
632 | center_crop=args.center_crop, | 594 | center_crop=args.center_crop, |
633 | valid_set_size=args.sample_batch_size*args.stable_sample_batches, | 595 | valid_set_size=args.sample_batch_size*args.sample_batches, |
634 | collate_fn=collate_fn) | 596 | collate_fn=collate_fn) |
635 | 597 | ||
636 | datamodule.prepare_data() | 598 | datamodule.prepare_data() |
@@ -649,8 +611,7 @@ def main(): | |||
649 | output_dir=basepath, | 611 | output_dir=basepath, |
650 | sample_image_size=args.sample_image_size, | 612 | sample_image_size=args.sample_image_size, |
651 | sample_batch_size=args.sample_batch_size, | 613 | sample_batch_size=args.sample_batch_size, |
652 | random_sample_batches=args.random_sample_batches, | 614 | sample_batches=args.sample_batches, |
653 | stable_sample_batches=args.stable_sample_batches, | ||
654 | seed=args.seed or torch.random.seed() | 615 | seed=args.seed or torch.random.seed() |
655 | ) | 616 | ) |
656 | 617 | ||
diff --git a/textual_inversion.py b/textual_inversion.py index 285aa0a..00d460f 100644 --- a/textual_inversion.py +++ b/textual_inversion.py | |||
@@ -207,16 +207,10 @@ def parse_args(): | |||
207 | help="Size of sample images", | 207 | help="Size of sample images", |
208 | ) | 208 | ) |
209 | parser.add_argument( | 209 | parser.add_argument( |
210 | "--stable_sample_batches", | 210 | "--sample_batches", |
211 | type=int, | 211 | type=int, |
212 | default=1, | 212 | default=1, |
213 | help="Number of fixed seed sample batches to generate per checkpoint", | 213 | help="Number of sample batches to generate per checkpoint", |
214 | ) | ||
215 | parser.add_argument( | ||
216 | "--random_sample_batches", | ||
217 | type=int, | ||
218 | default=1, | ||
219 | help="Number of random seed sample batches to generate per checkpoint", | ||
220 | ) | 214 | ) |
221 | parser.add_argument( | 215 | parser.add_argument( |
222 | "--sample_batch_size", | 216 | "--sample_batch_size", |
@@ -319,9 +313,8 @@ class Checkpointer: | |||
319 | placeholder_token_id, | 313 | placeholder_token_id, |
320 | output_dir, | 314 | output_dir, |
321 | sample_image_size, | 315 | sample_image_size, |
322 | random_sample_batches, | 316 | sample_batches, |
323 | sample_batch_size, | 317 | sample_batch_size, |
324 | stable_sample_batches, | ||
325 | seed | 318 | seed |
326 | ): | 319 | ): |
327 | self.datamodule = datamodule | 320 | self.datamodule = datamodule |
@@ -334,9 +327,8 @@ class Checkpointer: | |||
334 | self.output_dir = output_dir | 327 | self.output_dir = output_dir |
335 | self.sample_image_size = sample_image_size | 328 | self.sample_image_size = sample_image_size |
336 | self.seed = seed | 329 | self.seed = seed |
337 | self.random_sample_batches = random_sample_batches | 330 | self.sample_batches = sample_batches |
338 | self.sample_batch_size = sample_batch_size | 331 | self.sample_batch_size = sample_batch_size |
339 | self.stable_sample_batches = stable_sample_batches | ||
340 | 332 | ||
341 | @torch.no_grad() | 333 | @torch.no_grad() |
342 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): | 334 | def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): |
@@ -385,63 +377,33 @@ class Checkpointer: | |||
385 | train_data = self.datamodule.train_dataloader() | 377 | train_data = self.datamodule.train_dataloader() |
386 | val_data = self.datamodule.val_dataloader() | 378 | val_data = self.datamodule.val_dataloader() |
387 | 379 | ||
388 | if self.stable_sample_batches > 0: | 380 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
389 | stable_latents = torch.randn( | 381 | stable_latents = torch.randn( |
390 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), | 382 | (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), |
391 | device=pipeline.device, | 383 | device=pipeline.device, |
392 | generator=torch.Generator(device=pipeline.device).manual_seed(self.seed), | 384 | generator=generator, |
393 | ) | 385 | ) |
394 | |||
395 | all_samples = [] | ||
396 | file_path = samples_path.joinpath("stable", f"step_{step}.png") | ||
397 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
398 | |||
399 | data_enum = enumerate(val_data) | ||
400 | |||
401 | # Generate and save stable samples | ||
402 | for i in range(0, self.stable_sample_batches): | ||
403 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | ||
404 | batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] | ||
405 | |||
406 | with self.accelerator.autocast(): | ||
407 | samples = pipeline( | ||
408 | prompt=prompt, | ||
409 | height=self.sample_image_size, | ||
410 | latents=stable_latents[:len(prompt)], | ||
411 | width=self.sample_image_size, | ||
412 | guidance_scale=guidance_scale, | ||
413 | eta=eta, | ||
414 | num_inference_steps=num_inference_steps, | ||
415 | output_type='pil' | ||
416 | )["sample"] | ||
417 | |||
418 | all_samples += samples | ||
419 | |||
420 | del samples | ||
421 | |||
422 | image_grid = make_grid(all_samples, self.stable_sample_batches, self.sample_batch_size) | ||
423 | image_grid.save(file_path) | ||
424 | |||
425 | del all_samples | ||
426 | del image_grid | ||
427 | del stable_latents | ||
428 | 386 | ||
429 | for data, pool in [(val_data, "val"), (train_data, "train")]: | 387 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: |
430 | all_samples = [] | 388 | all_samples = [] |
431 | file_path = samples_path.joinpath(pool, f"step_{step}.png") | 389 | file_path = samples_path.joinpath(pool, f"step_{step}.png") |
432 | file_path.parent.mkdir(parents=True, exist_ok=True) | 390 | file_path.parent.mkdir(parents=True, exist_ok=True) |
433 | 391 | ||
434 | data_enum = enumerate(data) | 392 | data_enum = enumerate(data) |
435 | 393 | ||
436 | for i in range(0, self.random_sample_batches): | 394 | for i in range(self.sample_batches): |
437 | prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( | 395 | batches = [batch for j, batch in data_enum if j * data.batch_size < self.sample_batch_size] |
438 | batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] | 396 | prompt = [prompt for batch in batches for prompt in batch["prompts"]][:self.sample_batch_size] |
397 | nprompt = [prompt for batch in batches for prompt in batch["nprompts"]][:self.sample_batch_size] | ||
439 | 398 | ||
440 | with self.accelerator.autocast(): | 399 | with self.accelerator.autocast(): |
441 | samples = pipeline( | 400 | samples = pipeline( |
442 | prompt=prompt, | 401 | prompt=prompt, |
402 | negative_prompt=nprompt, | ||
443 | height=self.sample_image_size, | 403 | height=self.sample_image_size, |
444 | width=self.sample_image_size, | 404 | width=self.sample_image_size, |
405 | latents=latents[:len(prompt)] if latents is not None else None, | ||
406 | generator=generator if latents is not None else None, | ||
445 | guidance_scale=guidance_scale, | 407 | guidance_scale=guidance_scale, |
446 | eta=eta, | 408 | eta=eta, |
447 | num_inference_steps=num_inference_steps, | 409 | num_inference_steps=num_inference_steps, |
@@ -452,7 +414,7 @@ class Checkpointer: | |||
452 | 414 | ||
453 | del samples | 415 | del samples |
454 | 416 | ||
455 | image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) | 417 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) |
456 | image_grid.save(file_path) | 418 | image_grid.save(file_path) |
457 | 419 | ||
458 | del all_samples | 420 | del all_samples |
@@ -461,6 +423,8 @@ class Checkpointer: | |||
461 | del unwrapped | 423 | del unwrapped |
462 | del scheduler | 424 | del scheduler |
463 | del pipeline | 425 | del pipeline |
426 | del generator | ||
427 | del stable_latents | ||
464 | 428 | ||
465 | if torch.cuda.is_available(): | 429 | if torch.cuda.is_available(): |
466 | torch.cuda.empty_cache() | 430 | torch.cuda.empty_cache() |
@@ -603,7 +567,7 @@ def main(): | |||
603 | placeholder_token=args.placeholder_token, | 567 | placeholder_token=args.placeholder_token, |
604 | repeats=args.repeats, | 568 | repeats=args.repeats, |
605 | center_crop=args.center_crop, | 569 | center_crop=args.center_crop, |
606 | valid_set_size=args.sample_batch_size*args.stable_sample_batches | 570 | valid_set_size=args.sample_batch_size*args.sample_batches |
607 | ) | 571 | ) |
608 | 572 | ||
609 | datamodule.prepare_data() | 573 | datamodule.prepare_data() |
@@ -623,8 +587,7 @@ def main(): | |||
623 | output_dir=basepath, | 587 | output_dir=basepath, |
624 | sample_image_size=args.sample_image_size, | 588 | sample_image_size=args.sample_image_size, |
625 | sample_batch_size=args.sample_batch_size, | 589 | sample_batch_size=args.sample_batch_size, |
626 | random_sample_batches=args.random_sample_batches, | 590 | sample_batches=args.sample_batches, |
627 | stable_sample_batches=args.stable_sample_batches, | ||
628 | seed=args.seed or torch.random.seed() | 591 | seed=args.seed or torch.random.seed() |
629 | ) | 592 | ) |
630 | 593 | ||