author    Volpeon <git@volpeon.ink>  2023-03-21 13:46:36 +0100
committer Volpeon <git@volpeon.ink>  2023-03-21 13:46:36 +0100
commit    f5e0e98f6df9260a93fb650a0b97c85eb87b0fd3 (patch)
tree      0d061f5fd8950d7ca7e0198731ee58980859dd18 /training/strategy
parent    Restore min SNR (diff)
Fixed SNR weighting, re-enabled xformers
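
The min-SNR part of this commit lands outside training/strategy, so it does not appear in the diff below (the diffstat is limited to that directory). For context, min-SNR loss weighting (Hang et al., 2023) is typically computed along the following lines; the function name and signature here are illustrative assumptions, not code from this commit.

import torch

def min_snr_weights(timesteps: torch.Tensor, alphas_cumprod: torch.Tensor, gamma: float = 5.0) -> torch.Tensor:
    # For a DDPM-style scheduler, SNR(t) = alpha_bar_t / (1 - alpha_bar_t).
    alpha_bar = alphas_cumprod[timesteps]
    snr = alpha_bar / (1 - alpha_bar)
    # Cap the per-timestep weight at gamma so easy, high-SNR timesteps
    # do not dominate the epsilon-prediction loss.
    return torch.clamp(snr, max=gamma) / snr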
Diffstat (limited to 'training/strategy')
-rw-r--r--  training/strategy/lora.py  70
1 file changed, 59 insertions(+), 11 deletions(-)
diff --git a/training/strategy/lora.py b/training/strategy/lora.py
index cab5e4c..aa75bec 100644
--- a/training/strategy/lora.py
+++ b/training/strategy/lora.py
@@ -2,6 +2,7 @@ from typing import Optional
 from functools import partial
 from contextlib import contextmanager
 from pathlib import Path
+import itertools
 
 import torch
 from torch.utils.data import DataLoader
@@ -9,12 +10,18 @@ from torch.utils.data import DataLoader
 from accelerate import Accelerator
 from transformers import CLIPTextModel
 from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
-from diffusers.loaders import AttnProcsLayers
+from peft import LoraConfig, LoraModel, get_peft_model_state_dict
+from peft.tuners.lora import mark_only_lora_as_trainable
 
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.functional import TrainingStrategy, TrainingCallbacks, save_samples
 
 
+# https://github.com/huggingface/peft/blob/main/examples/lora_dreambooth/train_dreambooth.py
+UNET_TARGET_MODULES = ["to_q", "to_v", "query", "value"]
+TEXT_ENCODER_TARGET_MODULES = ["q_proj", "v_proj"]
+
+
 def lora_strategy_callbacks(
     accelerator: Accelerator,
     unet: UNet2DConditionModel,
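
The two target-module lists above name the attention projections that LoRA will wrap: to_q/to_v (plus the older query/value names) in the diffusers UNet, and q_proj/v_proj in the CLIP text encoder. A quick way to check what such a list matches on a loaded model; this helper is hypothetical and only roughly mirrors peft's name-based matching:

import torch

def find_target_modules(model: torch.nn.Module, targets: list[str]) -> list[str]:
    # Collect module paths whose final path component is one of the
    # LoRA target names, e.g. "...attn1.to_q" for target "to_q".
    return [name for name, _ in model.named_modules() if name.split(".")[-1] in targets]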
@@ -27,7 +34,6 @@ def lora_strategy_callbacks(
     sample_output_dir: Path,
     checkpoint_output_dir: Path,
     seed: int,
-    lora_layers: AttnProcsLayers,
     max_grad_norm: float = 1.0,
     sample_batch_size: int = 1,
     sample_num_batches: int = 1,
@@ -57,7 +63,8 @@ def lora_strategy_callbacks(
     )
 
     def on_prepare():
-        lora_layers.requires_grad_(True)
+        mark_only_lora_as_trainable(unet.model, unet.peft_config.bias)
+        mark_only_lora_as_trainable(text_encoder.model, text_encoder.peft_config.bias)
 
     def on_accum_model():
         return unet
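
mark_only_lora_as_trainable replaces the old blanket requires_grad_(True) on the adapter layers: it freezes every base-model parameter and leaves only the injected LoRA weights trainable, honoring the config's bias mode. A simplified conceptual sketch (not peft's exact code; the real function also supports bias="lora_only"):

import torch

def mark_only_lora_as_trainable_sketch(model: torch.nn.Module, bias: str = "none") -> None:
    for name, param in model.named_parameters():
        # Only parameters injected by LoRA (their names contain "lora_") stay trainable.
        param.requires_grad = "lora_" in name
    if bias == "all":
        for name, param in model.named_parameters():
            if "bias" in name:
                param.requires_grad = True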
@@ -73,24 +80,44 @@ def lora_strategy_callbacks(
         yield
 
     def on_before_optimize(lr: float, epoch: int):
-        accelerator.clip_grad_norm_(lora_layers.parameters(), max_grad_norm)
+        accelerator.clip_grad_norm_(
+            itertools.chain(unet.parameters(), text_encoder.parameters()),
+            max_grad_norm
+        )
 
     @torch.no_grad()
     def on_checkpoint(step, postfix):
         print(f"Saving checkpoint for step {step}...")
 
         unet_ = accelerator.unwrap_model(unet, False)
-        unet_.save_attn_procs(
-            checkpoint_output_dir / f"{step}_{postfix}",
-            safe_serialization=True
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
+        lora_config = {}
+        state_dict = get_peft_model_state_dict(unet, state_dict=accelerator.get_state_dict(unet))
+        lora_config["peft_config"] = unet.get_peft_config_as_dict(inference=True)
+
+        text_encoder_state_dict = get_peft_model_state_dict(
+            text_encoder, state_dict=accelerator.get_state_dict(text_encoder)
         )
+        text_encoder_state_dict = {f"text_encoder_{k}": v for k, v in text_encoder_state_dict.items()}
+        state_dict.update(text_encoder_state_dict)
+        lora_config["text_encoder_peft_config"] = text_encoder.get_peft_config_as_dict(inference=True)
+
+        accelerator.print(state_dict)
+        accelerator.save(state_dict, checkpoint_output_dir / f"{step}_{postfix}.pt")
+
         del unet_
+        del text_encoder_
 
     @torch.no_grad()
     def on_sample(step):
         unet_ = accelerator.unwrap_model(unet, False)
+        text_encoder_ = accelerator.unwrap_model(text_encoder, False)
+
         save_samples_(step=step, unet=unet_)
+
         del unet_
+        del text_encoder_
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
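
on_checkpoint now saves both adapters into a single .pt file, with the text encoder's keys namespaced by a text_encoder_ prefix. Restoring such a checkpoint could look like the sketch below; load_lora_checkpoint is a hypothetical helper, while set_peft_model_state_dict is peft's counterpart to the get_peft_model_state_dict used above:

import torch
from peft import set_peft_model_state_dict

def load_lora_checkpoint(unet, text_encoder, path):
    state_dict = torch.load(path, map_location="cpu")
    # Split the combined dict back apart using the save-side naming convention.
    text_encoder_sd = {
        k.removeprefix("text_encoder_"): v
        for k, v in state_dict.items()
        if k.startswith("text_encoder_")
    }
    unet_sd = {k: v for k, v in state_dict.items() if not k.startswith("text_encoder_")}
    set_peft_model_state_dict(unet, unet_sd)
    set_peft_model_state_dict(text_encoder, text_encoder_sd)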
@@ -114,13 +141,34 @@ def lora_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
-    lora_layers: AttnProcsLayers,
+    lora_rank: int = 4,
+    lora_alpha: int = 32,
+    lora_dropout: float = 0,
+    lora_bias: str = "none",
     **kwargs
 ):
-    lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        lora_layers, optimizer, train_dataloader, val_dataloader, lr_scheduler)
+    unet_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=UNET_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    unet = LoraModel(unet_config, unet)
+
+    text_encoder_config = LoraConfig(
+        r=lora_rank,
+        lora_alpha=lora_alpha,
+        target_modules=TEXT_ENCODER_TARGET_MODULES,
+        lora_dropout=lora_dropout,
+        bias=lora_bias,
+    )
+    text_encoder = LoraModel(text_encoder_config, text_encoder)
+
+    text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler)
 
-    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {"lora_layers": lora_layers}
+    return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, {}
 
 
 lora_strategy = TrainingStrategy(
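
With the defaults introduced in lora_prepare (r=4, lora_alpha=32), peft scales the low-rank update by lora_alpha / r = 8. The per-module math is, in essence (illustrative sketch, not peft's actual implementation):

import torch

def lora_linear(x: torch.Tensor, W: torch.Tensor, A: torch.Tensor, B: torch.Tensor,
                lora_alpha: int = 32, r: int = 4) -> torch.Tensor:
    # W: frozen (out, in) base weight; A: (r, in) and B: (out, r) are the
    # trainable low-rank factors. With r=4 and lora_alpha=32 the scale is 8.
    return x @ W.T + (lora_alpha / r) * (x @ A.T @ B.T)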