-rw-r--r--   environment.yaml            4
-rw-r--r--   models/clip/tokenizer.py    2
-rw-r--r--   train_ti.py                41
-rw-r--r--   training/util.py           12

4 files changed, 26 insertions, 33 deletions
diff --git a/environment.yaml b/environment.yaml
index 6e689c7..eff69ed 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -16,11 +16,11 @@ dependencies:
 - -e git+https://github.com/huggingface/diffusers#egg=diffusers
 - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion
 - accelerate==0.15.0
-- bitsandbytes==0.35.4
+- bitsandbytes==0.36.0.post2
 - python-slugify>=6.1.2
 - safetensors==0.2.7
 - setuptools==65.6.3
 - test-tube>=0.7.5
 - transformers==4.25.1
 - triton==2.0.0.dev20221202
-- xformers==0.0.16rc399
+- xformers==0.0.16rc401
diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py
index 034adf9..39c41ed 100644
--- a/models/clip/tokenizer.py
+++ b/models/clip/tokenizer.py
@@ -1,5 +1,5 @@
 import copy
-from typing import NamedTuple, Union, Literal
+from typing import Union, Literal
 
 import numpy as np
 
diff --git a/train_ti.py b/train_ti.py
index f622299..9aab00c 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -280,7 +280,7 @@ def parse_args():
     parser.add_argument(
         "--ema_power",
         type=float,
-        default=6/7
+        default=7/8
     )
     parser.add_argument(
         "--ema_max_decay",
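Note: assuming the diffusers-style EMA warm-up schedule, decay = 1 - (1 + step / inv_gamma) ** -power, capped at ema_max_decay, raising the power from 6/7 to 7/8 lets the decay approach its ceiling slightly faster. A quick, hypothetical comparison of the two defaults (inv_gamma and max_decay values are assumptions, not taken from this repo):

# Hypothetical check of the ema_power change, assuming the diffusers-style
# warm-up schedule decay = 1 - (1 + step / inv_gamma) ** -power.
def ema_decay(step, power, inv_gamma=1.0, max_decay=0.9999):
    return min(max_decay, 1 - (1 + step / inv_gamma) ** -power)

for step in (100, 1_000, 10_000):
    print(step, round(ema_decay(step, 6 / 7), 5), round(ema_decay(step, 7 / 8), 5))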
@@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         weight_dtype,
-        datamodule,
-        accelerator,
-        vae,
-        unet,
-        tokenizer,
-        text_encoder,
-        ema_embeddings,
+        accelerator: Accelerator,
+        vae: AutoencoderKL,
+        unet: UNet2DConditionModel,
+        tokenizer: MultiCLIPTokenizer,
+        text_encoder: CLIPTextModel,
+        ema_embeddings: EMAModel,
         scheduler,
         placeholder_token,
         new_ids,
-        output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        *args,
+        **kwargs
     ):
-        super().__init__(
-            datamodule=datamodule,
-            output_dir=output_dir,
-            sample_image_size=sample_image_size,
-            seed=seed or torch.random.seed(),
-            sample_batches=sample_batches,
-            sample_batch_size=sample_batch_size
-        )
+        super().__init__(*args, **kwargs)
 
         self.weight_dtype = weight_dtype
         self.accelerator = accelerator
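The rewritten constructor annotates the model arguments and forwards everything sampling-related to CheckpointerBase via *args/**kwargs instead of re-listing and re-passing each option by hand. A minimal sketch of the forwarding pattern (class and argument names are simplified stand-ins, not the actual Checkpointer signature):

# Minimal sketch of the *args/**kwargs forwarding used above; names are
# simplified stand-ins, not the real classes from this repository.
class Base:
    def __init__(self, output_dir, sample_batch_size, seed=None):
        self.output_dir = output_dir
        self.sample_batch_size = sample_batch_size
        self.seed = seed

class Derived(Base):
    def __init__(self, weight_dtype, *args, **kwargs):
        super().__init__(*args, **kwargs)  # pass base-class options through unchanged
        self.weight_dtype = weight_dtype

# Callers now supply the base-class options as keyword arguments:
ckpt = Derived("fp16", output_dir="out", sample_batch_size=1, seed=42)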
@@ -829,7 +818,9 @@ def main():
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
     unet.to(accelerator.device, dtype=weight_dtype)
-    ema_embeddings.to(accelerator.device)
+
+    if args.use_ema:
+        ema_embeddings.to(accelerator.device)
 
     # Keep vae and unet in eval mode as we don't train these
     vae.eval()
@@ -854,13 +845,15 @@
             tokenizer.train()
             yield
         finally:
-            tokenizer.eval()
+            pass
 
     @contextmanager
     def on_eval():
         try:
-            ema_context = ema_embeddings.apply_temporary(
-                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext()
+            tokenizer.eval()
+
+            ema_context = ema_embeddings.apply_temporary(
+                text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext()
 
             with ema_context:
                 yield
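Two fixes here: tokenizer.eval() now happens when evaluation starts rather than when training ends, and the EMA weights are only swapped in when --use_ema is actually enabled. The old condition `args.use_ema is not None and eval` was always true, because `eval` refers to the Python builtin (a truthy function object) and a False flag is still not None. A tiny demonstration of why the old check misfires:

# Why the old condition always enabled the EMA swap: `eval` is the builtin
# function (truthy), and `False is not None` is True.
args_use_ema = False
print(bool(args_use_ema is not None and eval))  # True  -- old check, wrong
print(bool(args_use_ema))                       # False -- new check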
diff --git a/training/util.py b/training/util.py
index bed7111..bc466e2 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,7 +1,7 @@
 from pathlib import Path
 import json
 import copy
-from typing import Iterable
+from typing import Iterable, Optional
 from contextlib import contextmanager
 
 import torch
@@ -42,15 +42,15 @@ class CheckpointerBase:
         self,
         datamodule,
         output_dir: Path,
-        sample_image_size,
-        sample_batches,
-        sample_batch_size,
-        seed
+        sample_image_size: int,
+        sample_batches: int,
+        sample_batch_size: int,
+        seed: Optional[int] = None
     ):
         self.datamodule = datamodule
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
+        self.seed = seed if seed is not None else torch.random.seed()
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
 
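The `Optional[int]` default and the explicit None check matter for reproducibility: with `seed or torch.random.seed()`, a seed of 0 is falsy and gets silently replaced by a fresh random seed. A small illustration:

import torch

seed = 0
# Old behaviour: 0 is falsy, so a random seed replaces it.
print(seed or torch.random.seed())
# New behaviour: 0 is kept as a valid, reproducible seed.
print(seed if seed is not None else torch.random.seed())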
