diff options
Diffstat (limited to 'training/util.py')
| -rw-r--r-- | training/util.py | 11 |
1 files changed, 3 insertions, 8 deletions
diff --git a/training/util.py b/training/util.py index 557b196..237626f 100644 --- a/training/util.py +++ b/training/util.py | |||
| @@ -1,18 +1,11 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | 2 | import json |
| 3 | import copy | 3 | import copy |
| 4 | from typing import Iterable, Union | 4 | from typing import Iterable, Any |
| 5 | from contextlib import contextmanager | 5 | from contextlib import contextmanager |
| 6 | 6 | ||
| 7 | import torch | 7 | import torch |
| 8 | 8 | ||
| 9 | from transformers import CLIPTextModel | ||
| 10 | from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
| 11 | |||
| 12 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
| 13 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
| 14 | from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings | ||
| 15 | |||
| 16 | 9 | ||
| 17 | def save_args(basepath: Path, args, extra={}): | 10 | def save_args(basepath: Path, args, extra={}): |
| 18 | info = {"args": vars(args)} | 11 | info = {"args": vars(args)} |
| @@ -22,6 +15,8 @@ def save_args(basepath: Path, args, extra={}): | |||
| 22 | 15 | ||
| 23 | 16 | ||
| 24 | class AverageMeter: | 17 | class AverageMeter: |
| 18 | avg: Any | ||
| 19 | |||
| 25 | def __init__(self, name=None): | 20 | def __init__(self, name=None): |
| 26 | self.name = name | 21 | self.name = name |
| 27 | self.reset() | 22 | self.reset() |
