diff options
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r-- | training/strategy/ti.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 9d39e15..00f3529 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
@@ -27,7 +27,8 @@ def textual_inversion_strategy( | |||
27 | sample_scheduler: DPMSolverMultistepScheduler, | 27 | sample_scheduler: DPMSolverMultistepScheduler, |
28 | train_dataloader: DataLoader, | 28 | train_dataloader: DataLoader, |
29 | val_dataloader: Optional[DataLoader], | 29 | val_dataloader: Optional[DataLoader], |
30 | output_dir: Path, | 30 | sample_output_dir: Path, |
31 | checkpoint_output_dir: Path, | ||
31 | seed: int, | 32 | seed: int, |
32 | placeholder_tokens: list[str], | 33 | placeholder_tokens: list[str], |
33 | placeholder_token_ids: list[list[int]], | 34 | placeholder_token_ids: list[list[int]], |
@@ -47,6 +48,9 @@ def textual_inversion_strategy( | |||
47 | sample_guidance_scale: float = 7.5, | 48 | sample_guidance_scale: float = 7.5, |
48 | sample_image_size: Optional[int] = None, | 49 | sample_image_size: Optional[int] = None, |
49 | ): | 50 | ): |
51 | sample_output_dir.mkdir(parents=True, exist_ok=True) | ||
52 | checkpoint_output_dir.mkdir(parents=True, exist_ok=True) | ||
53 | |||
50 | weight_dtype = torch.float32 | 54 | weight_dtype = torch.float32 |
51 | if accelerator.state.mixed_precision == "fp16": | 55 | if accelerator.state.mixed_precision == "fp16": |
52 | weight_dtype = torch.float16 | 56 | weight_dtype = torch.float16 |
@@ -64,7 +68,7 @@ def textual_inversion_strategy( | |||
64 | train_dataloader=train_dataloader, | 68 | train_dataloader=train_dataloader, |
65 | val_dataloader=val_dataloader, | 69 | val_dataloader=val_dataloader, |
66 | dtype=weight_dtype, | 70 | dtype=weight_dtype, |
67 | output_dir=output_dir, | 71 | output_dir=sample_output_dir, |
68 | seed=seed, | 72 | seed=seed, |
69 | batch_size=sample_batch_size, | 73 | batch_size=sample_batch_size, |
70 | num_batches=sample_num_batches, | 74 | num_batches=sample_num_batches, |
@@ -135,14 +139,11 @@ def textual_inversion_strategy( | |||
135 | def on_checkpoint(step, postfix): | 139 | def on_checkpoint(step, postfix): |
136 | print(f"Saving checkpoint for step {step}...") | 140 | print(f"Saving checkpoint for step {step}...") |
137 | 141 | ||
138 | checkpoints_path = output_dir.joinpath("checkpoints") | ||
139 | checkpoints_path.mkdir(parents=True, exist_ok=True) | ||
140 | |||
141 | with ema_context(): | 142 | with ema_context(): |
142 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): | 143 | for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): |
143 | text_encoder.text_model.embeddings.save_embed( | 144 | text_encoder.text_model.embeddings.save_embed( |
144 | ids, | 145 | ids, |
145 | checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") | 146 | checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") |
146 | ) | 147 | ) |
147 | 148 | ||
148 | @torch.no_grad() | 149 | @torch.no_grad() |