summaryrefslogtreecommitdiffstats
path: root/training/strategy/lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
committerVolpeon <git@volpeon.ink>2023-03-07 07:11:51 +0100
commitfe3113451fdde72ddccfc71639f0a2a1e146209a (patch)
treeba4114faf1bd00a642f97b5e7729ad74213c3b80 /training/strategy/lora.py
parentUpdate (diff)
downloadtextual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.gz
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.tar.bz2
textual-inversion-diff-fe3113451fdde72ddccfc71639f0a2a1e146209a.zip
Update
Diffstat (limited to 'training/strategy/lora.py')
-rw-r--r--training/strategy/lora.py25
1 files changed, 11 insertions, 14 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
11from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler 11from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
12from diffusers.loaders import AttnProcsLayers 12from diffusers.loaders import AttnProcsLayers
13 13
14from slugify import slugify
15
16from models.clip.tokenizer import MultiCLIPTokenizer 14from models.clip.tokenizer import MultiCLIPTokenizer
17from training.util import EMAModel
18from training.functional import TrainingStrategy, TrainingCallbacks, save_samples 15from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
19 16
20 17
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
41 sample_output_dir.mkdir(parents=True, exist_ok=True) 38 sample_output_dir.mkdir(parents=True, exist_ok=True)
42 checkpoint_output_dir.mkdir(parents=True, exist_ok=True) 39 checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
43 40
44 weight_dtype = torch.float32
45 if accelerator.state.mixed_precision == "fp16":
46 weight_dtype = torch.float16
47 elif accelerator.state.mixed_precision == "bf16":
48 weight_dtype = torch.bfloat16
49
50 save_samples_ = partial( 41 save_samples_ = partial(
51 save_samples, 42 save_samples,
52 accelerator=accelerator, 43 accelerator=accelerator,
53 unet=unet,
54 text_encoder=text_encoder, 44 text_encoder=text_encoder,
55 tokenizer=tokenizer, 45 tokenizer=tokenizer,
56 vae=vae, 46 vae=vae,
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
83 yield 73 yield
84 74
85 def on_before_optimize(lr: float, epoch: int): 75 def on_before_optimize(lr: float, epoch: int):
86 if accelerator.sync_gradients: 76 accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
87 accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
88 77
89 @torch.no_grad() 78 @torch.no_grad()
90 def on_checkpoint(step, postfix): 79 def on_checkpoint(step, postfix):
91 print(f"Saving checkpoint for step {step}...") 80 print(f"Saving checkpoint for step {step}...")
92 81
93 unet_ = accelerator.unwrap_model(unet, False) 82 unet_ = accelerator.unwrap_model(unet, False)
94 unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}") 83 unet_.save_attn_procs(
84 checkpoint_output_dir / f"{step}_{postfix}",
85 safe_serialization=True
86 )
95 del unet_ 87 del unet_
96 88
97 @torch.no_grad() 89 @torch.no_grad()
98 def on_sample(step): 90 def on_sample(step):
99 save_samples_(step=step) 91 unet_ = accelerator.unwrap_model(unet, False)
92 save_samples_(step=step, unet=unet_)
93 del unet_
94
95 if torch.cuda.is_available():
96 torch.cuda.empty_cache()
100 97
101 return TrainingCallbacks( 98 return TrainingCallbacks(
102 on_prepare=on_prepare, 99 on_prepare=on_prepare,