Diffstat (limited to 'training')
-rw-r--r--  training/functional.py  30
1 file changed, 22 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index e1035ce..b7ea90d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
 
 from accelerate import Accelerator
 from transformers import CLIPTextModel
-from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 
 from tqdm.auto import tqdm
 from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
+from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
 from training.util import AverageMeter
 
 
@@ -79,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str):
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
     noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
-    sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
+    sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
     embeddings = patch_managed_embeddings(text_encoder)
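
This hunk and the import change above swap the image-sampling scheduler from DPM-Solver++ to UniPC, while DDPM remains the training noise scheduler. Both are instantiated from the same pretrained scheduler config, so only the inference-time solver changes. A minimal sketch of the idea, assuming a diffusers-style API and a placeholder model path:

# Sketch only: the model path is a placeholder; UniPCMultistepScheduler is
# imported from the repo's vendored schedulers module, not from diffusers.
from diffusers import DDPMScheduler
from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler

model_path = 'path/to/pretrained-model'  # placeholder

# DDPM drives the forward (noising) process during training ...
noise_scheduler = DDPMScheduler.from_pretrained(model_path, subfolder='scheduler')
# ... while UniPC only renders the periodic sample images, converging in
# far fewer inference steps.
sample_scheduler = UniPCMultistepScheduler.from_pretrained(model_path, subfolder='scheduler')
sample_scheduler.set_timesteps(25)  # standard diffusers scheduler API
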
@@ -93,7 +94,7 @@ def save_samples(
     text_encoder: CLIPTextModel,
     tokenizer: MultiCLIPTokenizer,
     vae: AutoencoderKL,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     output_dir: Path,
@@ -180,7 +181,7 @@ def generate_class_images(
     vae: AutoencoderKL,
     unet: UNet2DConditionModel,
     tokenizer: MultiCLIPTokenizer,
-    sample_scheduler: DPMSolverMultistepScheduler,
+    sample_scheduler: UniPCMultistepScheduler,
     train_dataset: VlpnDataset,
     sample_batch_size: int,
     sample_image_size: int,
@@ -284,6 +285,7 @@ def loss_step(
         device=latents.device,
         generator=generator
     )
+
     bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
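
For context, the blank line added here sits in the standard denoising-diffusion training step: noise is drawn once per batch, then each latent is pushed to a random timestep before the model predicts that noise. A hedged sketch of the surrounding logic, with names mirroring the diff and add_noise being the standard diffusers scheduler call:

import torch

# Assumes latents, generator and noise_scheduler are in scope, as in loss_step.
noise = torch.randn(
    latents.shape,
    dtype=latents.dtype,
    device=latents.device,
    generator=generator
)

bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(
    0,
    noise_scheduler.config.num_train_timesteps,
    (bsz,),
    device=latents.device,
)
# Forward-diffuse each latent to its sampled timestep.
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
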
@@ -351,6 +353,7 @@ def train_loop(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     loss_step: LossCallable,
+    no_val: bool = False,
     sample_frequency: int = 10,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
@@ -406,9 +409,15 @@ def train_loop(
     for epoch in range(num_epochs):
         if accelerator.is_main_process:
             if epoch % sample_frequency == 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_sample(global_step + global_step_offset)
 
             if epoch % checkpoint_frequency == 0 and epoch != 0:
+                local_progress_bar.clear()
+                global_progress_bar.clear()
+
                 on_checkpoint(global_step + global_step_offset, "training")
 
         local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
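
The relocated clear() calls fix a cosmetic issue: tqdm keeps redrawing its bar on the current terminal line, so anything the sample or checkpoint callbacks print gets interleaved with the bar unless it is erased first. A self-contained illustration of the pattern:

import time
from tqdm.auto import tqdm

bar = tqdm(total=10)
for i in range(10):
    time.sleep(0.1)
    if i == 5:
        bar.clear()                 # erase the bar before other output
        print('saving checkpoint')  # stands in for on_checkpoint(...)
    bar.update(1)
bar.close()

tqdm.write() is the usual alternative for one-off messages, but clear() is the right tool when a callback does its own printing, as on_sample and on_checkpoint do here.
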
@@ -463,7 +472,7 @@ def train_loop(
 
         on_after_epoch(lr_scheduler.get_last_lr()[0])
 
-        if val_dataloader is not None:
+        if val_dataloader is not None and not no_val:
             model.eval()
 
             cur_loss_val = AverageMeter()
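
The new no_val flag skips the validation pass even when a val_dataloader was constructed; milestone checkpoints then fall back to the training-accuracy branch. Condensed control flow, assembled from this and the following hunks, with helper bodies omitted and the best_* updates assumed:

if val_dataloader is not None and not no_val:
    model.eval()
    # ... run validation, update avg_acc_val ...
    if accelerator.is_main_process and avg_acc_val.avg.item() > best_acc_val:
        on_checkpoint(global_step + global_step_offset, "milestone")
        best_acc_val = avg_acc_val.avg.item()  # assumed update
else:
    # Validation disabled or unavailable: milestones track train accuracy.
    if accelerator.is_main_process and avg_acc.avg.item() > best_acc:
        on_checkpoint(global_step + global_step_offset, "milestone")
        best_acc = avg_acc.avg.item()  # assumed update
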
@@ -498,11 +507,11 @@ def train_loop(
 
             accelerator.log(logs, step=global_step)
 
-            local_progress_bar.clear()
-            global_progress_bar.clear()
-
             if accelerator.is_main_process:
                 if avg_acc_val.avg.item() > best_acc_val:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -513,6 +522,9 @@ def train_loop(
         else:
             if accelerator.is_main_process:
                 if avg_acc.avg.item() > best_acc:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
                     accelerator.print(
                         f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
                     on_checkpoint(global_step + global_step_offset, "milestone")
@@ -550,6 +562,7 @@ def train(
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
     strategy: TrainingStrategy,
+    no_val: bool = False,
     num_train_epochs: int = 100,
     sample_frequency: int = 20,
     checkpoint_frequency: int = 50,
@@ -604,6 +617,7 @@ def train(
         lr_scheduler=lr_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
+        no_val=no_val,
         loss_step=loss_step_,
         sample_frequency=sample_frequency,
         checkpoint_frequency=checkpoint_frequency,
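
A hypothetical call site for the new flag; the argument list is inferred from the visible signature fragments, and everything besides no_val is a placeholder:

train(
    accelerator=accelerator,
    model=model,
    optimizer=optimizer,
    lr_scheduler=lr_scheduler,
    strategy=strategy,
    train_dataloader=train_dataloader,
    val_dataloader=val_dataloader,  # may exist, but its loop is skipped
    no_val=True,                    # milestones now track train accuracy
    num_train_epochs=100,
)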
