Diffstat (limited to 'training')
 training/functional.py | 30 ++++++++++++++++++++++--------
 1 file changed, 22 insertions(+), 8 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index e1035ce..b7ea90d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from training.util import AverageMeter
 
 
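UniPC is imported from a project-local `schedulers` package rather than from diffusers, which did not ship an implementation at the time of this commit (upstream added `UniPCMultistepScheduler` in later releases). A hypothetical compatibility shim, not part of this commit, could prefer the upstream class when it is available:

```python
# Hypothetical shim (not in this commit): use the diffusers implementation
# when the installed version ships it, else fall back to the vendored copy.
try:
    from diffusers import UniPCMultistepScheduler
except ImportError:
    from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
```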
@@ -79,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str):
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
     embeddings = patch_managed_embeddings(text_encoder)
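Both schedulers read the same `scheduler` subfolder of the checkpoint; the UniPC sampler only needs the config stored there. A minimal sketch of an equivalent construction, assuming the vendored class follows the diffusers `ConfigMixin`/`SchedulerMixin` API (`model_path` is a placeholder):

```python
from diffusers import DDPMScheduler
from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

model_path = 'path/to/checkpoint'  # placeholder

# Load the training scheduler once, then derive the sampler from its
# config instead of re-reading the subfolder.
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
sample_scheduler = UniPCMultistepScheduler.from_config(noise_scheduler.config)
```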
@@ -93,7 +94,7 @@ def save_samples(
     text_encoder: CLIPTextModel,
     tokenizer: MultiCLIPTokenizer,
     vae: AutoencoderKL,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     output_dir: Path,
@@ -180,7 +181,7 @@ def generate_class_images(
     vae: AutoencoderKL,
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
@@ -284,6 +285,7 @@ def loss_step(
         device=latents.device,
         generator=generator
     )
+
     bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
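The only change in this hunk is a blank line, but the surrounding code is the standard latent-diffusion training step: draw noise, pick a random timestep per image, forward-diffuse the latents, and let the U-Net predict the target. A condensed sketch assuming the usual diffusers APIs; `latents`, `noise`, `generator`, and `encoder_hidden_states` are produced earlier in `loss_step`:

```python
bsz = latents.shape[0]
# Sample a random timestep for each image in the batch
timesteps = torch.randint(
    0, noise_scheduler.config.num_train_timesteps,
    (bsz,), device=latents.device, generator=generator
).long()
# Forward diffusion: noise the latents to the sampled timesteps
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# U-Net prediction conditioned on the text embeddings
model_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
```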
@@ -351,6 +353,7 @@ def train_loop(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     loss_step: LossCallable,
+    no_val: bool = False,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,9 +409,15 @@ def train_loop(
     for epoch in range(num_epochs):
         if accelerator.is_main_process:
             if epoch % sample_frequency == 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_sample(global_step + global_step_offset)
 
             if epoch % checkpoint_frequency == 0 and epoch != 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_checkpoint(global_step + global_step_offset, "training")
 
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
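These `clear()` calls fix a cosmetic problem: `on_sample` and `on_checkpoint` write to stdout, and a live tqdm bar occupies the current terminal line, so printing through it leaves garbled, half-overwritten bars. Clearing both bars first erases them; each redraws on its next `update()`. A standalone illustration of the pattern:

```python
from tqdm.auto import tqdm

bar = tqdm(total=100, desc="train")
for step in range(100):
    bar.update(1)
    if step % 25 == 0:
        bar.clear()  # erase the bar line before printing
        print(f"callback output at step {step}")
bar.close()
```

When the printing code is under your own control, `tqdm.write()` achieves the same in one call; here the callbacks print on their own, so the bars are cleared around them instead.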
@@ -463,7 +472,7 @@ def train_loop(
 
         on_after_epoch(lr_scheduler.get_last_lr()[0])
 
-        if val_dataloader is not None:
+        if val_dataloader is not None and not no_val:
             model.eval()
 
             cur_loss_val = AverageMeter()
@@ -498,11 +507,11 @@
 
             accelerator.log(logs, step=global_step)
 
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
             if accelerator.is_main_process:
                 if avg_acc_val.avg.item() > best_acc_val:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -513,6 +522,9 @@
         else:
             if accelerator.is_main_process:
                 if avg_acc.avg.item() > best_acc:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -550,6 +562,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     strategy: TrainingStrategy,
+    no_val: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -604,6 +617,7 @@ def train(
         lr_scheduler=lr_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
+        no_val=no_val,
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
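With the flag threaded from `train()` into `train_loop()`, a caller can keep its `val_dataloader` wired up yet skip the per-epoch validation pass; milestone checkpoints then fall back to training accuracy, per the `else` branch above. A hypothetical call site:

```python
train(
    # ... other arguments as before ...
    val_dataloader=val_dataloader,
    no_val=True,  # skip validation, e.g. for quick smoke-test runs
)
```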