From 7ccd4614a56cfd6ecacba85605f338593f1059f0 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 7 Feb 2023 20:44:43 +0100
Subject: Add Lora

---
 training/functional.py | 31 +++++++++++--------------------
 1 file changed, 11 insertions(+), 20 deletions(-)

diff --git a/training/functional.py b/training/functional.py
index c373ac9..8f47734 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[], None] = const()
-    on_model: Callable[[], torch.nn.Module] = const(None)
+    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], None] = const()
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol):
         accelerator: Accelerator,
         text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
-        *args
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+        **kwargs
     ) -> Tuple: ...
 
 
@@ -92,7 +96,6 @@ def save_samples(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     step: int,
@@ -107,15 +110,6 @@ def save_samples(
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols
 
-    unet = accelerator.unwrap_model(unet)
-    text_encoder = accelerator.unwrap_model(text_encoder)
-
-    orig_unet_dtype = unet.dtype
-    orig_text_encoder_dtype = text_encoder.dtype
-
-    unet.to(dtype=dtype)
-    text_encoder.to(dtype=dtype)
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -172,11 +166,6 @@ def save_samples(
         image_grid = make_grid(all_samples, grid_rows, grid_cols)
         image_grid.save(file_path, quality=85)
 
-    unet.to(dtype=orig_unet_dtype)
-    text_encoder.to(dtype=orig_text_encoder_dtype)
-
-    del unet
-    del text_encoder
     del generator
     del pipeline
 
@@ -393,7 +382,7 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
-    model = callbacks.on_model()
+    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -559,8 +548,10 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+
+    kwargs.update(extra)
 
     vae.to(accelerator.device, dtype=dtype)
 
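
Note on the strategy.prepare change above: a strategy now receives the optimizer, dataloaders, and LR scheduler as named parameters and returns an additional `extra` dict, which train() merges back into its **kwargs. The following is a minimal sketch of a callable conforming to the updated TrainingStrategyPrepareCallable protocol, not code from this commit; the name `lora_prepare` and the contents of `extra` are assumptions made for illustration.

# Illustrative sketch only: a strategy "prepare" callable matching the new
# TrainingStrategyPrepareCallable signature. The name lora_prepare and the
# keys placed into `extra` are hypothetical.
from typing import Any, Optional, Tuple

import torch
from torch.utils.data import DataLoader
from accelerate import Accelerator
from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel


def lora_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    **kwargs,
) -> Tuple:
    # Wrap models, optimizer, dataloaders and scheduler for the current
    # accelerate configuration (DDP, mixed precision, ...); accelerate
    # typically passes unrecognized objects (e.g. a None val_dataloader)
    # through unchanged.
    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)

    # Strategy-specific state; train() calls kwargs.update(extra), so anything
    # placed here becomes visible to the rest of the training call.
    extra: dict[str, Any] = {}

    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra

A Lora strategy could, for example, keep references to its injected adapter layers in `extra` so that later callbacks have access to them when checkpointing.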