Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 221
-rw-r--r-- | training/lr.py | 4
-rw-r--r-- | training/optimization.py | 38
-rw-r--r-- | training/sampler.py | 2
-rw-r--r-- | training/strategy/dreambooth.py | 29
-rw-r--r-- | training/strategy/lora.py | 41
-rw-r--r-- | training/strategy/ti.py | 27
7 files changed, 245 insertions, 117 deletions
diff --git a/training/functional.py b/training/functional.py
index fd3f9f4..f68faf9 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -14,7 +14,13 @@ import numpy as np
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, UniPCMultistepScheduler, SchedulerMixin
+from diffusers import (
+    AutoencoderKL,
+    UNet2DConditionModel,
+    DDPMScheduler,
+    UniPCMultistepScheduler,
+    SchedulerMixin,
+)
 
 from tqdm.auto import tqdm
 
@@ -33,11 +39,12 @@ from util.noise import perlin_noise
 def const(result=None):
     def fn(*args, **kwargs):
         return result
+
     return fn
 
 
 @dataclass
-class TrainingCallbacks():
+class TrainingCallbacks:
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], Any] = const()
@@ -58,23 +65,36 @@ class TrainingStrategyPrepareCallable(Protocol):
         train_dataloader: DataLoader,
         val_dataloader: Optional[DataLoader],
         lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-        **kwargs
-    ) -> Tuple: ...
+        **kwargs,
+    ) -> Tuple:
+        ...
 
 
 @dataclass
-class TrainingStrategy():
+class TrainingStrategy:
     callbacks: Callable[..., TrainingCallbacks]
     prepare: TrainingStrategyPrepareCallable
 
 
 def get_models(pretrained_model_name_or_path: str, torch_dtype=torch.float32):
-    tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
-    text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder', torch_dtype=torch_dtype)
-    vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae', torch_dtype=torch_dtype)
-    unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet', torch_dtype=torch_dtype)
-    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = UniPCMultistepScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    tokenizer = MultiCLIPTokenizer.from_pretrained(
+        pretrained_model_name_or_path, subfolder="tokenizer"
+    )
+    text_encoder = CLIPTextModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="text_encoder", torch_dtype=torch_dtype
+    )
+    vae = AutoencoderKL.from_pretrained(
+        pretrained_model_name_or_path, subfolder="vae", torch_dtype=torch_dtype
+    )
+    unet = UNet2DConditionModel.from_pretrained(
+        pretrained_model_name_or_path, subfolder="unet", torch_dtype=torch_dtype
+    )
+    noise_scheduler = DDPMScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
+        pretrained_model_name_or_path, subfolder="scheduler"
+    )
 
     return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler
 
