diff options
| author | Volpeon <git@volpeon.ink> | 2023-01-11 21:54:10 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-01-11 21:54:10 +0100 |
| commit | 92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d (patch) | |
| tree | 7c032c88126b15e2a9eb13ccc4a8293e8d660f29 | |
| parent | Better defaults (diff) | |
| download | textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.gz textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.bz2 textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.zip | |
TI: Use grad clipping from LoRA #104
| -rw-r--r-- | environment.yaml | 2 | ||||
| -rw-r--r-- | train_dreambooth.py | 4 | ||||
| -rw-r--r-- | train_ti.py | 19 | ||||
| -rw-r--r-- | training/common.py | 2 |
4 files changed, 15 insertions(+), 12 deletions(-)
diff --git a/environment.yaml b/environment.yaml index eff69ed..9af40eb 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -23,4 +23,4 @@ dependencies: | |||
| 23 | - test-tube>=0.7.5 | 23 | - test-tube>=0.7.5 |
| 24 | - transformers==4.25.1 | 24 | - transformers==4.25.1 |
| 25 | - triton==2.0.0.dev20221202 | 25 | - triton==2.0.0.dev20221202 |
| 26 | - xformers==0.0.16rc401 | 26 | - xformers==0.0.16rc403 |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0182693..73d9935 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -24,7 +24,7 @@ from slugify import slugify | |||
| 24 | from util import load_config, load_embeddings_from_dir | 24 | from util import load_config, load_embeddings_from_dir |
| 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 26 | from data.csv import VlpnDataModule, VlpnDataItem | 26 | from data.csv import VlpnDataModule, VlpnDataItem |
| 27 | from training.common import run_model, generate_class_images | 27 | from training.common import loss_step, generate_class_images |
| 28 | from training.optimization import get_one_cycle_schedule | 28 | from training.optimization import get_one_cycle_schedule |
| 29 | from training.lr import LRFinder | 29 | from training.lr import LRFinder |
| 30 | from training.util import AverageMeter, CheckpointerBase, save_args | 30 | from training.util import AverageMeter, CheckpointerBase, save_args |
| @@ -883,7 +883,7 @@ def main(): | |||
| 883 | pass | 883 | pass |
| 884 | 884 | ||
| 885 | loop = partial( | 885 | loop = partial( |
| 886 | run_model, | 886 | loss_step, |
| 887 | vae, | 887 | vae, |
| 888 | noise_scheduler, | 888 | noise_scheduler, |
| 889 | unet, | 889 | unet, |
diff --git a/train_ti.py b/train_ti.py index 4e2c3c5..1054a5d 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -7,6 +7,7 @@ from pathlib import Path | |||
| 7 | from contextlib import contextmanager, nullcontext | 7 | from contextlib import contextmanager, nullcontext |
| 8 | 8 | ||
| 9 | import torch | 9 | import torch |
| 10 | import torch.nn.functional as F | ||
| 10 | import torch.utils.checkpoint | 11 | import torch.utils.checkpoint |
| 11 | 12 | ||
| 12 | from accelerate import Accelerator | 13 | from accelerate import Accelerator |
| @@ -22,7 +23,7 @@ from slugify import slugify | |||
| 22 | from util import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
| 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 24 | from data.csv import VlpnDataModule, VlpnDataItem | 25 | from data.csv import VlpnDataModule, VlpnDataItem |
| 25 | from training.common import run_model, generate_class_images | 26 | from training.common import loss_step, generate_class_images |
| 26 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
| 27 | from training.lr import LRFinder | 28 | from training.lr import LRFinder |
| 28 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args | 29 | from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args |
| @@ -165,7 +166,7 @@ def parse_args(): | |||
| 165 | parser.add_argument( | 166 | parser.add_argument( |
| 166 | "--tag_dropout", | 167 | "--tag_dropout", |
| 167 | type=float, | 168 | type=float, |
| 168 | default=0, | 169 | default=0.1, |
| 169 | help="Tag dropout probability.", | 170 | help="Tag dropout probability.", |
| 170 | ) | 171 | ) |
| 171 | parser.add_argument( | 172 | parser.add_argument( |
| @@ -866,14 +867,16 @@ def main(): | |||
| 866 | finally: | 867 | finally: |
| 867 | pass | 868 | pass |
| 868 | 869 | ||
| 870 | @torch.no_grad() | ||
| 869 | def on_clip(): | 871 | def on_clip(): |
| 870 | accelerator.clip_grad_norm_( | 872 | embeddings = text_encoder.text_model.embeddings.temp_token_embedding |
| 871 | text_encoder.text_model.embeddings.temp_token_embedding.parameters(), | 873 | |
| 872 | args.max_grad_norm | 874 | pre_norm = embeddings.weight.norm(dim=-1, keepdim=True) |
| 873 | ) | 875 | lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0]) |
| 876 | embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm)) | ||
| 874 | 877 | ||
| 875 | loop = partial( | 878 | loop = partial( |
| 876 | run_model, | 879 | loss_step, |
| 877 | vae, | 880 | vae, |
| 878 | noise_scheduler, | 881 | noise_scheduler, |
| 879 | unet, | 882 | unet, |
| @@ -971,7 +974,7 @@ def main(): | |||
| 971 | 974 | ||
| 972 | try: | 975 | try: |
| 973 | for epoch in range(num_epochs): | 976 | for epoch in range(num_epochs): |
| 974 | if accelerator.is_main_process: | 977 | if accelerator.is_main_process and False: |
| 975 | if epoch % args.sample_frequency == 0: | 978 | if epoch % args.sample_frequency == 0: |
| 976 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) | 979 | checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) |
| 977 | 980 | ||
diff --git a/training/common.py b/training/common.py index 67c2ab6..0b2ae44 100644 --- a/training/common.py +++ b/training/common.py | |||
| @@ -58,7 +58,7 @@ def generate_class_images( | |||
| 58 | torch.cuda.empty_cache() | 58 | torch.cuda.empty_cache() |
| 59 | 59 | ||
| 60 | 60 | ||
| 61 | def run_model( | 61 | def loss_step( |
| 62 | vae: AutoencoderKL, | 62 | vae: AutoencoderKL, |
| 63 | noise_scheduler: DDPMScheduler, | 63 | noise_scheduler: DDPMScheduler, |
| 64 | unet: UNet2DConditionModel, | 64 | unet: UNet2DConditionModel, |
