From 55a12f2c683b2ecfa4fc8b4015462ad2798abda5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 16 May 2023 16:48:51 +0200
Subject: Fix LoRA training with DAdan

---
 environment.yaml         |  2 +-
 environment_nightly.yaml | 19 +++++++++-------
 train_lora.py            | 58 ++++++++++++++++++++++--------------------
 training/functional.py   |  4 ++--
 training/sampler.py      |  2 +-
 5 files changed, 42 insertions(+), 43 deletions(-)

diff --git a/environment.yaml b/environment.yaml
index cf2b732..1a55967 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -11,7 +11,7 @@ dependencies:
   - gcc=11.3.0
   - gxx=11.3.0
   - matplotlib=3.6.2
-  - numpy=1.23.4
+  - numpy=1.24.3
   - pip=22.3.1
   - python=3.10.8
   - pytorch=2.0.0=*cuda11.8*
diff --git a/environment_nightly.yaml b/environment_nightly.yaml
index 4c5c798..d315bd8 100644
--- a/environment_nightly.yaml
+++ b/environment_nightly.yaml
@@ -4,28 +4,31 @@
   - nvidia
   - xformers/label/dev
   - defaults
-  - conda-forge
+  - conda-forge
 dependencies:
-  - cuda-nvcc=12.1.105
+  - cuda-nvcc=11.8
+  - cuda-cudart-dev=11.8
+  - gcc=11.3.0
+  - gxx=11.3.0
   - matplotlib=3.6.2
   - numpy=1.24.3
   - pip=22.3.1
   - python=3.10.8
-  - pytorch=2.1.0.dev20230429=*cuda12.1*
-  - torchvision=0.16.0.dev20230429
+  - pytorch=2.1.0.dev20230515=*cuda11.8*
+  - torchvision=0.16.0.dev20230516
   # - xformers=0.0.19
   - pip:
     - -e .
     - -e git+https://github.com/huggingface/accelerate#egg=accelerate
     - -e git+https://github.com/huggingface/diffusers#egg=diffusers
     - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+    - --pre --extra-index-url https://download.hidet.org/whl hidet
     - bitsandbytes==0.38.1
-    - hidet==0.2.3
     - lion-pytorch==0.0.7
-    - peft==0.2.0
+    - peft==0.3.0
     - python-slugify>=6.1.2
     - safetensors==0.3.1
     - setuptools==65.6.3
     - test-tube>=0.7.5
-    - timm==0.8.17.dev0
-    - transformers==4.28.1
+    - timm==0.9.2
+    - transformers==4.29.1
diff --git a/train_lora.py b/train_lora.py
index 12d7e72..c74dd8f 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -48,8 +48,8 @@ warnings.filterwarnings('ignore')
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
-torch._dynamo.config.log_level = logging.WARNING
-# torch._dynamo.config.suppress_errors = True
+# torch._dynamo.config.log_level = logging.WARNING
+torch._dynamo.config.suppress_errors = True
 
 hidet.torch.dynamo_config.use_tensor_core(True)
 hidet.torch.dynamo_config.search_space(0)
@@ -1143,6 +1143,28 @@ def main():
     avg_loss_val = AverageMeter()
     avg_acc_val = AverageMeter()
 
+    params_to_optimize = [
+        {
+            "params": (
+                param
+                for param in unet.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_unet,
+        },
+        {
+            "params": (
+                param
+                for param in text_encoder.parameters()
+                if param.requires_grad
+            ),
+            "lr": learning_rate_text,
+        }
+    ]
+    group_labels = ["unet", "text"]
+
+    lora_optimizer = create_optimizer(params_to_optimize)
+
     while True:
         if len(auto_cycles) != 0:
             response = auto_cycles.pop(0)
@@ -1182,35 +1204,9 @@ def main():
         print("")
         print(f"============ LoRA cycle {training_iter + 1}: {response} ============")
         print("")
-
-        params_to_optimize = []
-        group_labels = []
-
-        params_to_optimize.append({
-            "params": (
-                param
-                for param in unet.parameters()
-                if param.requires_grad
-            ),
-            "lr": learning_rate_unet,
-        })
-        group_labels.append("unet")
-
-        if training_iter < args.train_text_encoder_cycles:
-            params_to_optimize.append({
-                "params": (
-                    param
-                    for param in itertools.chain(
-                        text_encoder.text_model.encoder.parameters(),
-                        text_encoder.text_model.final_layer_norm.parameters(),
-                    )
-                    if param.requires_grad
-                ),
-                "lr": learning_rate_text,
-            })
-            group_labels.append("text")
-
-        lora_optimizer = create_optimizer(params_to_optimize)
+
+        for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]):
+            group['lr'] = lr
 
         lora_lr_scheduler = create_lr_scheduler(
             lr_scheduler,
diff --git a/training/functional.py b/training/functional.py
index 10560e5..fd3f9f4 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -710,8 +710,8 @@ def train(
         vae = torch.compile(vae, backend='hidet')
 
     if compile_unet:
-        # unet = torch.compile(unet, backend='hidet')
-        unet = torch.compile(unet, mode="reduce-overhead")
+        unet = torch.compile(unet, backend='hidet')
+        # unet = torch.compile(unet, mode="reduce-overhead")
 
     callbacks = strategy.callbacks(
         accelerator=accelerator,
diff --git a/training/sampler.py b/training/sampler.py
index 8afe255..bdb3e90 100644
--- a/training/sampler.py
+++ b/training/sampler.py
@@ -129,7 +129,7 @@ class LossSecondMomentResampler(LossAwareSampler):
         self._loss_history = np.zeros(
             [self.num_timesteps, history_per_term], dtype=np.float64
         )
-        self._loss_counts = np.zeros([self.num_timesteps], dtype=np.int)
+        self._loss_counts = np.zeros([self.num_timesteps], dtype=int)
 
     def weights(self):
         if not self._warmed_up():
-- 
cgit v1.2.3-70-g09d2
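
Below is a minimal, self-contained sketch (not part of the patch) of the pattern the train_lora.py change moves to, with nn.Linear toys standing in for the LoRA-wrapped UNet and text encoder and plain AdamW standing in for whatever create_optimizer() returns (e.g. a D-Adaptation optimizer). The point of the change is that the optimizer is built once, before the cycle loop, so the internal state a D-Adaptation optimizer accumulates to estimate its step size survives across cycles; each cycle then only adjusts the per-group learning rates.

# Sketch only: toy modules and AdamW stand in for the real UNet / text encoder
# and for the D-Adaptation optimizer that create_optimizer() builds in train_lora.py.
import torch
import torch.nn as nn

unet = nn.Linear(4, 4)          # stand-in for the LoRA-wrapped UNet
text_encoder = nn.Linear(4, 4)  # stand-in for the LoRA-wrapped text encoder
learning_rate_unet, learning_rate_text = 1e-4, 5e-5

params_to_optimize = [
    {"params": (p for p in unet.parameters() if p.requires_grad), "lr": learning_rate_unet},
    {"params": (p for p in text_encoder.parameters() if p.requires_grad), "lr": learning_rate_text},
]

# Created once, outside the cycle loop: rebuilding the optimizer every cycle
# (as the removed code did) discards the accumulated state a D-Adaptation
# optimizer relies on to estimate its step size.
lora_optimizer = torch.optim.AdamW(params_to_optimize)

for training_iter in range(3):  # stands in for the `while True:` cycle loop
    # Per-cycle update touches only the learning rates; optimizer state is kept.
    for group, lr in zip(lora_optimizer.param_groups, [learning_rate_unet, learning_rate_text]):
        group["lr"] = lr
    # ... run this cycle's training here ...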