From 6c8cffe28baeafac77d047ff3f8ded9418033e2f Mon Sep 17 00:00:00 2001 From: Volpeon Date: Mon, 16 Jan 2023 15:52:43 +0100 Subject: More training adjustments --- training/functional.py | 5 +++-- training/optimization.py | 10 +++++----- training/strategy/ti.py | 2 +- 3 files changed, 9 insertions(+), 8 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index c6b4dc3..b6b5d87 100644 --- a/training/functional.py +++ b/training/functional.py @@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol from tqdm.auto import tqdm from PIL import Image +from data.csv import VlpnDataset from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings from models.clip.util import get_extended_embeddings @@ -175,12 +176,12 @@ def generate_class_images( unet: UNet2DConditionModel, tokenizer: MultiCLIPTokenizer, sample_scheduler: DPMSolverMultistepScheduler, - data_train, + train_dataset: VlpnDataset, sample_batch_size: int, sample_image_size: int, sample_steps: int ): - missing_data = [item for item in data_train if not item.class_image_path.exists()] + missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()] if len(missing_data) == 0: return diff --git a/training/optimization.py b/training/optimization.py index 5db7794..6dee4bc 100644 --- a/training/optimization.py +++ b/training/optimization.py @@ -49,8 +49,8 @@ def get_one_cycle_schedule( annealing: Literal["cos", "half_cos", "linear"] = "cos", warmup_exp: int = 1, annealing_exp: int = 1, - min_lr: int = 0.04, - mid_point: int = 0.3, + min_lr: float = 0.04, + mid_point: float = 0.3, last_epoch: int = -1 ): if warmup == "linear": @@ -91,10 +91,10 @@ def get_scheduler( id: str, optimizer: torch.optim.Optimizer, num_training_steps_per_epoch: int, - gradient_accumulation_steps: int, + gradient_accumulation_steps: int = 1, min_lr: float = 0.04, - warmup_func: str = "cos", - annealing_func: str = "cos", + warmup_func: Literal["cos", "linear"] = "cos", + annealing_func: Literal["cos", "half_cos", "linear"] = "cos", warmup_exp: int = 1, annealing_exp: int = 1, cycles: int = 1, diff --git a/training/strategy/ti.py b/training/strategy/ti.py index 568f9eb..9d39e15 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py @@ -36,7 +36,7 @@ def textual_inversion_strategy( use_emb_decay: bool = False, emb_decay_target: float = 0.4, emb_decay_factor: float = 1, - emb_decay_start: float = 1e-4, + emb_decay_start: float = 0, use_ema: bool = False, ema_inv_gamma: float = 1.0, ema_power: int = 1, -- cgit v1.2.3-54-g00ecf