summaryrefslogtreecommitdiffstats
path: root/training/strategy/ti.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-16 21:47:06 +0100
committerVolpeon <git@volpeon.ink>2023-01-16 21:47:06 +0100
commit2dfd1790078753f19ca8c585ac77079f3114f3a9 (patch)
treed1d1d643f247767c13535105dbe4afafcc5ab8c0 /training/strategy/ti.py
parentIf valid set size is 0, re-use one image from train set (diff)
downloadtextual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.gz
textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.tar.bz2
textual-inversion-diff-2dfd1790078753f19ca8c585ac77079f3114f3a9.zip
Training update
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--training/strategy/ti.py13
1 file changed, 7 insertions, 6 deletions
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 9d39e15..00f3529 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,8 @@ def textual_inversion_strategy(
27 sample_scheduler: DPMSolverMultistepScheduler, 27 sample_scheduler: DPMSolverMultistepScheduler,
28 train_dataloader: DataLoader, 28 train_dataloader: DataLoader,
29 val_dataloader: Optional[DataLoader], 29 val_dataloader: Optional[DataLoader],
30 output_dir: Path, 30 sample_output_dir: Path,
31 checkpoint_output_dir: Path,
31 seed: int, 32 seed: int,
32 placeholder_tokens: list[str], 33 placeholder_tokens: list[str],
33 placeholder_token_ids: list[list[int]], 34 placeholder_token_ids: list[list[int]],
@@ -47,6 +48,9 @@ def textual_inversion_strategy(
47 sample_guidance_scale: float = 7.5, 48 sample_guidance_scale: float = 7.5,
48 sample_image_size: Optional[int] = None, 49 sample_image_size: Optional[int] = None,
49): 50):
51 sample_output_dir.mkdir(parents=True, exist_ok=True)
52 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
53
50 weight_dtype = torch.float32 54 weight_dtype = torch.float32
51 if accelerator.state.mixed_precision == "fp16": 55 if accelerator.state.mixed_precision == "fp16":
52 weight_dtype = torch.float16 56 weight_dtype = torch.float16
@@ -64,7 +68,7 @@ def textual_inversion_strategy(
64 train_dataloader=train_dataloader, 68 train_dataloader=train_dataloader,
65 val_dataloader=val_dataloader, 69 val_dataloader=val_dataloader,
66 dtype=weight_dtype, 70 dtype=weight_dtype,
67 output_dir=output_dir, 71 output_dir=sample_output_dir,
68 seed=seed, 72 seed=seed,
69 batch_size=sample_batch_size, 73 batch_size=sample_batch_size,
70 num_batches=sample_num_batches, 74 num_batches=sample_num_batches,
@@ -135,14 +139,11 @@ def textual_inversion_strategy(
135 def on_checkpoint(step, postfix): 139 def on_checkpoint(step, postfix):
136 print(f"Saving checkpoint for step {step}...") 140 print(f"Saving checkpoint for step {step}...")
137 141
138 checkpoints_path = output_dir.joinpath("checkpoints")
139 checkpoints_path.mkdir(parents=True, exist_ok=True)
140
141 with ema_context(): 142 with ema_context():
142 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids): 143 for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
143 text_encoder.text_model.embeddings.save_embed( 144 text_encoder.text_model.embeddings.save_embed(
144 ids, 145 ids,
145 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin") 146 checkpoint_output_dir.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
146 ) 147 )
147 148
148 @torch.no_grad() 149 @torch.no_grad()