summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-02-15 15:11:37 +0100
committerVolpeon <git@volpeon.ink>2023-02-15 15:11:37 +0100
commit0b8568ebfb484b9155d19c567e890ff52776544d (patch)
tree09ff3256ca8d2c9359a236e7bdb274f539c5253e
parentImproved batch padding (diff)
downloadtextual-inversion-diff-0b8568ebfb484b9155d19c567e890ff52776544d.tar.gz
textual-inversion-diff-0b8568ebfb484b9155d19c567e890ff52776544d.tar.bz2
textual-inversion-diff-0b8568ebfb484b9155d19c567e890ff52776544d.zip
Update
-rw-r--r--data/csv.py1
-rw-r--r--train_ti.py1
-rw-r--r--training/functional.py2
3 files changed, 3 insertions(+), 1 deletion(-)
diff --git a/data/csv.py b/data/csv.py
index 233f5d8..619452e 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -403,6 +403,7 @@ class VlpnDataset(IterableDataset):
403 if len(batch) >= batch_size: 403 if len(batch) >= batch_size:
404 yield batch 404 yield batch
405 batch = [] 405 batch = []
406 continue
406 407
407 bucket_mask = mask.logical_and(self.bucket_assignments == bucket) 408 bucket_mask = mask.logical_and(self.bucket_assignments == bucket)
408 bucket_items = self.bucket_items[bucket_mask] 409 bucket_items = self.bucket_items[bucket_mask]
diff --git a/train_ti.py b/train_ti.py
index 171d085..f78c7d2 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -586,6 +586,7 @@ def main():
586 seed=args.seed, 586 seed=args.seed,
587 with_prior_preservation=args.num_class_images != 0, 587 with_prior_preservation=args.num_class_images != 0,
588 prior_loss_weight=args.prior_loss_weight, 588 prior_loss_weight=args.prior_loss_weight,
589 low_freq_noise=0,
589 strategy=textual_inversion_strategy, 590 strategy=textual_inversion_strategy,
590 num_train_epochs=args.num_train_epochs, 591 num_train_epochs=args.num_train_epochs,
591 sample_frequency=args.sample_frequency, 592 sample_frequency=args.sample_frequency,
diff --git a/training/functional.py b/training/functional.py
index a9c7a8a..e1035ce 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -556,7 +556,7 @@ def train(
556 global_step_offset: int = 0, 556 global_step_offset: int = 0,
557 with_prior_preservation: bool = False, 557 with_prior_preservation: bool = False,
558 prior_loss_weight: float = 1.0, 558 prior_loss_weight: float = 1.0,
559 low_freq_noise: float = 0.05, 559 low_freq_noise: float = 0.1,
560 **kwargs, 560 **kwargs,
561): 561):
562 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( 562 text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(