Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/lora.py  25
1 file changed, 11 insertions, 14 deletions
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index ccec215..cab5e4c 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -11,10 +11,7 @@ from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
 from diffusers.loaders import AttnProcsLayers
 
-from slugify import slugify
-
 from models.clip.tokenizer import MultiCLIPTokenizer
-from training.util import EMAModel
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
@@ -41,16 +38,9 @@ def lora_strategy_callbacks(
     sample_output_dir.mkdir(parents=True, exist_ok=True)
     checkpoint_output_dir.mkdir(parents=True, exist_ok=True)
 
-    weight_dtype = torch.float32
-    if accelerator.state.mixed_precision == "fp16":
-        weight_dtype = torch.float16
-    elif accelerator.state.mixed_precision == "bf16":
-        weight_dtype = torch.bfloat16
-
     save_samples_ = partial(
         save_samples,
         accelerator=accelerator,
-        unet=unet,
         text_encoder=text_encoder,
         tokenizer=tokenizer,
         vae=vae,
@@ -83,20 +73,27 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        if accelerator.sync_gradients:
-            accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(checkpoint_output_dir / f"{step}_{postfix}")
+        unet_.save_attn_procs(
+            checkpoint_output_dir / f"{step}_{postfix}",
+            safe_serialization=True
+        )
         del unet_
 
     @torch.no_grad()
     def on_sample(step):
-        save_samples_(step=step)
+        unet_ = accelerator.unwrap_model(unet, False)
+        save_samples_(step=step, unet=unet_)
+        del unet_
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
 
     return TrainingCallbacks(
         on_prepare=on_prepare,