| author | Volpeon <git@volpeon.ink> | 2022-09-27 08:21:28 +0200 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-09-27 08:21:28 +0200 |
| commit | 68a3735352d3e62d46556b677407fc71b78c47c4 (patch) | |
| tree | 5732c82dc668da7fe702e5fe1fdfaa09e5cec2d0 | |
| parent | Autocast on sample generation, progress bar cleanup (diff) | |
| download | textual-inversion-diff-68a3735352d3e62d46556b677407fc71b78c47c4 (.tar.gz, .tar.bz2, .zip) | |
Better sample filenames, optimizations
| -rw-r--r-- | .gitignore | 4 |
|---|---|---|
| -rw-r--r-- | environment.yaml | 10 |
| -rw-r--r-- | main.py | 37 |

3 files changed, 32 insertions, 19 deletions
```diff
diff --git a/.gitignore b/.gitignore
@@ -160,5 +160,5 @@ cython_debug/
 #.idea/
 
 text-inversion-model/
-conf.json
-v1-inference.yaml
+conf*.json
+v1-inference.yaml*
```
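The ignore patterns are widened to globs, so variants such as `conf-sd14.json` or a re-downloaded `v1-inference.yaml.1` are ignored as well. A tiny sketch of the matching, using Python's `fnmatch` as a stand-in for git's glob rules (which differ in some details):

```python
from fnmatch import fnmatch

# Check a few filenames against the new .gitignore globs.
for name in ["conf.json", "conf-sd14.json", "v1-inference.yaml", "v1-inference.yaml.1"]:
    ignored = fnmatch(name, "conf*.json") or fnmatch(name, "v1-inference.yaml*")
    print(name, ignored)
# conf.json True
# conf-sd14.json True
# v1-inference.yaml True
# v1-inference.yaml.1 True
```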
```diff
diff --git a/environment.yaml b/environment.yaml
index a460158..8dc930c 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -16,21 +16,19 @@ dependencies:
   - -e git+https://github.com/openai/CLIP.git@main#egg=clip
   - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
   - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
+  - -e git+https://github.com/huggingface/diffusers#egg=diffusers
   - accelerate==0.12.0
   - albumentations==1.1.0
-  - diffusers==0.3.0
   - einops==0.4.1
-  - imageio-ffmpeg==0.4.7
-  - imageio==2.14.1
+  - imageio==2.22.0
   - kornia==0.6
   - pudb==2019.2
-  - omegaconf==2.1.1
+  - omegaconf==2.2.3
   - opencv-python-headless==4.6.0.66
   - python-slugify>=6.1.2
   - pytorch-lightning==1.7.7
   - setuptools==59.5.0
-  - streamlit>=0.73.1
   - test-tube>=0.7.5
   - torch-fidelity==0.3.0
   - torchmetrics==0.9.3
-  - transformers==4.19.2
+  - transformers==4.22.1
```
```diff
diff --git a/main.py b/main.py
@@ -106,6 +106,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--learning_rate",
         type=float,
         default=1e-4,
@@ -213,7 +218,7 @@ def parse_args():
     elif args.config is not None:
         with open(args.config, 'rt') as f:
             args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)))
+                namespace=argparse.Namespace(**json.load(f)["args"]))
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
```
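The config loader now reads the saved argparse values from an `"args"` key instead of the file's top level, matching the shape of the resume files written by `save_resume_file` further down. A minimal runnable sketch of a compatible file (the two option names are just illustrative):

```python
import argparse
import json

# Write a config file in the shape the loader now expects:
# the saved argparse values live under an "args" key.
with open("conf.json", "w") as f:
    json.dump({"args": {"learning_rate": 1e-5, "gradient_checkpointing": True}}, f)

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--gradient_checkpointing", action="store_true")

with open("conf.json", "rt") as f:
    # Values already present on the namespace survive; CLI flags would override them.
    args = parser.parse_args([], namespace=argparse.Namespace(**json.load(f)["args"]))

print(args.learning_rate, args.gradient_checkpointing)  # 1e-05 True
```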
```diff
@@ -289,7 +294,7 @@ class Checkpointer:
         self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
-    def checkpoint(self, step, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
         print("Saving checkpoint for step %d..." % step)
         with self.accelerator.autocast():
             if path is None:
@@ -302,12 +307,11 @@ class Checkpointer:
             learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
             learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
-            filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
+            filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
             if path is not None:
                 torch.save(learned_embeds_dict, path)
             else:
-                torch.save(learned_embeds_dict,
-                           f"{checkpoints_path}/{filename}")
+                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
                 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
             del unwrapped
             del learned_embeds
```
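Checkpoints are now written as `<token-slug>_<step>_<postfix>.bin`, so files from different save events no longer collide; the hunks below pass `"training"`, `"milestone"`, and `"end"` as the postfix. A quick sketch of the resulting names (the token is illustrative; note the `f` prefix in the diff is redundant, since %-formatting does all the work):

```python
from slugify import slugify  # python-slugify, pinned in environment.yaml

placeholder_token = "<my-token>"  # illustrative placeholder token

for step, postfix in [(500, "training"), (1200, "milestone"), (5000, "end")]:
    filename = "%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)
    print(filename)
# my-token_500_training.bin
# my-token_1200_milestone.bin
# my-token_5000_end.bin
```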
```diff
@@ -338,7 +342,7 @@ class Checkpointer:
             "validation": self.datamodule.val_dataloader(),
         }[mode]
 
-        if mode == "validation" and self.stable_sample_batches > 0:
+        if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
             stable_latents = torch.randn(
                 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
                 device=pipeline.device,
@@ -348,15 +352,18 @@ class Checkpointer:
             all_samples = []
             filename = f"stable_step_%d.png" % (step)
 
+            data_enum = enumerate(data)
+
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
-                prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         height=self.sample_image_size,
-                        latents=stable_latents,
+                        latents=stable_latents[:len(prompt)],
                         width=self.sample_image_size,
                         guidance_scale=guidance_scale,
                         eta=eta,
```
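The old comprehension compared the batch index against `sample_batch_size` and collected whole per-batch prompt lists; the rewrite flattens individual prompts across batches and caps the total at `sample_batch_size`, with `stable_latents[:len(prompt)]` then matching the latent batch to however many prompts were collected. The hunk also hoists `enumerate(data)` out of the loop, so sample batches draw from one shared enumerator instead of restarting at the first dataloader batch. A toy sketch of the flattening (the fake dataloader is illustrative):

```python
# Toy stand-in for the dataloader: yields prompt batches and exposes batch_size.
class ToyLoader:
    batch_size = 2

    def __iter__(self):
        for k in range(0, 8, 2):
            yield {"prompt": [f"prompt {k}", f"prompt {k + 1}"]}

data = ToyLoader()
sample_batch_size = 3
data_enum = enumerate(data)

# Flatten prompts across batches, keeping only the first sample_batch_size of them.
prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
    batch["prompt"]) if i * data.batch_size + j < sample_batch_size]

print(prompt)  # ['prompt 0', 'prompt 1', 'prompt 2']
```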
```diff
@@ -377,9 +384,12 @@ class Checkpointer:
         all_samples = []
         filename = f"step_%d.png" % (step)
 
+        data_enum = enumerate(data)
+
         # Generate and save random samples
         for i in range(0, self.random_sample_batches):
-            prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
             with self.accelerator.autocast():
                 samples = pipeline(
@@ -486,6 +496,9 @@ def main():
         args.pretrained_model_name_or_path + '/unet',
     )
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     slice_size = unet.config.attention_head_dim // 2
     unet.set_attention_slice(slice_size)
 
```
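Gradient checkpointing trades compute for memory: intermediate activations are recomputed during the backward pass instead of being kept around, which is why the new flag's help text warns about a slower backward pass. A minimal sketch of the wiring, assuming a diffusers `UNet2DConditionModel` as in main.py (the model path is a placeholder):

```python
import argparse

from diffusers import UNet2DConditionModel

parser = argparse.ArgumentParser()
parser.add_argument("--gradient_checkpointing", action="store_true")
args = parser.parse_args(["--gradient_checkpointing"])

# Placeholder path; main.py loads from args.pretrained_model_name_or_path + '/unet'.
unet = UNet2DConditionModel.from_pretrained("path/to/pretrained-model/unet")

if args.gradient_checkpointing:
    # Recompute activations on backward instead of storing them all forward.
    unet.enable_gradient_checkpointing()
```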
```diff
@@ -693,7 +706,7 @@ def main():
                     progress_bar.clear()
                     local_progress_bar.clear()
 
-                    checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+                    checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -753,6 +766,7 @@ def main():
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
             checkpointer.save_samples(
```
```diff
@@ -768,6 +782,7 @@ def main():
         print("Finished! Saving final checkpoint and resume state.")
         checkpointer.checkpoint(
             global_step + global_step_offset,
+            "end",
             text_encoder,
             path=f"{basepath}/learned_embeds.bin"
         )
@@ -782,7 +797,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
```
