diff options
Diffstat (limited to 'training/util.py')
-rw-r--r-- | training/util.py | 11 |
1 files changed, 3 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py | |||
@@ -1,18 +1,11 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | 2 | import json |
3 | import copy | 3 | import copy |
4 | from typing import Iterable, Union | 4 | from typing import Iterable, Any |
5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
6 | 6 | ||
7 | import torch | 7 | import torch |
8 | 8 | ||
9 | from transformers import CLIPTextModel | ||
10 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
11 | |||
12 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
13 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
14 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
15 | |||
16 | 9 | ||
17 | def save_args(basepath: Path, args, extra={}): | 10 | def save_args(basepath: Path, args, extra={}): |
18 | info = {"args": vars(args)} | 11 | info = {"args": vars(args)} |
@@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): | |||
22 | 15 | ||
23 | 16 | ||
24 | class AverageMeter: | 17 | class AverageMeter: |
18 | avg: Any | ||
19 | |||
25 | def __init__(self, name=None): | 20 | def __init__(self, name=None): |
26 | self.name = name | 21 | self.name = name |
27 | self.reset() | 22 | self.reset() |