Diffstat (limited to 'training/functional.py')
-rw-r--r--  training/functional.py  31
1 file changed, 11 insertions(+), 20 deletions(-)
diff --git a/training/functional.py b/training/functional.py
index c373ac9..8f47734 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -34,7 +34,7 @@ def const(result=None):
 @dataclass
 class TrainingCallbacks():
     on_prepare: Callable[[], None] = const()
-    on_model: Callable[[], torch.nn.Module] = const(None)
+    on_accum_model: Callable[[], torch.nn.Module] = const(None)
     on_log: Callable[[], dict[str, Any]] = const({})
     on_train: Callable[[int], _GeneratorContextManager] = const(nullcontext())
     on_before_optimize: Callable[[float, int], None] = const()
@@ -51,7 +51,11 @@ class TrainingStrategyPrepareCallable(Protocol):
         accelerator: Accelerator,
         text_encoder: CLIPTextModel,
         unet: UNet2DConditionModel,
-        *args
+        optimizer: torch.optim.Optimizer,
+        train_dataloader: DataLoader,
+        val_dataloader: Optional[DataLoader],
+        lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+        **kwargs
     ) -> Tuple: ...
 
 
@@ -92,7 +96,6 @@ def save_samples(
     sample_scheduler: DPMSolverMultistepScheduler,
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
-    dtype: torch.dtype,
     output_dir: Path,
     seed: int,
     step: int,
@@ -107,15 +110,6 @@ def save_samples(
     grid_cols = min(batch_size, 4)
     grid_rows = (num_batches * batch_size) // grid_cols
 
-    unet = accelerator.unwrap_model(unet)
-    text_encoder = accelerator.unwrap_model(text_encoder)
-
-    orig_unet_dtype = unet.dtype
-    orig_text_encoder_dtype = text_encoder.dtype
-
-    unet.to(dtype=dtype)
-    text_encoder.to(dtype=dtype)
-
     pipeline = VlpnStableDiffusion(
         text_encoder=text_encoder,
         vae=vae,
@@ -172,11 +166,6 @@ def save_samples(
         image_grid = make_grid(all_samples, grid_rows, grid_cols)
         image_grid.save(file_path, quality=85)
 
-    unet.to(dtype=orig_unet_dtype)
-    text_encoder.to(dtype=orig_text_encoder_dtype)
-
-    del unet
-    del text_encoder
     del generator
     del pipeline
 
@@ -393,7 +382,7 @@ def train_loop(
     )
     global_progress_bar.set_description("Total progress")
 
-    model = callbacks.on_model()
+    model = callbacks.on_accum_model()
     on_log = callbacks.on_log
     on_train = callbacks.on_train
     on_before_optimize = callbacks.on_before_optimize
@@ -559,8 +548,10 @@ def train(
     prior_loss_weight: float = 1.0,
     **kwargs,
 ):
-    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = strategy.prepare(
-        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
+        accelerator, text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, **kwargs)
+
+    kwargs.update(extra)
 
     vae.to(accelerator.device, dtype=dtype)
 
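
With the widened TrainingStrategyPrepareCallable protocol and the extra return value now consumed by train(), a strategy's prepare hook receives the optimizer, dataloaders, and LR scheduler explicitly and returns them (possibly wrapped by Accelerate) together with a dict of extra keyword arguments that train() merges back into **kwargs. Below is a minimal sketch of a conforming hook; the function name and the choice to wrap only the text encoder are illustrative assumptions, not part of this commit.

# Illustrative sketch only -- one possible prepare hook matching the new protocol.
# The name and the "train only the text encoder" choice are assumptions.
from typing import Optional, Tuple

import torch
from accelerate import Accelerator
from torch.utils.data import DataLoader
from transformers import CLIPTextModel
from diffusers import UNet2DConditionModel


def example_prepare(
    accelerator: Accelerator,
    text_encoder: CLIPTextModel,
    unet: UNet2DConditionModel,
    optimizer: torch.optim.Optimizer,
    train_dataloader: DataLoader,
    val_dataloader: Optional[DataLoader],
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
    **kwargs
) -> Tuple:
    # Let Accelerate wrap the pieces this strategy actually trains; the U-Net
    # is only moved to the target device here.
    text_encoder, optimizer, train_dataloader, lr_scheduler = accelerator.prepare(
        text_encoder, optimizer, train_dataloader, lr_scheduler
    )
    unet.to(accelerator.device)

    # The last element is a dict of extra keyword arguments; train() merges it
    # back into its **kwargs via kwargs.update(extra).
    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}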