author     Volpeon <git@volpeon.ink>  2022-09-27 08:21:28 +0200
committer  Volpeon <git@volpeon.ink>  2022-09-27 08:21:28 +0200
commit     68a3735352d3e62d46556b677407fc71b78c47c4 (patch)
tree       5732c82dc668da7fe702e5fe1fdfaa09e5cec2d0
parent     Autocast on sample generation, progress bar cleanup (diff)
Better sample filenames, optimizations
 .gitignore       |  4 ++--
 environment.yaml | 10 ++++------
 main.py          | 37 ++++++++++++++++++++++++++++++++-----------
 3 files changed, 32 insertions(+), 19 deletions(-)
diff --git a/.gitignore b/.gitignore
index 00e7681..a8893c3 100644
--- a/.gitignore
+++ b/.gitignore
@@ -160,5 +160,5 @@ cython_debug/
 #.idea/
 
 text-inversion-model/
-conf.json
-v1-inference.yaml
+conf*.json
+v1-inference.yaml*
diff --git a/environment.yaml b/environment.yaml
index a460158..8dc930c 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -16,21 +16,19 @@ dependencies:
     - -e git+https://github.com/openai/CLIP.git@main#egg=clip
     - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion
     - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion
+    - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - accelerate==0.12.0
     - albumentations==1.1.0
-    - diffusers==0.3.0
     - einops==0.4.1
-    - imageio-ffmpeg==0.4.7
-    - imageio==2.14.1
+    - imageio==2.22.0
     - kornia==0.6
     - pudb==2019.2
-    - omegaconf==2.1.1
+    - omegaconf==2.2.3
     - opencv-python-headless==4.6.0.66
     - python-slugify>=6.1.2
     - pytorch-lightning==1.7.7
     - setuptools==59.5.0
-    - streamlit>=0.73.1
     - test-tube>=0.7.5
     - torch-fidelity==0.3.0
     - torchmetrics==0.9.3
-    - transformers==4.19.2
+    - transformers==4.22.1
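Note that the pinned diffusers==0.3.0 release is swapped for an editable install from the main branch of huggingface/diffusers, presumably to pick up changes not yet in a tagged release (the motivation is an assumption; the commit message only says "optimizations"). A quick way to check which revision the environment actually resolved:

    import diffusers

    # Editable installs from git main typically report a dev version
    # string (e.g. "0.4.0.dev0") rather than a plain release number.
    print(diffusers.__version__)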
diff --git a/main.py b/main.py
index aa5af72..8be79e5 100644
--- a/main.py
+++ b/main.py
@@ -106,6 +106,11 @@ def parse_args():
         help="Number of updates steps to accumulate before performing a backward/update pass.",
     )
     parser.add_argument(
+        "--gradient_checkpointing",
+        action="store_true",
+        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
+    )
+    parser.add_argument(
         "--learning_rate",
         type=float,
         default=1e-4,
@@ -213,7 +218,7 @@ def parse_args():
     elif args.config is not None:
         with open(args.config, 'rt') as f:
             args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)))
+                namespace=argparse.Namespace(**json.load(f)["args"]))
 
     env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
     if env_local_rank != -1 and env_local_rank != args.local_rank:
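The added ["args"] lookup changes the expected config file layout: the saved CLI arguments now live under an "args" key instead of at the top level, which lets the same file carry extra state alongside them. A sketch of the shape this implies, with every key except "args" inferred from the save_resume_file calls further down:

    {
        "args": {
            "learning_rate": 1e-4,
            "gradient_checkpointing": true
        },
        "global_step": 1000,
        "resume_checkpoint": "output/checkpoints/last.bin"
    }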
@@ -289,7 +294,7 @@ class Checkpointer:
         self.stable_sample_batches = stable_sample_batches
 
     @torch.no_grad()
-    def checkpoint(self, step, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
         print("Saving checkpoint for step %d..." % step)
         with self.accelerator.autocast():
             if path is None:
@@ -302,12 +307,11 @@ class Checkpointer:
             learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
             learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
-            filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step)
+            filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
             if path is not None:
                 torch.save(learned_embeds_dict, path)
             else:
-                torch.save(learned_embeds_dict,
-                           f"{checkpoints_path}/{filename}")
+                torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
             torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
             del unwrapped
             del learned_embeds
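With the postfix threaded through, embedding checkpoints are named <token-slug>_<step>_<postfix>.bin instead of <token-slug>_<step>.bin. (The f prefix on the format string is redundant alongside %-formatting, though harmless.) A quick illustration with an invented token:

    from slugify import slugify

    filename = "%s_%d_%s.bin" % (slugify("<ugly-sweater>"), 500, "training")
    # -> "ugly-sweater_500_training.bin"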
@@ -338,7 +342,7 @@ class Checkpointer:
338 "validation": self.datamodule.val_dataloader(), 342 "validation": self.datamodule.val_dataloader(),
339 }[mode] 343 }[mode]
340 344
341 if mode == "validation" and self.stable_sample_batches > 0: 345 if mode == "validation" and self.stable_sample_batches > 0 and step > 0:
342 stable_latents = torch.randn( 346 stable_latents = torch.randn(
343 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), 347 (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8),
344 device=pipeline.device, 348 device=pipeline.device,
@@ -348,15 +352,18 @@ class Checkpointer:
             all_samples = []
             filename = f"stable_step_%d.png" % (step)
 
+            data_enum = enumerate(data)
+
             # Generate and save stable samples
             for i in range(0, self.stable_sample_batches):
-                prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+                prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                    batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
                 with self.accelerator.autocast():
                     samples = pipeline(
                         prompt=prompt,
                         height=self.sample_image_size,
-                        latents=stable_latents,
+                        latents=stable_latents[:len(prompt)],
                         width=self.sample_image_size,
                         guidance_scale=guidance_scale,
                         eta=eta,
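The old comprehension built one list element per batch (each batch["prompt"] is itself a list of strings) and compared the batch index to sample_batch_size, so it could hand the pipeline nested lists and miscount prompts. The replacement flattens individual prompts across batches and caps the total at sample_batch_size, while the stable_latents[:len(prompt)] slice keeps the latent batch aligned when fewer prompts come back. A standalone sketch of the flattening with invented data:

    sample_batch_size = 4
    batch_size = 2
    batches = [{"prompt": ["a", "b"]}, {"prompt": ["c", "d"]}, {"prompt": ["e", "f"]}]

    data_enum = enumerate(batches)
    prompt = [p
              for i, batch in data_enum
              for j, p in enumerate(batch["prompt"])
              if i * batch_size + j < sample_batch_size]
    # -> ["a", "b", "c", "d"]

One caveat: a list comprehension runs its outer iterator to exhaustion, so the shared data_enum is fully consumed on the first pass; with stable_sample_batches (or random_sample_batches below) greater than one, later loop iterations would see an empty prompt list.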
@@ -377,9 +384,12 @@ class Checkpointer:
         all_samples = []
         filename = f"step_%d.png" % (step)
 
+        data_enum = enumerate(data)
+
         # Generate and save random samples
         for i in range(0, self.random_sample_batches):
-            prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size]
+            prompt = [prompt for i, batch in data_enum for j, prompt in enumerate(
+                batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size]
 
             with self.accelerator.autocast():
                 samples = pipeline(
@@ -486,6 +496,9 @@ def main():
         args.pretrained_model_name_or_path + '/unet',
     )
 
+    if args.gradient_checkpointing:
+        unet.enable_gradient_checkpointing()
+
     slice_size = unet.config.attention_head_dim // 2
     unet.set_attention_slice(slice_size)
 
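Gradient checkpointing discards intermediate activations on the forward pass and recomputes them during backward, which is why the new --gradient_checkpointing flag is described as saving memory at the expense of a slower backward pass. A minimal sketch of the same wiring against diffusers (the model path is a placeholder):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained("path/to/model/unet")

    # Recompute activations in backward instead of caching them in forward.
    unet.enable_gradient_checkpointing()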
@@ -693,7 +706,7 @@ def main():
             progress_bar.clear()
             local_progress_bar.clear()
 
-            checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -753,6 +766,7 @@ def main():
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
                 min_val_loss = val_loss
 
             checkpointer.save_samples(
@@ -768,6 +782,7 @@ def main():
768 print("Finished! Saving final checkpoint and resume state.") 782 print("Finished! Saving final checkpoint and resume state.")
769 checkpointer.checkpoint( 783 checkpointer.checkpoint(
770 global_step + global_step_offset, 784 global_step + global_step_offset,
785 "end",
771 text_encoder, 786 text_encoder,
772 path=f"{basepath}/learned_embeds.bin" 787 path=f"{basepath}/learned_embeds.bin"
773 ) 788 )
@@ -782,7 +797,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"