From b8df3dd5330845ff9f9f9af187a09ef0dbfc1c20 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 6 Jan 2023 17:34:23 +0100 Subject: Update --- environment.yaml | 4 ++-- models/clip/tokenizer.py | 2 +- train_ti.py | 41 +++++++++++++++++------------------------ training/util.py | 12 ++++++------ 4 files changed, 26 insertions(+), 33 deletions(-) diff --git a/environment.yaml b/environment.yaml index 6e689c7..eff69ed 100644 --- a/environment.yaml +++ b/environment.yaml @@ -16,11 +16,11 @@ dependencies: - -e git+https://github.com/huggingface/diffusers#egg=diffusers - -e git+https://github.com/cloneofsimo/lora#egg=lora-diffusion - accelerate==0.15.0 - - bitsandbytes==0.35.4 + - bitsandbytes==0.36.0.post2 - python-slugify>=6.1.2 - safetensors==0.2.7 - setuptools==65.6.3 - test-tube>=0.7.5 - transformers==4.25.1 - triton==2.0.0.dev20221202 - - xformers==0.0.16rc399 + - xformers==0.0.16rc401 diff --git a/models/clip/tokenizer.py b/models/clip/tokenizer.py index 034adf9..39c41ed 100644 --- a/models/clip/tokenizer.py +++ b/models/clip/tokenizer.py @@ -1,5 +1,5 @@ import copy -from typing import NamedTuple, Union, Literal +from typing import Union, Literal import numpy as np diff --git a/train_ti.py b/train_ti.py index f622299..9aab00c 100644 --- a/train_ti.py +++ b/train_ti.py @@ -280,7 +280,7 @@ def parse_args(): parser.add_argument( "--ema_power", type=float, - default=6/7 + default=7/8 ) parser.add_argument( "--ema_max_decay", @@ -464,30 +464,19 @@ class Checkpointer(CheckpointerBase): def __init__( self, weight_dtype, - datamodule, - accelerator, - vae, - unet, - tokenizer, - text_encoder, - ema_embeddings, + accelerator: Accelerator, + vae: AutoencoderKL, + unet: UNet2DConditionModel, + tokenizer: MultiCLIPTokenizer, + text_encoder: CLIPTextModel, + ema_embeddings: EMAModel, scheduler, placeholder_token, new_ids, - output_dir: Path, - sample_image_size, - sample_batches, - sample_batch_size, - seed + *args, + **kwargs ): - super().__init__( - datamodule=datamodule, - output_dir=output_dir, - sample_image_size=sample_image_size, - seed=seed or torch.random.seed(), - sample_batches=sample_batches, - sample_batch_size=sample_batch_size - ) + super().__init__(*args, **kwargs) self.weight_dtype = weight_dtype self.accelerator = accelerator @@ -829,7 +818,9 @@ def main(): # Move vae and unet to device vae.to(accelerator.device, dtype=weight_dtype) unet.to(accelerator.device, dtype=weight_dtype) - ema_embeddings.to(accelerator.device) + + if args.use_ema: + ema_embeddings.to(accelerator.device) # Keep vae and unet in eval mode as we don't train these vae.eval() @@ -854,13 +845,15 @@ def main(): tokenizer.train() yield finally: - tokenizer.eval() + pass @contextmanager def on_eval(): try: + tokenizer.eval() + ema_context = ema_embeddings.apply_temporary( - text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema is not None and eval else nullcontext() + text_encoder.text_model.embeddings.temp_token_embedding.parameters()) if args.use_ema else nullcontext() with ema_context: yield diff --git a/training/util.py b/training/util.py index bed7111..bc466e2 100644 --- a/training/util.py +++ b/training/util.py @@ -1,7 +1,7 @@ from pathlib import Path import json import copy -from typing import Iterable +from typing import Iterable, Optional from contextlib import contextmanager import torch @@ -42,15 +42,15 @@ class CheckpointerBase: self, datamodule, output_dir: Path, - sample_image_size, - sample_batches, - sample_batch_size, - seed + sample_image_size: int, + sample_batches: int, + sample_batch_size: int, + seed: Optional[int] = None ): self.datamodule = datamodule self.output_dir = output_dir self.sample_image_size = sample_image_size - self.seed = seed or torch.random.seed() + self.seed = seed if seed is not None else torch.random.seed() self.sample_batches = sample_batches self.sample_batch_size = sample_batch_size -- cgit v1.2.3-54-g00ecf