@@ -113,7 +133,9 @@ def save_samples(
 
     generator = torch.Generator(device=accelerator.device).manual_seed(seed)
 
-    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [("train", train_dataloader, None)]
+    datasets: list[tuple[str, DataLoader, Optional[torch.Generator]]] = [
+        ("train", train_dataloader, None)
+    ]
 
     if val_dataloader is not None:
         datasets.append(("stable", val_dataloader, generator))
@@ -124,17 +146,11 @@ def save_samples(
         file_path = output_dir / pool / f"step_{cycle}_{step}.jpg"
         file_path.parent.mkdir(parents=True, exist_ok=True)
 
-        batches = list(itertools.islice(itertools.cycle(data), batch_size * num_batches))
-        prompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["prompt_ids"]
-        ]
-        nprompt_ids = [
-            prompt
-            for batch in batches
-            for prompt in batch["nprompt_ids"]
-        ]
+        batches = list(
+            itertools.islice(itertools.cycle(data), batch_size * num_batches)
+        )
+        prompt_ids = [prompt for batch in batches for prompt in batch["prompt_ids"]]
+        nprompt_ids = [prompt for batch in batches for prompt in batch["nprompt_ids"]]
 
         with torch.inference_mode():
             for i in range(num_batches):
@@ -165,7 +181,9 @@ def save_samples(
                 pass
 
         image_grid = make_grid(all_samples.permute(0, 3, 1, 2), grid_cols)
-        image_grid = pipeline.numpy_to_pil(image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy())[0]
+        image_grid = pipeline.numpy_to_pil(
+            image_grid.unsqueeze(0).permute(0, 2, 3, 1).numpy()
+        )[0]
         image_grid.save(file_path, quality=85)
 
     del generator, pipeline
@@ -184,15 +202,17 @@ def generate_class_images(
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
-    sample_steps: int
+    sample_steps: int,
 ):
-    missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
+    missing_data = [
+        item for item in train_dataset.items if not item.class_image_path.exists()
+    ]
 
     if len(missing_data) == 0:
         return
 
     batched_data = [
-        missing_data[i:i+sample_batch_size]
+        missing_data[i : i + sample_batch_size]
         for i in range(0, len(missing_data), sample_batch_size)
     ]
 
@@ -216,7 +236,7 @@ def generate_class_images(
             negative_prompt=nprompt,
             height=sample_image_size,
             width=sample_image_size,
-            num_inference_steps=sample_steps
+            num_inference_steps=sample_steps,
         ).images
 
         for i, image in enumerate(images):
@@ -245,8 +265,12 @@ def add_placeholder_tokens(
 
     embeddings.resize(len(tokenizer))
 
-    for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
-        embeddings.add_embed(placeholder_token_id, initializer_token_id, initializer_noise)
+    for placeholder_token_id, initializer_token_id in zip(
+        placeholder_token_ids, initializer_token_ids
+    ):
+        embeddings.add_embed(
+            placeholder_token_id, initializer_token_id, initializer_noise
+        )
 
     return placeholder_token_ids, initializer_token_ids
 
@@ -261,12 +285,16 @@ def compute_snr(timesteps, noise_scheduler):
 
     # Expand the tensors.
     # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(timesteps.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape)
@@ -277,21 +305,22 @@ def compute_snr(timesteps, noise_scheduler):
 
 
 def get_original(
-    noise_scheduler,
-    model_output,
-    sample: torch.FloatTensor,
-    timesteps: torch.IntTensor
+    noise_scheduler, model_output, sample: torch.FloatTensor, timesteps: torch.IntTensor
 ):
     alphas_cumprod = noise_scheduler.alphas_cumprod
     sqrt_alphas_cumprod = alphas_cumprod**0.5
     sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5
 
-    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[
+        timesteps
+    ].float()
     while len(sqrt_alphas_cumprod.shape) < len(sample.shape):
         sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None]
     alpha = sqrt_alphas_cumprod.expand(sample.shape)
 
-    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float()
+    sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(
+        device=timesteps.device
+    )[timesteps].float()
     while len(sqrt_one_minus_alphas_cumprod.shape) < len(sample.shape):
         sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None]
     sigma = sqrt_one_minus_alphas_cumprod.expand(sample.shape)
@@ -329,7 +358,9 @@ def loss_step(
     eval: bool = False,
 ):
     images = batch["pixel_values"]
-    generator = torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    generator = (
+        torch.Generator(device=images.device).manual_seed(seed + step) if eval else None
+    )
     bsz = images.shape[0]
 
     # Convert images to latent space
@@ -342,7 +373,7 @@ def loss_step(
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
-        generator=generator
+        generator=generator,
     )
     applied_noise = noise
 
@@ -353,7 +384,7 @@ def loss_step(
             octaves=4,
             dtype=latents.dtype,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     if input_pertubation != 0:
@@ -362,7 +393,7 @@ def loss_step(
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
-            generator=generator
+            generator=generator,
         )
 
     # Sample a random timestep for each image
@@ -375,25 +406,27 @@ def loss_step(
 
     # Get the text embedding for conditioning
     encoder_hidden_states = get_extended_embeddings(
-        text_encoder,
-        batch["input_ids"],
-        batch["attention_mask"]
+        text_encoder, batch["input_ids"], batch["attention_mask"]
     )
     encoder_hidden_states = encoder_hidden_states.to(dtype=unet.dtype)
 
     # Predict the noise residual
-    model_pred = unet(noisy_latents, timesteps, encoder_hidden_states, return_dict=False)[0]
+    model_pred = unet(
+        noisy_latents, timesteps, encoder_hidden_states, return_dict=False
+    )[0]
 
     if guidance_scale != 0:
         uncond_encoder_hidden_states = get_extended_embeddings(
-            text_encoder,
-            batch["negative_input_ids"],
-            batch["negative_attention_mask"]
+            text_encoder, batch["negative_input_ids"], batch["negative_attention_mask"]
         )
         uncond_encoder_hidden_states = uncond_encoder_hidden_states.to(dtype=unet.dtype)
 
-        model_pred_uncond = unet(noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False)[0]
-        model_pred = model_pred_uncond + guidance_scale * (model_pred - model_pred_uncond)
+        model_pred_uncond = unet(
+            noisy_latents, timesteps, uncond_encoder_hidden_states, return_dict=False
+        )[0]
+        model_pred = model_pred_uncond + guidance_scale * (
+            model_pred - model_pred_uncond
+        )
 
     # Get the target for loss depending on the prediction type
     if noise_scheduler.config.prediction_type == "epsilon":
@@ -401,7 +434,9 @@ def loss_step(
     elif noise_scheduler.config.prediction_type == "v_prediction":
         target = noise_scheduler.get_velocity(latents, noise, timesteps)
     else:
-        raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
+        raise ValueError(
+            f"Unknown prediction type {noise_scheduler.config.prediction_type}"
+        )
 
     acc = (model_pred == target).float().mean()
 
@@ -414,7 +449,9 @@ def loss_step(
         loss = F.mse_loss(model_pred.float(), target.float(), reduction="none")
 
         # Compute prior loss
-        prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="none")
+        prior_loss = F.mse_loss(
+            model_pred_prior.float(), target_prior.float(), reduction="none"
+        )
 
         # Add the prior loss to the instance loss.
         loss = loss + prior_loss_weight * prior_loss
@@ -433,7 +470,10 @@ def loss_step(
     if min_snr_gamma != 0:
         snr = compute_snr(timesteps, noise_scheduler)
         mse_loss_weights = (
-            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(dim=1)[0] / snr
+            torch.stack([snr, min_snr_gamma * torch.ones_like(timesteps)], dim=1).min(
+                dim=1
+            )[0]
+            / snr
         )
         loss = loss * mse_loss_weights
 
@@ -447,8 +487,14 @@ def loss_step(
 
 
 class LossCallable(Protocol):
-    def __call__(self, step: int, batch: dict[Any, Any], cache: dict[str, Any],
-                 eval: bool = False) -> Tuple[Any, Any, int]: ...
+    def __call__(
+        self,
+        step: int,
+        batch: dict[Any, Any],
+        cache: dict[str, Any],
+        eval: bool = False,
+    ) -> Tuple[Any, Any, int]:
+        ...
 
 
 def train_loop(
@@ -472,9 +518,14 @@ def train_loop(
     avg_acc_val: AverageMeter = AverageMeter(),
     callbacks: TrainingCallbacks = TrainingCallbacks(),
 ):
-    num_training_steps_per_epoch = math.ceil(len(train_dataloader) / gradient_accumulation_steps)
-    num_val_steps_per_epoch = math.ceil(
-        len(val_dataloader) / gradient_accumulation_steps) if val_dataloader is not None else 0
+    num_training_steps_per_epoch = math.ceil(
+        len(train_dataloader) / gradient_accumulation_steps
+    )
+    num_val_steps_per_epoch = (
+        math.ceil(len(val_dataloader) / gradient_accumulation_steps)
+        if val_dataloader is not None
+        else 0
+    )
 
     num_training_steps = num_training_steps_per_epoch * num_epochs
     num_val_steps = num_val_steps_per_epoch * num_epochs
@@ -488,14 +539,14 @@ def train_loop(
     local_progress_bar = tqdm(
         range(num_training_steps_per_epoch + num_val_steps_per_epoch),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     local_progress_bar.set_description(f"Epoch 1 / {num_epochs}")
 
     global_progress_bar = tqdm(
         range(num_training_steps + num_val_steps),
         disable=not accelerator.is_local_main_process,
-        dynamic_ncols=True
+        dynamic_ncols=True,
     )
     global_progress_bar.set_description("Total progress")
 
@@ -513,7 +564,9 @@ def train_loop(
     try:
         import dadaptation
 
-        isDadaptation = isinstance(optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan))
+        isDadaptation = isinstance(
+            optimizer.optimizer, (dadaptation.DAdaptAdam, dadaptation.DAdaptAdan)
+        )
     except ImportError:
         pass
 
@@ -565,7 +618,10 @@ def train_loop(
                     label = group_labels[i] if i < len(group_labels) else f"{i}"
                     logs[f"lr/{label}"] = lr
                     if isDadaptation:
-                        lr = optimizer.param_groups[i]["d"] * optimizer.param_groups[i]["lr"]
+                        lr = (
+                            optimizer.param_groups[i]["d"]
+                            * optimizer.param_groups[i]["lr"]
+                        )
                         logs[f"d*lr/{label}"] = lr
                     lrs[label] = lr
 
@@ -573,8 +629,10 @@ def train_loop(
 
                 local_progress_bar.set_postfix(**logs)
 
-                if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(train_dataloader)):
-                    before_optimize_result = on_before_optimize(epoch)
+                if ((step + 1) % gradient_accumulation_steps == 0) or (
+                    (step + 1) == len(train_dataloader)
+                ):
+                    before_optimize_result = on_before_optimize(cycle)
 
                     optimizer.step()
                     lr_scheduler.step()
@@ -614,7 +672,9 @@ def train_loop(
                     }
                     local_progress_bar.set_postfix(**logs)
 
-                    if ((step + 1) % gradient_accumulation_steps == 0) or ((step + 1) == len(val_dataloader)):
+                    if ((step + 1) % gradient_accumulation_steps == 0) or (
+                        (step + 1) == len(val_dataloader)
+                    ):
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
 
@@ -634,7 +694,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}")
+                    f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc_val = avg_acc_val.max
         else:
@@ -644,7 +705,8 @@ def train_loop(
                 global_progress_bar.clear()
 
                 accelerator.print(
-                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}")
+                    f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg:.2e}"
+                )
                 on_checkpoint(global_step, "milestone")
                 best_acc = avg_acc.max
 
@@ -700,17 +762,32 @@ def train(
     avg_acc_val: AverageMeter = AverageMeter(),
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = strategy.prepare(
+        accelerator,
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+        **kwargs,
+    )
 
     vae.to(accelerator.device, dtype=dtype)
     vae.requires_grad_(False)
     vae.eval()
 
-    vae = torch.compile(vae, backend='hidet')
+    vae = torch.compile(vae, backend="hidet")
 
     if compile_unet:
-        unet = torch.compile(unet, backend='hidet')
+        unet = torch.compile(unet, backend="hidet")
         # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
diff --git a/training/lr.py b/training/lr.py
index f5b362f..a75078f 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -23,12 +23,12 @@ def plot_metrics(
     fig, ax_loss = plt.subplots()
     ax_acc = ax_loss.twinx()
 
-    ax_loss.plot(lrs, losses, color='red')
+    ax_loss.plot(lrs, losses, color="red")
     ax_loss.set_xscale("log")
     ax_loss.set_xlabel(f"Learning rate")
     ax_loss.set_ylabel("Loss")
 
-    ax_acc.plot(lrs, accs, color='blue')
+    ax_acc.plot(lrs, accs, color="blue")
     ax_acc.set_xscale("log")
     ax_acc.set_ylabel("Accuracy")
 
diff --git a/training/optimization.py b/training/optimization.py
index d22a900..55531bf 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,7 +5,10 @@ from functools import partial
 import torch
 from torch.optim.lr_scheduler import LambdaLR
 
-from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
+from diffusers.optimization import (
+    get_scheduler as get_scheduler_,
+    get_cosine_with_hard_restarts_schedule_with_warmup,
+)
 from transformers.optimization import get_adafactor_schedule
 
 
@@ -52,7 +55,7 @@ def get_one_cycle_schedule(
     annealing_exp: int = 1,
     min_lr: float = 0.04,
     mid_point: float = 0.3,
-    last_epoch: int = -1
+    last_epoch: int = -1,
 ):
     if warmup == "linear":
         warmup_func = warmup_linear
@@ -83,12 +86,16 @@ def get_one_cycle_schedule(
 
     def lr_lambda(current_step: int):
         phase = [p for p in phases if current_step >= p.step_min][-1]
-        return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
+        return phase.min + phase.func(
+            (current_step - phase.step_min) / (phase.step_max - phase.step_min)
+        ) * (phase.max - phase.min)
 
     return LambdaLR(optimizer, lr_lambda, last_epoch)
 
 
-def get_exponential_growing_schedule(optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1):
+def get_exponential_growing_schedule(
+    optimizer, end_lr: float, num_training_steps: int, last_epoch: int = -1
+):
     def lr_lambda(base_lr: float, current_step: int):
         return (end_lr / base_lr) ** (current_step / num_training_steps)
 
@@ -132,7 +139,14 @@ def get_scheduler(
         )
     elif id == "exponential_growth":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_exponential_growing_schedule(
             optimizer=optimizer,
@@ -141,7 +155,14 @@ def get_scheduler(
         )
     elif id == "cosine_with_restarts":
         if cycles is None:
-            cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
+            cycles = math.ceil(
+                math.sqrt(
+                    (
+                        (num_training_steps - num_warmup_steps)
+                        / num_training_steps_per_epoch
+                    )
+                )
+            )
 
         lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
             optimizer=optimizer,
@@ -150,10 +171,7 @@ def get_scheduler(
             num_cycles=cycles,
         )
     elif id == "adafactor":
-        lr_scheduler = get_adafactor_schedule(
-            optimizer,
-            initial_lr=min_lr
-        )
+        lr_scheduler = get_adafactor_schedule(optimizer, initial_lr=min_lr)
     else:
         lr_scheduler = get_scheduler_(
             id,
diff --git a/training/sampler.py b/training/sampler.py
index bdb3e90..0487d66 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -134,7 +134,7 @@ class LossSecondMomentResampler(LossAwareSampler):
     def weights(self):
        if not self._warmed_up():
            return np.ones([self.num_timesteps], dtype=np.float64)
-        weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+        weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
        weights /= np.sum(weights)
        weights *= 1 - self.uniform_prob
        weights += self.uniform_prob / len(weights)
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index e6fcc89..88b441b 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -29,7 +29,7 @@ def dreambooth_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    train_text_encoder_epochs: int,
+    train_text_encoder_cycles: int,
     max_grad_norm: float = 1.0,
     use_ema: bool = False,
     ema_inv_gamma: float = 1.0,
@@ -85,15 +85,13 @@ def dreambooth_strategy_callbacks(
         return nullcontext()
 
     @contextmanager
-    def on_train(epoch: int):
+    def on_train(cycle: int):
         unet.train()
         tokenizer.train()
 
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             text_encoder.train()
-        elif epoch == train_text_encoder_epochs:
-            text_encoder.requires_grad_(False)
-            text_encoder.eval()
+            tokenizer.train()
 
         yield
 
@@ -106,9 +104,9 @@ def dreambooth_strategy_callbacks(
         with ema_context():
             yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         params_to_clip = [unet.parameters()]
-        if epoch < train_text_encoder_epochs:
+        if cycle < train_text_encoder_cycles:
             params_to_clip.append(text_encoder.parameters())
         accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
 
@@ -189,8 +187,16 @@ def dreambooth_prepare(
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     **kwargs
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     text_encoder.text_model.embeddings.requires_grad_(False)
 
@@ -198,6 +204,5 @@ def dreambooth_prepare(
 
 
 dreambooth_strategy = TrainingStrategy(
-    callbacks=dreambooth_strategy_callbacks,
-    prepare=dreambooth_prepare
+    callbacks=dreambooth_strategy_callbacks, prepare=dreambooth_prepare
 )
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index f942b76..14e3384 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -81,7 +81,7 @@ def lora_strategy_callbacks(
         tokenizer.eval()
         yield
 
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if not pti_mode:
             accelerator.clip_grad_norm_(
                 itertools.chain(
@@ -89,7 +89,7 @@ def lora_strategy_callbacks(
                     text_encoder.text_model.encoder.parameters(),
                     text_encoder.text_model.final_layer_norm.parameters(),
                 ),
-                max_grad_norm
+                max_grad_norm,
             )
 
         if len(placeholder_tokens) != 0 and use_emb_decay:
@@ -108,7 +108,9 @@ def lora_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
@@ -128,25 +130,32 @@ def lora_strategy_callbacks(
 
         if not pti_mode:
             lora_config = {}
-            state_dict = get_peft_model_state_dict(unet_, state_dict=accelerator.get_state_dict(unet_))
+            state_dict = get_peft_model_state_dict(
+                unet_, state_dict=accelerator.get_state_dict(unet_)
+            )
             lora_config["peft_config"] = unet_.get_peft_config_as_dict(inference=True)
 
             text_encoder_state_dict = get_peft_model_state_dict(
                 text_encoder_, state_dict=accelerator.get_state_dict(text_encoder_)
             )
-            text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+            text_encoder_state_dict = {
+                f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()
+            }
             state_dict.update(text_encoder_state_dict)
-            lora_config["text_encoder_peft_config"] = text_encoder_.get_peft_config_as_dict(inference=True)
+            lora_config[
+                "text_encoder_peft_config"
+            ] = text_encoder_.get_peft_config_as_dict(inference=True)
 
         if len(placeholder_tokens) != 0:
             ti_state_dict = {
                 f"ti_${token}": text_encoder.text_model.embeddings.get_embed(ids)
-                for (token, ids)
-                in zip(placeholder_tokens, placeholder_token_ids)
+                for (token, ids) in zip(placeholder_tokens, placeholder_token_ids)
             }
             state_dict.update(ti_state_dict)
 
-        save_file(state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors")
+        save_file(
+            state_dict, checkpoint_output_dir / f"{step}_{postfix}.safetensors"
+        )
         with open(checkpoint_output_dir / "lora_config.json", "w") as f:
             json.dump(lora_config, f)
 
@@ -185,10 +194,18 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    **kwargs
+    **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        unet,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     # text_encoder.text_model.embeddings.token_embedding.requires_grad_(True)
 
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6bc1d7d..7373982 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -104,7 +104,7 @@ def textual_inversion_strategy_callbacks(
         yield
 
     @torch.no_grad()
-    def on_before_optimize(epoch: int):
+    def on_before_optimize(cycle: int):
         if use_emb_decay:
             params = [
                 p
@@ -116,7 +116,9 @@ def textual_inversion_strategy_callbacks(
     @torch.no_grad()
     def on_after_optimize(w, lrs: dict[str, float]):
         if ema_embeddings is not None:
-            ema_embeddings.step(text_encoder.text_model.embeddings.token_embedding.parameters())
+            ema_embeddings.step(
+                text_encoder.text_model.embeddings.token_embedding.parameters()
+            )
 
         if use_emb_decay and w is not None:
             lr = lrs["emb"] if "emb" in lrs else lrs["0"]
@@ -124,7 +126,9 @@ def textual_inversion_strategy_callbacks(
 
             if lambda_ != 0:
                 norm = w[:, :].norm(dim=-1, keepdim=True)
-                w[:].add_((w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm))
+                w[:].add_(
+                    (w[:] / norm.clamp_min(1e-12)) * lambda_ * (emb_decay_target - norm)
+                )
 
     def on_log():
         if ema_embeddings is not None:
@@ -136,10 +140,10 @@ def textual_inversion_strategy_callbacks(
             print(f"Saving checkpoint for step {step}...")
 
         with ema_context():
-            for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
+            for token, ids in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
-                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin"
+                    checkpoint_output_dir / f"{slugify(token)}_{step}_{postfix}.bin",
                 )
 
     @torch.no_grad()
@@ -183,7 +187,7 @@ def textual_inversion_prepare(
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     gradient_checkpointing: bool = False,
-    **kwargs
+    **kwargs,
 ):
     weight_dtype = torch.float32
     if accelerator.state.mixed_precision == "fp16":
@@ -191,8 +195,15 @@ def textual_inversion_prepare(
     elif accelerator.state.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    (
+        text_encoder,
+        optimizer,
+        train_dataloader,
+        val_dataloader,
+        lr_scheduler,
+    ) = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    )
 
     unet.to(accelerator.device, dtype=weight_dtype)
     unet.requires_grad_(False)