summaryrefslogtreecommitdiffstats
path: root/train_dreambooth.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-12-23 23:02:01 +0100
committerVolpeon <git@volpeon.ink>2022-12-23 23:02:01 +0100
commit3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de (patch)
tree7b12a26c195e7298bb6cbc993ad0dd0f322fede4 /train_dreambooth.py
parentnum_class_images is now class images per train image (diff)
downloadtextual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.tar.gz
textual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.tar.bz2
textual-inversion-diff-3a83ec17318dc60ed46b4a3279d3dcbe7e8b02de.zip
Better dataset prompt handling
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--train_dreambooth.py2
1 files changed, 1 insertions, 1 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2f913e7..1a79b2b 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -687,7 +687,7 @@ def main():
687 ).to(accelerator.device) 687 ).to(accelerator.device)
688 pipeline.set_progress_bar_config(dynamic_ncols=True) 688 pipeline.set_progress_bar_config(dynamic_ncols=True)
689 689
690 with torch.autocast("cuda"), torch.inference_mode(): 690 with torch.inference_mode():
691 for batch in batched_data: 691 for batch in batched_data:
692 image_name = [item.class_image_path for item in batch] 692 image_name = [item.class_image_path for item in batch]
693 prompt = [item.cprompt for item in batch] 693 prompt = [item.cprompt for item in batch]