author     Volpeon <git@volpeon.ink>   2023-01-16 15:52:43 +0100
committer  Volpeon <git@volpeon.ink>   2023-01-16 15:52:43 +0100
commit     6c8cffe28baeafac77d047ff3f8ded9418033e2f (patch)
tree       807c527deb1b15ef795f5cd8a7682151c69a037e /training
parent     Pad dataset if len(items) < batch_size (diff)
More training adjustments
Diffstat (limited to 'training')

-rw-r--r--  training/functional.py   |  5 +++--
-rw-r--r--  training/optimization.py | 10 +++++-----
-rw-r--r--  training/strategy/ti.py  |  2 +-

3 files changed, 9 insertions(+), 8 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index c6b4dc3..b6b5d87 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol
 from tqdm.auto import tqdm
 from PIL import Image
 
+from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
@@ -175,12 +176,12 @@ def generate_class_images(
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
     sample_scheduler: DPMSolverMultistepScheduler,
-    data_train,
+    train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
     sample_steps: int
 ):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
 
     if len(missing_data) == 0:
         return
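The change above swaps the untyped `data_train` parameter for a typed `VlpnDataset` and filters `train_dataset.items` instead of iterating the dataset object directly, so the raw items are evidently exposed through an `.items` attribute rather than through iteration. A minimal sketch of the shape `generate_class_images()` now relies on; the `VlpnItem` name and its field list are illustrative assumptions, since only `class_image_path` and `.items` appear in this diff:

```python
from dataclasses import dataclass, field
from pathlib import Path


@dataclass
class VlpnItem:
    # Hypothetical item type: only class_image_path is taken from the
    # diff above; whatever other fields a real item carries are omitted.
    prompt: str
    class_image_path: Path


@dataclass
class VlpnDataset:
    # generate_class_images() reads train_dataset.items, so the raw
    # items are exposed separately from the dataset's batch iteration.
    items: list[VlpnItem] = field(default_factory=list)


def missing_class_images(train_dataset: VlpnDataset) -> list[VlpnItem]:
    # The same filter as the new list comprehension in the hunk above:
    # keep only items whose class image has not been generated yet.
    return [
        item
        for item in train_dataset.items
        if not item.class_image_path.exists()
    ]
```

Typing the parameter also lets a checker reject callers that still pass a plain list of items.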
diff --git a/training/optimization.py b/training/optimization.py
index 5db7794..6dee4bc 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -49,8 +49,8 @@ def get_one_cycle_schedule(
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
-    min_lr: int = 0.04,
-    mid_point: int = 0.3,
+    min_lr: float = 0.04,
+    mid_point: float = 0.3,
     last_epoch: int = -1
 ):
     if warmup == "linear":
@@ -91,10 +91,10 @@ def get_scheduler(
     id: str,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
-    gradient_accumulation_steps: int,
+    gradient_accumulation_steps: int = 1,
     min_lr: float = 0.04,
-    warmup_func: str = "cos",
-    annealing_func: str = "cos",
+    warmup_func: Literal["cos", "linear"] = "cos",
+    annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
     cycles: int = 1,
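Both hunks are annotation fixes: `min_lr` and `mid_point` default to 0.04 and 0.3, values that only make sense as fractions, so `int` was simply wrong, and the `Literal` types on `warmup_func`/`annealing_func` let a type checker reject unsupported mode strings. A rough sketch of how a one-cycle schedule treats those two fractions; this is an illustration built on `torch.optim.lr_scheduler.LambdaLR`, not the repository's `get_one_cycle_schedule`:

```python
import math

import torch
from torch.optim.lr_scheduler import LambdaLR


def one_cycle_sketch(
    optimizer: torch.optim.Optimizer,
    num_training_steps: int,
    min_lr: float = 0.04,    # floor, as a fraction of the peak LR
    mid_point: float = 0.3,  # fraction of the run spent warming up
    last_epoch: int = -1,
) -> LambdaLR:
    peak_step = int(num_training_steps * mid_point)

    def lr_lambda(step: int) -> float:
        if step < peak_step:
            # Linear warmup from min_lr up to the peak multiplier 1.0.
            return min_lr + (1.0 - min_lr) * step / max(1, peak_step)
        # Cosine annealing from the peak back down to min_lr.
        progress = (step - peak_step) / max(1, num_training_steps - peak_step)
        return min_lr + (1.0 - min_lr) * 0.5 * (1.0 + math.cos(math.pi * progress))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
```

Truncating either value to `int` would collapse the schedule: `min_lr = 0` removes the floor and `mid_point = 0` skips the warmup phase entirely.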
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 568f9eb..9d39e15 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -36,7 +36,7 @@ def textual_inversion_strategy(
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay_factor: float = 1,
-    emb_decay_start: float = 1e-4,
+    emb_decay_start: float = 0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
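The only change here is the `emb_decay_start` default dropping from 1e-4 to 0. The hunk does not show how the strategy consumes the value; a hedged reading, sketched below, is that it gates when the decay of trained embeddings toward `emb_decay_target` kicks in, so a default of 0 applies the decay unconditionally. Both the gating condition and the norm-based decay rule are assumptions made for illustration, not the repository's implementation:

```python
import torch


@torch.no_grad()
def decay_embeddings_sketch(
    embeddings: torch.Tensor,  # (num_trained_tokens, embedding_dim)
    emb_decay_target: float = 0.4,
    emb_decay_factor: float = 1,
    emb_decay_start: float = 0,
) -> None:
    # Assumed gate: only embeddings whose norm exceeds emb_decay_start
    # are decayed, so a start of 0 makes every embedding eligible.
    norms = embeddings.norm(dim=-1, keepdim=True)
    mask = (norms > emb_decay_start).squeeze(-1)
    # Pull each eligible embedding's norm toward the target; with
    # emb_decay_factor = 1 the norm lands exactly on emb_decay_target.
    scale = 1.0 - emb_decay_factor * (1.0 - emb_decay_target / norms.clamp(min=1e-12))
    embeddings[mask] *= scale[mask]
```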