diff options
-rw-r--r-- | data/csv.py | 1 | ||||
-rw-r--r-- | train_ti.py | 1 | ||||
-rw-r--r-- | training/functional.py | 2 |
3 files changed, 3 insertions, 1 deletions
diff --git a/data/csv.py b/data/csv.py index 233f5d8..619452e 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -403,6 +403,7 @@ class VlpnDataset(IterableDataset): | |||
403 | if len(batch) >= batch_size: | 403 | if len(batch) >= batch_size: |
404 | yield batch | 404 | yield batch |
405 | batch = [] | 405 | batch = [] |
406 | continue | ||
406 | 407 | ||
407 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) | 408 | bucket_mask = mask.logical_and(self.bucket_assignments == bucket) |
408 | bucket_items = self.bucket_items[bucket_mask] | 409 | bucket_items = self.bucket_items[bucket_mask] |
diff --git a/train_ti.py b/train_ti.py index 171d085..f78c7d2 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -586,6 +586,7 @@ def main(): | |||
586 | seed=args.seed, | 586 | seed=args.seed, |
587 | with_prior_preservation=args.num_class_images != 0, | 587 | with_prior_preservation=args.num_class_images != 0, |
588 | prior_loss_weight=args.prior_loss_weight, | 588 | prior_loss_weight=args.prior_loss_weight, |
589 | low_freq_noise=0, | ||
589 | strategy=textual_inversion_strategy, | 590 | strategy=textual_inversion_strategy, |
590 | num_train_epochs=args.num_train_epochs, | 591 | num_train_epochs=args.num_train_epochs, |
591 | sample_frequency=args.sample_frequency, | 592 | sample_frequency=args.sample_frequency, |
diff --git a/training/functional.py b/training/functional.py index a9c7a8a..e1035ce 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -556,7 +556,7 @@ def train( | |||
556 | global_step_offset: int = 0, | 556 | global_step_offset: int = 0, |
557 | with_prior_preservation: bool = False, | 557 | with_prior_preservation: bool = False, |
558 | prior_loss_weight: float = 1.0, | 558 | prior_loss_weight: float = 1.0, |
559 | low_freq_noise: float = 0.05, | 559 | low_freq_noise: float = 0.1, |
560 | **kwargs, | 560 | **kwargs, |
561 | ): | 561 | ): |
562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 562 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |