author     Volpeon <git@volpeon.ink>  2023-05-16 16:48:51 +0200
committer  Volpeon <git@volpeon.ink>  2023-05-16 16:48:51 +0200
commit     55a12f2c683b2ecfa4fc8b4015462ad2798abda5 (patch)
tree       feeb3f9a041466e773bb5921cbf0adb208d60a49
parent     Avoid model recompilation due to varying prompt lengths (diff)
Fix LoRA training with DAdaptation
-rw-r--r--  environment.yaml          |  2
-rw-r--r--  environment_nightly.yaml  | 19
-rw-r--r--  train_lora.py             | 58
-rw-r--r--  training/functional.py    |  4
-rw-r--r--  training/sampler.py       |  2
5 files changed, 42 insertions, 43 deletions
diff --git a/environment.yaml b/environment.yaml
index cf2b732..1a55967 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
   - matplotlib=3.6.2
-  - numpy=1.23.4
+  - numpy=1.24.3
   - pip=22.3.1
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
index 4c5c798..d315bd8 100644
--- a/environment_nightly.yaml
+++ b/environment_nightly.yaml
@@ -4,28 +4,31 @@ channels:
   - nvidia
   - xformers/label/dev
   - defaults
   - conda-forge
 dependencies:
-  - cuda-nvcc=12.1.105
+  - cuda-nvcc=11.8
+  - cuda-cudart-dev=11.8
+  - gcc=11.3.0
+  - gxx=11.3.0
   - matplotlib=3.6.2
   - numpy=1.24.3
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=2.1.0.dev20230429=*cuda12.1*
-  - torchvision=0.16.0.dev20230429
+  - pytorch=2.1.0.dev20230515=*cuda11.8*
+  - torchvision=0.16.0.dev20230516
   # - xformers=0.0.19
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/accelerate#egg=accelerate
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+    - --pre --extra-index-url https://download.hidet.org/whl hidet
     - bitsandbytes==0.38.1
-    - hidet==0.2.3
     - lion-pytorch==0.0.7
-    - peft==0.2.0
+    - peft==0.3.0
     - python-slugify>=6.1.2
     - safetensors==0.3.1
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - timm==0.8.17.dev0
-    - transformers==4.28.1
+    - timm==0.9.2
+    - transformers==4.29.1
diff --git a/train_lora.py b/train_lora.py
index 12d7e72..c74dd8f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -48,8 +48,8 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.WARNING
-# torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -1143,6 +1143,28 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    params_to_optimize = [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        }
+    ]
+    group_labels = ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1182,35 +1204,9 @@ def main():
         print("")
         print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
         print("")
 
-        params_to_optimize = []
-        group_labels = []
-
-        params_to_optimize.append({
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": learning_rate_unet,
-        })
-        group_labels.append("unet")
-
-        if training_iter < args.train_text_encoder_cycles:
-            params_to_optimize.append({
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_text,
-            })
-            group_labels.append("text")
-
-        lora_optimizer = create_optimizer(params_to_optimize)
+        for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]):
+            group['lr'] = lr
 
         lora_lr_scheduler = create_lr_scheduler(
             lr_scheduler,
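Note on the train_lora.py change: the optimizer is now built once, before the cycle loop, and each cycle only rewrites the per-group learning rates instead of calling create_optimizer again; the text-encoder group also switches to all trainable text_encoder parameters and is no longer conditional. Presumably this is what broke training with D-Adaptation optimizers (the dadaptation package pinned above), which keep a running step-size estimate inside the optimizer state and lose it when the optimizer is recreated every cycle. A minimal self-contained sketch of the same pattern follows; the stand-in modules, the per-cycle learning rates, and the direct use of dadaptation.DAdaptAdam are illustrative assumptions, since the script itself goes through its create_optimizer factory.

    import torch.nn as nn
    import dadaptation

    # Stand-ins for the real UNet and text encoder (assumption for this sketch).
    unet = nn.Linear(8, 8)
    text_encoder = nn.Linear(8, 8)

    # Build the optimizer once, as the diff now does before the cycle loop.
    optimizer = dadaptation.DAdaptAdam([
        {"params": [p for p in unet.parameters() if p.requires_grad], "lr": 1.0},
        {"params": [p for p in text_encoder.parameters() if p.requires_grad], "lr": 1.0},
    ])

    # Each cycle only rewrites the per-group learning rates, so the optimizer's
    # internal D-Adaptation state carries over between cycles.
    for lr_unet, lr_text in [(1.0, 1.0), (0.5, 0.5)]:  # hypothetical per-cycle LRs
        for group, lr in zip(optimizer.param_groups, (lr_unet, lr_text)):
            group["lr"] = lr
        # ... run one training cycle with `optimizer` here ...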
diff --git a/training/functional.py b/training/functional.py
index 10560e5..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        # unet = torch.compile(unet, backend='hidet')
-        unet = torch.compile(unet, mode="reduce-overhead")
+        unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
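Note on the training/functional.py change: the UNet goes back to the hidet torch.compile backend, and the default inductor "reduce-overhead" path is commented out again, which lines up with installing hidet from its own wheel index in environment_nightly.yaml. A rough sketch of compiling a module with that backend, reusing the dynamo settings train_lora.py sets at module level; the toy model, tensor shapes, and the assumption of a CUDA device are illustrative only.

    import torch
    import torch.nn as nn
    import hidet  # importing hidet makes the 'hidet' torch.compile backend available

    torch._dynamo.config.suppress_errors = True      # fall back to eager on unsupported graphs
    hidet.torch.dynamo_config.use_tensor_core(True)  # same settings train_lora.py applies
    hidet.torch.dynamo_config.search_space(0)        # quickest compile, no kernel tuning

    model = nn.Sequential(nn.Linear(64, 64), nn.GELU()).cuda()  # stand-in for the UNet
    compiled = torch.compile(model, backend='hidet')
    out = compiled(torch.randn(2, 64, device='cuda'))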
diff --git a/training/sampler.py b/training/sampler.py
index 8afe255..bdb3e90 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -129,7 +129,7 @@ class LossSecondMomentResampler(LossAwareSampler):
         self._loss_history = np.zeros(
             [self.num_timesteps, history_per_term], dtype=np.float64
         )
-        self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int)
+        self._loss_counts = np.zeros([self.num_timesteps], dtype=int)
 
     def weights(self):
         if not self._warmed_up():
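Note on the training/sampler.py change: it follows from the NumPy bump above. NumPy 1.24 removed the long-deprecated np.int alias, so dtype=np.int now raises an AttributeError; the builtin int (NumPy's default integer type) is the drop-in replacement. A tiny illustration, with an arbitrary array size:

    import numpy as np

    # Works on every NumPy version, including the 1.24.3 pinned above.
    loss_counts = np.zeros([10], dtype=int)

    # The old spelling fails on NumPy >= 1.24:
    #   np.zeros([10], dtype=np.int)  ->  AttributeError: module 'numpy' has no attribute 'int'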