 environment.yaml         |  4 ++--
 models/clip/tokenizer.py |  2 +-
 train_ti.py              | 41 +++++++++++++----------------------------
 training/util.py         | 12 ++++++------
 4 files changed, 26 insertions(+), 33 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 6e689c7..eff69ed 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -16,11 +16,11 @@ dependencies:
   - -e git+https://github.com/huggingface/diffusers#egg=diffusers
   - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion
   - accelerate==0.15.0
-  - bitsandbytes==0.35.4
+  - bitsandbytes==0.36.0.post2
   - python-slugify>=6.1.2
   - safetensors==0.2.7
   - setuptools==65.6.3
   - test-tube>=0.7.5
   - transformers==4.25.1
   - triton==2.0.0.dev20221202
-  - xformers==0.0.16rc399
+  - xformers==0.0.16rc401
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 034adf9..39c41ed 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,5 +1,5 @@
 import copy
-from typing import NamedTuple, Union, Literal
+from typing import Union, Literal
 
 import numpy as np
 
diff --git a/train_ti.py b/train_ti.py
index f622299..9aab00c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -280,7 +280,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6/7
+        default=7/8
     )
     parser.add_argument(
         "--ema_max_decay",
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        ema_embeddings,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
         scheduler,
         placeholder_token,
         new_ids,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)
 
         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
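
Note: the Checkpointer constructor now forwards the base-class options
(datamodule, output_dir, sample_image_size, sample_batches, sample_batch_size,
seed) through *args/**kwargs instead of unpacking them by hand, and the seed
fallback moves into CheckpointerBase (see training/util.py below). A runnable
toy model of the pattern, with heavily simplified names and arguments:

    # Toy model of the refactor: subclass-specific arguments are named
    # explicitly, everything else rides through **kwargs to the base class.
    from pathlib import Path
    from typing import Optional

    class CheckpointerBase:
        def __init__(self, output_dir: Path, sample_batches: int,
                     seed: Optional[int] = None):
            self.output_dir = output_dir
            self.sample_batches = sample_batches
            self.seed = seed

    class Checkpointer(CheckpointerBase):
        def __init__(self, weight_dtype, *args, **kwargs):
            super().__init__(*args, **kwargs)  # base options pass straight through
            self.weight_dtype = weight_dtype

    c = Checkpointer("fp16", output_dir=Path("out"), sample_batches=2)
    print(c.output_dir, c.sample_batches, c.seed)
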
@@ -829,7 +818,9 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
-    ema_embeddings.to(accelerator.device)
+
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
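
Note: moving the EMA state to the device is now gated on --use_ema, matching
the usual pattern where the EMA wrapper is only created when the flag is set,
and is None otherwise (a sketch under that assumption; the stand-in class
below is not the real EMAModel):

    # Minimal sketch of the pattern the new guard assumes. EMAModel here is
    # a stand-in; the real class comes from the training utilities.
    class EMAModel:
        def to(self, device):
            print(f"EMA weights moved to {device}")

    use_ema = False  # stand-in for args.use_ema

    ema_embeddings = EMAModel() if use_ema else None

    # The old unconditional call would raise AttributeError on None:
    # ema_embeddings.to("cuda")
    if use_ema:
        ema_embeddings.to("cuda")
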
@@ -854,13 +845,15 @@ def main():
             tokenizer.train()
             yield
         finally:
-            tokenizer.eval()
+            pass
 
     @contextmanager
     def on_eval():
         try:
+            tokenizer.eval()
+
             ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
 
             with ema_context:
                 yield
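
Note: besides moving tokenizer.eval() into on_eval() (the tokenizer now
switches to eval mode when evaluation starts, rather than after every training
context exits), this fixes a real bug in the EMA guard. In the old expression
`args.use_ema is not None and eval`, the `is not None` test is true for both
True and False flag values, and the bare name `eval` is the Python builtin,
which is always truthy, so EMA weights were applied even with EMA disabled:

    # Why the old guard was always true, regardless of the flag's value:
    use_ema = False
    print(use_ema is not None)                 # True: 'is not None' ignores falsiness
    print(bool(eval))                          # True: bare 'eval' is the builtin function
    print(bool(use_ema is not None and eval))  # True: the old guard fires anyway
    print(bool(use_ema))                       # False: the corrected test
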
diff --git a/training/util.py b/training/util.py
index bed7111..bc466e2 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 import json
 import copy
-from typing import Iterable
+from typing import Iterable, Optional
 from contextlib import contextmanager
 
 import torch
@@ -42,15 +42,15 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        sample_image_size: int,
+        sample_batches: int,
+        sample_batch_size: int,
+        seed: Optional[int] = None
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
+        self.seed = seed if seed is not None else torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
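
Note: the seed fallback change is more than style. With
`seed or torch.random.seed()`, an explicit seed of 0 is falsy and gets
silently replaced by a random one; `seed if seed is not None else ...`
preserves it:

    import torch

    seed = 0  # an explicit, perfectly valid seed
    print(seed or torch.random.seed())                        # random: 0 is discarded
    print(seed if seed is not None else torch.random.seed())  # 0, as intended
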