From 2af0d47b44fe02269b1378f7691d258d35544bb3 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Fri, 7 Oct 2022 14:54:44 +0200
Subject: Fix small details

---
 textual_inversion.py | 70 ++++++++++++++++++++++++++--------------------------
 1 file changed, 35 insertions(+), 35 deletions(-)

diff --git a/textual_inversion.py b/textual_inversion.py
index 86fcdfe..4f2de9e 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -19,7 +19,7 @@ from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
 from tqdm.auto import tqdm
-from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 import json
@@ -70,7 +70,7 @@ def parse_args():
     parser.add_argument(
         "--num_class_images",
         type=int,
-        default=2,
+        default=4,
         help="How many class images to generate per training image."
     )
     parser.add_argument(
@@ -107,7 +107,8 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100)
+        default=100
+    )
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -128,7 +129,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate",
         type=float,
-        default=1e-4,
+        default=5e-5,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -325,9 +326,10 @@ class Checkpointer:
         vae,
         unet,
         tokenizer,
+        text_encoder,
         placeholder_token,
         placeholder_token_id,
-        output_dir,
+        output_dir: Path,
         sample_image_size,
         sample_batches,
         sample_batch_size,
@@ -338,6 +340,7 @@ class Checkpointer:
         self.vae = vae
         self.unet = unet
         self.tokenizer = tokenizer
+        self.text_encoder = text_encoder
         self.placeholder_token = placeholder_token
         self.placeholder_token_id = placeholder_token_id
         self.output_dir = output_dir
@@ -347,14 +350,14 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, text_encoder, save_samples=True, path=None):
+    def checkpoint(self, step, postfix, path=None):
         print("Saving checkpoint for step %d..." % step)
 
         if path is None:
-            checkpoints_path = f"{self.output_dir}/checkpoints"
-            os.makedirs(checkpoints_path, exist_ok=True)
+            checkpoints_path = self.output_dir.joinpath("checkpoints")
+            checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
         # Save a checkpoint
         learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
@@ -364,17 +367,16 @@ class Checkpointer:
         if path is not None:
             torch.save(learned_embeds_dict, path)
         else:
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/{filename}")
-            torch.save(learned_embeds_dict, f"{checkpoints_path}/last.bin")
+            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
 
     @torch.no_grad()
-    def save_samples(self, step, text_encoder, height, width, guidance_scale, eta, num_inference_steps):
+    def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        unwrapped = self.accelerator.unwrap_model(text_encoder)
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
         scheduler = EulerAScheduler(
             beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
         )
@@ -608,7 +610,7 @@ def main():
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
         class_identifier=args.initializer_token,
-        class_subdir="ti_cls",
+        class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
         repeats=args.repeats,
@@ -664,21 +666,6 @@ def main():
     train_dataloader = datamodule.train_dataloader()
     val_dataloader = datamodule.val_dataloader()
 
-    checkpointer = Checkpointer(
-        datamodule=datamodule,
-        accelerator=accelerator,
-        vae=vae,
-        unet=unet,
-        tokenizer=tokenizer,
-        placeholder_token=args.placeholder_token,
-        placeholder_token_id=placeholder_token_id,
-        output_dir=basepath,
-        sample_image_size=args.sample_image_size,
-        sample_batch_size=args.sample_batch_size,
-        sample_batches=args.sample_batches,
-        seed=args.seed
-    )
-
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
     num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
@@ -733,10 +720,25 @@ def main():
     global_step = 0
     min_val_loss = np.inf
 
+    checkpointer = Checkpointer(
+        datamodule=datamodule,
+        accelerator=accelerator,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        text_encoder=text_encoder,
+        placeholder_token=args.placeholder_token,
+        placeholder_token_id=placeholder_token_id,
+        output_dir=basepath,
+        sample_image_size=args.sample_image_size,
+        sample_batch_size=args.sample_batch_size,
+        sample_batches=args.sample_batches,
+        seed=args.seed
+    )
+
     if accelerator.is_main_process:
         checkpointer.save_samples(
             0,
-            text_encoder,
             args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
     local_progress_bar = tqdm(
@@ -838,7 +840,7 @@ def main():
                     local_progress_bar.clear()
                     global_progress_bar.clear()
 
-                    checkpointer.checkpoint(global_step + global_step_offset, "training", text_encoder)
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
                     save_resume_file(basepath, args, {
                         "global_step": global_step + global_step_offset,
                         "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
@@ -897,13 +899,12 @@ def main():
 
             if min_val_loss > val_loss:
                 accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
-                checkpointer.checkpoint(global_step + global_step_offset, "milestone", text_encoder)
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
                 checkpointer.save_samples(
                     global_step + global_step_offset,
-                    text_encoder,
                     args.resolution, args.resolution, 7.5, 0.0, args.sample_steps)
 
         # Create the pipeline using using the trained modules and save it.
@@ -912,7 +913,6 @@ def main():
             checkpointer.checkpoint(
                 global_step + global_step_offset,
                 "end",
-                text_encoder,
                 path=f"{basepath}/learned_embeds.bin"
             )
 
@@ -926,7 +926,7 @@ def main():
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint(global_step + global_step_offset, "end", text_encoder)
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
--
cgit v1.2.3-54-g00ecf
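
Editor's note on the interface change: after this commit, Checkpointer binds the text encoder at construction time and treats output_dir as a pathlib.Path, so call sites only pass a step and a postfix. The Python sketch below is not the author's class; it is a minimal, self-contained approximation of the post-commit interface. The class name EmbeddingCheckpointer, the "{step}_{postfix}.bin" filename pattern, and the omission of sampling and accelerate unwrapping are illustrative assumptions, not code from the repository.

# Minimal sketch of the post-commit Checkpointer interface (illustrative names,
# not the author's full class; sampling and accelerate unwrapping are omitted).
from pathlib import Path

import torch


class EmbeddingCheckpointer:
    def __init__(self, text_encoder, placeholder_token, placeholder_token_id, output_dir: Path):
        # The text encoder is stored once here, so checkpoint() no longer needs
        # it passed in at every call site (the core change in this commit).
        self.text_encoder = text_encoder
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.output_dir = output_dir

    @torch.no_grad()
    def checkpoint(self, step, postfix, path=None):
        # Extract the single learned embedding row for the placeholder token.
        learned_embeds = self.text_encoder.get_input_embeddings().weight[self.placeholder_token_id]
        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}

        if path is not None:
            torch.save(learned_embeds_dict, path)
            return

        # Path-based bookkeeping, mirroring the joinpath()/mkdir() change in the diff
        # (hypothetical filename pattern; the real script builds its own name).
        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)
        torch.save(learned_embeds_dict, checkpoints_path.joinpath(f"{step}_{postfix}.bin"))


# Usage sketch matching the simplified call sites in main() after this commit:
#   checkpointer = EmbeddingCheckpointer(text_encoder, "<token>", token_id, Path("output/run"))
#   checkpointer.checkpoint(global_step, "training")
#   checkpointer.checkpoint(global_step, "end", path=Path("output/run/learned_embeds.bin"))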