From f0a171923cc8240177302f3dccb6177a2c646ab3 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Thu, 11 May 2023 18:37:43 +0200 Subject: Update --- environment.yaml | 8 +++++--- train_lora.py | 2 +- train_ti.py | 24 ++++++++++++++++++++++++ training/functional.py | 24 +++++++++++++++--------- 4 files changed, 45 insertions(+), 13 deletions(-) diff --git a/environment.yaml b/environment.yaml index dfbafaf..85033ce 100644 --- a/environment.yaml +++ b/environment.yaml @@ -6,6 +6,8 @@ channels: - defaults - conda-forge dependencies: + - cuda-nvcc=11.8 + - cuda-cudart-dev=11.8 - gcc=11.3.0 - gxx=11.3.0 - matplotlib=3.6.2 @@ -20,13 +22,13 @@ dependencies: - -e git+https://github.com/huggingface/accelerate#egg=accelerate - -e git+https://github.com/huggingface/diffusers#egg=diffusers - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation + - --pre --extra-index-url https://download.hidet.org/whl hidet - bitsandbytes==0.38.1 - - hidet==0.2.3 - lion-pytorch==0.0.7 - - peft==0.2.0 + - peft==0.3.0 - python-slugify>=6.1.2 - safetensors==0.3.1 - setuptools==65.6.3 - test-tube>=0.7.5 - timm==0.8.17.dev0 - - transformers==4.28.1 + - transformers==4.29.0 diff --git a/train_lora.py b/train_lora.py index 70fbae4..737af58 100644 --- a/train_lora.py +++ b/train_lora.py @@ -46,7 +46,7 @@ torch.backends.cudnn.benchmark = True torch._dynamo.config.log_level = logging.WARNING hidet.torch.dynamo_config.use_tensor_core(True) -# hidet.torch.dynamo_config.use_attention(True) +hidet.torch.dynamo_config.use_attention(True) hidet.torch.dynamo_config.search_space(0) diff --git a/train_ti.py b/train_ti.py index 26f7941..6fd974e 100644 --- a/train_ti.py +++ b/train_ti.py @@ -5,13 +5,16 @@ from functools import partial from pathlib import Path from typing import Union import math +import warnings import torch import torch.utils.checkpoint +import hidet from accelerate import Accelerator from accelerate.logging import get_logger from accelerate.utils import LoggerType, set_seed +from diffusers.models.attention_processor import AttnProcessor from timm.models import create_model import transformers @@ -28,10 +31,18 @@ from training.util import AverageMeter, save_args logger = get_logger(__name__) +warnings.filterwarnings('ignore') + torch.backends.cuda.matmul.allow_tf32 = True torch.backends.cudnn.benchmark = True +# torch._dynamo.config.log_level = logging.WARNING + +hidet.torch.dynamo_config.use_tensor_core(True) +hidet.torch.dynamo_config.use_attention(True) +hidet.torch.dynamo_config.search_space(0) + def parse_args(): parser = argparse.ArgumentParser( @@ -706,6 +717,19 @@ def main(): if args.use_xformers: vae.set_use_memory_efficient_attention_xformers(True) unet.enable_xformers_memory_efficient_attention() + elif args.compile_unet: + unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False + + proc = AttnProcessor() + + def fn_recursive_set_proc(module: torch.nn.Module): + if hasattr(module, "processor"): + module.processor = proc + + for child in module.children(): + fn_recursive_set_proc(child) + + fn_recursive_set_proc(unet) if args.gradient_checkpointing: unet.enable_gradient_checkpointing() diff --git a/training/functional.py b/training/functional.py index eae5681..49c21c7 100644 --- a/training/functional.py +++ b/training/functional.py @@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator from training.util import AverageMeter from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler from util.slerp import slerp +from util.noise import perlin_noise def const(result=None): @@ -350,28 +351,33 @@ def loss_step( device=latents.device, generator=generator ) + applied_noise = noise if offset_noise_strength != 0: - offset_noise = torch.randn( - (latents.shape[0], latents.shape[1], 1, 1), + applied_noise = applied_noise + offset_noise_strength * perlin_noise( + latents.shape, + res=1, + octaves=4, dtype=latents.dtype, device=latents.device, generator=generator - ).expand(noise.shape) - noise = noise + offset_noise_strength * offset_noise + ) if input_pertubation != 0: - new_noise = noise + input_pertubation * torch.randn_like(noise) + applied_noise = applied_noise + input_pertubation * torch.randn( + latents.shape, + dtype=latents.dtype, + layout=latents.layout, + device=latents.device, + generator=generator + ) # Sample a random timestep for each image timesteps, weights = schedule_sampler.sample(bsz, latents.device) # Add noise to the latents according to the noise magnitude at each timestep # (this is the forward diffusion process) - if input_pertubation != 0: - noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps) - else: - noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps) + noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps) noisy_latents = noisy_latents.to(dtype=unet.dtype) # Get the text embedding for conditioning -- cgit v1.2.3-70-g09d2