-rw-r--r--  .gitignore       |  4
-rw-r--r--  environment.yaml | 10
-rw-r--r--  main.py          | 37
3 files changed, 32 insertions(+), 19 deletions(-)
diff --git a/.gitignore b/.gitignore
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,5 @@ cython_debug/
 #.idea/

 text-inversion-model/
-conf.json
-v1-inference.yaml
+conf*.json
+v1-inference.yaml*
diff --git a/environment.yaml b/environment.yaml
index a460158..8dc930c 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -16,21 +16,19 @@ dependencies:
 - -e git+https://github.com/openai/CLIP.git@main#egg=clip
 - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
 - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
+- -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - accelerate==0.12.0
 - albumentations==1.1.0
-- diffusers==0.3.0
 - einops==0.4.1
-- imageio-ffmpeg==0.4.7
-- imageio==2.14.1
+- imageio==2.22.0
 - kornia==0.6
 - pudb==2019.2
-- omegaconf==2.1.1
+- omegaconf==2.2.3
 - opencv-python-headless==4.6.0.66
 - python-slugify>=6.1.2
 - pytorch-lightning==1.7.7
 - setuptools==59.5.0
-- streamlit>=0.73.1
 - test-tube>=0.7.5
 - torch-fidelity==0.3.0
 - torchmetrics==0.9.3
-- transformers==4.19.2
+- transformers==4.22.1
diff --git a/main.py b/main.py
--- a/main.py
+++ b/main.py
@@ -106,6 +106,11 @@ def parse_args():
 help="Number of updates steps to accumulate before performing a backward/update pass.",
 )
 parser.add_argument(
+"--gradient_checkpointing",
+action="store_true",
+help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+)
+parser.add_argument(
 "--learning_rate",
 type=float,
 default=1e-4,
@@ -213,7 +218,7 @@ def parse_args():
 elif args.config is not None:
 with open(args.config, 'rt') as f:
 args = parser.parse_args(
-namespace=argparse.Namespace(**json.load(f)))
+namespace=argparse.Namespace(**json.load(f)["args"]))

 env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
 if env_local_rank != -1 and env_local_rank != args.local_rank:
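
Note on the changed config parsing: the --config branch now reads the argparse values from a top-level "args" key instead of the root of the JSON file, so a flat config written for the old code would need to be wrapped. A minimal, self-contained sketch of the shape this implies (the keys and values here are illustrative only, not taken from the repository):

import argparse
import json

# Hypothetical config content matching json.load(f)["args"]:
# the argparse values are nested under "args" rather than at the top level.
config = json.loads("""
{
  "args": {
    "learning_rate": 1e-4,
    "gradient_checkpointing": true
  },
  "global_step": 0
}
""")

parser = argparse.ArgumentParser()
parser.add_argument("--learning_rate", type=float, default=1e-4)
parser.add_argument("--gradient_checkpointing", action="store_true")

# Mirrors the changed line: seed the namespace from the "args" sub-dict.
args = parser.parse_args([], namespace=argparse.Namespace(**config["args"]))
print(args.learning_rate, args.gradient_checkpointing)  # 0.0001 True
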
@@ -289,7 +294,7 @@ class Checkpointer:
 self.stable_sample_batches = stable_sample_batches

 @torch.no_grad()
-def checkpoint(self, step, text_encoder, save_samples=True, path=None):
+def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
 print("Saving checkpoint for step %d..." % step)
 with self.accelerator.autocast():
 if path is None:
@@ -302,12 +307,11 @@ class Checkpointer:
 learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
 learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

-filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
+filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
 if path is not None:
 torch.save(learned_embeds_dict, path)
 else:
-torch.save(learned_embeds_dict,
-f"{checkpoints_path}/{filename}")
+torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
 torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
 del unwrapped
 del learned_embeds
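
The checkpoint() signature gains a postfix argument, so intermediate embedding files now encode both the step and the reason for the save; the calls further down pass "training", "milestone", and "end". A quick sketch of the resulting file names under the %-format above (the token and step values are made up):

from slugify import slugify  # python-slugify, pinned in environment.yaml

placeholder_token, step = "<my-token>", 500  # illustrative values only
for postfix in ("training", "milestone", "end"):
    print("%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix))
# my-token_500_training.bin
# my-token_500_milestone.bin
# my-token_500_end.bin
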
@@ -338,7 +342,7 @@ class Checkpointer:
 "validation": self.datamodule.val_dataloader(),
 }[mode]

-if mode == "validation" and self.stable_sample_batches > 0:
+if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
 stable_latents = torch.randn(
 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
 device=pipeline.device,
@@ -348,15 +352,18 @@ class Checkpointer:
 all_samples = []
 filename = f"stable_step_%d.png" % (step)

+data_enum = enumerate(data)
+
 # Generate and save stable samples
 for i in range(0, self.stable_sample_batches):
-prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]

 with self.accelerator.autocast():
 samples = pipeline(
 prompt=prompt,
 height=self.sample_image_size,
-latents=stable_latents,
+latents=stable_latents[:len(prompt)],
 width=self.sample_image_size,
 guidance_scale=guidance_scale,
 eta=eta,
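
The old comprehension collected one batch["prompt"] entry per dataloader batch, which does not line up with sample_batch_size when the dataloader batch size differs from one. The new version flattens individual prompts across batches until sample_batch_size prompts are collected, and the stable latents are sliced to len(prompt) to match. A standalone sketch of the flattening logic, with a plain list standing in for the dataloader (the prompts and batch_size are invented for illustration):

# Stand-in for the dataloader: each element is a batch dict holding a list of prompts.
class FakeLoader(list):
    batch_size = 2

data = FakeLoader([
    {"prompt": ["a photo of X", "a painting of X"]},
    {"prompt": ["a sketch of X", "a render of X"]},
])
sample_batch_size = 3

data_enum = enumerate(data)
prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
    batch["prompt"]) if i * data.batch_size + j < sample_batch_size]
print(prompt)  # ['a photo of X', 'a painting of X', 'a sketch of X']
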
@@ -377,9 +384,12 @@ class Checkpointer:
 all_samples = []
 filename = f"step_%d.png" % (step)

+data_enum = enumerate(data)
+
 # Generate and save random samples
 for i in range(0, self.random_sample_batches):
-prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]

 with self.accelerator.autocast():
 samples = pipeline(
@@ -486,6 +496,9 @@ def main():
 args.pretrained_model_name_or_path + '/unet',
 )

+if args.gradient_checkpointing:
+unet.enable_gradient_checkpointing()
+
 slice_size = unet.config.attention_head_dim // 2
 unet.set_attention_slice(slice_size)

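
Here the new flag is wired to the UNet's enable_gradient_checkpointing(), trading extra compute in the backward pass for lower activation memory, as the help text above states. For reference, a minimal, self-contained sketch of the same idea using torch.utils.checkpoint on a toy module (not the UNet, purely illustrative):

import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint

# Toy block: with checkpointing, intermediate activations are dropped in the
# forward pass and recomputed during backward, saving memory at the cost of time.
block = nn.Sequential(nn.Linear(64, 64), nn.ReLU(), nn.Linear(64, 64))
x = torch.randn(8, 64, requires_grad=True)

out = checkpoint(block, x, use_reentrant=False)  # forward without caching activations
out.sum().backward()                             # activations recomputed here
print(x.grad.shape)  # torch.Size([8, 64])
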
@@ -693,7 +706,7 @@ def main():
 progress_bar.clear()
 local_progress_bar.clear()

-checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
 save_resume_file(basepath, args, {
 "global_step": global_step + global_step_offset,
 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -753,6 +766,7 @@ def main():

 if min_val_loss > val_loss:
 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
 min_val_loss = val_loss

 checkpointer.save_samples(
@@ -768,6 +782,7 @@ def main():
 print("Finished! Saving final checkpoint and resume state.")
 checkpointer.checkpoint(
 global_step + global_step_offset,
+"end",
 text_encoder,
 path=f"{basepath}/learned_embeds.bin"
 )
@@ -782,7 +797,7 @@ def main():
 except KeyboardInterrupt:
 if accelerator.is_main_process:
 print("Interrupted, saving checkpoint and resume state...")
-checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
 save_resume_file(basepath, args, {
 "global_step": global_step + global_step_offset,
 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"