diff options
-rw-r--r-- | environment.yaml | 4 | ||||
-rw-r--r-- | models/clip/tokenizer.py | 2 | ||||
-rw-r--r-- | train_ti.py | 41 | ||||
-rw-r--r-- | training/util.py | 12 |
4 files changed, 26 insertions, 33 deletions
diff --git a/environment.yaml b/environment.yaml index 6e689c7..eff69ed 100644 --- a/environment.yaml +++ b/environment.yaml | |||
@@ -16,11 +16,11 @@ dependencies: | |||
16 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers | 16 | - -e git+https://github.com/huggingface/diffusers#egg=diffusers |
17 | - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion | 17 | - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion |
18 | - accelerate==0.15.0 | 18 | - accelerate==0.15.0 |
19 | - bitsandbytes==0.35.4 | 19 | - bitsandbytes==0.36.0.post2 |
20 | - python-slugify>=6.1.2 | 20 | - python-slugify>=6.1.2 |
21 | - safetensors==0.2.7 | 21 | - safetensors==0.2.7 |
22 | - setuptools==65.6.3 | 22 | - setuptools==65.6.3 |
23 | - test-tube>=0.7.5 | 23 | - test-tube>=0.7.5 |
24 | - transformers==4.25.1 | 24 | - transformers==4.25.1 |
25 | - triton==2.0.0.dev20221202 | 25 | - triton==2.0.0.dev20221202 |
26 | - xformers==0.0.16rc399 | 26 | - xformers==0.0.16rc401 |
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 034adf9..39c41ed 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py | |||
@@ -1,5 +1,5 @@ | |||
1 | import copy | 1 | import copy |
2 | from typing import NamedTuple, Union, Literal | 2 | from typing import Union, Literal |
3 | 3 | ||
4 | import numpy as np | 4 | import numpy as np |
5 | 5 | ||
diff --git a/train_ti.py b/train_ti.py index f622299..9aab00c 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -280,7 +280,7 @@ def parse_args(): | |||
280 | parser.add_argument( | 280 | parser.add_argument( |
281 | "--ema_power", | 281 | "--ema_power", |
282 | type=float, | 282 | type=float, |
283 | default=6/7 | 283 | default=7/8 |
284 | ) | 284 | ) |
285 | parser.add_argument( | 285 | parser.add_argument( |
286 | "--ema_max_decay", | 286 | "--ema_max_decay", |
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase): | |||
464 | def __init__( | 464 | def __init__( |
465 | self, | 465 | self, |
466 | weight_dtype, | 466 | weight_dtype, |
467 | datamodule, | 467 | accelerator: Accelerator, |
468 | accelerator, | 468 | vae: AutoencoderKL, |
469 | vae, | 469 | unet: UNet2DConditionModel, |
470 | unet, | 470 | tokenizer: MultiCLIPTokenizer, |
471 | tokenizer, | 471 | text_encoder: CLIPTextModel, |
472 | text_encoder, | 472 | ema_embeddings: EMAModel, |
473 | ema_embeddings, | ||
474 | scheduler, | 473 | scheduler, |
475 | placeholder_token, | 474 | placeholder_token, |
476 | new_ids, | 475 | new_ids, |
477 | output_dir: Path, | 476 | *args, |
478 | sample_image_size, | 477 | **kwargs |
479 | sample_batches, | ||
480 | sample_batch_size, | ||
481 | seed | ||
482 | ): | 478 | ): |
483 | super().__init__( | 479 | super().__init__(*args, **kwargs) |
484 | datamodule=datamodule, | ||
485 | output_dir=output_dir, | ||
486 | sample_image_size=sample_image_size, | ||
487 | seed=seed or torch.random.seed(), | ||
488 | sample_batches=sample_batches, | ||
489 | sample_batch_size=sample_batch_size | ||
490 | ) | ||
491 | 480 | ||
492 | self.weight_dtype = weight_dtype | 481 | self.weight_dtype = weight_dtype |
493 | self.accelerator = accelerator | 482 | self.accelerator = accelerator |
@@ -829,7 +818,9 @@ def main(): | |||
829 | # Move vae and unet to device | 818 | # Move vae and unet to device |
830 | vae.to(accelerator.device, dtype=weight_dtype) | 819 | vae.to(accelerator.device, dtype=weight_dtype) |
831 | unet.to(accelerator.device, dtype=weight_dtype) | 820 | unet.to(accelerator.device, dtype=weight_dtype) |
832 | ema_embeddings.to(accelerator.device) | 821 | |
822 | if args.use_ema: | ||
823 | ema_embeddings.to(accelerator.device) | ||
833 | 824 | ||
834 | # Keep vae and unet in eval mode as we don't train these | 825 | # Keep vae and unet in eval mode as we don't train these |
835 | vae.eval() | 826 | vae.eval() |
@@ -854,13 +845,15 @@ def main(): | |||
854 | tokenizer.train() | 845 | tokenizer.train() |
855 | yield | 846 | yield |
856 | finally: | 847 | finally: |
857 | tokenizer.eval() | 848 | pass |
858 | 849 | ||
859 | @contextmanager | 850 | @contextmanager |
860 | def on_eval(): | 851 | def on_eval(): |
861 | try: | 852 | try: |
853 | tokenizer.eval() | ||
854 | |||
862 | ema_context = ema_embeddings.apply_temporary( | 855 | ema_context = ema_embeddings.apply_temporary( |
863 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext() | 856 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext() |
864 | 857 | ||
865 | with ema_context: | 858 | with ema_context: |
866 | yield | 859 | yield |
diff --git a/training/util.py b/training/util.py index bed7111..bc466e2 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,7 +1,7 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable | 4 | from typing import Iterable, Optional |
5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
6 | 6 | ||
7 | import torch | 7 | import torch |
@@ -42,15 +42,15 @@ class CheckpointerBase: | |||
42 | self, | 42 | self, |
43 | datamodule, | 43 | datamodule, |
44 | output_dir: Path, | 44 | output_dir: Path, |
45 | sample_image_size, | 45 | sample_image_size: int, |
46 | sample_batches, | 46 | sample_batches: int, |
47 | sample_batch_size, | 47 | sample_batch_size: int, |
48 | seed | 48 | seed: Optional[int] = None |
49 | ): | 49 | ): |
50 | self.datamodule = datamodule | 50 | self.datamodule = datamodule |
51 | self.output_dir = output_dir | 51 | self.output_dir = output_dir |
52 | self.sample_image_size = sample_image_size | 52 | self.sample_image_size = sample_image_size |
53 | self.seed = seed or torch.random.seed() | 53 | self.seed = seed if seed is not None else torch.random.seed() |
54 | self.sample_batches = sample_batches | 54 | self.sample_batches = sample_batches |
55 | self.sample_batch_size = sample_batch_size | 55 | self.sample_batch_size = sample_batch_size |
56 | 56 | ||