summary refs log tree commit diff stats
path: root/training/functional.py
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-02-16 09:16:05 +0100
committer Volpeon <git@volpeon.ink> 2023-02-16 09:16:05 +0100
commit d673760fc671d665aadae3b032f8e99f21ab986d (patch)
tree 7c14a998742b19ddecac6ee25a669892b41c305e /training/functional.py
parent Update (diff)
download textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.gz
textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.tar.bz2
textual-inversion-diff-d673760fc671d665aadae3b032f8e99f21ab986d.zip
Integrated WIP UniPC scheduler
Diffstat (limited to 'training/functional.py')
-rw-r--r-- training/functional.py 30
1 files changed, 22 insertions, 8 deletions
diff --git a/training/functional.py b/training/functional.py
index e1035ce..b7ea90d 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -12,7 +12,7 @@ from torch.utils.data import DataLoader
12 12
13from accelerate import Accelerator 13from accelerate import Accelerator
14from transformers import CLIPTextModel 14from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler 15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
16 16
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
@@ -22,6 +22,7 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
23from models.clip.util import get_extended_embeddings 23from models.clip.util import get_extended_embeddings
24from models.clip.tokenizer import MultiCLIPTokenizer 24from models.clip.tokenizer import MultiCLIPTokenizer
25from schedulers.scheduling_unipc_multistep import UniPCMultistepScheduler
25from training.util import AverageMeter 26from training.util import AverageMeter
26 27
27 28
@@ -79,7 +80,7 @@ def get_models(pretrained_model_name_or_path: str):
79 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 80 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
80 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 81 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
81 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 82 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
82 sample_scheduler = DPMSolverMultistepScheduler.from_pretrained( 83 sample_scheduler = UniPCMultistepScheduler.from_pretrained(
83 pretrained_model_name_or_path, subfolder='scheduler') 84 pretrained_model_name_or_path, subfolder='scheduler')
84 85
85 embeddings = patch_managed_embeddings(text_encoder) 86 embeddings = patch_managed_embeddings(text_encoder)
@@ -93,7 +94,7 @@ def save_samples(
93 text_encoder: CLIPTextModel, 94 text_encoder: CLIPTextModel,
94 tokenizer: MultiCLIPTokenizer, 95 tokenizer: MultiCLIPTokenizer,
95 vae: AutoencoderKL, 96 vae: AutoencoderKL,
96 sample_scheduler: DPMSolverMultistepScheduler, 97 sample_scheduler: UniPCMultistepScheduler,
97 train_dataloader: DataLoader, 98 train_dataloader: DataLoader,
98 val_dataloader: Optional[DataLoader], 99 val_dataloader: Optional[DataLoader],
99 output_dir: Path, 100 output_dir: Path,
@@ -180,7 +181,7 @@ def generate_class_images(
180 vae: AutoencoderKL, 181 vae: AutoencoderKL,
181 unet: UNet2DConditionModel, 182 unet: UNet2DConditionModel,
182 tokenizer: MultiCLIPTokenizer, 183 tokenizer: MultiCLIPTokenizer,
183 sample_scheduler: DPMSolverMultistepScheduler, 184 sample_scheduler: UniPCMultistepScheduler,
184 train_dataset: VlpnDataset, 185 train_dataset: VlpnDataset,
185 sample_batch_size: int, 186 sample_batch_size: int,
186 sample_image_size: int, 187 sample_image_size: int,
@@ -284,6 +285,7 @@ def loss_step(
284 device=latents.device, 285 device=latents.device,
285 generator=generator 286 generator=generator
286 ) 287 )
288
287 bsz = latents.shape[0] 289 bsz = latents.shape[0]
288 # Sample a random timestep for each image 290 # Sample a random timestep for each image
289 timesteps = torch.randint( 291 timesteps = torch.randint(
@@ -351,6 +353,7 @@ def train_loop(
351 train_dataloader: DataLoader, 353 train_dataloader: DataLoader,
352 val_dataloader: Optional[DataLoader], 354 val_dataloader: Optional[DataLoader],
353 loss_step: LossCallable, 355 loss_step: LossCallable,
356 no_val: bool = False,
354 sample_frequency: int = 10, 357 sample_frequency: int = 10,
355 checkpoint_frequency: int = 50, 358 checkpoint_frequency: int = 50,
356 global_step_offset: int = 0, 359 global_step_offset: int = 0,
@@ -406,9 +409,15 @@ def train_loop(
406 for epoch in range(num_epochs): 409 for epoch in range(num_epochs):
407 if accelerator.is_main_process: 410 if accelerator.is_main_process:
408 if epoch % sample_frequency == 0: 411 if epoch % sample_frequency == 0:
412 local_progress_bar.clear()
413 global_progress_bar.clear()
414
409 on_sample(global_step + global_step_offset) 415 on_sample(global_step + global_step_offset)
410 416
411 if epoch % checkpoint_frequency == 0 and epoch != 0: 417 if epoch % checkpoint_frequency == 0 and epoch != 0:
418 local_progress_bar.clear()
419 global_progress_bar.clear()
420
412 on_checkpoint(global_step + global_step_offset, "training") 421 on_checkpoint(global_step + global_step_offset, "training")
413 422
414 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}") 423 local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
@@ -463,7 +472,7 @@ def train_loop(
463 472
464 on_after_epoch(lr_scheduler.get_last_lr()[0]) 473 on_after_epoch(lr_scheduler.get_last_lr()[0])
465 474
466 if val_dataloader is not None: 475 if val_dataloader is not None and not no_val:
467 model.eval() 476 model.eval()
468 477
469 cur_loss_val = AverageMeter() 478 cur_loss_val = AverageMeter()
@@ -498,11 +507,11 @@ def train_loop(
498 507
499 accelerator.log(logs, step=global_step) 508 accelerator.log(logs, step=global_step)
500 509
501 local_progress_bar.clear()
502 global_progress_bar.clear()
503
504 if accelerator.is_main_process: 510 if accelerator.is_main_process:
505 if avg_acc_val.avg.item() > best_acc_val: 511 if avg_acc_val.avg.item() > best_acc_val:
512 local_progress_bar.clear()
513 global_progress_bar.clear()
514
506 accelerator.print( 515 accelerator.print(
507 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}") 516 f"Global step {global_step}: Validation accuracy reached new maximum: {best_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
508 on_checkpoint(global_step + global_step_offset, "milestone") 517 on_checkpoint(global_step + global_step_offset, "milestone")
@@ -513,6 +522,9 @@ def train_loop(
513 else: 522 else:
514 if accelerator.is_main_process: 523 if accelerator.is_main_process:
515 if avg_acc.avg.item() > best_acc: 524 if avg_acc.avg.item() > best_acc:
525 local_progress_bar.clear()
526 global_progress_bar.clear()
527
516 accelerator.print( 528 accelerator.print(
517 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}") 529 f"Global step {global_step}: Training accuracy reached new maximum: {best_acc:.2e} -> {avg_acc.avg.item():.2e}")
518 on_checkpoint(global_step + global_step_offset, "milestone") 530 on_checkpoint(global_step + global_step_offset, "milestone")
@@ -550,6 +562,7 @@ def train(
550 optimizer: torch.optim.Optimizer, 562 optimizer: torch.optim.Optimizer,
551 lr_scheduler: torch.optim.lr_scheduler._LRScheduler, 563 lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
552 strategy: TrainingStrategy, 564 strategy: TrainingStrategy,
565 no_val: bool = False,
553 num_train_epochs: int = 100, 566 num_train_epochs: int = 100,
554 sample_frequency: int = 20, 567 sample_frequency: int = 20,
555 checkpoint_frequency: int = 50, 568 checkpoint_frequency: int = 50,
@@ -604,6 +617,7 @@ def train(
604 lr_scheduler=lr_scheduler, 617 lr_scheduler=lr_scheduler,
605 train_dataloader=train_dataloader, 618 train_dataloader=train_dataloader,
606 val_dataloader=val_dataloader, 619 val_dataloader=val_dataloader,
620 no_val=no_val,
607 loss_step=loss_step_, 621 loss_step=loss_step_,
608 sample_frequency=sample_frequency, 622 sample_frequency=sample_frequency,
609 checkpoint_frequency=checkpoint_frequency, 623 checkpoint_frequency=checkpoint_frequency,