From 5b54788842cdd7b342bd60d6944158009130b4d4 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Wed, 28 Sep 2022 15:05:30 +0200 Subject: Improved sample output and progress bars --- data/dreambooth/csv.py | 10 ++--- dreambooth.py | 104 +++++++++++++++++++++---------------------------- textual_inversion.py | 2 +- 3 files changed, 50 insertions(+), 66 deletions(-) diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py index 14c13bb..85ed4a5 100644 --- a/data/dreambooth/csv.py +++ b/data/dreambooth/csv.py @@ -108,14 +108,14 @@ class CSVDataset(Dataset): else: self.class_data_root = None - self.interpolation = {"linear": PIL.Image.LINEAR, - "bilinear": PIL.Image.BILINEAR, - "bicubic": PIL.Image.BICUBIC, - "lanczos": PIL.Image.LANCZOS, + self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, + "bilinear": transforms.InterpolationMode.BILINEAR, + "bicubic": transforms.InterpolationMode.BICUBIC, + "lanczos": transforms.InterpolationMode.LANCZOS, }[interpolation] self.image_transforms = transforms.Compose( [ - transforms.Resize(size, interpolation=transforms.InterpolationMode.BILINEAR), + transforms.Resize(size, interpolation=self.interpolation), transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), transforms.RandomHorizontalFlip(), transforms.ToTensor(), diff --git a/dreambooth.py b/dreambooth.py index 170b8e9..2df6858 100644 --- a/dreambooth.py +++ b/dreambooth.py @@ -112,7 +112,7 @@ def parse_args(): parser.add_argument( "--learning_rate", type=float, - default=5e-6, + default=3e-6, help="Initial learning rate (after the potential warmup period) to use.", ) parser.add_argument( @@ -182,12 +182,6 @@ def parse_args(): default=-1, help="For distributed training: local_rank" ) - parser.add_argument( - "--checkpoint_frequency", - type=int, - default=200, - help="How often to save a checkpoint and sample image", - ) parser.add_argument( "--sample_image_size", type=int, @@ -379,8 +373,8 @@ class Checkpointer: torch.cuda.empty_cache() @torch.no_grad() - def save_samples(self, mode, step, height, width, guidance_scale, eta, num_inference_steps): - samples_path = Path(self.output_dir).joinpath("samples").joinpath(mode) + def save_samples(self, step, height, width, guidance_scale, eta, num_inference_steps): + samples_path = Path(self.output_dir).joinpath("samples") samples_path.mkdir(parents=True, exist_ok=True) unwrapped = self.accelerator.unwrap_model(self.unet) @@ -397,12 +391,10 @@ class Checkpointer: ).to(self.accelerator.device) pipeline.enable_attention_slicing() - data = { - "training": self.datamodule.train_dataloader(), - "validation": self.datamodule.val_dataloader(), - }[mode] + train_data = self.datamodule.train_dataloader() + val_data = self.datamodule.val_dataloader() - if mode == "validation" and self.stable_sample_batches > 0 and step > 0: + if self.stable_sample_batches > 0: stable_latents = torch.randn( (self.sample_batch_size, pipeline.unet.in_channels, height // 8, width // 8), device=pipeline.device, @@ -410,14 +402,14 @@ class Checkpointer: ) all_samples = [] - filename = f"stable_step_%d.png" % (step) + filename = f"step_{step}_val_stable.png" - data_enum = enumerate(data) + data_enum = enumerate(val_data) # Generate and save stable samples for i in range(0, self.stable_sample_batches): prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] + batch["prompts"]) if i * val_data.batch_size + j < self.sample_batch_size] with self.accelerator.autocast(): samples = pipeline( @@ -441,35 +433,35 @@ class Checkpointer: del image_grid del stable_latents - all_samples = [] - filename = f"step_%d.png" % (step) + for data, pool in [(train_data, "train"), (val_data, "val")]: + all_samples = [] + filename = f"step_{step}_{pool}.png" - data_enum = enumerate(data) + data_enum = enumerate(data) - # Generate and save random samples - for i in range(0, self.random_sample_batches): - prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( - batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] + for i in range(0, self.random_sample_batches): + prompt = [prompt for i, batch in data_enum for j, prompt in enumerate( + batch["prompts"]) if i * data.batch_size + j < self.sample_batch_size] - with self.accelerator.autocast(): - samples = pipeline( - prompt=prompt, - height=self.sample_image_size, - width=self.sample_image_size, - guidance_scale=guidance_scale, - eta=eta, - num_inference_steps=num_inference_steps, - output_type='pil' - )["sample"] + with self.accelerator.autocast(): + samples = pipeline( + prompt=prompt, + height=self.sample_image_size, + width=self.sample_image_size, + guidance_scale=guidance_scale, + eta=eta, + num_inference_steps=num_inference_steps, + output_type='pil' + )["sample"] - all_samples += samples - del samples + all_samples += samples + del samples - image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) - image_grid.save(f"{samples_path}/{filename}") + image_grid = make_grid(all_samples, self.random_sample_batches, self.sample_batch_size) + image_grid.save(f"{samples_path}/{filename}") - del all_samples - del image_grid + del all_samples + del image_grid del unwrapped del pipeline @@ -594,8 +586,7 @@ def main(): beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", - num_train_timesteps=1000, - tensor_format="pt" + num_train_timesteps=1000 ) def collate_fn(examples): @@ -687,6 +678,7 @@ def main(): num_val_steps_per_epoch = len(val_dataloader) num_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch) + val_steps = num_val_steps_per_epoch * num_epochs # We need to initialize the trackers we use, and also store our configuration. # The trackers initializes automatically on the main process. @@ -707,16 +699,16 @@ def main(): global_step = 0 min_val_loss = np.inf - checkpointer.save_samples( - "validation", - 0, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + if accelerator.is_main_process: + checkpointer.save_samples( + 0, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) local_progress_bar = tqdm(range(num_update_steps_per_epoch + num_val_steps_per_epoch), disable=not accelerator.is_local_main_process) local_progress_bar.set_description("Batch X out of Y") - global_progress_bar = tqdm(range(args.max_train_steps), disable=not accelerator.is_local_main_process) + global_progress_bar = tqdm(range(args.max_train_steps + val_steps), disable=not accelerator.is_local_main_process) global_progress_bar.set_description("Total progress") try: @@ -789,15 +781,6 @@ def main(): global_step += 1 - if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process: - local_progress_bar.clear() - global_progress_bar.clear() - - checkpointer.save_samples( - "training", - global_step, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) - logs = {"mode": "training", "loss": loss, "lr": lr_scheduler.get_last_lr()[0]} local_progress_bar.set_postfix(**logs) @@ -847,6 +830,7 @@ def main(): if accelerator.sync_gradients: local_progress_bar.update(1) + global_progress_bar.update(1) logs = {"mode": "validation", "loss": loss} local_progress_bar.set_postfix(**logs) @@ -862,10 +846,10 @@ def main(): accelerator.print(f"Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}") min_val_loss = val_loss - checkpointer.save_samples( - "validation", - global_step, - args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) + if accelerator.is_main_process: + checkpointer.save_samples( + global_step, + args.resolution, args.resolution, 7.5, 0.0, args.sample_steps) accelerator.wait_for_everyone() diff --git a/textual_inversion.py b/textual_inversion.py index 81f1cf5..399d876 100644 --- a/textual_inversion.py +++ b/textual_inversion.py @@ -581,7 +581,7 @@ def main(): # TODO (patil-suraj): laod scheduler using args noise_scheduler = DDPMScheduler( - beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000, tensor_format="pt" + beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000 ) datamodule = CSVDataModule( -- cgit v1.2.3-70-g09d2