diff options
author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /training/functional.py | |
parent | Fixed TI normalization order (diff) | |
download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip |
Update
Diffstat (limited to 'training/functional.py')
-rw-r--r-- | training/functional.py | 50 |
1 files changed, 27 insertions, 23 deletions
diff --git a/training/functional.py b/training/functional.py index b830261..990c4cd 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | |||
22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | 22 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings |
23 | from models.clip.util import get_extended_embeddings | 23 | from models.clip.util import get_extended_embeddings |
24 | from models.clip.tokenizer import MultiCLIPTokenizer | 24 | from models.clip.tokenizer import MultiCLIPTokenizer |
25 | from schedulers.scheduling_deis_multistep import DEISMultistepScheduler | ||
26 | from training.util import AverageMeter | 25 | from training.util import AverageMeter |
27 | 26 | ||
28 | 27 | ||
@@ -74,19 +73,12 @@ def make_grid(images, rows, cols): | |||
74 | return grid | 73 | return grid |
75 | 74 | ||
76 | 75 | ||
77 | def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"): | 76 | def get_models(pretrained_model_name_or_path: str): |
78 | if noise_scheduler == "deis": | ||
79 | noise_scheduler_cls = DEISMultistepScheduler | ||
80 | elif noise_scheduler == "ddpm": | ||
81 | noise_scheduler_cls = DDPMScheduler | ||
82 | else: | ||
83 | raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}") | ||
84 | |||
85 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') | 77 | tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') |
86 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') | 78 | text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') |
87 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') | 79 | vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') |
88 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') | 80 | unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') |
89 | noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') | 81 | noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') |
90 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( | 82 | sample_scheduler = UniPCMultistepScheduler.from_pretrained( |
91 | pretrained_model_name_or_path, subfolder='scheduler') | 83 | pretrained_model_name_or_path, subfolder='scheduler') |
92 | 84 | ||
@@ -232,9 +224,6 @@ def generate_class_images( | |||
232 | 224 | ||
233 | del pipeline | 225 | del pipeline |
234 | 226 | ||
235 | if torch.cuda.is_available(): | ||
236 | torch.cuda.empty_cache() | ||
237 | |||
238 | 227 | ||
239 | def add_placeholder_tokens( | 228 | def add_placeholder_tokens( |
240 | tokenizer: MultiCLIPTokenizer, | 229 | tokenizer: MultiCLIPTokenizer, |
@@ -274,26 +263,41 @@ def loss_step( | |||
274 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() | 263 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
275 | latents = latents * vae.config.scaling_factor | 264 | latents = latents * vae.config.scaling_factor |
276 | 265 | ||
266 | bsz = latents.shape[0] | ||
267 | |||
277 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
278 | 269 | ||
279 | # Sample noise that we'll add to the latents | 270 | # Sample noise that we'll add to the latents |
280 | noise = torch.randn( | 271 | if low_freq_noise == 0: |
281 | latents.shape, | 272 | noise = torch.randn( |
282 | dtype=latents.dtype, | 273 | latents.shape, |
283 | layout=latents.layout, | 274 | dtype=latents.dtype, |
284 | device=latents.device, | 275 | layout=latents.layout, |
285 | generator=generator | 276 | device=latents.device, |
286 | ) | 277 | generator=generator |
287 | if low_freq_noise > 0: | 278 | ) |
288 | noise += low_freq_noise * torch.randn( | 279 | else: |
280 | noise = (1 - low_freq_noise) * torch.randn( | ||
281 | latents.shape, | ||
282 | dtype=latents.dtype, | ||
283 | layout=latents.layout, | ||
284 | device=latents.device, | ||
285 | generator=generator | ||
286 | ) + low_freq_noise * torch.randn( | ||
289 | latents.shape[0], latents.shape[1], 1, 1, | 287 | latents.shape[0], latents.shape[1], 1, 1, |
290 | dtype=latents.dtype, | 288 | dtype=latents.dtype, |
291 | layout=latents.layout, | 289 | layout=latents.layout, |
292 | device=latents.device, | 290 | device=latents.device, |
293 | generator=generator | 291 | generator=generator |
294 | ) | 292 | ) |
293 | # noise += low_freq_noise * torch.randn( | ||
294 | # bsz, 1, 1, 1, | ||
295 | # dtype=latents.dtype, | ||
296 | # layout=latents.layout, | ||
297 | # device=latents.device, | ||
298 | # generator=generator | ||
299 | # ) | ||
295 | 300 | ||
296 | bsz = latents.shape[0] | ||
297 | # Sample a random timestep for each image | 301 | # Sample a random timestep for each image |
298 | timesteps = torch.randint( | 302 | timesteps = torch.randint( |
299 | 0, | 303 | 0, |