diff options
author | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-26 14:24:21 +0100 |
commit | e0b686b475885f0c8480f7173eaa7359adf17e27 (patch) | |
tree | 6ad882f152e63801d31230466e4d6468e7ada697 | |
parent | Code simplifications, avoid autocast (diff) | |
download | textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.gz textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.tar.bz2 textual-inversion-diff-e0b686b475885f0c8480f7173eaa7359adf17e27.zip |
Set default dimensions to 768; add config inheritance
-rw-r--r-- | common.py | 14 | ||||
-rw-r--r-- | data/csv.py | 4 | ||||
-rw-r--r-- | infer.py | 11 | ||||
-rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 8 | ||||
-rw-r--r-- | train_dreambooth.py | 7 | ||||
-rw-r--r-- | train_lora.py | 7 | ||||
-rw-r--r-- | train_ti.py | 11 |
7 files changed, 36 insertions, 26 deletions
@@ -1,9 +1,23 @@ | |||
1 | from pathlib import Path | 1 | from pathlib import Path |
2 | import json | ||
3 | |||
2 | import torch | 4 | import torch |
3 | 5 | ||
4 | from transformers import CLIPTextModel, CLIPTokenizer | 6 | from transformers import CLIPTextModel, CLIPTokenizer |
5 | 7 | ||
6 | 8 | ||
9 | def load_config(filename): | ||
10 | with open(filename, 'rt') as f: | ||
11 | config = json.load(f) | ||
12 | |||
13 | args = config["args"] | ||
14 | |||
15 | if "base" in config: | ||
16 | args = load_config(Path(filename).parent.joinpath(config["base"])) | args | ||
17 | |||
18 | return args | ||
19 | |||
20 | |||
7 | def load_text_embedding(embeddings, token_id, file): | 21 | def load_text_embedding(embeddings, token_id, file): |
8 | data = torch.load(file, map_location="cpu") | 22 | data = torch.load(file, map_location="cpu") |
9 | 23 | ||
diff --git a/data/csv.py b/data/csv.py index 0810c2c..0ad36dc 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -51,7 +51,7 @@ class CSVDataModule(): | |||
51 | prompt_processor: PromptProcessor, | 51 | prompt_processor: PromptProcessor, |
52 | class_subdir: str = "cls", | 52 | class_subdir: str = "cls", |
53 | num_class_images: int = 1, | 53 | num_class_images: int = 1, |
54 | size: int = 512, | 54 | size: int = 768, |
55 | repeats: int = 1, | 55 | repeats: int = 1, |
56 | dropout: float = 0, | 56 | dropout: float = 0, |
57 | interpolation: str = "bicubic", | 57 | interpolation: str = "bicubic", |
@@ -196,7 +196,7 @@ class CSVDataset(Dataset): | |||
196 | prompt_processor: PromptProcessor, | 196 | prompt_processor: PromptProcessor, |
197 | batch_size: int = 1, | 197 | batch_size: int = 1, |
198 | num_class_images: int = 0, | 198 | num_class_images: int = 0, |
199 | size: int = 512, | 199 | size: int = 768, |
200 | repeats: int = 1, | 200 | repeats: int = 1, |
201 | dropout: float = 0, | 201 | dropout: float = 0, |
202 | interpolation: str = "bicubic", | 202 | interpolation: str = "bicubic", |
@@ -24,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
24 | from slugify import slugify | 24 | from slugify import slugify |
25 | 25 | ||
26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
27 | from common import load_text_embeddings | 27 | from common import load_text_embeddings, load_config |
28 | 28 | ||
29 | 29 | ||
30 | torch.backends.cuda.matmul.allow_tf32 = True | 30 | torch.backends.cuda.matmul.allow_tf32 = True |
@@ -46,8 +46,8 @@ default_cmds = { | |||
46 | "negative_prompt": None, | 46 | "negative_prompt": None, |
47 | "image": None, | 47 | "image": None, |
48 | "image_noise": .7, | 48 | "image_noise": .7, |
49 | "width": 512, | 49 | "width": 768, |
50 | "height": 512, | 50 | "height": 768, |
51 | "batch_size": 1, | 51 | "batch_size": 1, |
52 | "batch_num": 1, | 52 | "batch_num": 1, |
53 | "steps": 30, | 53 | "steps": 30, |
@@ -163,9 +163,8 @@ def run_parser(parser, defaults, input=None): | |||
163 | conf_args = argparse.Namespace() | 163 | conf_args = argparse.Namespace() |
164 | 164 | ||
165 | if args.config is not None: | 165 | if args.config is not None: |
166 | with open(args.config, 'rt') as f: | 166 | args = load_config(args.config) |
167 | conf_args = parser.parse_known_args( | 167 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
168 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] | ||
169 | 168 | ||
170 | res = defaults.copy() | 169 | res = defaults.copy() |
171 | for dict in [vars(conf_args), vars(args)]: | 170 | for dict in [vars(conf_args), vars(args)]: |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a43a8e4..53b5eea 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
@@ -318,8 +318,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, |
319 | num_images_per_prompt: Optional[int] = 1, | 319 | num_images_per_prompt: Optional[int] = 1, |
320 | strength: float = 0.8, | 320 | strength: float = 0.8, |
321 | height: Optional[int] = 512, | 321 | height: Optional[int] = 768, |
322 | width: Optional[int] = 512, | 322 | width: Optional[int] = 768, |
323 | num_inference_steps: Optional[int] = 50, | 323 | num_inference_steps: Optional[int] = 50, |
324 | guidance_scale: Optional[float] = 7.5, | 324 | guidance_scale: Optional[float] = 7.5, |
325 | eta: Optional[float] = 0.0, | 325 | eta: Optional[float] = 0.0, |
@@ -342,9 +342,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
342 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added | 342 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added |
343 | noise will be maximum and the denoising process will run for the full number of iterations specified in | 343 | noise will be maximum and the denoising process will run for the full number of iterations specified in |
344 | `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. | 344 | `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. |
345 | height (`int`, *optional*, defaults to 512): | 345 | height (`int`, *optional*, defaults to 768): |
346 | The height in pixels of the generated image. | 346 | The height in pixels of the generated image. |
347 | width (`int`, *optional*, defaults to 512): | 347 | width (`int`, *optional*, defaults to 768): |
348 | The width in pixels of the generated image. | 348 | The width in pixels of the generated image. |
349 | num_inference_steps (`int`, *optional*, defaults to 50): | 349 | num_inference_steps (`int`, *optional*, defaults to 50): |
350 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | 350 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c765ec..08bc9e0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
@@ -355,9 +355,8 @@ def parse_args(): | |||
355 | 355 | ||
356 | args = parser.parse_args() | 356 | args = parser.parse_args() |
357 | if args.config is not None: | 357 | if args.config is not None: |
358 | with open(args.config, 'rt') as f: | 358 | args = load_config(args.config) |
359 | args = parser.parse_args( | 359 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
360 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
361 | 360 | ||
362 | if args.train_data_file is None: | 361 | if args.train_data_file is None: |
363 | raise ValueError("You must specify --train_data_file") | 362 | raise ValueError("You must specify --train_data_file") |
diff --git a/train_lora.py b/train_lora.py index 34e1008..ffca304 100644 --- a/train_lora.py +++ b/train_lora.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
26 | from training.lora import LoraAttnProcessor | 26 | from training.lora import LoraAttnProcessor |
@@ -317,9 +317,8 @@ def parse_args(): | |||
317 | 317 | ||
318 | args = parser.parse_args() | 318 | args = parser.parse_args() |
319 | if args.config is not None: | 319 | if args.config is not None: |
320 | with open(args.config, 'rt') as f: | 320 | args = load_config(args.config) |
321 | args = parser.parse_args( | 321 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
322 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
323 | 322 | ||
324 | if args.train_data_file is None: | 323 | if args.train_data_file is None: |
325 | raise ValueError("You must specify --train_data_file") | 324 | raise ValueError("You must specify --train_data_file") |
diff --git a/train_ti.py b/train_ti.py index a228795..6e30ac3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
21 | from slugify import slugify | 21 | from slugify import slugify |
22 | 22 | ||
23 | from common import load_text_embeddings, load_text_embedding | 23 | from common import load_text_embeddings, load_text_embedding, load_config |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
@@ -225,7 +225,7 @@ def parse_args(): | |||
225 | parser.add_argument( | 225 | parser.add_argument( |
226 | "--adam_weight_decay", | 226 | "--adam_weight_decay", |
227 | type=float, | 227 | type=float, |
228 | default=1e-2, | 228 | default=0, |
229 | help="Weight decay to use." | 229 | help="Weight decay to use." |
230 | ) | 230 | ) |
231 | parser.add_argument( | 231 | parser.add_argument( |
@@ -324,9 +324,8 @@ def parse_args(): | |||
324 | 324 | ||
325 | args = parser.parse_args() | 325 | args = parser.parse_args() |
326 | if args.config is not None: | 326 | if args.config is not None: |
327 | with open(args.config, 'rt') as f: | 327 | args = load_config(args.config) |
328 | args = parser.parse_args( | 328 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
329 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
330 | 329 | ||
331 | if args.train_data_file is None: | 330 | if args.train_data_file is None: |
332 | raise ValueError("You must specify --train_data_file") | 331 | raise ValueError("You must specify --train_data_file") |
@@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase): | |||
407 | 406 | ||
408 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
409 | # Save a checkpoint | 408 | # Save a checkpoint |
410 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] | 409 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id] |
411 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
412 | 411 | ||
413 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |