diff options
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 7 | ||||
-rw-r--r-- | training/strategy/dreambooth.py | 3 |
2 files changed, 5 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py index 43b03ac..546aaff 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -2,6 +2,7 @@ from dataclasses import dataclass | |||
2 | import math | 2 | import math |
3 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext |
4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol | 4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol |
5 | from types import MethodType | ||
5 | from functools import partial | 6 | from functools import partial |
6 | from pathlib import Path | 7 | from pathlib import Path |
7 | import itertools | 8 | import itertools |
@@ -108,6 +109,7 @@ def save_samples( | |||
108 | output_dir: Path, | 109 | output_dir: Path, |
109 | seed: int, | 110 | seed: int, |
110 | step: int, | 111 | step: int, |
112 | validation_prompts: list[str] = [], | ||
111 | cycle: int = 1, | 113 | cycle: int = 1, |
112 | batch_size: int = 1, | 114 | batch_size: int = 1, |
113 | num_batches: int = 1, | 115 | num_batches: int = 1, |
@@ -136,7 +138,6 @@ def save_samples( | |||
136 | 138 | ||
137 | if val_dataloader is not None: | 139 | if val_dataloader is not None: |
138 | datasets.append(("stable", val_dataloader, generator)) | 140 | datasets.append(("stable", val_dataloader, generator)) |
139 | datasets.append(("val", val_dataloader, None)) | ||
140 | 141 | ||
141 | for pool, data, gen in datasets: | 142 | for pool, data, gen in datasets: |
142 | all_samples = [] | 143 | all_samples = [] |
@@ -165,7 +166,6 @@ def save_samples( | |||
165 | guidance_scale=guidance_scale, | 166 | guidance_scale=guidance_scale, |
166 | sag_scale=0, | 167 | sag_scale=0, |
167 | num_inference_steps=num_steps, | 168 | num_inference_steps=num_steps, |
168 | output_type=None, | ||
169 | ).images | 169 | ).images |
170 | 170 | ||
171 | all_samples.append(torch.from_numpy(samples)) | 171 | all_samples.append(torch.from_numpy(samples)) |
@@ -803,4 +803,7 @@ def train( | |||
803 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 803 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
804 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 804 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
805 | 805 | ||
806 | text_encoder.forward = MethodType(text_encoder.forward, text_encoder) | ||
807 | unet.forward = MethodType(unet.forward, unet) | ||
808 | |||
806 | accelerator.free_memory() | 809 | accelerator.free_memory() |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0f64747..bd853e2 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
@@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks( | |||
155 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 155 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
156 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 156 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
157 | 157 | ||
158 | unet_.forward = MethodType(unet_.forward, unet_) | ||
159 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
160 | |||
161 | with ema_context(): | 158 | with ema_context(): |
162 | pipeline = VlpnStableDiffusion( | 159 | pipeline = VlpnStableDiffusion( |
163 | text_encoder=text_encoder_, | 160 | text_encoder=text_encoder_, |