author     Volpeon <git@volpeon.ink>  2023-01-16 15:52:43 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-16 15:52:43 +0100
commit     6c8cffe28baeafac77d047ff3f8ded9418033e2f (patch)
tree       807c527deb1b15ef795f5cd8a7682151c69a037e /training
parent     Pad dataset if len(items) < batch_size (diff)
More training adjustments
Diffstat (limited to 'training')
-rw-r--r--  training/functional.py    5
-rw-r--r--  training/optimization.py  10
-rw-r--r--  training/strategy/ti.py   2
3 files changed, 9 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index c6b4dc3..b6b5d87 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol
 from tqdm.auto import tqdm
 from PIL import Image
 
+from data.csv import VlpnDataset
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
@@ -175,12 +176,12 @@ def generate_class_images(
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
     sample_scheduler: DPMSolverMultistepScheduler,
-    data_train,
+    train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
     sample_steps: int
 ):
-    missing_data = [item for item in data_train if not item.class_image_path.exists()]
+    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
 
     if len(missing_data) == 0:
         return
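Note: generate_class_images now takes a typed VlpnDataset instead of an untyped data_train and filters its .items for entries whose class image is not yet on disk. A minimal, self-contained sketch of that filtering step; the Item/VlpnDataset shapes here are hypothetical stand-ins for illustration, not the actual definitions in data/csv.py:

from dataclasses import dataclass, field
from pathlib import Path
from typing import List

@dataclass
class Item:  # hypothetical stand-in for one dataset entry
    class_image_path: Path

@dataclass
class VlpnDataset:  # hypothetical stand-in for data.csv.VlpnDataset
    items: List[Item] = field(default_factory=list)

train_dataset = VlpnDataset(items=[Item(Path("cls/0.png")), Item(Path("cls/1.png"))])

# Same shape as the new list comprehension in generate_class_images:
# keep only the items whose class image still needs to be generated.
missing_data = [item for item in train_dataset.items
                if not item.class_image_path.exists()]

if len(missing_data) == 0:
    print("all class images already exist")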
diff --git a/training/optimization.py b/training/optimization.py
index 5db7794..6dee4bc 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -49,8 +49,8 @@ def get_one_cycle_schedule(
     annealing: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
-    min_lr: int = 0.04,
-    mid_point: int = 0.3,
+    min_lr: float = 0.04,
+    mid_point: float = 0.3,
     last_epoch: int = -1
 ):
     if warmup == "linear":
@@ -91,10 +91,10 @@ def get_scheduler(
     id: str,
     optimizer: torch.optim.Optimizer,
     num_training_steps_per_epoch: int,
-    gradient_accumulation_steps: int,
+    gradient_accumulation_steps: int = 1,
     min_lr: float = 0.04,
-    warmup_func: str = "cos",
-    annealing_func: str = "cos",
+    warmup_func: Literal["cos", "linear"] = "cos",
+    annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
     warmup_exp: int = 1,
     annealing_exp: int = 1,
     cycles: int = 1,
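Note: the optimization.py changes correct annotations (min_lr and mid_point were typed int despite float defaults), default gradient_accumulation_steps to 1, and narrow the warmup/annealing selectors to Literal types so a type checker rejects invalid strings. A minimal sketch of the kind of one-cycle shape that min_lr and mid_point control, built on torch's LambdaLR; the formula below is illustrative only, not the repository's get_one_cycle_schedule:

import math
import torch
from torch.optim.lr_scheduler import LambdaLR

def one_cycle(num_training_steps: int, min_lr: float = 0.04, mid_point: float = 0.3):
    peak = max(int(num_training_steps * mid_point), 1)  # warmup ends here

    def lr_lambda(step: int) -> float:
        if step < peak:
            # cosine warmup: min_lr -> 1.0
            t = step / peak
            return min_lr + (1.0 - min_lr) * 0.5 * (1.0 - math.cos(math.pi * t))
        # cosine annealing: 1.0 -> min_lr
        t = (step - peak) / max(num_training_steps - peak, 1)
        return min_lr + (1.0 - min_lr) * 0.5 * (1.0 + math.cos(math.pi * t))

    return lr_lambda

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = LambdaLR(optimizer, one_cycle(num_training_steps=1000))

With gradient accumulation, the effective step count would typically be the per-epoch step count divided by gradient_accumulation_steps, which is why a default of 1 keeps the non-accumulating case unchanged.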
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 568f9eb..9d39e15 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -36,7 +36,7 @@ def textual_inversion_strategy(
     use_emb_decay: bool = False,
     emb_decay_target: float = 0.4,
     emb_decay_factor: float = 1,
-    emb_decay_start: float = 1e-4,
+    emb_decay_start: float = 0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
     ema_power: int = 1,
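Note: in the textual inversion strategy, the default emb_decay_start drops from 1e-4 to 0, so whatever threshold it gates is now disabled by default; the hunk does not show how the strategy applies these parameters. A minimal sketch of one plausible reading of emb_decay_target/emb_decay_factor, namely pulling an embedding's norm toward a target norm; this update rule is an assumption inferred from the parameter names, not the strategy's actual code:

import torch

def decay_toward_target(emb: torch.Tensor,
                        emb_decay_target: float = 0.4,
                        emb_decay_factor: float = 1.0) -> torch.Tensor:
    """Move the embedding's norm toward the target norm without
    changing its direction (assumed semantics, for illustration)."""
    norm = emb.norm()
    # Interpolate the norm toward the target; factor=1 jumps all the way.
    new_norm = norm + emb_decay_factor * (emb_decay_target - norm)
    return emb * (new_norm / norm.clamp(min=1e-8))

emb = torch.randn(768)
emb = decay_toward_target(emb)
print(float(emb.norm()))  # ~0.4 with the defaults above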