-rw-r--r--  environment.yaml        |  8
-rw-r--r--  train_lora.py           |  2
-rw-r--r--  train_ti.py             | 24
-rw-r--r--  training/functional.py  | 24

4 files changed, 45 insertions(+), 13 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index dfbafaf..85033ce 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -6,6 +6,8 @@ channels:
 - defaults
 - conda-forge
 dependencies:
+- cuda-nvcc=11.8
+- cuda-cudart-dev=11.8
 - gcc=11.3.0
 - gxx=11.3.0
 - matplotlib=3.6.2
@@ -20,13 +22,13 @@ dependencies:
 - -e git+https://github.com/huggingface/accelerate#egg=accelerate
 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
+- --pre --extra-index-url https://download.hidet.org/whl hidet
 - bitsandbytes==0.38.1
-- hidet==0.2.3
 - lion-pytorch==0.0.7
-- peft==0.2.0
+- peft==0.3.0
 - python-slugify>=6.1.2
 - safetensors==0.3.1
 - setuptools==65.6.3
 - test-tube>=0.7.5
 - timm==0.8.17.dev0
-- transformers==4.28.1
+- transformers==4.29.0
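Note: hidet moves from a pinned conda package (0.2.3) to a pre-release pip wheel from download.hidet.org, and cuda-nvcc/cuda-cudart-dev are added, presumably for hidet's runtime kernel compilation. As a small hypothetical sanity check (not part of this commit), the bumped pins could be verified at startup:

# Hypothetical helper (not in the commit): warn if installed versions
# drift from the pins in environment.yaml above.
from importlib.metadata import PackageNotFoundError, version

PINS = {"peft": "0.3.0", "transformers": "4.29.0", "bitsandbytes": "0.38.1"}

for name, expected in PINS.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        print(f"warning: {name} not installed (pinned to {expected})")
        continue
    if installed != expected:
        print(f"warning: {name} {installed} installed, environment.yaml pins {expected}")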
diff --git a/train_lora.py b/train_lora.py
index 70fbae4..737af58 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -46,7 +46,7 @@ torch.backends.cudnn.benchmark = True
 torch._dynamo.config.log_level = logging.WARNING
 
 hidet.torch.dynamo_config.use_tensor_core(True)
-# hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.use_attention(True)
 hidet.torch.dynamo_config.search_space(0)
 
 
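The dynamo_config knobs above only take effect once the model is actually compiled with the hidet backend elsewhere in the script. A minimal sketch of that hookup, assuming a boolean flag controls it (the compile call itself is not part of this hunk):

import hidet   # provides the "hidet" torch.compile backend and dynamo_config
import torch

hidet.torch.dynamo_config.use_tensor_core(True)   # allow tensor-core kernels
hidet.torch.dynamo_config.use_attention(True)     # use hidet's fused attention
hidet.torch.dynamo_config.search_space(0)         # smallest tuning space, fastest compile

def maybe_compile(module: torch.nn.Module, enabled: bool) -> torch.nn.Module:
    # Hypothetical helper: compile with the hidet backend only when requested.
    return torch.compile(module, backend="hidet") if enabled else module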
diff --git a/train_ti.py b/train_ti.py
index 26f7941..6fd974e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,13 +5,16 @@ from functools import partial
 from pathlib import Path
 from typing import Union
 import math
+import warnings
 
 import torch
 import torch.utils.checkpoint
+import hidet
 
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
+from diffusers.models.attention_processor import AttnProcessor
 from timm.models import create_model
 import transformers
 
@@ -28,10 +31,18 @@ from training.util import AverageMeter, save_args
 
 logger = get_logger(__name__)
 
+warnings.filterwarnings('ignore')
+
 
 torch.backends.cuda.matmul.allow_tf32 = True
 torch.backends.cudnn.benchmark = True
 
+# torch._dynamo.config.log_level = logging.WARNING
+
+hidet.torch.dynamo_config.use_tensor_core(True)
+hidet.torch.dynamo_config.use_attention(True)
+hidet.torch.dynamo_config.search_space(0)
+
 
 def parse_args():
     parser = argparse.ArgumentParser(
@@ -706,6 +717,19 @@ def main():
     if args.use_xformers:
         vae.set_use_memory_efficient_attention_xformers(True)
         unet.enable_xformers_memory_efficient_attention()
+    elif args.compile_unet:
+        unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
+
+        proc = AttnProcessor()
+
+        def fn_recursive_set_proc(module: torch.nn.Module):
+            if hasattr(module, "processor"):
+                module.processor = proc
+
+            for child in module.children():
+                fn_recursive_set_proc(child)
+
+        fn_recursive_set_proc(unet)
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
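The recursive swap above forces every attention block back onto the plain AttnProcessor (and disables the PyTorch 2.0 attention path on the mid-block), presumably so the compiled graph sticks to ops the hidet backend handles. If the installed diffusers version supports it, roughly the same swap could go through the model-level helper; a hedged sketch, not what the commit uses:

from diffusers.models.attention_processor import AttnProcessor

# Roughly equivalent in effect to fn_recursive_set_proc(unet): replace every
# attention processor in the UNet with the plain (non-SDPA, non-xformers) one.
# Assumes a diffusers version whose set_attn_processor accepts a single instance.
unet.set_attn_processor(AttnProcessor())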
diff --git a/training/functional.py b/training/functional.py
index eae5681..49c21c7 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.util import AverageMeter
 from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
 from util.slerp import slerp
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -350,28 +351,33 @@ def loss_step(
         device=latents.device,
         generator=generator
     )
+    applied_noise = noise
 
     if offset_noise_strength != 0:
-        offset_noise = torch.randn(
-            (latents.shape[0], latents.shape[1], 1, 1),
+        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
+            latents.shape,
+            res=1,
+            octaves=4,
             dtype=latents.dtype,
             device=latents.device,
             generator=generator
-        ).expand(noise.shape)
-        noise = noise + offset_noise_strength * offset_noise
+        )
 
     if input_pertubation != 0:
-        new_noise = noise + input_pertubation * torch.randn_like(noise)
+        applied_noise = applied_noise + input_pertubation * torch.randn(
+            latents.shape,
+            dtype=latents.dtype,
+            layout=latents.layout,
+            device=latents.device,
+            generator=generator
+        )
 
     # Sample a random timestep for each image
     timesteps, weights = schedule_sampler.sample(bsz, latents.device)
 
     # Add noise to the latents according to the noise magnitude at each timestep
     # (this is the forward diffusion process)
-    if input_pertubation != 0:
-        noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
-    else:
-        noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
+    noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)
     noisy_latents = noisy_latents.to(dtype=unet.dtype)
 
     # Get the text embedding for conditioning
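Pulled out of the hunk above for readability: offset noise (now low-frequency Perlin noise rather than a per-channel randn broadcast) and input perturbation are folded into a single applied_noise tensor, which is what add_noise receives. A standalone sketch under the same assumptions (perlin_noise is this repo's util.noise helper, called with the arguments shown in the diff):

import torch
from util.noise import perlin_noise  # repo-local helper used by the new code

def compose_applied_noise(latents, noise, offset_noise_strength, input_pertubation, generator):
    """Build the noise tensor handed to noise_scheduler.add_noise()."""
    applied_noise = noise

    if offset_noise_strength != 0:
        # Low-frequency Perlin offset noise, matching the latent shape directly.
        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
            latents.shape, res=1, octaves=4,
            dtype=latents.dtype, device=latents.device, generator=generator,
        )

    if input_pertubation != 0:
        # Input perturbation is now part of the same tensor instead of a
        # separate new_noise branch at add_noise() time.
        applied_noise = applied_noise + input_pertubation * torch.randn(
            latents.shape, dtype=latents.dtype, layout=latents.layout,
            device=latents.device, generator=generator,
        )

    return applied_noise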
