path: root/training/strategy/ti.py
author    Volpeon <git@volpeon.ink>  2023-01-15 12:33:52 +0100
committer Volpeon <git@volpeon.ink>  2023-01-15 12:33:52 +0100
commit    59bf501198d7ff6c0c03c45e92adef14069d5ac6
tree      aae4c7204b4f04bf2146408fb88892071840a05d  /training/strategy/ti.py
parent    Removed unused code, put training callbacks in dataclass
Update
Diffstat (limited to 'training/strategy/ti.py')
-rw-r--r--  training/strategy/ti.py  54
1 file changed, 26 insertions(+), 28 deletions(-)
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6f8384f..753dce0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,6 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
@@ -48,6 +47,12 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
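Note: the added block resolves the working dtype from the accelerator's mixed-precision setting instead of taking it as a `dtype` parameter. A minimal standalone sketch of the same mapping, for illustration only (the helper name is not part of this repo):

import torch
from accelerate import Accelerator

def resolve_weight_dtype(accelerator: Accelerator) -> torch.dtype:
    # accelerator.state.mixed_precision is a string such as "no", "fp16" or
    # "bf16"; anything other than fp16/bf16 falls back to full precision.
    if accelerator.state.mixed_precision == "fp16":
        return torch.float16
    if accelerator.state.mixed_precision == "bf16":
        return torch.bfloat16
    return torch.float32

# Example: Accelerator(mixed_precision="bf16") -> torch.bfloat16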
@@ -58,7 +63,7 @@ def textual_inversion_strategy(
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=dtype,
+        dtype=weight_dtype,
         output_dir=output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -78,6 +83,17 @@ def textual_inversion_strategy(
     else:
         ema_embeddings = None
 
+    def ema_context():
+        if use_ema:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
+    def on_model():
+        return text_encoder
+
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
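The new `ema_context()` helper replaces three near-identical inline constructions below; because it is a factory, every caller gets a fresh context manager rather than reusing one instance. A rough self-contained sketch of the pattern, under the assumption that the EMA object exposes a temporary weight-swap context (the `SimpleEMA` class is a hypothetical stand-in, not this repo's implementation):

from contextlib import contextmanager, nullcontext
from typing import Optional

import torch

class SimpleEMA:
    """Toy stand-in: keeps a shadow copy and can swap it in temporarily."""
    def __init__(self, param: torch.nn.Parameter):
        self.shadow = param.detach().clone()

    @contextmanager
    def apply_temporary(self, param: torch.nn.Parameter):
        # Swap the shadow (EMA) weights in, restore the live weights on exit.
        backup = param.detach().clone()
        with torch.no_grad():
            param.copy_(self.shadow)
        try:
            yield
        finally:
            with torch.no_grad():
                param.copy_(backup)

def make_ema_context(ema: Optional[SimpleEMA], param: torch.nn.Parameter):
    # Factory in the same shape as ema_context() above: each call returns a
    # fresh context manager, or a no-op when EMA is disabled.
    def ema_context():
        if ema is not None:
            return ema.apply_temporary(param)
        return nullcontext()
    return ema_context

param = torch.nn.Parameter(torch.ones(4))
ctx = make_ema_context(SimpleEMA(param), param)
with ctx():   # EMA weights active here
    pass
with ctx():   # safe to use again; a new context manager is built per call
    pass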
@@ -89,24 +105,15 @@ def textual_inversion_strategy(
 
     @contextmanager
     def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
+        tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval():
-        try:
-            tokenizer.eval()
+        tokenizer.eval()
 
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
+        with ema_context():
+            yield
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
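The removed `try/finally` blocks wrapped a bare `pass`, so they added no cleanup behaviour; the callbacks reduce to a plain `yield` (or a `yield` inside the EMA context). A minimal illustration of the generator-based context manager these callbacks rely on, with a toy tokenizer standing in for the real one:

from contextlib import contextmanager

class Tokenizer:
    """Toy stand-in for the project's tokenizer with train/eval modes."""
    def __init__(self):
        self.training = True
    def train(self):
        self.training = True
    def eval(self):
        self.training = False

tokenizer = Tokenizer()

@contextmanager
def on_train(epoch: int):
    # Everything before `yield` runs on entry; with no cleanup to do,
    # wrapping this in try/finally would change nothing.
    tokenizer.train()
    yield

with on_train(epoch=0):
    assert tokenizer.training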
@@ -131,13 +138,7 @@ def textual_inversion_strategy(
         checkpoints_path = output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
@@ -146,15 +147,12 @@ def textual_inversion_strategy(
 
     @torch.no_grad()
     def on_sample(step):
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             save_samples_(step=step)
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
+        on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
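The parent commit moved the training callbacks into a dataclass; the new `on_model` entry lets the training loop retrieve the model a strategy trains. A guess at the minimal shape of that dataclass, covering only the fields visible in this diff (everything else about the real `TrainingCallbacks` is an assumption):

from contextlib import nullcontext
from dataclasses import dataclass
from typing import Any, Callable, ContextManager, Optional

@dataclass
class TrainingCallbacks:
    # Fields visible in this diff only; the real dataclass from the parent
    # commit likely also carries checkpoint/sample hooks.
    on_prepare: Callable[[], None] = lambda: None
    on_model: Callable[[], Optional[Any]] = lambda: None
    on_train: Callable[[int], ContextManager] = lambda epoch: nullcontext()
    on_eval: Callable[[], ContextManager] = lambda: nullcontext()
    on_after_optimize: Callable[[float], None] = lambda lr: None

callbacks = TrainingCallbacks(on_model=lambda: "text_encoder goes here")
assert callbacks.on_model() == "text_encoder goes here"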