diff options
| -rw-r--r-- | data/csv.py | 47 | ||||
| -rw-r--r-- | train_dreambooth.py | 6 | ||||
| -rw-r--r-- | train_ti.py | 9 | ||||
| -rw-r--r-- | training/functional.py | 93 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
6 files changed, 87 insertions, 72 deletions
diff --git a/data/csv.py b/data/csv.py index 002fdd2..968af8d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -269,18 +269,22 @@ class VlpnDataModule(): | |||
| 269 | 269 | ||
| 270 | num_images = len(items) | 270 | num_images = len(items) |
| 271 | 271 | ||
| 272 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 | 272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
| 273 | valid_set_size = max(valid_set_size, 1) | 273 | train_set_size = max(num_images - valid_set_size, 1) |
| 274 | train_set_size = num_images - valid_set_size | 274 | valid_set_size = num_images - train_set_size |
| 275 | 275 | ||
| 276 | generator = torch.Generator(device="cpu") | 276 | generator = torch.Generator(device="cpu") |
| 277 | if self.seed is not None: | 277 | if self.seed is not None: |
| 278 | generator = generator.manual_seed(self.seed) | 278 | generator = generator.manual_seed(self.seed) |
| 279 | 279 | ||
| 280 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 280 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 281 | |||
| 282 | if valid_set_size == 0: | ||
| 283 | data_train, data_val = items, [] | ||
| 284 | else: | ||
| 285 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
| 281 | 286 | ||
| 282 | self.data_train = self.pad_items(data_train, self.num_class_images) | 287 | self.data_train = self.pad_items(data_train, self.num_class_images) |
| 283 | self.data_val = self.pad_items(data_val) | ||
| 284 | 288 | ||
| 285 | train_dataset = VlpnDataset( | 289 | train_dataset = VlpnDataset( |
| 286 | self.data_train, self.tokenizer, | 290 | self.data_train, self.tokenizer, |
| @@ -291,26 +295,29 @@ class VlpnDataModule(): | |||
| 291 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 295 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
| 292 | ) | 296 | ) |
| 293 | 297 | ||
| 294 | val_dataset = VlpnDataset( | ||
| 295 | self.data_val, self.tokenizer, | ||
| 296 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 298 | repeat=self.valid_set_repeat, | ||
| 299 | batch_size=self.batch_size, generator=generator, | ||
| 300 | size=self.size, interpolation=self.interpolation, | ||
| 301 | ) | ||
| 302 | |||
| 303 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | ||
| 304 | |||
| 305 | self.train_dataloader = DataLoader( | 298 | self.train_dataloader = DataLoader( |
| 306 | train_dataset, | 299 | train_dataset, |
| 307 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 300 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
| 308 | ) | 301 | ) |
| 309 | 302 | ||
| 310 | self.val_dataloader = DataLoader( | 303 | if valid_set_size != 0: |
| 311 | val_dataset, | 304 | self.data_val = self.pad_items(data_val) |
| 312 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | |
| 313 | ) | 306 | val_dataset = VlpnDataset( |
| 307 | self.data_val, self.tokenizer, | ||
| 308 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
| 309 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
| 310 | repeat=self.valid_set_repeat, | ||
| 311 | batch_size=self.batch_size, generator=generator, | ||
| 312 | size=self.size, interpolation=self.interpolation, | ||
| 313 | ) | ||
| 314 | |||
| 315 | self.val_dataloader = DataLoader( | ||
| 316 | val_dataset, | ||
| 317 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
| 318 | ) | ||
| 319 | else: | ||
| 320 | self.val_dataloader = None | ||
| 314 | 321 | ||
| 315 | 322 | ||
| 316 | class VlpnDataset(IterableDataset): | 323 | class VlpnDataset(IterableDataset): |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 05777d0..4e41f77 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -564,7 +564,7 @@ def main(): | |||
| 564 | embeddings=embeddings, | 564 | embeddings=embeddings, |
| 565 | placeholder_tokens=[placeholder_token], | 565 | placeholder_tokens=[placeholder_token], |
| 566 | initializer_tokens=[initializer_token], | 566 | initializer_tokens=[initializer_token], |
| 567 | num_vectors=num_vectors | 567 | num_vectors=[num_vectors] |
| 568 | ) | 568 | ) |
| 569 | 569 | ||
| 570 | datamodule = VlpnDataModule( | 570 | datamodule = VlpnDataModule( |
| @@ -579,7 +579,7 @@ def main(): | |||
| 579 | valid_set_size=args.valid_set_size, | 579 | valid_set_size=args.valid_set_size, |
| 580 | valid_set_repeat=args.valid_set_repeat, | 580 | valid_set_repeat=args.valid_set_repeat, |
| 581 | seed=args.seed, | 581 | seed=args.seed, |
| 582 | filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), | 582 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
| 583 | dtype=weight_dtype | 583 | dtype=weight_dtype |
| 584 | ) | 584 | ) |
| 585 | datamodule.setup() | 585 | datamodule.setup() |
| @@ -654,7 +654,7 @@ def main(): | |||
| 654 | valid_set_size=args.valid_set_size, | 654 | valid_set_size=args.valid_set_size, |
| 655 | valid_set_repeat=args.valid_set_repeat, | 655 | valid_set_repeat=args.valid_set_repeat, |
| 656 | seed=args.seed, | 656 | seed=args.seed, |
| 657 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 657 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
| 658 | dtype=weight_dtype | 658 | dtype=weight_dtype |
| 659 | ) | 659 | ) |
| 660 | datamodule.setup() | 660 | datamodule.setup() |
diff --git a/train_ti.py b/train_ti.py index 48a2333..a894ee7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -582,9 +582,6 @@ def main(): | |||
| 582 | ) | 582 | ) |
| 583 | datamodule.setup() | 583 | datamodule.setup() |
| 584 | 584 | ||
| 585 | train_dataloader = datamodule.train_dataloader | ||
| 586 | val_dataloader = datamodule.val_dataloader | ||
| 587 | |||
| 588 | if args.num_class_images != 0: | 585 | if args.num_class_images != 0: |
| 589 | generate_class_images( | 586 | generate_class_images( |
| 590 | accelerator, | 587 | accelerator, |
| @@ -623,7 +620,7 @@ def main(): | |||
| 623 | lr_scheduler = get_scheduler( | 620 | lr_scheduler = get_scheduler( |
| 624 | args.lr_scheduler, | 621 | args.lr_scheduler, |
| 625 | optimizer=optimizer, | 622 | optimizer=optimizer, |
| 626 | num_training_steps_per_epoch=len(train_dataloader), | 623 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
| 627 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 624 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
| 628 | min_lr=args.lr_min_lr, | 625 | min_lr=args.lr_min_lr, |
| 629 | warmup_func=args.lr_warmup_func, | 626 | warmup_func=args.lr_warmup_func, |
| @@ -637,8 +634,8 @@ def main(): | |||
| 637 | 634 | ||
| 638 | trainer( | 635 | trainer( |
| 639 | project="textual_inversion", | 636 | project="textual_inversion", |
| 640 | train_dataloader=train_dataloader, | 637 | train_dataloader=datamodule.train_dataloader, |
| 641 | val_dataloader=val_dataloader, | 638 | val_dataloader=datamodule.val_dataloader, |
| 642 | optimizer=optimizer, | 639 | optimizer=optimizer, |
| 643 | lr_scheduler=lr_scheduler, | 640 | lr_scheduler=lr_scheduler, |
| 644 | num_train_epochs=args.num_train_epochs, | 641 | num_train_epochs=args.num_train_epochs, |
diff --git a/training/functional.py b/training/functional.py index 1b6162b..c6b4dc3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -73,7 +73,7 @@ def save_samples( | |||
| 73 | vae: AutoencoderKL, | 73 | vae: AutoencoderKL, |
| 74 | sample_scheduler: DPMSolverMultistepScheduler, | 74 | sample_scheduler: DPMSolverMultistepScheduler, |
| 75 | train_dataloader: DataLoader, | 75 | train_dataloader: DataLoader, |
| 76 | val_dataloader: DataLoader, | 76 | val_dataloader: Optional[DataLoader], |
| 77 | dtype: torch.dtype, | 77 | dtype: torch.dtype, |
| 78 | output_dir: Path, | 78 | output_dir: Path, |
| 79 | seed: int, | 79 | seed: int, |
| @@ -111,11 +111,13 @@ def save_samples( | |||
| 111 | 111 | ||
| 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
| 113 | 113 | ||
| 114 | for pool, data, gen in [ | 114 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] |
| 115 | ("stable", val_dataloader, generator), | 115 | |
| 116 | ("val", val_dataloader, None), | 116 | if val_dataloader is not None: |
| 117 | ("train", train_dataloader, None) | 117 | datasets.append(("stable", val_dataloader, generator)) |
| 118 | ]: | 118 | datasets.append(("val", val_dataloader, None)) |
| 119 | |||
| 120 | for pool, data, gen in datasets: | ||
| 119 | all_samples = [] | 121 | all_samples = [] |
| 120 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 122 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
| 121 | file_path.parent.mkdir(parents=True, exist_ok=True) | 123 | file_path.parent.mkdir(parents=True, exist_ok=True) |
| @@ -328,7 +330,7 @@ def train_loop( | |||
| 328 | optimizer: torch.optim.Optimizer, | 330 | optimizer: torch.optim.Optimizer, |
| 329 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 331 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 330 | train_dataloader: DataLoader, | 332 | train_dataloader: DataLoader, |
| 331 | val_dataloader: DataLoader, | 333 | val_dataloader: Optional[DataLoader], |
| 332 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 334 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
| 333 | sample_frequency: int = 10, | 335 | sample_frequency: int = 10, |
| 334 | checkpoint_frequency: int = 50, | 336 | checkpoint_frequency: int = 50, |
| @@ -337,7 +339,7 @@ def train_loop( | |||
| 337 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 339 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
| 338 | ): | 340 | ): |
| 339 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 341 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
| 340 | num_val_steps_per_epoch = len(val_dataloader) | 342 | num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 |
| 341 | 343 | ||
| 342 | num_training_steps = num_training_steps_per_epoch * num_epochs | 344 | num_training_steps = num_training_steps_per_epoch * num_epochs |
| 343 | num_val_steps = num_val_steps_per_epoch * num_epochs | 345 | num_val_steps = num_val_steps_per_epoch * num_epochs |
| @@ -350,6 +352,7 @@ def train_loop( | |||
| 350 | avg_loss_val = AverageMeter() | 352 | avg_loss_val = AverageMeter() |
| 351 | avg_acc_val = AverageMeter() | 353 | avg_acc_val = AverageMeter() |
| 352 | 354 | ||
| 355 | max_acc = 0.0 | ||
| 353 | max_acc_val = 0.0 | 356 | max_acc_val = 0.0 |
| 354 | 357 | ||
| 355 | local_progress_bar = tqdm( | 358 | local_progress_bar = tqdm( |
| @@ -432,49 +435,57 @@ def train_loop( | |||
| 432 | 435 | ||
| 433 | accelerator.wait_for_everyone() | 436 | accelerator.wait_for_everyone() |
| 434 | 437 | ||
| 435 | model.eval() | 438 | if val_dataloader is not None: |
| 439 | model.eval() | ||
| 436 | 440 | ||
| 437 | cur_loss_val = AverageMeter() | 441 | cur_loss_val = AverageMeter() |
| 438 | cur_acc_val = AverageMeter() | 442 | cur_acc_val = AverageMeter() |
| 439 | 443 | ||
| 440 | with torch.inference_mode(), on_eval(): | 444 | with torch.inference_mode(), on_eval(): |
| 441 | for step, batch in enumerate(val_dataloader): | 445 | for step, batch in enumerate(val_dataloader): |
| 442 | loss, acc, bsz = loss_step(step, batch, True) | 446 | loss, acc, bsz = loss_step(step, batch, True) |
| 443 | 447 | ||
| 444 | loss = loss.detach_() | 448 | loss = loss.detach_() |
| 445 | acc = acc.detach_() | 449 | acc = acc.detach_() |
| 446 | 450 | ||
| 447 | cur_loss_val.update(loss, bsz) | 451 | cur_loss_val.update(loss, bsz) |
| 448 | cur_acc_val.update(acc, bsz) | 452 | cur_acc_val.update(acc, bsz) |
| 449 | 453 | ||
| 450 | avg_loss_val.update(loss, bsz) | 454 | avg_loss_val.update(loss, bsz) |
| 451 | avg_acc_val.update(acc, bsz) | 455 | avg_acc_val.update(acc, bsz) |
| 452 | 456 | ||
| 453 | local_progress_bar.update(1) | 457 | local_progress_bar.update(1) |
| 454 | global_progress_bar.update(1) | 458 | global_progress_bar.update(1) |
| 455 | 459 | ||
| 456 | logs = { | 460 | logs = { |
| 457 | "val/loss": avg_loss_val.avg.item(), | 461 | "val/loss": avg_loss_val.avg.item(), |
| 458 | "val/acc": avg_acc_val.avg.item(), | 462 | "val/acc": avg_acc_val.avg.item(), |
| 459 | "val/cur_loss": loss.item(), | 463 | "val/cur_loss": loss.item(), |
| 460 | "val/cur_acc": acc.item(), | 464 | "val/cur_acc": acc.item(), |
| 461 | } | 465 | } |
| 462 | local_progress_bar.set_postfix(**logs) | 466 | local_progress_bar.set_postfix(**logs) |
| 463 | 467 | ||
| 464 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 468 | logs["val/cur_loss"] = cur_loss_val.avg.item() |
| 465 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 469 | logs["val/cur_acc"] = cur_acc_val.avg.item() |
| 466 | 470 | ||
| 467 | accelerator.log(logs, step=global_step) | 471 | accelerator.log(logs, step=global_step) |
| 468 | 472 | ||
| 469 | local_progress_bar.clear() | 473 | local_progress_bar.clear() |
| 470 | global_progress_bar.clear() | 474 | global_progress_bar.clear() |
| 471 | 475 | ||
| 472 | if accelerator.is_main_process: | 476 | if accelerator.is_main_process: |
| 473 | if avg_acc_val.avg.item() > max_acc_val: | 477 | if avg_acc_val.avg.item() > max_acc_val: |
| 474 | accelerator.print( | 478 | accelerator.print( |
| 475 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") |
| 476 | on_checkpoint(global_step + global_step_offset, "milestone") | 480 | on_checkpoint(global_step + global_step_offset, "milestone") |
| 477 | max_acc_val = avg_acc_val.avg.item() | 481 | max_acc_val = avg_acc_val.avg.item() |
| 482 | else: | ||
| 483 | if accelerator.is_main_process: | ||
| 484 | if avg_acc.avg.item() > max_acc: | ||
| 485 | accelerator.print( | ||
| 486 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | ||
| 487 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
| 488 | max_acc = avg_acc.avg.item() | ||
| 478 | 489 | ||
| 479 | # Create the pipeline using the trained modules and save it. | 490 |
| 480 | if accelerator.is_main_process: | 491 | if accelerator.is_main_process: |
| @@ -499,7 +510,7 @@ def train( | |||
| 499 | seed: int, | 510 | seed: int, |
| 500 | project: str, | 511 | project: str, |
| 501 | train_dataloader: DataLoader, | 512 | train_dataloader: DataLoader, |
| 502 | val_dataloader: DataLoader, | 513 | val_dataloader: Optional[DataLoader], |
| 503 | optimizer: torch.optim.Optimizer, | 514 | optimizer: torch.optim.Optimizer, |
| 504 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 515 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 505 | callbacks_fn: Callable[..., TrainingCallbacks], | 516 | callbacks_fn: Callable[..., TrainingCallbacks], |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 6e7ebe2..aeaa828 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -26,7 +26,7 @@ def dreambooth_strategy( | |||
| 26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
| 27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
| 28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
| 29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
| 30 | output_dir: Path, | 30 | output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_epochs: int, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 753dce0..568f9eb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -26,7 +26,7 @@ def textual_inversion_strategy( | |||
| 26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
| 27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
| 28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
| 29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
| 30 | output_dir: Path, | 30 | output_dir: Path, |
| 31 | seed: int, | 31 | seed: int, |
| 32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |
