diff options
-rw-r--r-- | data/csv.py | 47 | ||||
-rw-r--r-- | train_dreambooth.py | 6 | ||||
-rw-r--r-- | train_ti.py | 9 | ||||
-rw-r--r-- | training/functional.py | 101 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
6 files changed, 91 insertions, 76 deletions
diff --git a/data/csv.py b/data/csv.py index 002fdd2..968af8d 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -269,18 +269,22 @@ class VlpnDataModule(): | |||
269 | 269 | ||
270 | num_images = len(items) | 270 | num_images = len(items) |
271 | 271 | ||
272 | valid_set_size = self.valid_set_size if self.valid_set_size is not None else num_images // 10 | 272 | valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 |
273 | valid_set_size = max(valid_set_size, 1) | 273 | train_set_size = max(num_images - valid_set_size, 1) |
274 | train_set_size = num_images - valid_set_size | 274 | valid_set_size = num_images - train_set_size |
275 | 275 | ||
276 | generator = torch.Generator(device="cpu") | 276 | generator = torch.Generator(device="cpu") |
277 | if self.seed is not None: | 277 | if self.seed is not None: |
278 | generator = generator.manual_seed(self.seed) | 278 | generator = generator.manual_seed(self.seed) |
279 | 279 | ||
280 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 280 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
281 | |||
282 | if valid_set_size == 0: | ||
283 | data_train, data_val = items, [] | ||
284 | else: | ||
285 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | ||
281 | 286 | ||
282 | self.data_train = self.pad_items(data_train, self.num_class_images) | 287 | self.data_train = self.pad_items(data_train, self.num_class_images) |
283 | self.data_val = self.pad_items(data_val) | ||
284 | 288 | ||
285 | train_dataset = VlpnDataset( | 289 | train_dataset = VlpnDataset( |
286 | self.data_train, self.tokenizer, | 290 | self.data_train, self.tokenizer, |
@@ -291,26 +295,29 @@ class VlpnDataModule(): | |||
291 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, | 295 | num_class_images=self.num_class_images, dropout=self.dropout, shuffle=self.shuffle, |
292 | ) | 296 | ) |
293 | 297 | ||
294 | val_dataset = VlpnDataset( | ||
295 | self.data_val, self.tokenizer, | ||
296 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
297 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
298 | repeat=self.valid_set_repeat, | ||
299 | batch_size=self.batch_size, generator=generator, | ||
300 | size=self.size, interpolation=self.interpolation, | ||
301 | ) | ||
302 | |||
303 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | ||
304 | |||
305 | self.train_dataloader = DataLoader( | 298 | self.train_dataloader = DataLoader( |
306 | train_dataset, | 299 | train_dataset, |
307 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 300 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ |
308 | ) | 301 | ) |
309 | 302 | ||
310 | self.val_dataloader = DataLoader( | 303 | if valid_set_size != 0: |
311 | val_dataset, | 304 | self.data_val = self.pad_items(data_val) |
312 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | 305 | |
313 | ) | 306 | val_dataset = VlpnDataset( |
307 | self.data_val, self.tokenizer, | ||
308 | num_buckets=self.num_buckets, progressive_buckets=True, | ||
309 | bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, | ||
310 | repeat=self.valid_set_repeat, | ||
311 | batch_size=self.batch_size, generator=generator, | ||
312 | size=self.size, interpolation=self.interpolation, | ||
313 | ) | ||
314 | |||
315 | self.val_dataloader = DataLoader( | ||
316 | val_dataset, | ||
317 | batch_size=None, pin_memory=True, collate_fn=collate_fn_ | ||
318 | ) | ||
319 | else: | ||
320 | self.val_dataloader = None | ||
314 | 321 | ||
315 | 322 | ||
316 | class VlpnDataset(IterableDataset): | 323 | class VlpnDataset(IterableDataset): |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 05777d0..4e41f77 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -564,7 +564,7 @@ def main(): | |||
564 | embeddings=embeddings, | 564 | embeddings=embeddings, |
565 | placeholder_tokens=[placeholder_token], | 565 | placeholder_tokens=[placeholder_token], |
566 | initializer_tokens=[initializer_token], | 566 | initializer_tokens=[initializer_token], |
567 | num_vectors=num_vectors | 567 | num_vectors=[num_vectors] |
568 | ) | 568 | ) |
569 | 569 | ||
570 | datamodule = VlpnDataModule( | 570 | datamodule = VlpnDataModule( |
@@ -579,7 +579,7 @@ def main(): | |||
579 | valid_set_size=args.valid_set_size, | 579 | valid_set_size=args.valid_set_size, |
580 | valid_set_repeat=args.valid_set_repeat, | 580 | valid_set_repeat=args.valid_set_repeat, |
581 | seed=args.seed, | 581 | seed=args.seed, |
582 | filter=partial(keyword_filter, placeholder_token, args.collection, args.exclude_collections), | 582 | filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), |
583 | dtype=weight_dtype | 583 | dtype=weight_dtype |
584 | ) | 584 | ) |
585 | datamodule.setup() | 585 | datamodule.setup() |
@@ -654,7 +654,7 @@ def main(): | |||
654 | valid_set_size=args.valid_set_size, | 654 | valid_set_size=args.valid_set_size, |
655 | valid_set_repeat=args.valid_set_repeat, | 655 | valid_set_repeat=args.valid_set_repeat, |
656 | seed=args.seed, | 656 | seed=args.seed, |
657 | filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), | 657 | filter=partial(keyword_filter, None, args.collection, args.exclude_collections), |
658 | dtype=weight_dtype | 658 | dtype=weight_dtype |
659 | ) | 659 | ) |
660 | datamodule.setup() | 660 | datamodule.setup() |
diff --git a/train_ti.py b/train_ti.py index 48a2333..a894ee7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -582,9 +582,6 @@ def main(): | |||
582 | ) | 582 | ) |
583 | datamodule.setup() | 583 | datamodule.setup() |
584 | 584 | ||
585 | train_dataloader = datamodule.train_dataloader | ||
586 | val_dataloader = datamodule.val_dataloader | ||
587 | |||
588 | if args.num_class_images != 0: | 585 | if args.num_class_images != 0: |
589 | generate_class_images( | 586 | generate_class_images( |
590 | accelerator, | 587 | accelerator, |
@@ -623,7 +620,7 @@ def main(): | |||
623 | lr_scheduler = get_scheduler( | 620 | lr_scheduler = get_scheduler( |
624 | args.lr_scheduler, | 621 | args.lr_scheduler, |
625 | optimizer=optimizer, | 622 | optimizer=optimizer, |
626 | num_training_steps_per_epoch=len(train_dataloader), | 623 | num_training_steps_per_epoch=len(datamodule.train_dataloader), |
627 | gradient_accumulation_steps=args.gradient_accumulation_steps, | 624 | gradient_accumulation_steps=args.gradient_accumulation_steps, |
628 | min_lr=args.lr_min_lr, | 625 | min_lr=args.lr_min_lr, |
629 | warmup_func=args.lr_warmup_func, | 626 | warmup_func=args.lr_warmup_func, |
@@ -637,8 +634,8 @@ def main(): | |||
637 | 634 | ||
638 | trainer( | 635 | trainer( |
639 | project="textual_inversion", | 636 | project="textual_inversion", |
640 | train_dataloader=train_dataloader, | 637 | train_dataloader=datamodule.train_dataloader, |
641 | val_dataloader=val_dataloader, | 638 | val_dataloader=datamodule.val_dataloader, |
642 | optimizer=optimizer, | 639 | optimizer=optimizer, |
643 | lr_scheduler=lr_scheduler, | 640 | lr_scheduler=lr_scheduler, |
644 | num_train_epochs=args.num_train_epochs, | 641 | num_train_epochs=args.num_train_epochs, |
diff --git a/training/functional.py b/training/functional.py index 1b6162b..c6b4dc3 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -73,7 +73,7 @@ def save_samples( | |||
73 | vae: AutoencoderKL, | 73 | vae: AutoencoderKL, |
74 | sample_scheduler: DPMSolverMultistepScheduler, | 74 | sample_scheduler: DPMSolverMultistepScheduler, |
75 | train_dataloader: DataLoader, | 75 | train_dataloader: DataLoader, |
76 | val_dataloader: DataLoader, | 76 | val_dataloader: Optional[DataLoader], |
77 | dtype: torch.dtype, | 77 | dtype: torch.dtype, |
78 | output_dir: Path, | 78 | output_dir: Path, |
79 | seed: int, | 79 | seed: int, |
@@ -111,11 +111,13 @@ def save_samples( | |||
111 | 111 | ||
112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) | 112 | generator = torch.Generator(device=accelerator.device).manual_seed(seed) |
113 | 113 | ||
114 | for pool, data, gen in [ | 114 | datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)] |
115 | ("stable", val_dataloader, generator), | 115 | |
116 | ("val", val_dataloader, None), | 116 | if val_dataloader is not None: |
117 | ("train", train_dataloader, None) | 117 | datasets.append(("stable", val_dataloader, generator)) |
118 | ]: | 118 | datasets.append(("val", val_dataloader, None)) |
119 | |||
120 | for pool, data, gen in datasets: | ||
119 | all_samples = [] | 121 | all_samples = [] |
120 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | 122 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") |
121 | file_path.parent.mkdir(parents=True, exist_ok=True) | 123 | file_path.parent.mkdir(parents=True, exist_ok=True) |
@@ -328,7 +330,7 @@ def train_loop( | |||
328 | optimizer: torch.optim.Optimizer, | 330 | optimizer: torch.optim.Optimizer, |
329 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 331 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
330 | train_dataloader: DataLoader, | 332 | train_dataloader: DataLoader, |
331 | val_dataloader: DataLoader, | 333 | val_dataloader: Optional[DataLoader], |
332 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], | 334 | loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]], |
333 | sample_frequency: int = 10, | 335 | sample_frequency: int = 10, |
334 | checkpoint_frequency: int = 50, | 336 | checkpoint_frequency: int = 50, |
@@ -337,7 +339,7 @@ def train_loop( | |||
337 | callbacks: TrainingCallbacks = TrainingCallbacks(), | 339 | callbacks: TrainingCallbacks = TrainingCallbacks(), |
338 | ): | 340 | ): |
339 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) | 341 | num_training_steps_per_epoch = math.ceil(len(train_dataloader) / accelerator.gradient_accumulation_steps) |
340 | num_val_steps_per_epoch = len(val_dataloader) | 342 | num_val_steps_per_epoch = len(val_dataloader) if val_dataloader is not None else 0 |
341 | 343 | ||
342 | num_training_steps = num_training_steps_per_epoch * num_epochs | 344 | num_training_steps = num_training_steps_per_epoch * num_epochs |
343 | num_val_steps = num_val_steps_per_epoch * num_epochs | 345 | num_val_steps = num_val_steps_per_epoch * num_epochs |
@@ -350,6 +352,7 @@ def train_loop( | |||
350 | avg_loss_val = AverageMeter() | 352 | avg_loss_val = AverageMeter() |
351 | avg_acc_val = AverageMeter() | 353 | avg_acc_val = AverageMeter() |
352 | 354 | ||
355 | max_acc = 0.0 | ||
353 | max_acc_val = 0.0 | 356 | max_acc_val = 0.0 |
354 | 357 | ||
355 | local_progress_bar = tqdm( | 358 | local_progress_bar = tqdm( |
@@ -432,49 +435,57 @@ def train_loop( | |||
432 | 435 | ||
433 | accelerator.wait_for_everyone() | 436 | accelerator.wait_for_everyone() |
434 | 437 | ||
435 | model.eval() | 438 | if val_dataloader is not None: |
436 | 439 | model.eval() | |
437 | cur_loss_val = AverageMeter() | ||
438 | cur_acc_val = AverageMeter() | ||
439 | |||
440 | with torch.inference_mode(), on_eval(): | ||
441 | for step, batch in enumerate(val_dataloader): | ||
442 | loss, acc, bsz = loss_step(step, batch, True) | ||
443 | |||
444 | loss = loss.detach_() | ||
445 | acc = acc.detach_() | ||
446 | |||
447 | cur_loss_val.update(loss, bsz) | ||
448 | cur_acc_val.update(acc, bsz) | ||
449 | 440 | ||
450 | avg_loss_val.update(loss, bsz) | 441 | cur_loss_val = AverageMeter() |
451 | avg_acc_val.update(acc, bsz) | 442 | cur_acc_val = AverageMeter() |
452 | 443 | ||
453 | local_progress_bar.update(1) | 444 | with torch.inference_mode(), on_eval(): |
454 | global_progress_bar.update(1) | 445 | for step, batch in enumerate(val_dataloader): |
446 | loss, acc, bsz = loss_step(step, batch, True) | ||
455 | 447 | ||
456 | logs = { | 448 | loss = loss.detach_() |
457 | "val/loss": avg_loss_val.avg.item(), | 449 | acc = acc.detach_() |
458 | "val/acc": avg_acc_val.avg.item(), | ||
459 | "val/cur_loss": loss.item(), | ||
460 | "val/cur_acc": acc.item(), | ||
461 | } | ||
462 | local_progress_bar.set_postfix(**logs) | ||
463 | 450 | ||
464 | logs["val/cur_loss"] = cur_loss_val.avg.item() | 451 | cur_loss_val.update(loss, bsz) |
465 | logs["val/cur_acc"] = cur_acc_val.avg.item() | 452 | cur_acc_val.update(acc, bsz) |
466 | 453 | ||
467 | accelerator.log(logs, step=global_step) | 454 | avg_loss_val.update(loss, bsz) |
455 | avg_acc_val.update(acc, bsz) | ||
468 | 456 | ||
469 | local_progress_bar.clear() | 457 | local_progress_bar.update(1) |
470 | global_progress_bar.clear() | 458 | global_progress_bar.update(1) |
471 | 459 | ||
472 | if accelerator.is_main_process: | 460 | logs = { |
473 | if avg_acc_val.avg.item() > max_acc_val: | 461 | "val/loss": avg_loss_val.avg.item(), |
474 | accelerator.print( | 462 | "val/acc": avg_acc_val.avg.item(), |
475 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | 463 | "val/cur_loss": loss.item(), |
476 | on_checkpoint(global_step + global_step_offset, "milestone") | 464 | "val/cur_acc": acc.item(), |
477 | max_acc_val = avg_acc_val.avg.item() | 465 | } |
466 | local_progress_bar.set_postfix(**logs) | ||
467 | |||
468 | logs["val/cur_loss"] = cur_loss_val.avg.item() | ||
469 | logs["val/cur_acc"] = cur_acc_val.avg.item() | ||
470 | |||
471 | accelerator.log(logs, step=global_step) | ||
472 | |||
473 | local_progress_bar.clear() | ||
474 | global_progress_bar.clear() | ||
475 | |||
476 | if accelerator.is_main_process: | ||
477 | if avg_acc_val.avg.item() > max_acc_val: | ||
478 | accelerator.print( | ||
479 | f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") | ||
480 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
481 | max_acc_val = avg_acc_val.avg.item() | ||
482 | else: | ||
483 | if accelerator.is_main_process: | ||
484 | if avg_acc.avg.item() > max_acc: | ||
485 | accelerator.print( | ||
486 | f"Global step {global_step}: Training accuracy reached new maximum: {max_acc:.2e} -> {avg_acc.avg.item():.2e}") | ||
487 | on_checkpoint(global_step + global_step_offset, "milestone") | ||
488 | max_acc = avg_acc.avg.item() | ||
478 | 489 | ||
479 | # Create the pipeline using the trained modules and save it. | 490 | # Create the pipeline using the trained modules and save it. |
480 | if accelerator.is_main_process: | 491 | if accelerator.is_main_process: |
@@ -499,7 +510,7 @@ def train( | |||
499 | seed: int, | 510 | seed: int, |
500 | project: str, | 511 | project: str, |
501 | train_dataloader: DataLoader, | 512 | train_dataloader: DataLoader, |
502 | val_dataloader: DataLoader, | 513 | val_dataloader: Optional[DataLoader], |
503 | optimizer: torch.optim.Optimizer, | 514 | optimizer: torch.optim.Optimizer, |
504 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 515 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
505 | callbacks_fn: Callable[..., TrainingCallbacks], | 516 | callbacks_fn: Callable[..., TrainingCallbacks], |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 6e7ebe2..aeaa828 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -26,7 +26,7 @@ def dreambooth_strategy( | |||
26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
30 | output_dir: Path, | 30 | output_dir: Path, |
31 | seed: int, | 31 | seed: int, |
32 | train_text_encoder_epochs: int, | 32 | train_text_encoder_epochs: int, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 753dce0..568f9eb 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -26,7 +26,7 @@ def textual_inversion_strategy( | |||
26 | vae: AutoencoderKL, | 26 | vae: AutoencoderKL, |
27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
29 | val_dataloader: DataLoader, | 29 | val_dataloader: Optional[DataLoader], |
30 | output_dir: Path, | 30 | output_dir: Path, |
31 | seed: int, | 31 | seed: int, |
32 | placeholder_tokens: list[str], | 32 | placeholder_tokens: list[str], |