From 89afcfda3f824cc44221e877182348f9b09687d2 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 10:31:55 +0100 Subject: Handle empty validation dataset --- data/csv.py | 47 +++++++++++-------- train_dreambooth.py | 6 +-- train_ti.py | 9 ++-- training/functional.py | 101 ++++++++++++++++++++++------------------ training/strategy/dreambooth.py | 2 +- training/strategy/ti.py | 2 +- 6 files changed, 91 insertions(+), 76 deletions(-) diff --git a/data/csv.py b/data/csv.py index 002fdd2..968af8d 100644 --- a/data/csv.py +++ b/data/csv.py @@ -269,18 +269,22 @@ class VlpnDataModule(): num_images = len(items) - valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 - valid_set_size = max(valid_set_size, 1) - train_set_size = num_images - valid_set_size + valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 + train_set_size = max(num_images - valid_set_size, 1) + valid_set_size = num_images - train_set_size generator = torch.Generator(device="cpu") if self.seed is not None: generator = generator.manual_seed(self.seed) - data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) + collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) + + if valid_set_size == 0: + data_train, data_val = items, [] + else: + data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) self.data_train = self.pad_items(data_train, self.num_class_images) - self.data_val = self.pad_items(data_val) train_dataset = VlpnDataset( self.data_train, self.tokenizer, @@ -291,26 +295,29 @@ class VlpnDataModule(): num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, ) - val_dataset = VlpnDataset( - self.data_val, self.tokenizer, - num_buckets=self.num_buckets, progressive_buckets=True, - bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, - repeat=self.valid_set_repeat, - batch_size=self.batch_size, generator=generator, - size=self.size, interpolation=self.interpolation, - ) - - collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) - self.train_dataloader = DataLoader( train_dataset, batch_size=None, pin_memory=True, collate_fn=collate_fn_ ) - self.val_dataloader = DataLoader( - val_dataset, - batch_size=None, pin_memory=True, collate_fn=collate_fn_ - ) + if valid_set_size != 0: + self.data_val = self.pad_items(data_val) + + val_dataset = VlpnDataset( + self.data_val, self.tokenizer, + num_buckets=self.num_buckets, progressive_buckets=True, + bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, + repeat=self.valid_set_repeat, + batch_size=self.batch_size, generator=generator, + size=self.size, interpolation=self.interpolation, + ) + + self.val_dataloader = DataLoader( + val_dataset, + batch_size=None, pin_memory=True, collate_fn=collate_fn_ + ) + else: + self.val_dataloader = None class VlpnDataset(IterableDataset): diff --git a/train_dreambooth.py b/train_dreambooth.py index 05777d0..4e41f77 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -564,7 +564,7 @@ def main(): embeddings=embeddings, placeholder_tokens=[placeholder_token], initializer_tokens=[initializer_token], - num_vectors=num_vectors + num_vectors=[num_vectors] ) datamodule = VlpnDataModule( @@ -579,7 +579,7 @@ def main(): valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, seed=args.seed, - 
filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), + filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() @@ -654,7 +654,7 @@ def main(): valid_set_size=args.valid_set_size, valid_set_repeat=args.valid_set_repeat, seed=args.seed, - filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), + filter=partial(keyword_filter, None, args.collection, args.exclude_collections), dtype=weight_dtype ) datamodule.setup() diff --git a/train_ti.py b/train_ti.py index 48a2333..a894ee7 100644 --- a/train_ti.py +++ b/train_ti.py @@ -582,9 +582,6 @@ def main(): ) datamodule.setup() - train_dataloader = datamodule.train_dataloader - val_dataloader = datamodule.val_dataloader - if args.num_class_images != 0: generate_class_images( accelerator, @@ -623,7 +620,7 @@ def main(): lr_scheduler = get_scheduler( args.lr_scheduler, optimizer=optimizer, - num_training_steps_per_epoch=len(train_dataloader), + num_training_steps_per_epoch=len(datamodule.train_dataloader), gradient_accumulation_steps=args.gradient_accumulation_steps, min_lr=args.lr_min_lr, warmup_func=args.lr_warmup_func, @@ -637,8 +634,8 @@ def main(): trainer( project="textual_inversion", - train_dataloader=train_dataloader, - val_dataloader=val_dataloader, + train_dataloader=datamodule.train_dataloader, + val_dataloader=datamodule.val_dataloader, optimizer=optimizer, lr_scheduler=lr_scheduler, num_train_epochs=args.num_train_epochs, diff --git a/training/functional.py b/training/functional.py index 1b6162b..c6b4dc3 100644 --- a/training/functional.py +++ b/training/functional.py @@ -73,7 +73,7 @@ def save_samples( vae: AutoencoderKL, sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, - val_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], dtype: torch.dtype, output_dir: Path, seed: int, @@ -111,11 +111,13 @@ def save_samples( generator = torch.Generator(device=accelerator.device).manual_seed(seed) - for pool, data, gen in [ - ("stable", val_dataloader, generator), - ("val", val_dataloader, None), - ("train", train_dataloader, None) - ]: + datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] + + if val_dataloader is not None: + datasets.append(("stable", val_dataloader, generator)) + datasets.append(("val", val_dataloader, None)) + + for pool, data, gen in datasets: all_samples = [] file_path = samples_path.joinpath(pool, f"step_{step}.jpg") file_path.parent.mkdir(parents=True, exist_ok=True) @@ -328,7 +330,7 @@ def train_loop( optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, train_dataloader: DataLoader, - val_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], sample_frequency: int = 10, checkpoint_frequency: int = 50, @@ -337,7 +339,7 @@ def train_loop( callbacks: TrainingCallbacks = TrainingCallbacks(), ): num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) - num_val_steps_per_epoch = len(val_dataloader) + num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 num_training_steps = num_training_steps_per_epoch * num_epochs num_val_steps = num_val_steps_per_epoch * num_epochs @@ -350,6 +352,7 @@ def train_loop( avg_loss_val = AverageMeter() avg_acc_val = 
AverageMeter() + max_acc = 0.0 max_acc_val = 0.0 local_progress_bar = tqdm( @@ -432,49 +435,57 @@ def train_loop( accelerator.wait_for_everyone() - model.eval() - - cur_loss_val = AverageMeter() - cur_acc_val = AverageMeter() - - with torch.inference_mode(), on_eval(): - for step, batch in enumerate(val_dataloader): - loss, acc, bsz = loss_step(step, batch, True) - - loss = loss.detach_() - acc = acc.detach_() - - cur_loss_val.update(loss, bsz) - cur_acc_val.update(acc, bsz) + if val_dataloader is not None: + model.eval() - avg_loss_val.update(loss, bsz) - avg_acc_val.update(acc, bsz) + cur_loss_val = AverageMeter() + cur_acc_val = AverageMeter() - local_progress_bar.update(1) - global_progress_bar.update(1) + with torch.inference_mode(), on_eval(): + for step, batch in enumerate(val_dataloader): + loss, acc, bsz = loss_step(step, batch, True) - logs = { - "val/loss": avg_loss_val.avg.item(), - "val/acc": avg_acc_val.avg.item(), - "val/cur_loss": loss.item(), - "val/cur_acc": acc.item(), - } - local_progress_bar.set_postfix(**logs) + loss = loss.detach_() + acc = acc.detach_() - logs["val/cur_loss"] = cur_loss_val.avg.item() - logs["val/cur_acc"] = cur_acc_val.avg.item() + cur_loss_val.update(loss, bsz) + cur_acc_val.update(acc, bsz) - accelerator.log(logs, step=global_step) + avg_loss_val.update(loss, bsz) + avg_acc_val.update(acc, bsz) - local_progress_bar.clear() - global_progress_bar.clear() + local_progress_bar.update(1) + global_progress_bar.update(1) - if accelerator.is_main_process: - if avg_acc_val.avg.item() > max_acc_val: - accelerator.print( - f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") - on_checkpoint(global_step + global_step_offset, "milestone") - max_acc_val = avg_acc_val.avg.item() + logs = { + "val/loss": avg_loss_val.avg.item(), + "val/acc": avg_acc_val.avg.item(), + "val/cur_loss": loss.item(), + "val/cur_acc": acc.item(), + } + local_progress_bar.set_postfix(**logs) + + logs["val/cur_loss"] = cur_loss_val.avg.item() + logs["val/cur_acc"] = cur_acc_val.avg.item() + + accelerator.log(logs, step=global_step) + + local_progress_bar.clear() + global_progress_bar.clear() + + if accelerator.is_main_process: + if avg_acc_val.avg.item() > max_acc_val: + accelerator.print( + f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") + on_checkpoint(global_step + global_step_offset, "milestone") + max_acc_val = avg_acc_val.avg.item() + else: + if accelerator.is_main_process: + if avg_acc.avg.item() > max_acc: + accelerator.print( + f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") + on_checkpoint(global_step + global_step_offset, "milestone") + max_acc = avg_acc.avg.item() # Create the pipeline using using the trained modules and save it. 
if accelerator.is_main_process: @@ -499,7 +510,7 @@ def train( seed: int, project: str, train_dataloader: DataLoader, - val_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], optimizer: torch.optim.Optimizer, lr_scheduler: torch.optim.lr_scheduler._LRScheduler, callbacks_fn: Callable[..., TrainingCallbacks], diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 6e7ebe2..aeaa828 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -26,7 +26,7 @@ def dreambooth_strategy( vae: AutoencoderKL, sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, - val_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], output_dir: Path, seed: int, train_text_encoder_epochs: int, diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 753dce0..568f9eb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -26,7 +26,7 @@ def textual_inversion_strategy( vae: AutoencoderKL, sample_scheduler: DPMSolverMultistepScheduler, train_dataloader: DataLoader, - val_dataloader: DataLoader, + val_dataloader: Optional[DataLoader], output_dir: Path, seed: int, placeholder_tokens: list[str], -- cgit v1.2.3-70-g09d2
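
Editor's note (not part of the commit): the core of this change is the new split arithmetic in data/csv.py, which clamps the requested validation size to the dataset size, guarantees at least one training item, and tolerates an empty validation split. Below is a minimal, self-contained Python sketch of that logic; split_items, requested_valid_size and seed are illustrative names standing in for the VlpnDataModule attributes and are not part of the patch.

    from typing import Optional

    import torch
    from torch.utils.data import random_split


    def split_items(items: list, requested_valid_size: Optional[int], seed: Optional[int] = None):
        num_images = len(items)

        # Fall back to a ~10% validation split, but never request more items than exist.
        valid_set_size = (
            min(requested_valid_size, num_images)
            if requested_valid_size is not None
            else num_images // 10
        )
        # The training split always keeps at least one item; validation gets whatever
        # remains, which may now legitimately be zero.
        train_set_size = max(num_images - valid_set_size, 1)
        valid_set_size = num_images - train_set_size

        if valid_set_size == 0:
            # Empty validation split: keep everything for training and let the caller
            # skip building a validation dataset and dataloader entirely.
            return list(items), []

        generator = torch.Generator(device="cpu")
        if seed is not None:
            generator = generator.manual_seed(seed)

        return random_split(items, [train_set_size, valid_set_size], generator=generator)

When the second split is empty, VlpnDataModule.setup() leaves self.val_dataloader set to None, which is why the dataloader is now typed Optional[DataLoader] throughout training/functional.py and both training strategies.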
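
A second editor's sketch, also not from the commit: with validation now optional, the milestone-checkpoint logic in train_loop has two symmetric branches, keyed to validation accuracy when a validation loader exists and to training accuracy otherwise. The helper below is a hypothetical reduction of that pattern; on_checkpoint stands in for the callback supplied by the training strategies.

    def update_milestone(current_acc: float, best_acc: float, label: str,
                         step: int, on_checkpoint) -> float:
        # Print, checkpoint, and return the new best when accuracy improves;
        # otherwise keep the previous best unchanged.
        if current_acc > best_acc:
            print(f"Global step {step}: {label} accuracy reached new maximum: "
                  f"{best_acc:.2e} -> {current_acc:.2e}")
            on_checkpoint(step, "milestone")
            return current_acc
        return best_acc

At the end of an epoch the loop effectively calls update_milestone(avg_acc_val.avg.item(), max_acc_val, "Validation", ...) when val_dataloader is not None, and update_milestone(avg_acc.avg.item(), max_acc, "Training", ...) otherwise, which is why the new max_acc = 0.0 accumulator is introduced alongside the existing max_acc_val.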