| author | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-01 12:34:42 +0100 |
| commit | a1b8327085ddeab589be074d7e9df4291aba1210 (patch) | |
| tree | 2f2016916d7a2f659268c3e375d55c59583c2b3b /training/functional.py | |
| parent | Fixed TI normalization order (diff) | |
| download | textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.gz textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.tar.bz2 textual-inversion-diff-a1b8327085ddeab589be074d7e9df4291aba1210.zip | |
Update
Diffstat (limited to 'training/functional.py')
| -rw-r--r-- | training/functional.py | 50 |
1 file changed, 27 insertions, 23 deletions
```diff
diff --git a/training/functional.py b/training/functional.py
index b830261..990c4cd 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -22,7 +22,6 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
-from schedulers.scheduling_deis_multistep import DEISMultistepScheduler
 from training.util import AverageMeter
 
 
@@ -74,19 +73,12 @@ def make_grid(images, rows, cols):
     return grid
 
 
-def get_models(pretrained_model_name_or_path: str, noise_scheduler: str = "ddpm"):
-    if noise_scheduler == "deis":
-        noise_scheduler_cls = DEISMultistepScheduler
-    elif noise_scheduler == "ddpm":
-        noise_scheduler_cls = DDPMScheduler
-    else:
-        raise ValueError(f"noise_scheduler must be one of [\"ddpm\", \"deis\"], got {noise_scheduler}")
-
+def get_models(pretrained_model_name_or_path: str):
     tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
     text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
     vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
     unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
-    noise_scheduler = noise_scheduler_cls.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
+    noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
     sample_scheduler = UniPCMultistepScheduler.from_pretrained(
         pretrained_model_name_or_path, subfolder='scheduler')
 
```
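For context, a minimal sketch of how the simplified `get_models` might be called after this change. The call sites and the function's return statement are not part of this diff, so the unpacking order and the model path below are illustrative assumptions:

```python
# Hypothetical call site (not shown in the diff). After this change the
# noise_scheduler argument is gone: training always uses DDPMScheduler and
# sampling uses UniPCMultistepScheduler. The unpacking order is assumed to
# follow the order in which the components are loaded.
tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler = get_models(
    "path/to/pretrained-model"  # illustrative path
)
```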
```diff
@@ -232,9 +224,6 @@ def generate_class_images(
 
     del pipeline
 
-    if torch.cuda.is_available():
-        torch.cuda.empty_cache()
-
 
 def add_placeholder_tokens(
     tokenizer: MultiCLIPTokenizer,
@@ -274,26 +263,41 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
     latents = latents * vae.config.scaling_factor
 
+    bsz = latents.shape[0]
+
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
-        dtype=latents.dtype,
-        layout=latents.layout,
-        device=latents.device,
-        generator=generator
-    )
-    if low_freq_noise > 0:
-        noise += low_freq_noise * torch.randn(
+    if low_freq_noise == 0:
+        noise = torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        )
+    else:
+        noise = (1 - low_freq_noise) * torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        ) + low_freq_noise * torch.randn(
             latents.shape[0], latents.shape[1], 1, 1,
             dtype=latents.dtype,
             layout=latents.layout,
             device=latents.device,
             generator=generator
         )
+        # noise += low_freq_noise * torch.randn(
+        #     bsz, 1, 1, 1,
+        #     dtype=latents.dtype,
+        #     layout=latents.layout,
+        #     device=latents.device,
+        #     generator=generator
+        # )
 
-    bsz = latents.shape[0]
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
```
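The reworked noise sampling blends full-resolution Gaussian noise with a low-frequency component drawn once per sample and channel, weighted by `low_freq_noise`. Below is a standalone sketch of the same idea, assuming only what the diff shows; the helper name and signature are illustrative and not part of the repository:

```python
from typing import Optional

import torch


def sample_training_noise(latents: torch.Tensor,
                          low_freq_noise: float,
                          generator: Optional[torch.Generator] = None) -> torch.Tensor:
    """Illustrative helper (not in the repository) mirroring the diff above."""
    if low_freq_noise == 0:
        # Plain per-pixel Gaussian noise.
        return torch.randn(
            latents.shape,
            dtype=latents.dtype,
            layout=latents.layout,
            device=latents.device,
            generator=generator,
        )
    # Weighted blend: (1 - w) * per-pixel noise + w * one value per (sample, channel),
    # broadcast over the spatial dimensions, which biases the result toward
    # low-frequency (offset-style) noise.
    return (1 - low_freq_noise) * torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    ) + low_freq_noise * torch.randn(
        latents.shape[0], latents.shape[1], 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
```

Compared with the previous additive form (`noise += low_freq_noise * ...`), the new code takes a convex combination of the two noise tensors, with `low_freq_noise` acting as the blend weight.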
