diff options
| -rw-r--r-- | train_lora.py | 4 | ||||
| -rw-r--r-- | training/functional.py | 4 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 2 | ||||
| -rw-r--r-- | training/strategy/lora.py | 4 | ||||
| -rw-r--r-- | training/strategy/ti.py | 2 |
5 files changed, 5 insertions, 11 deletions
diff --git a/train_lora.py b/train_lora.py index 2cb85cc..b273ae1 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -13,7 +13,7 @@ from accelerate.logging import get_logger | |||
| 13 | from accelerate.utils import LoggerType, set_seed | 13 | from accelerate.utils import LoggerType, set_seed |
| 14 | from slugify import slugify | 14 | from slugify import slugify |
| 15 | from diffusers.loaders import AttnProcsLayers | 15 | from diffusers.loaders import AttnProcsLayers |
| 16 | from diffusers.models.cross_attention import LoRACrossAttnProcessor | 16 | from diffusers.models.cross_attention import LoRAXFormersCrossAttnProcessor |
| 17 | 17 | ||
| 18 | from util import load_config, load_embeddings_from_dir | 18 | from util import load_config, load_embeddings_from_dir |
| 19 | from data.csv import VlpnDataModule, keyword_filter | 19 | from data.csv import VlpnDataModule, keyword_filter |
| @@ -430,7 +430,7 @@ def main(): | |||
| 430 | block_id = int(name[len("down_blocks.")]) | 430 | block_id = int(name[len("down_blocks.")]) |
| 431 | hidden_size = unet.config.block_out_channels[block_id] | 431 | hidden_size = unet.config.block_out_channels[block_id] |
| 432 | 432 | ||
| 433 | lora_attn_procs[name] = LoRACrossAttnProcessor( | 433 | lora_attn_procs[name] = LoRAXFormersCrossAttnProcessor( |
| 434 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim | 434 | hidden_size=hidden_size, cross_attention_dim=cross_attention_dim |
| 435 | ) | 435 | ) |
| 436 | 436 | ||
diff --git a/training/functional.py b/training/functional.py index 8f47734..ccbb4ad 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -261,8 +261,8 @@ def loss_step( | |||
| 261 | eval: bool = False | 261 | eval: bool = False |
| 262 | ): | 262 | ): |
| 263 | # Convert images to latent space | 263 | # Convert images to latent space |
| 264 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach() | 264 | latents = vae.encode(batch["pixel_values"]).latent_dist.sample() |
| 265 | latents = latents * 0.18215 | 265 | latents = latents * vae.config.scaling_factor |
| 266 | 266 | ||
| 267 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 267 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
| 268 | 268 | ||
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index b4c77f3..8aaed3a 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -1,4 +1,3 @@ | |||
| 1 | from contextlib import nullcontext | ||
| 2 | from typing import Optional | 1 | from typing import Optional |
| 3 | from functools import partial | 2 | from functools import partial |
| 4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager, nullcontext |
| @@ -6,7 +5,6 @@ from pathlib import Path | |||
| 6 | import itertools | 5 | import itertools |
| 7 | 6 | ||
| 8 | import torch | 7 | import torch |
| 9 | import torch.nn as nn | ||
| 10 | from torch.utils.data import DataLoader | 8 | from torch.utils.data import DataLoader |
| 11 | 9 | ||
| 12 | from accelerate import Accelerator | 10 | from accelerate import Accelerator |
diff --git a/training/strategy/lora.py b/training/strategy/lora.py index 88d1824..92abaa6 100644 --- a/training/strategy/lora.py +++ b/training/strategy/lora.py | |||
| @@ -1,11 +1,9 @@ | |||
| 1 | from contextlib import nullcontext | ||
| 2 | from typing import Optional | 1 | from typing import Optional |
| 3 | from functools import partial | 2 | from functools import partial |
| 4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager |
| 5 | from pathlib import Path | 4 | from pathlib import Path |
| 6 | 5 | ||
| 7 | import torch | 6 | import torch |
| 8 | import torch.nn as nn | ||
| 9 | from torch.utils.data import DataLoader | 7 | from torch.utils.data import DataLoader |
| 10 | 8 | ||
| 11 | from accelerate import Accelerator | 9 | from accelerate import Accelerator |
diff --git a/training/strategy/ti.py b/training/strategy/ti.py index d306f18..da2b81c 100644 --- a/training/strategy/ti.py +++ b/training/strategy/ti.py | |||
| @@ -1,11 +1,9 @@ | |||
| 1 | from contextlib import nullcontext | ||
| 2 | from typing import Optional | 1 | from typing import Optional |
| 3 | from functools import partial | 2 | from functools import partial |
| 4 | from contextlib import contextmanager, nullcontext | 3 | from contextlib import contextmanager, nullcontext |
| 5 | from pathlib import Path | 4 | from pathlib import Path |
| 6 | 5 | ||
| 7 | import torch | 6 | import torch |
| 8 | import torch.nn as nn | ||
| 9 | from torch.utils.data import DataLoader | 7 | from torch.utils.data import DataLoader |
| 10 | 8 | ||
| 11 | from accelerate import Accelerator | 9 | from accelerate import Accelerator |
