summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-01-11 21:54:10 +0100
committerVolpeon <git@volpeon.ink>2023-01-11 21:54:10 +0100
commit92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d (patch)
tree7c032c88126b15e2a9eb13ccc4a8293e8d660f29
parentBetter defaults (diff)
downloadtextual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.gz
textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.tar.bz2
textual-inversion-diff-92113e3a7c9cfda2bc2f6cc0fa5b1234505f145d.zip
TI: Use grad clipping from LoRA #104
-rw-r--r--environment.yaml2
-rw-r--r--train_dreambooth.py4
-rw-r--r--train_ti.py19
-rw-r--r--training/common.py2
4 files changed, 15 insertions, 12 deletions
diff --git a/environment.yaml b/environment.yaml
index eff69ed..9af40eb 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -23,4 +23,4 @@ dependencies:
23 - test-tube>=0.7.5 23 - test-tube>=0.7.5
24 - transformers==4.25.1 24 - transformers==4.25.1
25 - triton==2.0.0.dev20221202 25 - triton==2.0.0.dev20221202
26 - xformers==0.0.16rc401 26 - xformers==0.0.16rc403
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0182693..73d9935 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -24,7 +24,7 @@ from slugify import slugify
24from util import load_config, load_embeddings_from_dir 24from util import load_config, load_embeddings_from_dir
25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 25from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
26from data.csv import VlpnDataModule, VlpnDataItem 26from data.csv import VlpnDataModule, VlpnDataItem
27from training.common import run_model, generate_class_images 27from training.common import loss_step, generate_class_images
28from training.optimization import get_one_cycle_schedule 28from training.optimization import get_one_cycle_schedule
29from training.lr import LRFinder 29from training.lr import LRFinder
30from training.util import AverageMeter, CheckpointerBase, save_args 30from training.util import AverageMeter, CheckpointerBase, save_args
@@ -883,7 +883,7 @@ def main():
883 pass 883 pass
884 884
885 loop = partial( 885 loop = partial(
886 run_model, 886 loss_step,
887 vae, 887 vae,
888 noise_scheduler, 888 noise_scheduler,
889 unet, 889 unet,
diff --git a/train_ti.py b/train_ti.py
index 4e2c3c5..1054a5d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -7,6 +7,7 @@ from pathlib import Path
7from contextlib import contextmanager, nullcontext 7from contextlib import contextmanager, nullcontext
8 8
9import torch 9import torch
10import torch.nn.functional as F
10import torch.utils.checkpoint 11import torch.utils.checkpoint
11 12
12from accelerate import Accelerator 13from accelerate import Accelerator
@@ -22,7 +23,7 @@ from slugify import slugify
22from util import load_config, load_embeddings_from_dir 23from util import load_config, load_embeddings_from_dir
23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24from data.csv import VlpnDataModule, VlpnDataItem 25from data.csv import VlpnDataModule, VlpnDataItem
25from training.common import run_model, generate_class_images 26from training.common import loss_step, generate_class_images
26from training.optimization import get_one_cycle_schedule 27from training.optimization import get_one_cycle_schedule
27from training.lr import LRFinder 28from training.lr import LRFinder
28from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args 29from training.util import AverageMeter, CheckpointerBase, EMAModel, save_args
@@ -165,7 +166,7 @@ def parse_args():
165 parser.add_argument( 166 parser.add_argument(
166 "--tag_dropout", 167 "--tag_dropout",
167 type=float, 168 type=float,
168 default=0, 169 default=0.1,
169 help="Tag dropout probability.", 170 help="Tag dropout probability.",
170 ) 171 )
171 parser.add_argument( 172 parser.add_argument(
@@ -866,14 +867,16 @@ def main():
866 finally: 867 finally:
867 pass 868 pass
868 869
870 @torch.no_grad()
869 def on_clip(): 871 def on_clip():
870 accelerator.clip_grad_norm_( 872 embeddings = text_encoder.text_model.embeddings.temp_token_embedding
871 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 873
872 args.max_grad_norm 874 pre_norm = embeddings.weight.norm(dim=-1, keepdim=True)
873 ) 875 lambda_ = min(1.0, 100 * lr_scheduler.get_last_lr()[0])
876 embeddings.weight[:] = F.normalize(embeddings.weight, dim=-1) * (pre_norm + lambda_ * (0.4 - pre_norm))
874 877
875 loop = partial( 878 loop = partial(
876 run_model, 879 loss_step,
877 vae, 880 vae,
878 noise_scheduler, 881 noise_scheduler,
879 unet, 882 unet,
@@ -971,7 +974,7 @@ def main():
971 974
972 try: 975 try:
973 for epoch in range(num_epochs): 976 for epoch in range(num_epochs):
974 if accelerator.is_main_process: 977 if accelerator.is_main_process and False:
975 if epoch % args.sample_frequency == 0: 978 if epoch % args.sample_frequency == 0:
976 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps) 979 checkpointer.save_samples(global_step + global_step_offset, args.sample_steps)
977 980
diff --git a/training/common.py b/training/common.py
index 67c2ab6..0b2ae44 100644
--- a/training/common.py
+++ b/training/common.py
@@ -58,7 +58,7 @@ def generate_class_images(
58 torch.cuda.empty_cache() 58 torch.cuda.empty_cache()
59 59
60 60
61def run_model( 61def loss_step(
62 vae: AutoencoderKL, 62 vae: AutoencoderKL,
63 noise_scheduler: DDPMScheduler, 63 noise_scheduler: DDPMScheduler,
64 unet: UNet2DConditionModel, 64 unet: UNet2DConditionModel,