summaryrefslogtreecommitdiffstats
path: root/training/util.py
diff options
context:
space:
mode:
Diffstat (limited to 'training/util.py')
-rw-r--r--training/util.py11
1 files changed, 3 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py
index 557b196..237626f 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,18 +1,11 @@
1from pathlib import Path 1from pathlib import Path
2import json 2import json
3import copy 3import copy
4from typing import Iterable, Union 4from typing import Iterable, Any
5from contextlib import contextmanager 5from contextlib import contextmanager
6 6
7import torch 7import torch
8 8
9from transformers import CLIPTextModel
10from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
11
12from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
13from models.clip.tokenizer import MultiCLIPTokenizer
14from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
15
16 9
17def save_args(basepath: Path, args, extra={}): 10def save_args(basepath: Path, args, extra={}):
18 info = {"args": vars(args)} 11 info = {"args": vars(args)}
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}):
22 15
23 16
24class AverageMeter: 17class AverageMeter:
18 avg: Any
19
25 def __init__(self, name=None): 20 def __init__(self, name=None):
26 self.name = name 21 self.name = name
27 self.reset() 22 self.reset()