summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Volpeon <git@volpeon.ink> 2023-05-11 18:37:43 +0200
committer: Volpeon <git@volpeon.ink> 2023-05-11 18:37:43 +0200
commit: f0a171923cc8240177302f3dccb6177a2c646ab3 (patch)
tree: 15f11293c048c8c9b5d625dc27bb8662e15e685e
parent: Update (diff)
download: textual-inversion-diff-f0a171923cc8240177302f3dccb6177a2c646ab3.tar.gz
textual-inversion-diff-f0a171923cc8240177302f3dccb6177a2c646ab3.tar.bz2
textual-inversion-diff-f0a171923cc8240177302f3dccb6177a2c646ab3.zip
Update
-rw-r--r-- environment.yaml 8
-rw-r--r-- train_lora.py 2
-rw-r--r-- train_ti.py 24
-rw-r--r-- training/functional.py 24
4 files changed, 45 insertions, 13 deletions
diff --git a/environment.yaml b/environment.yaml
index dfbafaf..85033ce 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -6,6 +6,8 @@ channels:
6 - defaults 6 - defaults
7 - conda-forge 7 - conda-forge
8dependencies: 8dependencies:
9 - cuda-nvcc=11.8
10 - cuda-cudart-dev=11.8
9 - gcc=11.3.0 11 - gcc=11.3.0
10 - gxx=11.3.0 12 - gxx=11.3.0
11 - matplotlib=3.6.2 13 - matplotlib=3.6.2
@@ -20,13 +22,13 @@ dependencies:
20 - -e git+https://github.com/huggingface/accelerate#egg=accelerate 22 - -e git+https://github.com/huggingface/accelerate#egg=accelerate
21 - -e git+https://github.com/huggingface/diffusers#egg=diffusers 23 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
22 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation 24 - -e git+https://github.com/facebookresearch/dadaptation#egg=dadaptation
25 - --pre --extra-index-url https://download.hidet.org/whl hidet
23 - bitsandbytes==0.38.1 26 - bitsandbytes==0.38.1
24 - hidet==0.2.3
25 - lion-pytorch==0.0.7 27 - lion-pytorch==0.0.7
26 - peft==0.2.0 28 - peft==0.3.0
27 - python-slugify>=6.1.2 29 - python-slugify>=6.1.2
28 - safetensors==0.3.1 30 - safetensors==0.3.1
29 - setuptools==65.6.3 31 - setuptools==65.6.3
30 - test-tube>=0.7.5 32 - test-tube>=0.7.5
31 - timm==0.8.17.dev0 33 - timm==0.8.17.dev0
32 - transformers==4.28.1 34 - transformers==4.29.0
diff --git a/train_lora.py b/train_lora.py
index 70fbae4..737af58 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -46,7 +46,7 @@ torch.backends.cudnn.benchmark = True
46torch._dynamo.config.log_level = logging.WARNING 46torch._dynamo.config.log_level = logging.WARNING
47 47
48hidet.torch.dynamo_config.use_tensor_core(True) 48hidet.torch.dynamo_config.use_tensor_core(True)
49# hidet.torch.dynamo_config.use_attention(True) 49hidet.torch.dynamo_config.use_attention(True)
50hidet.torch.dynamo_config.search_space(0) 50hidet.torch.dynamo_config.search_space(0)
51 51
52 52
diff --git a/train_ti.py b/train_ti.py
index 26f7941..6fd974e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -5,13 +5,16 @@ from functools import partial
5from pathlib import Path 5from pathlib import Path
6from typing import Union 6from typing import Union
7import math 7import math
8import warnings
8 9
9import torch 10import torch
10import torch.utils.checkpoint 11import torch.utils.checkpoint
12import hidet
11 13
12from accelerate import Accelerator 14from accelerate import Accelerator
13from accelerate.logging import get_logger 15from accelerate.logging import get_logger
14from accelerate.utils import LoggerType, set_seed 16from accelerate.utils import LoggerType, set_seed
17from diffusers.models.attention_processor import AttnProcessor
15from timm.models import create_model 18from timm.models import create_model
16import transformers 19import transformers
17 20
@@ -28,10 +31,18 @@ from training.util import AverageMeter, save_args
28 31
29logger = get_logger(__name__) 32logger = get_logger(__name__)
30 33
34warnings.filterwarnings('ignore')
35
31 36
32torch.backends.cuda.matmul.allow_tf32 = True 37torch.backends.cuda.matmul.allow_tf32 = True
33torch.backends.cudnn.benchmark = True 38torch.backends.cudnn.benchmark = True
34 39
40# torch._dynamo.config.log_level = logging.WARNING
41
42hidet.torch.dynamo_config.use_tensor_core(True)
43hidet.torch.dynamo_config.use_attention(True)
44hidet.torch.dynamo_config.search_space(0)
45
35 46
36def parse_args(): 47def parse_args():
37 parser = argparse.ArgumentParser( 48 parser = argparse.ArgumentParser(
@@ -706,6 +717,19 @@ def main():
706 if args.use_xformers: 717 if args.use_xformers:
707 vae.set_use_memory_efficient_attention_xformers(True) 718 vae.set_use_memory_efficient_attention_xformers(True)
708 unet.enable_xformers_memory_efficient_attention() 719 unet.enable_xformers_memory_efficient_attention()
720 elif args.compile_unet:
721 unet.mid_block.attentions[0].transformer_blocks[0].attn1._use_2_0_attn = False
722
723 proc = AttnProcessor()
724
725 def fn_recursive_set_proc(module: torch.nn.Module):
726 if hasattr(module, "processor"):
727 module.processor = proc
728
729 for child in module.children():
730 fn_recursive_set_proc(child)
731
732 fn_recursive_set_proc(unet)
709 733
710 if args.gradient_checkpointing: 734 if args.gradient_checkpointing:
711 unet.enable_gradient_checkpointing() 735 unet.enable_gradient_checkpointing()
diff --git a/training/functional.py b/training/functional.py
index eae5681..49c21c7 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -27,6 +27,7 @@ from models.convnext.discriminator import ConvNeXtDiscriminator
27from training.util import AverageMeter 27from training.util import AverageMeter
28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler 28from training.sampler import ScheduleSampler, LossAwareSampler, UniformSampler
29from util.slerp import slerp 29from util.slerp import slerp
30from util.noise import perlin_noise
30 31
31 32
32def const(result=None): 33def const(result=None):
@@ -350,28 +351,33 @@ def loss_step(
350 device=latents.device, 351 device=latents.device,
351 generator=generator 352 generator=generator
352 ) 353 )
354 applied_noise = noise
353 355
354 if offset_noise_strength != 0: 356 if offset_noise_strength != 0:
355 offset_noise = torch.randn( 357 applied_noise = applied_noise + offset_noise_strength * perlin_noise(
356 (latents.shape[0], latents.shape[1], 1, 1), 358 latents.shape,
359 res=1,
360 octaves=4,
357 dtype=latents.dtype, 361 dtype=latents.dtype,
358 device=latents.device, 362 device=latents.device,
359 generator=generator 363 generator=generator
360 ).expand(noise.shape) 364 )
361 noise = noise + offset_noise_strength * offset_noise
362 365
363 if input_pertubation != 0: 366 if input_pertubation != 0:
364 new_noise = noise + input_pertubation * torch.randn_like(noise) 367 applied_noise = applied_noise + input_pertubation * torch.randn(
368 latents.shape,
369 dtype=latents.dtype,
370 layout=latents.layout,
371 device=latents.device,
372 generator=generator
373 )
365 374
366 # Sample a random timestep for each image 375 # Sample a random timestep for each image
367 timesteps, weights = schedule_sampler.sample(bsz, latents.device) 376 timesteps, weights = schedule_sampler.sample(bsz, latents.device)
368 377
369 # Add noise to the latents according to the noise magnitude at each timestep 378 # Add noise to the latents according to the noise magnitude at each timestep
370 # (this is the forward diffusion process) 379 # (this is the forward diffusion process)
371 if input_pertubation != 0: 380 noisy_latents = noise_scheduler.add_noise(latents, applied_noise, timesteps)
372 noisy_latents = noise_scheduler.add_noise(latents, new_noise, timesteps)
373 else:
374 noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
375 noisy_latents = noisy_latents.to(dtype=unet.dtype) 381 noisy_latents = noisy_latents.to(dtype=unet.dtype)
376 382
377 # Get the text embedding for conditioning 383 # Get the text embedding for conditioning