diff options
Diffstat (limited to 'train_ti.py')
| -rw-r--r-- | train_ti.py | 2 |
1 files changed, 1 insertions, 1 deletions
diff --git a/train_ti.py b/train_ti.py index e272b5d..cc208f0 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -621,7 +621,7 @@ def main(): | |||
| 621 | ).to(accelerator.device) | 621 | ).to(accelerator.device) |
| 622 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 622 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
| 623 | 623 | ||
| 624 | with torch.autocast("cuda"), torch.inference_mode(): | 624 | with torch.inference_mode(): |
| 625 | for batch in batched_data: | 625 | for batch in batched_data: |
| 626 | image_name = [item.class_image_path for item in batch] | 626 | image_name = [item.class_image_path for item in batch] |
| 627 | prompt = [item.cprompt for item in batch] | 627 | prompt = [item.cprompt for item in batch] |
