diff options
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 5 | ||||
| -rw-r--r-- | training/optimization.py | 10 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
3 files changed, 9 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol | |||
| 17 | from tqdm.auto import tqdm | 17 | from tqdm.auto import tqdm |
| 18 | from PIL import Image | 18 | from PIL import Image |
| 19 | 19 | ||
| 20 | from data.csv import VlpnDataset | ||
| 20 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 21 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 21 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
| 22 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
| @@ -175,12 +176,12 @@ def generate_class_images( | |||
| 175 | unet: UNet2DConditionModel, | 176 | unet: UNet2DConditionModel, |
| 176 | tokenizer: MultiCLIPTokenizer, | 177 | tokenizer: MultiCLIPTokenizer, |
| 177 | sample_scheduler: DPMSolverMultistepScheduler, | 178 | sample_scheduler: DPMSolverMultistepScheduler, |
| 178 | data_train, | 179 | train_dataset: VlpnDataset, |
| 179 | sample_batch_size: int, | 180 | sample_batch_size: int, |
| 180 | sample_image_size: int, | 181 | sample_image_size: int, |
| 181 | sample_steps: int | 182 | sample_steps: int |
| 182 | ): | 183 | ): |
| 183 | missing_data = [item for item in data_train if not item.class_image_path.exists()] | 184 | missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] |
| 184 | 185 | ||
| 185 | if len(missing_data) == 0: | 186 | if len(missing_data) == 0: |
| 186 | return | 187 | return |
diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -49,8 +49,8 @@ def get_one_cycle_schedule( | |||
| 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", | 49 | annealing: Literal["cos", "half_cos", "linear"] = "cos", |
| 50 | warmup_exp: int = 1, | 50 | warmup_exp: int = 1, |
| 51 | annealing_exp: int = 1, | 51 | annealing_exp: int = 1, |
| 52 | min_lr: int = 0.04, | 52 | min_lr: float = 0.04, |
| 53 | mid_point: int = 0.3, | 53 | mid_point: float = 0.3, |
| 54 | last_epoch: int = -1 | 54 | last_epoch: int = -1 |
| 55 | ): | 55 | ): |
| 56 | if warmup == "linear": | 56 | if warmup == "linear": |
| @@ -91,10 +91,10 @@ def get_scheduler( | |||
| 91 | id: str, | 91 | id: str, |
| 92 | optimizer: torch.optim.Optimizer, | 92 | optimizer: torch.optim.Optimizer, |
| 93 | num_training_steps_per_epoch: int, | 93 | num_training_steps_per_epoch: int, |
| 94 | gradient_accumulation_steps: int, | 94 | gradient_accumulation_steps: int = 1, |
| 95 | min_lr: float = 0.04, | 95 | min_lr: float = 0.04, |
| 96 | warmup_func: str = "cos", | 96 | warmup_func: Literal["cos", "linear"] = "cos", |
| 97 | annealing_func: str = "cos", | 97 | annealing_func: Literal["cos", "half_cos", "linear"] = "cos", |
| 98 | warmup_exp: int = 1, | 98 | warmup_exp: int = 1, |
| 99 | annealing_exp: int = 1, | 99 | annealing_exp: int = 1, |
| 100 | cycles: int = 1, | 100 | cycles: int = 1, |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 568f9eb..9d39e15 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -36,7 +36,7 @@ def textual_inversion_strategy( | |||
| 36 | use_emb_decay: bool = False, | 36 | use_emb_decay: bool = False, |
| 37 | emb_decay_target: float = 0.4, | 37 | emb_decay_target: float = 0.4, |
| 38 | emb_decay_factor: float = 1, | 38 | emb_decay_factor: float = 1, |
| 39 | emb_decay_start: float = 1e-4, | 39 | emb_decay_start: float = 0, |
| 40 | use_ema: bool = False, | 40 | use_ema: bool = False, |
| 41 | ema_inv_gamma: float = 1.0, | 41 | ema_inv_gamma: float = 1.0, |
| 42 | ema_power: int = 1, | 42 | ema_power: int = 1, |
