diff options
| -rw-r--r-- | environment.yaml | 2 | ||||
| -rw-r--r-- | environment_nightly.yaml | 19 | ||||
| -rw-r--r-- | train_lora.py | 58 | ||||
| -rw-r--r-- | training/functional.py | 4 | ||||
| -rw-r--r-- | training/sampler.py | 2 |
5 files changed, 42 insertions, 43 deletions
diff --git a/environment.yaml b/environment.yaml index cf2b732..1a55967 100644 --- a/environment.yaml +++ b/environment.yaml | |||
| @@ -11,7 +11,7 @@ dependencies: | |||
| 11 | - gcc=11.3.0 | 11 | - gcc=11.3.0 |
| 12 | - gxx=11.3.0 | 12 | - gxx=11.3.0 |
| 13 | - matplotlib=3.6.2 | 13 | - matplotlib=3.6.2 |
| 14 | - numpy=1.23.4 | 14 | - numpy=1.24.3 |
| 15 | - pip=22.3.1 | 15 | - pip=22.3.1 |
| 16 | - python=3.10.8 | 16 | - python=3.10.8 |
| 17 | - pytorch=2.0.0=*cuda11.8* | 17 | - pytorch=2.0.0=*cuda11.8* |
diff --git a/environment_nightly.yaml b/environment_nightly.yaml index 4c5c798..d315bd8 100644 --- a/environment_nightly.yaml +++ b/environment_nightly.yaml | |||
| @@ -4,28 +4,31 @@ channels: | |||
| 4 | - nvidia | 4 | - nvidia |
| 5 | - xformers/label/dev | 5 | - xformers/label/dev |
| 6 | - defaults | 6 | - defaults |
| 7 | - conda-forge | 7 | - conda-forge |
| 8 | dependencies: | 8 | dependencies: |
| 9 | - cuda-nvcc=12.1.105 | 9 | - cuda-nvcc=11.8 |
| 10 | - cuda-cudart-dev=11.8 | ||
| 11 | - gcc=11.3.0 | ||
| 12 | - gxx=11.3.0 | ||
| 10 | - matplotlib=3.6.2 | 13 | - matplotlib=3.6.2 |
| 11 | - numpy=1.24.3 | 14 | - numpy=1.24.3 |
| 12 | - pip=22.3.1 | 15 | - pip=22.3.1 |
| 13 | - python=3.10.8 | 16 | - python=3.10.8 |
| 14 | - pytorch=2.1.0.dev20230429=*cuda12.1* | 17 | - pytorch=2.1.0.dev20230515=*cuda11.8* |
| 15 | - torchvision=0.16.0.dev20230429 | 18 | - torchvision=0.16.0.dev20230516 |
| 16 | # - xformers=0.0.19 | 19 | # - xformers=0.0.19 |
| 17 | - pip: | 20 | - pip: |
| 18 | - -e . | 21 | - -e . |
| 19 | - -e git+https://github.com/huggingface/accelerate#egg=accelerate | 22 | - -e git+https://github.com/huggingface/accelerate#egg=accelerate |
| 20 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 23 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
| 21 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation | 24 | - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation |
| 25 | - --pre --extra-index-url https://download.hidet.org/whl hidet | ||
| 22 | - bitsandbytes==0.38.1 | 26 | - bitsandbytes==0.38.1 |
| 23 | - hidet==0.2.3 | ||
| 24 | - lion-pytorch==0.0.7 | 27 | - lion-pytorch==0.0.7 |
| 25 | - peft==0.2.0 | 28 | - peft==0.3.0 |
| 26 | - python-slugify>=6.1.2 | 29 | - python-slugify>=6.1.2 |
| 27 | - safetensors==0.3.1 | 30 | - safetensors==0.3.1 |
| 28 | - setuptools==65.6.3 | 31 | - setuptools==65.6.3 |
| 29 | - test-tube>=0.7.5 | 32 | - test-tube>=0.7.5 |
| 30 | - timm==0.8.17.dev0 | 33 | - timm==0.9.2 |
| 31 | - transformers==4.28.1 | 34 | - transformers==4.29.1 |
diff --git a/train_lora.py b/train_lora.py index 12d7e72..c74dd8f 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -48,8 +48,8 @@ warnings.filterwarnings('ignore') | |||
| 48 | torch.backends.cuda.matmul.allow_tf32 = True | 48 | torch.backends.cuda.matmul.allow_tf32 = True |
| 49 | torch.backends.cudnn.benchmark = True | 49 | torch.backends.cudnn.benchmark = True |
| 50 | 50 | ||
| 51 | torch._dynamo.config.log_level = logging.WARNING | 51 | # torch._dynamo.config.log_level = logging.WARNING |
| 52 | # torch._dynamo.config.suppress_errors = True | 52 | torch._dynamo.config.suppress_errors = True |
| 53 | 53 | ||
| 54 | hidet.torch.dynamo_config.use_tensor_core(True) | 54 | hidet.torch.dynamo_config.use_tensor_core(True) |
| 55 | hidet.torch.dynamo_config.search_space(0) | 55 | hidet.torch.dynamo_config.search_space(0) |
| @@ -1143,6 +1143,28 @@ def main(): | |||
| 1143 | avg_loss_val = AverageMeter() | 1143 | avg_loss_val = AverageMeter() |
| 1144 | avg_acc_val = AverageMeter() | 1144 | avg_acc_val = AverageMeter() |
| 1145 | 1145 | ||
| 1146 | params_to_optimize = [ | ||
| 1147 | { | ||
| 1148 | "params": ( | ||
| 1149 | param | ||
| 1150 | for param in unet.parameters() | ||
| 1151 | if param.requires_grad | ||
| 1152 | ), | ||
| 1153 | "lr": learning_rate_unet, | ||
| 1154 | }, | ||
| 1155 | { | ||
| 1156 | "params": ( | ||
| 1157 | param | ||
| 1158 | for param in text_encoder.parameters() | ||
| 1159 | if param.requires_grad | ||
| 1160 | ), | ||
| 1161 | "lr": learning_rate_text, | ||
| 1162 | } | ||
| 1163 | ] | ||
| 1164 | group_labels = ["unet", "text"] | ||
| 1165 | |||
| 1166 | lora_optimizer = create_optimizer(params_to_optimize) | ||
| 1167 | |||
| 1146 | while True: | 1168 | while True: |
| 1147 | if len(auto_cycles) != 0: | 1169 | if len(auto_cycles) != 0: |
| 1148 | response = auto_cycles.pop(0) | 1170 | response = auto_cycles.pop(0) |
| @@ -1182,35 +1204,9 @@ def main(): | |||
| 1182 | print("") | 1204 | print("") |
| 1183 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") | 1205 | print(f"============ LoRA cycle {training_iter + 1}: {response} ============") |
| 1184 | print("") | 1206 | print("") |
| 1185 | 1207 | ||
| 1186 | params_to_optimize = [] | 1208 | for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]): |
| 1187 | group_labels = [] | 1209 | group['lr'] = lr |
| 1188 | |||
| 1189 | params_to_optimize.append({ | ||
| 1190 | "params": ( | ||
| 1191 | param | ||
| 1192 | for param in unet.parameters() | ||
| 1193 | if param.requires_grad | ||
| 1194 | ), | ||
| 1195 | "lr": learning_rate_unet, | ||
| 1196 | }) | ||
| 1197 | group_labels.append("unet") | ||
| 1198 | |||
| 1199 | if training_iter < args.train_text_encoder_cycles: | ||
| 1200 | params_to_optimize.append({ | ||
| 1201 | "params": ( | ||
| 1202 | param | ||
| 1203 | for param in itertools.chain( | ||
| 1204 | text_encoder.text_model.encoder.parameters(), | ||
| 1205 | text_encoder.text_model.final_layer_norm.parameters(), | ||
| 1206 | ) | ||
| 1207 | if param.requires_grad | ||
| 1208 | ), | ||
| 1209 | "lr": learning_rate_text, | ||
| 1210 | }) | ||
| 1211 | group_labels.append("text") | ||
| 1212 | |||
| 1213 | lora_optimizer = create_optimizer(params_to_optimize) | ||
| 1214 | 1210 | ||
| 1215 | lora_lr_scheduler = create_lr_scheduler( | 1211 | lora_lr_scheduler = create_lr_scheduler( |
| 1216 | lr_scheduler, | 1212 | lr_scheduler, |
diff --git a/training/functional.py b/training/functional.py index 10560e5..fd3f9f4 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -710,8 +710,8 @@ def train( | |||
| 710 | vae = torch.compile(vae, backend='hidet') | 710 | vae = torch.compile(vae, backend='hidet') |
| 711 | 711 | ||
| 712 | if compile_unet: | 712 | if compile_unet: |
| 713 | # unet = torch.compile(unet, backend='hidet') | 713 | unet = torch.compile(unet, backend='hidet') |
| 714 | unet = torch.compile(unet, mode="reduce-overhead") | 714 | # unet = torch.compile(unet, mode="reduce-overhead") |
| 715 | 715 | ||
| 716 | callbacks = strategy.callbacks( | 716 | callbacks = strategy.callbacks( |
| 717 | accelerator=accelerator, | 717 | accelerator=accelerator, |
diff --git a/training/sampler.py b/training/sampler.py index 8afe255..bdb3e90 100644 --- a/training/sampler.py +++ b/training/sampler.py | |||
| @@ -129,7 +129,7 @@ class LossSecondMomentResampler(LossAwareSampler): | |||
| 129 | self._loss_history = np.zeros( | 129 | self._loss_history = np.zeros( |
| 130 | [self.num_timesteps, history_per_term], dtype=np.float64 | 130 | [self.num_timesteps, history_per_term], dtype=np.float64 |
| 131 | ) | 131 | ) |
| 132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int) | 132 | self._loss_counts = np.zeros([self.num_timesteps], dtype=int) |
| 133 | 133 | ||
| 134 | def weights(self): | 134 | def weights(self): |
| 135 | if not self._warmed_up(): | 135 | if not self._warmed_up(): |
