author     Volpeon <git@volpeon.ink>    2023-01-16 10:31:55 +0100
committer  Volpeon <git@volpeon.ink>    2023-01-16 10:31:55 +0100
commit     89afcfda3f824cc44221e877182348f9b09687d2 (patch)
tree       804b84322e5caa8fb861322ce6970bef4b532c61
parent     Extended Dreambooth: Train TI tokens separately (diff)
Handle empty validation dataset
-rw-r--r--  data/csv.py                      |  47
-rw-r--r--  train_dreambooth.py              |   6
-rw-r--r--  train_ti.py                      |   9
-rw-r--r--  training/functional.py           | 101
-rw-r--r--  training/strategy/dreambooth.py  |   2
-rw-r--r--  training/strategy/ti.py          |   2
6 files changed, 91 insertions, 76 deletions
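
The core of the change lives in VlpnDataModule.setup(): the requested validation size is clamped to the dataset size, at least one training item is always kept, and random_split is skipped entirely when the validation split ends up empty. A rough standalone sketch of that logic follows (the split_items helper is hypothetical, not code from the repository):

import torch
from torch.utils.data import random_split

def split_items(items, valid_set_size=None, seed=None):
    # Mirrors the new logic in VlpnDataModule.setup(): clamp the requested
    # validation size to the number of items, always keep at least one
    # training item, and skip random_split when the validation set is empty.
    num_images = len(items)

    valid_set_size = min(valid_set_size, num_images) if valid_set_size is not None else num_images // 10
    train_set_size = max(num_images - valid_set_size, 1)
    valid_set_size = num_images - train_set_size

    generator = torch.Generator(device="cpu")
    if seed is not None:
        generator = generator.manual_seed(seed)

    if valid_set_size == 0:
        return list(items), []
    data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
    return data_train, data_val

# e.g. a single-item dataset now yields an empty validation split:
# split_items(["img.png"], valid_set_size=0) -> (["img.png"], [])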
diff --git a/data/csv.py b/data/csv.py
index 002fdd2..968af8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -269,18 +269,22 @@ class VlpnDataModule():
 
         num_images = len(items)
 
-        valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10
-        valid_set_size = max(valid_set_size, 1)
-        train_set_size = num_images - valid_set_size
+        valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
+        train_set_size = max(num_images - valid_set_size, 1)
+        valid_set_size = num_images - train_set_size
 
         generator = torch.Generator(device="cpu")
         if self.seed is not None:
             generator = generator.manual_seed(self.seed)
 
-        data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
+        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
+
+        if valid_set_size == 0:
+            data_train, data_val = items, []
+        else:
+            data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
 
         self.data_train = self.pad_items(data_train, self.num_class_images)
-        self.data_val = self.pad_items(data_val)
 
         train_dataset = VlpnDataset(
             self.data_train, self.tokenizer,
@@ -291,26 +295,29 @@ class VlpnDataModule():
             num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle,
         )
 
-        val_dataset = VlpnDataset(
-            self.data_val, self.tokenizer,
-            num_buckets=self.num_buckets, progressive_buckets=True,
-            bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
-            repeat=self.valid_set_repeat,
-            batch_size=self.batch_size, generator=generator,
-            size=self.size, interpolation=self.interpolation,
-        )
-
-        collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
-
         self.train_dataloader = DataLoader(
             train_dataset,
             batch_size=None, pin_memory=True, collate_fn=collate_fn_
         )
 
-        self.val_dataloader = DataLoader(
-            val_dataset,
-            batch_size=None, pin_memory=True, collate_fn=collate_fn_
-        )
+        if valid_set_size != 0:
+            self.data_val = self.pad_items(data_val)
+
+            val_dataset = VlpnDataset(
+                self.data_val, self.tokenizer,
+                num_buckets=self.num_buckets, progressive_buckets=True,
+                bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
+                repeat=self.valid_set_repeat,
+                batch_size=self.batch_size, generator=generator,
+                size=self.size, interpolation=self.interpolation,
+            )
+
+            self.val_dataloader = DataLoader(
+                val_dataset,
+                batch_size=None, pin_memory=True, collate_fn=collate_fn_
+            )
+        else:
+            self.val_dataloader = None
 
 
 class VlpnDataset(IterableDataset):
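
After this change, VlpnDataModule.val_dataloader is None whenever the validation split is empty, so downstream code must guard before iterating. A minimal illustration of that contract (hypothetical helper, mirroring the guard used in training/functional.py below):

from typing import Optional

from torch.utils.data import DataLoader

def num_val_steps_per_epoch(val_dataloader: Optional[DataLoader]) -> int:
    # A missing validation loader simply contributes zero validation steps.
    return len(val_dataloader) if val_dataloader is not None else 0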
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 05777d0..4e41f77 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -564,7 +564,7 @@ def main():
         embeddings=embeddings,
         placeholder_tokens=[placeholder_token],
         initializer_tokens=[initializer_token],
-        num_vectors=num_vectors
+        num_vectors=[num_vectors]
     )
 
     datamodule = VlpnDataModule(
@@ -579,7 +579,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
@@ -654,7 +654,7 @@ def main():
         valid_set_size=args.valid_set_size,
         valid_set_repeat=args.valid_set_repeat,
         seed=args.seed,
-        filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
+        filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
         dtype=weight_dtype
     )
     datamodule.setup()
diff --git a/train_ti.py b/train_ti.py
index 48a2333..a894ee7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -582,9 +582,6 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloader = datamodule.train_dataloader
-    val_dataloader = datamodule.val_dataloader
-
     if args.num_class_images != 0:
         generate_class_images(
             accelerator,
@@ -623,7 +620,7 @@ def main():
     lr_scheduler = get_scheduler(
         args.lr_scheduler,
         optimizer=optimizer,
-        num_training_steps_per_epoch=len(train_dataloader),
+        num_training_steps_per_epoch=len(datamodule.train_dataloader),
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         min_lr=args.lr_min_lr,
         warmup_func=args.lr_warmup_func,
@@ -637,8 +634,8 @@ def main():
 
     trainer(
         project="textual_inversion",
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
+        train_dataloader=datamodule.train_dataloader,
+        val_dataloader=datamodule.val_dataloader,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
         num_train_epochs=args.num_train_epochs,
diff --git a/training/functional.py b/training/functional.py
index 1b6162b..c6b4dc3 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -73,7 +73,7 @@ def save_samples(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     dtype: torch.dtype,
     output_dir: Path,
     seed: int,
@@ -111,11 +111,13 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    for pool, data, gen in [
-        ("stable", val_dataloader, generator),
-        ("val", val_dataloader, None),
-        ("train", train_dataloader, None)
-    ]:
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+
+    if val_dataloader is not None:
+        datasets.append(("stable", val_dataloader, generator))
+        datasets.append(("val", val_dataloader, None))
+
+    for pool, data, gen in datasets:
         all_samples = []
         file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
         file_path.parent.mkdir(parents=True, exist_ok=True)
@@ -328,7 +330,7 @@ def train_loop(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
@@ -337,7 +339,7 @@ def train_loop(
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
     num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps)
-    num_val_steps_per_epoch = len(val_dataloader)
+    num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -350,6 +352,7 @@ def train_loop(
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    max_acc = 0.0
     max_acc_val = 0.0
 
     local_progress_bar = tqdm(
@@ -432,49 +435,57 @@ def train_loop(
 
         accelerator.wait_for_everyone()
 
-        model.eval()
-
-        cur_loss_val = AverageMeter()
-        cur_acc_val = AverageMeter()
-
-        with torch.inference_mode(), on_eval():
-            for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loss_step(step, batch, True)
-
-                loss = loss.detach_()
-                acc = acc.detach_()
-
-                cur_loss_val.update(loss, bsz)
-                cur_acc_val.update(acc, bsz)
-
-                avg_loss_val.update(loss, bsz)
-                avg_acc_val.update(acc, bsz)
-
-                local_progress_bar.update(1)
-                global_progress_bar.update(1)
-
-                logs = {
-                    "val/loss": avg_loss_val.avg.item(),
-                    "val/acc": avg_acc_val.avg.item(),
-                    "val/cur_loss": loss.item(),
-                    "val/cur_acc": acc.item(),
-                }
-                local_progress_bar.set_postfix(**logs)
-
-            logs["val/cur_loss"] = cur_loss_val.avg.item()
-            logs["val/cur_acc"] = cur_acc_val.avg.item()
-
-            accelerator.log(logs, step=global_step)
-
-        local_progress_bar.clear()
-        global_progress_bar.clear()
-
-        if accelerator.is_main_process:
-            if avg_acc_val.avg.item() > max_acc_val:
-                accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                on_checkpoint(global_step + global_step_offset, "milestone")
-                max_acc_val = avg_acc_val.avg.item()
+        if val_dataloader is not None:
+            model.eval()
+
+            cur_loss_val = AverageMeter()
+            cur_acc_val = AverageMeter()
+
+            with torch.inference_mode(), on_eval():
+                for step, batch in enumerate(val_dataloader):
+                    loss, acc, bsz = loss_step(step, batch, True)
+
+                    loss = loss.detach_()
+                    acc = acc.detach_()
+
+                    cur_loss_val.update(loss, bsz)
+                    cur_acc_val.update(acc, bsz)
+
+                    avg_loss_val.update(loss, bsz)
+                    avg_acc_val.update(acc, bsz)
+
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    logs = {
+                        "val/loss": avg_loss_val.avg.item(),
+                        "val/acc": avg_acc_val.avg.item(),
+                        "val/cur_loss": loss.item(),
+                        "val/cur_acc": acc.item(),
+                    }
+                    local_progress_bar.set_postfix(**logs)
+
+                logs["val/cur_loss"] = cur_loss_val.avg.item()
+                logs["val/cur_acc"] = cur_acc_val.avg.item()
+
+                accelerator.log(logs, step=global_step)
+
+            local_progress_bar.clear()
+            global_progress_bar.clear()
+
+            if accelerator.is_main_process:
+                if avg_acc_val.avg.item() > max_acc_val:
+                    accelerator.print(
+                        f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc_val = avg_acc_val.avg.item()
+        else:
+            if accelerator.is_main_process:
+                if avg_acc.avg.item() > max_acc:
+                    accelerator.print(
+                        f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
+                    max_acc = avg_acc.avg.item()
 
     # Create the pipeline using using the trained modules and save it.
     if accelerator.is_main_process:
@@ -499,7 +510,7 @@ def train(
     seed: int,
     project: str,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     callbacks_fn: Callable[..., TrainingCallbacks],
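
With no validation loader, train_loop now keys milestone checkpoints on the best training accuracy instead of the best validation accuracy. A condensed, illustrative sketch of that decision (the update_milestone helper is hypothetical, not the actual loop):

def update_milestone(avg_acc, avg_acc_val, max_acc, max_acc_val,
                     has_val_dataloader, save_milestone):
    # Condensed sketch: with a validation set, milestones follow the best
    # validation accuracy; without one, the best training accuracy.
    if has_val_dataloader:
        if avg_acc_val > max_acc_val:
            save_milestone()
            max_acc_val = avg_acc_val
    else:
        if avg_acc > max_acc:
            save_milestone()
            max_acc = avg_acc
    return max_acc, max_acc_val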
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 6e7ebe2..aeaa828 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -26,7 +26,7 @@ def dreambooth_strategy(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     output_dir: Path,
     seed: int,
     train_text_encoder_epochs: int,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 753dce0..568f9eb 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -26,7 +26,7 @@ def textual_inversion_strategy(
     vae: AutoencoderKL,
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
+    val_dataloader: Optional[DataLoader],
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],