From 59bf501198d7ff6c0c03c45e92adef14069d5ac6 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sun, 15 Jan 2023 12:33:52 +0100
Subject: Update

---
 training/functional.py  | 100 +++++++++++-------------------------------------
 training/lr.py          |  29 +++++++-------
 training/strategy/ti.py |  54 +++++++++++++-------------
 3 files changed, 64 insertions(+), 119 deletions(-)

(limited to 'training')

diff --git a/training/functional.py b/training/functional.py
index 4ca7470..c01595a 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -33,6 +33,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[float], None] = const()
+    on_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[int], None] = const()
@@ -267,6 +268,7 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
+    with_prior_preservation: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -322,7 +324,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if batch["with_prior"].all():
+    if with_prior_preservation:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -347,7 +349,6 @@ def train_loop(
     accelerator: Accelerator,
     optimizer: torch.optim.Optimizer,
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    model: torch.nn.Module,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
@@ -387,28 +388,37 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
+    model = callbacks.on_model()
+    on_log = callbacks.on_log
+    on_train = callbacks.on_train
+    on_before_optimize = callbacks.on_before_optimize
+    on_after_optimize = callbacks.on_after_optimize
+    on_eval = callbacks.on_eval
+    on_sample = callbacks.on_sample
+    on_checkpoint = callbacks.on_checkpoint
+
     try:
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    callbacks.on_sample(global_step + global_step_offset)
+                    on_sample(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
-                    callbacks.on_checkpoint(global_step + global_step_offset, "training")
+                    on_checkpoint(global_step + global_step_offset, "training")
 
             local_progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
             local_progress_bar.reset()
 
             model.train()
 
-            with callbacks.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(train_dataloader):
                     with accelerator.accumulate(model):
                         loss, acc, bsz = loss_step(step, batch)
 
                         accelerator.backward(loss)
 
-                        callbacks.on_before_optimize(epoch)
+                        on_before_optimize(epoch)
 
                         optimizer.step()
                         lr_scheduler.step()
@@ -419,7 +429,7 @@ def train_loop(
 
                     # Checks if the accelerator has performed an optimization step behind the scenes
                     if accelerator.sync_gradients:
-                        callbacks.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                         local_progress_bar.update(1)
                         global_progress_bar.update(1)
@@ -433,7 +443,7 @@ def train_loop(
                             "train/cur_acc": acc.item(),
                             "lr": lr_scheduler.get_last_lr()[0],
                         }
-                        logs.update(callbacks.on_log())
+                        logs.update(on_log())
 
                         accelerator.log(logs, step=global_step)
 
@@ -449,7 +459,7 @@ def train_loop(
             cur_loss_val = AverageMeter()
             cur_acc_val = AverageMeter()
 
-            with torch.inference_mode(), callbacks.on_eval():
+            with torch.inference_mode(), on_eval():
                 for step, batch in enumerate(val_dataloader):
                     loss, acc, bsz = loss_step(step, batch, True)
 
@@ -485,80 +495,16 @@ def train_loop(
                 if avg_acc_val.avg.item() > max_acc_val:
                     accelerator.print(
                         f"Global step {global_step}: Validation accuracy reached new maximum: {max_acc_val:.2e} -> {avg_acc_val.avg.item():.2e}")
-                    callbacks.on_checkpoint(global_step + global_step_offset, "milestone")
+                    on_checkpoint(global_step + global_step_offset, "milestone")
                     max_acc_val = avg_acc_val.avg.item()
 
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished!")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            callbacks.on_sample(global_step + global_step_offset)
-            accelerator.end_training()
+            on_checkpoint(global_step + global_step_offset, "end")
+            on_sample(global_step + global_step_offset)
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted")
-            callbacks.on_checkpoint(global_step + global_step_offset, "end")
-            accelerator.end_training()
-
-
-def train(
-    accelerator: Accelerator,
-    unet: UNet2DConditionModel,
-    text_encoder: CLIPTextModel,
-    vae: AutoencoderKL,
-    noise_scheduler: DDPMScheduler,
-    train_dataloader: DataLoader,
-    val_dataloader: DataLoader,
-    dtype: torch.dtype,
-    seed: int,
-    optimizer: torch.optim.Optimizer,
-    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    num_train_epochs: int = 100,
-    sample_frequency: int = 20,
-    checkpoint_frequency: int = 50,
-    global_step_offset: int = 0,
-    prior_loss_weight: float = 0,
-    callbacks: TrainingCallbacks = TrainingCallbacks(),
-):
-    unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        unet, text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
-    )
-
-    vae.to(accelerator.device, dtype=dtype)
-
-    for model in (unet, text_encoder, vae):
-        model.requires_grad_(False)
-        model.eval()
-
-    callbacks.on_prepare()
-
-    loss_step_ = partial(
-        loss_step,
-        vae,
-        noise_scheduler,
-        unet,
-        text_encoder,
-        prior_loss_weight,
-        seed,
-    )
-
-    if accelerator.is_main_process:
-        accelerator.init_trackers("textual_inversion")
-
-    train_loop(
-        accelerator=accelerator,
-        optimizer=optimizer,
-        lr_scheduler=lr_scheduler,
-        model=text_encoder,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        loss_step=loss_step_,
-        sample_frequency=sample_frequency,
-        checkpoint_frequency=checkpoint_frequency,
-        global_step_offset=global_step_offset,
-        num_epochs=num_train_epochs,
-        callbacks=callbacks,
-    )
-
-    accelerator.free_memory()
+            on_checkpoint(global_step + global_step_offset, "end")
diff --git a/training/lr.py b/training/lr.py
index 7584ba2..902c4eb 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -9,6 +9,7 @@ import torch
 from torch.optim.lr_scheduler import LambdaLR
 from tqdm.auto import tqdm
 
+from training.functional import TrainingCallbacks
 from training.util import AverageMeter
 
 
@@ -24,26 +25,19 @@ class LRFinder():
     def __init__(
         self,
         accelerator,
-        model,
         optimizer,
         train_dataloader,
         val_dataloader,
         loss_fn: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
-        on_train: Callable[[int], _GeneratorContextManager] = noop_ctx,
-        on_before_optimize: Callable[[int], None] = noop,
-        on_after_optimize: Callable[[float], None] = noop,
-        on_eval: Callable[[], _GeneratorContextManager] = noop_ctx
+        callbacks: TrainingCallbacks = TrainingCallbacks()
     ):
         self.accelerator = accelerator
-        self.model = model
+        self.model = callbacks.on_model()
        self.optimizer = optimizer
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.loss_fn = loss_fn
-        self.on_train = on_train
-        self.on_before_optimize = on_before_optimize
-        self.on_after_optimize = on_after_optimize
-        self.on_eval = on_eval
+        self.callbacks = callbacks
 
         # self.model_state = copy.deepcopy(model.state_dict())
         # self.optimizer_state = copy.deepcopy(optimizer.state_dict())
@@ -82,6 +76,13 @@ class LRFinder():
         )
         progress_bar.set_description("Epoch X / Y")
 
+        self.callbacks.on_prepare()
+
+        on_train = self.callbacks.on_train
+        on_before_optimize = self.callbacks.on_before_optimize
+        on_after_optimize = self.callbacks.on_after_optimize
+        on_eval = self.callbacks.on_eval
+
         for epoch in range(num_epochs):
             progress_bar.set_description(f"Epoch {epoch + 1} / {num_epochs}")
 
@@ -90,7 +91,7 @@ class LRFinder():
 
             self.model.train()
 
-            with self.on_train(epoch):
+            with on_train(epoch):
                 for step, batch in enumerate(self.train_dataloader):
                     if step >= num_train_batches:
                         break
@@ -100,21 +101,21 @@ class LRFinder():
 
                     self.accelerator.backward(loss)
 
-                    self.on_before_optimize(epoch)
+                    on_before_optimize(epoch)
 
                     self.optimizer.step()
                     lr_scheduler.step()
                     self.optimizer.zero_grad(set_to_none=True)
 
                     if self.accelerator.sync_gradients:
-                        self.on_after_optimize(lr_scheduler.get_last_lr()[0])
+                        on_after_optimize(lr_scheduler.get_last_lr()[0])
 
                     progress_bar.update(1)
 
             self.model.eval()
 
             with torch.inference_mode():
-                with self.on_eval():
+                with on_eval():
                     for step, batch in enumerate(self.val_dataloader):
                         if step >= num_val_batches:
                             break
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 6f8384f..753dce0 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -27,7 +27,6 @@ def textual_inversion_strategy(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: DataLoader,
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     placeholder_tokens: list[str],
@@ -48,6 +47,12 @@ def textual_inversion_strategy(
     sample_guidance_scale: float = 7.5,
     sample_image_size: Optional[int] = None,
 ):
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
@@ -58,7 +63,7 @@ def textual_inversion_strategy(
         sample_scheduler=sample_scheduler,
         train_dataloader=train_dataloader,
         val_dataloader=val_dataloader,
-        dtype=dtype,
+        dtype=weight_dtype,
         output_dir=output_dir,
         seed=seed,
         batch_size=sample_batch_size,
@@ -78,6 +83,17 @@ def textual_inversion_strategy(
     else:
         ema_embeddings = None
 
+    def ema_context():
+        if use_ema:
+            return ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()
+            )
+        else:
+            return nullcontext()
+
+    def on_model():
+        return text_encoder
+
     def on_prepare():
         text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)
 
@@ -89,24 +105,15 @@ def textual_inversion_strategy(
 
     @contextmanager
     def on_train(epoch: int):
-        try:
-            tokenizer.train()
-            yield
-        finally:
-            pass
+        tokenizer.train()
+        yield
 
     @contextmanager
     def on_eval():
-        try:
-            tokenizer.eval()
+        tokenizer.eval()
 
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if use_ema else nullcontext()
-
-            with ema_context:
-                yield
-        finally:
-            pass
+        with ema_context():
+            yield
 
     @torch.no_grad()
     def on_after_optimize(lr: float):
@@ -131,13 +138,7 @@ def textual_inversion_strategy(
         checkpoints_path = output_dir.joinpath("checkpoints")
         checkpoints_path.mkdir(parents=True, exist_ok=True)
 
-        text_encoder = accelerator.unwrap_model(text_encoder)
-
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             for (token, ids) in zip(placeholder_tokens, placeholder_token_ids):
                 text_encoder.text_model.embeddings.save_embed(
                     ids,
@@ -146,15 +147,12 @@ def textual_inversion_strategy(
 
     @torch.no_grad()
     def on_sample(step):
-        ema_context = ema_embeddings.apply_temporary(
-            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
-        ) if ema_embeddings is not None else nullcontext()
-
-        with ema_context:
+        with ema_context():
             save_samples_(step=step)
 
     return TrainingCallbacks(
         on_prepare=on_prepare,
+        on_model=on_model,
         on_train=on_train,
         on_eval=on_eval,
         on_after_optimize=on_after_optimize,
-- 
cgit v1.2.3-70-g09d2
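
For orientation (not part of the patch): the refactor replaces the explicit `model` argument of `train_loop`/`LRFinder`, and the individual hook keyword arguments, with a single `TrainingCallbacks` bundle whose new `on_model` callback hands back the module to train. A minimal sketch of how a caller might wire this up after the change; only the `TrainingCallbacks` and `LRFinder` signatures are taken from the diff above, every other name is an assumption about the calling scope:

    # Sketch only -- illustrates the callback wiring introduced by this patch.
    # `text_encoder`, `accelerator`, `optimizer`, the dataloaders and `loss_fn`
    # are assumed to exist in the calling scope; they are not defined here.
    from training.functional import TrainingCallbacks
    from training.lr import LRFinder

    callbacks = TrainingCallbacks(
        # New in this patch: the strategy exposes the trainable module itself,
        # so train_loop()/LRFinder() no longer take a `model` parameter.
        on_model=lambda: text_encoder,
    )

    lr_finder = LRFinder(
        accelerator,
        optimizer,
        train_dataloader,
        val_dataloader,
        loss_fn,
        callbacks=callbacks,
    )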