diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 2 | ||||
-rw-r--r-- | training/strategy/lora.py | 2 | ||||
-rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 3 insertions, 3 deletions
diff --git a/training/functional.py b/training/functional.py index ccbb4ad..83e70e2 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -129,7 +129,7 @@ def save_samples( | |||
129 | 129 | ||
130 | for pool, data, gen in datasets: | 130 | for pool, data, gen in datasets: |
131 | all_samples = [] | 131 | all_samples = [] |
132 | file_path = output_dir.joinpath(pool, f"step_{step}.jpg") | 132 | file_path = output_dir / pool / f"step_{step}.jpg" |
133 | file_path.parent.mkdir(parents=True, exist_ok=True) | 133 | file_path.parent.mkdir(parents=True, exist_ok=True) |
134 | 134 | ||
135 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) | 135 | batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches)) |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index bc10e58..4dd1100 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
@@ -91,7 +91,7 @@ def lora_strategy_callbacks( | |||
91 | print(f"Saving checkpoint for step {step}...") | 91 | print(f"Saving checkpoint for step {step}...") |
92 | 92 | ||
93 | unet_ = accelerator.unwrap_model(unet) | 93 | unet_ = accelerator.unwrap_model(unet) |
94 | unet_.save_attn_procs(checkpoint_output_dir.joinpath(f"{step}_{postfix}")) | 94 | unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") |
95 | del unet_ | 95 | del unet_ |
96 | 96 | ||
97 | @torch.no_grad() | 97 | @torch.no_grad() |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index da2b81c..0de3cb0 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -138,7 +138,7 @@ def textual_inversion_strategy_callbacks( | |||
138 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 138 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
139 | text_encoder.text_model.embeddings.save_embed( | 139 | text_encoder.text_model.embeddings.save_embed( |
140 | ids, | 140 | ids, |
141 | checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 141 | checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin" |
142 | ) | 142 | ) |
143 | 143 | ||
144 | @torch.no_grad() | 144 | @torch.no_grad() |