From 68a3735352d3e62d46556b677407fc71b78c47c4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Tue, 27 Sep 2022 08:21:28 +0200 Subject: Better sample filenames, optimizations --- .gitignore | 4 ++-- environment.yaml | 10 ++++------ main.py | 37 ++++++++++++++++++++++++++----------- 3 files changed, 32 insertions(+), 19 deletions(-) diff --git a/.gitignore b/.gitignore index 00e7681..a8893c3 100644 --- a/.gitignore +++ b/.gitignore @@ -160,5 +160,5 @@ cython_debug/ #.idea/ text-inversion-model/ -conf.json -v1-inference.yaml +conf*.json +v1-inference.yaml* diff --git a/environment.yaml b/environment.yaml index a460158..8dc930c 100644 --- a/environment.yaml +++ b/environment.yaml @@ -16,21 +16,19 @@ dependencies: - -e git+https://github.com/openai/CLIP.git@main#egg=clip - -e git+https://github.com/hlky/k-diffusion-sd#egg=k_diffusion - -e git+https://github.com/devilismyfriend/latent-diffusion#egg=latent-diffusion + - -e git+https://github.com/huggingface/diffusers#egg=diffusers - accelerate==0.12.0 - albumentations==1.1.0 - - diffusers==0.3.0 - einops==0.4.1 - - imageio-ffmpeg==0.4.7 - - imageio==2.14.1 + - imageio==2.22.0 - kornia==0.6 - pudb==2019.2 - - omegaconf==2.1.1 + - omegaconf==2.2.3 - opencv-python-headless==4.6.0.66 - python-slugify>=6.1.2 - pytorch-lightning==1.7.7 - setuptools==59.5.0 - - streamlit>=0.73.1 - test-tube>=0.7.5 - torch-fidelity==0.3.0 - torchmetrics==0.9.3 - - transformers==4.19.2 + - transformers==4.22.1 diff --git a/main.py b/main.py index aa5af72..8be79e5 100644 --- a/main.py +++ b/main.py @@ -105,6 +105,11 @@ def parse_args(): default=1, help="Number of updates steps to accumulate before performing a backward/update pass.", ) + parser.add_argument( + "--gradient_checkpointing", + action="store_true", + help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.", + ) parser.add_argument( "--learning_rate", type=float, @@ -213,7 +218,7 @@ def parse_args(): elif args.config is not None: with open(args.config, 'rt') as f: args = parser.parse_args( - namespace=argparse.Namespace(**json.load(f))) + namespace=argparse.Namespace(**json.load(f)["args"])) env_local_rank = int(os.environ.get("LOCAL_RANK", -1)) if env_local_rank != -1 and env_local_rank != args.local_rank: @@ -289,7 +294,7 @@ class Checkpointer: self.stable_sample_batches = stable_sample_batches @torch.no_grad() - def checkpoint(self, step, text_encoder, save_samples=True, path=None): + def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None): print("Saving checkpoint for step %d..." % step) with self.accelerator.autocast(): if path is None: @@ -302,12 +307,11 @@ class Checkpointer: learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id] learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()} - filename = f"%s_%d.bin" % (slugify(self.placeholder_token), step) + filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix) if path is not None: torch.save(learned_embeds_dict, path) else: - torch.save(learned_embeds_dict, - f"{checkpoints_path}/{filename}") + torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}") torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin") del unwrapped del learned_embeds @@ -338,7 +342,7 @@ class Checkpointer: "validation": self.datamodule.val_dataloader(), }[mode] - if mode == "validation" and self.stable_sample_batches > 0: + if mode == "validation" and self.stable_sample_batches > 0 and step > 0: stable_latents = torch.randn( (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), device=pipeline.device, @@ -348,15 +352,18 @@ class Checkpointer: all_samples = [] filename = f"stable_step_%d.png" % (step) + data_enum = enumerate(data) + # Generate and save stable samples for i in range(0, self.stable_sample_batches): - prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size] + prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( + batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] with self.accelerator.autocast(): samples = pipeline( prompt=prompt, height=self.sample_image_size, - latents=stable_latents, + latents=stable_latents[:len(prompt)], width=self.sample_image_size, guidance_scale=guidance_scale, eta=eta, @@ -377,9 +384,12 @@ class Checkpointer: all_samples = [] filename = f"step_%d.png" % (step) + data_enum = enumerate(data) + # Generate and save random samples for i in range(0, self.random_sample_batches): - prompt = [batch["prompt"] for i, batch in enumerate(data) if i < self.sample_batch_size] + prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( + batch["prompt"]) if i * data.batch_size + j < self.sample_batch_size] with self.accelerator.autocast(): samples = pipeline( @@ -486,6 +496,9 @@ def main(): args.pretrained_model_name_or_path + '/unet', ) + if args.gradient_checkpointing: + unet.enable_gradient_checkpointing() + slice_size = unet.config.attention_head_dim // 2 unet.set_attention_slice(slice_size) @@ -693,7 +706,7 @@ def main(): progress_bar.clear() local_progress_bar.clear() - checkpointer.checkpoint(global_step + global_step_offset, text_encoder) + checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder) save_resume_file(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" @@ -753,6 +766,7 @@ def main(): if min_val_loss > val_loss: accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") + checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder) min_val_loss = val_loss checkpointer.save_samples( @@ -768,6 +782,7 @@ def main(): print("Finished! Saving final checkpoint and resume state.") checkpointer.checkpoint( global_step + global_step_offset, + "end", text_encoder, path=f"{basepath}/learned_embeds.bin" ) @@ -782,7 +797,7 @@ def main(): except KeyboardInterrupt: if accelerator.is_main_process: print("Interrupted, saving checkpoint and resume state...") - checkpointer.checkpoint(global_step + global_step_offset, text_encoder) + checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder) save_resume_file(basepath, args, { "global_step": global_step + global_step_offset, "resume_checkpoint": f"{basepath}/checkpoints/last.bin" -- cgit v1.2.3-54-g00ecf