| field | value | date |
|---|---|---|
| author | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| committer | Volpeon <git@volpeon.ink> | 2023-06-21 13:28:49 +0200 |
| commit | 8364ce697ddf6117fdd4f7222832d546d63880de | |
| tree | 152c99815bbd8b2659d0dabe63c98f63151c97c2 /training | |
| parent | Fix LoRA training with DAdan | |
Update
Diffstat (limited to 'training')
| mode | file | lines changed |
|---|---|---|
| -rw-r--r-- | training/functional.py | 221 |
| -rw-r--r-- | training/lr.py | 4 |
| -rw-r--r-- | training/optimization.py | 38 |
| -rw-r--r-- | training/sampler.py | 2 |
| -rw-r--r-- | training/strategy/dreambooth.py | 29 |
| -rw-r--r-- | training/strategy/lora.py | 41 |
| -rw-r--r-- | training/strategy/ti.py | 27 |
7 files changed, 245 insertions, 117 deletions
```diff
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    DDPMScheduler,
+    UniPCMultistepScheduler,
+    SchedulerMixin,
+)
 
 from tqdm.auto import tqdm
 
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
 def const(result=None):
     def fn(*args, **kwargs):
         return result
+
     return fn
 
 
 @dataclass
-class TrainingCallbacks():
+class TrainingCallbacks:
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
         train_dataloader: DataLoader,
         val_dataloader: Optional[DataLoader],
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        **kwargs
-    ) -> Tuple: ...
+        **kwargs,
+    ) -> Tuple:
+        ...
 
 
 @dataclass
-class TrainingStrategy():
+class TrainingStrategy:
     callbacks: Callable[..., TrainingCallbacks]
     prepare: TrainingStrategyPrepareCallable
 
 
 def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer"
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
+    )
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
@@ -113,7 +133,9 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
+        ("train", train_dataloader, None)
+    ]
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
@@ -124,17 +146,11 @@ def save_samples(
         file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
-        prompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["prompt_ids"]
-        ]
-        nprompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["nprompt_ids"]
-        ]
+        batches = list(
+            itertools.islice(itertools.cycle(data), batch_size * num_batches)
+        )
+        prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
+        nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]
 
         with torch.inference_mode():
             for i in range(num_batches):
@@ -165,7 +181,9 @@ def save_samples(
                 pass
 
         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
-        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
+        image_grid = pipeline.numpy_to_pil(
+            image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
+        )[0]
         image_grid.save(file_path, quality=85)
 
     del generator, pipeline
@@ -184,15 +202,17 @@ def generate_class_images(
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
-    sample_steps: int
+    sample_steps: int,
 ):
-    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
+    missing_data = [
+        item for item in train_dataset.items if not item.class_image_path.exists()
+    ]
 
     if len(missing_data) == 0:
         return
 
     batched_data = [
-        missing_data[i:i+sample_batch_size]
+        missing_data[i : i + sample_batch_size]
         for i in range(0, len(missing_data), sample_batch_size)
     ]
 
@@ -216,7 +236,7 @@ def generate_class_images(
             negative_prompt=nprompt,
             height=sample_image_size,
             width=sample_image_size,
-            num_inference_steps=sample_steps
+            num_inference_steps=sample_steps,
         ).images
 
         for i, image in enumerate(images):
@@ -245,8 +265,12 @@ def add_placeholder_tokens(
 
     embeddings.resize(len(tokenizer))
 
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
+    for placeholder_token_id, initializer_token_id in zip(
+        placeholder_token_ids, initializer_token_ids
+    ):
+        embeddings.add_embed(
+            placeholder_token_id, initializer_token_id, initializer_noise
+        )
 
     return placeholder_token_ids, initializer_token_ids
 
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler):
 
     # Expand the tensors.
     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
```
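The `alpha` and `sigma` tensors expanded here are the diffusion signal and noise scales, sqrt(ᾱₜ) and sqrt(1 − ᾱₜ); `compute_snr` combines them into the per-timestep signal-to-noise ratio SNR(t) = (alpha / sigma)². A minimal sketch of that quantity, assuming the scheduler exposes `alphas_cumprod` as diffusers' `DDPMScheduler` does (the rest of the function body is not shown in this hunk):

```python
import torch

def snr_for_timesteps(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # abar_t = alphas_cumprod[t]; SNR(t) = abar_t / (1 - abar_t),
    # which equals (sqrt(abar_t) / sqrt(1 - abar_t)) ** 2.
    abar = alphas_cumprod.to(timesteps.device)[timesteps].float()
    return abar / (1.0 - abar)
```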
```diff
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
 
 
 def get_original(
-    noise_scheduler,
-    model_output,
-    sample: torch.FloatTensor,
-    timesteps: torch.IntTensor
+    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
 ):
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
 
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(sample.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
@@ -329,7 +358,9 @@ def loss_step(
     eval: bool = False,
 ):
     images = batch["pixel_values"]
-    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    generator = (
+        torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    )
     bsz = images.shape[0]
 
     # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
-        generator=generator
+        generator=generator,
     )
     applied_noise = noise
 
@@ -353,7 +384,7 @@ def loss_step(
             octaves=4,
             dtype=latents.dtype,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
 
     # Get the text embedding for conditioning
     encoder_hidden_states = get_extended_embeddings(
-        text_encoder,
-        batch["input_ids"],
-        batch["attention_mask"]
+        text_encoder, batch["input_ids"], batch["attention_mask"]
     )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+    model_pred = unet(
+        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
+    )[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder,
-            batch["negative_input_ids"],
-            batch["negative_attention_mask"]
+            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
        )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
-        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+        model_pred_uncond = unet(
+            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
+        )[0]
+        model_pred = model_pred_uncond + guidance_scale * (
+            model_pred - model_pred_uncond
+        )
 
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
```
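The `guidance_scale` branch blends the conditional and unconditional UNet predictions in the style of classifier-free guidance, applied here at training time. Reduced to its core (a sketch with hypothetical tensor names; shapes assumed broadcast-compatible):

```python
import torch

def apply_guidance(cond: torch.Tensor, uncond: torch.Tensor, scale: float) -> torch.Tensor:
    # scale == 1 reproduces cond; larger values push the prediction
    # further along the conditional direction.
    return uncond + scale * (cond - uncond)
```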
```diff
@@ -401,7 +434,9 @@ def loss_step(
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
     else:
-        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        raise ValueError(
+            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+        )
 
     acc = (model_pred == target).float().mean()
 
@@ -414,7 +449,9 @@ def loss_step(
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+        prior_loss = F.mse_loss(
+            model_pred_prior.float(), target_prior.float(), reduction="none"
+        )
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
@@ -433,7 +470,10 @@ def loss_step(
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
+            )[0]
+            / snr
         )
         loss = loss * mse_loss_weights
 
```
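The `min_snr_gamma` block implements Min-SNR loss weighting: each timestep's MSE loss is scaled by min(SNR(t), γ) / SNR(t), which caps the contribution of easy, high-SNR (low-noise) timesteps. An equivalent, more compact sketch of the weight:

```python
import torch

def min_snr_weight(snr: torch.Tensor, gamma: float) -> torch.Tensor:
    # Same result as stacking [snr, gamma * ones] and taking the element-wise minimum.
    return snr.clamp(max=gamma) / snr
```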
```diff
@@ -447,8 +487,14 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
-                 eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(
+        self,
+        step: int,
+        batch: dict[Any, Any],
+        cache: dict[str, Any],
+        eval: bool = False,
+    ) -> Tuple[Any, Any, int]:
+        ...
 
 
 def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
     avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = math.ceil(
-        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = (
+        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
+        if val_dataloader is not None
+        else 0
+    )
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     global_progress_bar.set_description("Total progress")
 
@@ -513,7 +564,9 @@ def train_loop(
     try:
         import dadaptation
 
-        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+        isDadaptation = isinstance(
+            optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
+        )
     except ImportError:
         pass
 
@@ -565,7 +618,10 @@ def train_loop(
                 label = group_labels[i] if i < len(group_labels) else f"{i}"
                 logs[f"lr/{label}"] = lr
                 if isDadaptation:
-                    lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                    lr = (
+                        optimizer.param_groups[i]["d"]
+                        * optimizer.param_groups[i]["lr"]
+                    )
                     logs[f"d*lr/{label}"] = lr
                 lrs[label] = lr
 
@@ -573,8 +629,10 @@ def train_loop(
 
                 local_progress_bar.set_postfix(**logs)
 
-                if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(epoch)
+                if ((step + 1) % gradient_accumulation_steps == 0) or (
+                    (step + 1) == len(train_dataloader)
+                ):
+                    before_optimize_result = on_before_optimize(cycle)
 
                     optimizer.step()
                     lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
                     }
                     local_progress_bar.set_postfix(**logs)
 
-                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                    if ((step + 1) % gradient_accumulation_steps == 0) or (
+                        (step + 1) == len(val_dataloader)
+                    ):
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
@@ -634,7 +694,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc_val = avg_acc_val.max
             else:
@@ -644,7 +705,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
+                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc = avg_acc.max
 
@@ -700,17 +762,32 @@ def train(
     avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = strategy.prepare(
+        accelerator,
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+        **kwargs,
+    )
 
     vae.to(accelerator.device, dtype=dtype)
     vae.requires_grad_(False)
     vae.eval()
 
-    vae = torch.compile(vae, backend='hidet')
+    vae = torch.compile(vae, backend="hidet")
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, backend="hidet")
         # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
```
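A note on the `d*lr` log entries above: D-Adaptation optimizers scale the base learning rate by an internally estimated factor `d`, so the product is the effective step size. A hedged sketch of the readout, assuming the `"d"` entry sits in each param group exactly as this loop reads it:

```python
def effective_lrs(optimizer) -> list[float]:
    # Effective learning rate per param group for dadaptation optimizers.
    return [group["d"] * group["lr"] for group in optimizer.param_groups]
```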
```diff
diff --git a/training/lr.py b/training/lr.py
index f5b362f..a75078f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -23,12 +23,12 @@ def plot_metrics(
     fig, ax_loss = plt.subplots()
     ax_acc = ax_loss.twinx()
 
-    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.plot(lrs, losses, color="red")
     ax_loss.set_xscale("log")
     ax_loss.set_xlabel(f"Learning rate")
     ax_loss.set_ylabel("Loss")
 
-    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.plot(lrs, accs, color="blue")
     ax_acc.set_xscale("log")
     ax_acc.set_ylabel("Accuracy")
 
```
```diff
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
```
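Both the `exponential_growth` and `cosine_with_restarts` branches default `cycles` to ceil(sqrt(post-warmup steps / steps per epoch)). A quick worked check of that default, using hypothetical numbers not taken from any config in this repository:

```python
import math

num_training_steps = 1000
num_warmup_steps = 100
num_training_steps_per_epoch = 30

cycles = math.ceil(
    math.sqrt((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)
)
print(cycles)  # (1000 - 100) / 30 = 30.0 -> sqrt ~= 5.48 -> ceil = 6
```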
```diff
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
     def weights(self):
         if not self._warmed_up():
             return np.ones([self.num_timesteps], dtype=np.float64)
-        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
         weights /= np.sum(weights)
         weights *= 1 - self.uniform_prob
         weights += self.uniform_prob / len(weights)
```
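For context, `LossSecondMomentResampler.weights` samples timesteps in proportion to the root-mean-square of their recent losses, mixed with a uniform floor so no timestep starves. The same computation, restated as a standalone sketch:

```python
import numpy as np

def resample_weights(loss_history: np.ndarray, uniform_prob: float) -> np.ndarray:
    # loss_history: (num_timesteps, history_len) of recent per-timestep losses.
    w = np.sqrt(np.mean(loss_history**2, axis=-1))  # RMS loss per timestep
    w /= np.sum(w)                                  # normalize to a distribution
    w *= 1 - uniform_prob                           # reserve mass for the floor
    w += uniform_prob / len(w)                      # keep every timestep sampleable
    return w
```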
```diff
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()
 
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()
 
         yield
 
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
@@ -198,6 +204,5 @@ def dreambooth_prepare(
 
 
 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
```
```diff
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
         if lambda_ != 0:
             norm = w[:, :].norm(dim=-1, keepdim=True)
-            w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+            w[:].add_(
+                (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+            )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
```
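This strategy and `ti.py` below share the same post-step embedding decay: each trainable embedding row is nudged so its L2 norm drifts toward `emb_decay_target`. Isolated as a sketch (hypothetical helper name; `w` holds the trainable rows and is updated in place):

```python
import torch

@torch.no_grad()
def decay_embeddings(w: torch.Tensor, lambda_: float, target: float) -> None:
    # w: (num_tokens, dim). Per-row norm; clamp_min guards against zero rows.
    norm = w.norm(dim=-1, keepdim=True)
    # Step each row's norm toward `target` by an amount proportional to lambda_.
    w.add_((w / norm.clamp_min(1e-12)) * lambda_ * (target - norm))
```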
```diff
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
         print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)
```
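The `ema_embeddings.step(...)` call above maintains an exponential moving average of the token embeddings. A hedged sketch of the standard update it presumably performs (the actual EMA class may also warm up its decay factor over time):

```python
import torch

@torch.no_grad()
def ema_step(shadow: list[torch.Tensor], params: list[torch.Tensor], decay: float) -> None:
    # shadow <- decay * shadow + (1 - decay) * param, element-wise.
    for s, p in zip(shadow, params):
        s.mul_(decay).add_(p, alpha=1.0 - decay)
```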
