summary refs log tree commit diff stats
diff options
context:
space:
mode:
-rw-r--r-- common.py 14
-rw-r--r-- data/csv.py 4
-rw-r--r-- infer.py 11
-rw-r--r-- pipelines/stable_diffusion/vlpn_stable_diffusion.py 8
-rw-r--r-- train_dreambooth.py 7
-rw-r--r-- train_lora.py 7
-rw-r--r-- train_ti.py 11
7 files changed, 36 insertions, 26 deletions
diff --git a/common.py b/common.py
index 7ffa77f..f369475 100644
--- a/common.py
+++ b/common.py
@@ -1,9 +1,23 @@
1from pathlib import Path 1from pathlib import Path
2import json
3
2import torch 4import torch
3 5
4from transformers import CLIPTextModel, CLIPTokenizer 6from transformers import CLIPTextModel, CLIPTokenizer
5 7
6 8
def load_config(filename):
    """Load the "args" mapping of a JSON config file, resolving "base" inheritance.

    The file must contain a top-level "args" object. If it also contains a
    "base" entry, that entry is treated as a path relative to the directory of
    the current config file; the base config is loaded the same way and its
    args are used as defaults that the current file's args override.

    Args:
        filename: Path (str or Path) of the JSON config file to load.

    Returns:
        A dict with the merged arguments; keys from more derived configs win
        over keys from their bases.

    Raises:
        ValueError: If the chain of "base" references loops back on itself.
        KeyError: If a config file in the chain has no "args" entry.
        OSError / json.JSONDecodeError: If a file cannot be read or parsed.
    """
    layers = []  # args dicts, most derived config first
    visited = set()
    path = Path(filename)

    while True:
        # Resolve so the same file reached via different relative paths is
        # still recognised as a cycle instead of recursing without end.
        resolved = path.resolve()
        if resolved in visited:
            raise ValueError(f"Circular \"base\" reference in config: {path}")
        visited.add(resolved)

        # JSON is UTF-8 by specification; do not rely on the locale default.
        with open(path, "rt", encoding="utf-8") as f:
            config = json.load(f)

        layers.append(config["args"])

        if "base" not in config:
            break

        # "base" is interpreted relative to the file that references it.
        path = path.parent.joinpath(config["base"])

    # Apply the deepest base first so every derived layer overrides it.
    args = {}
    for layer in reversed(layers):
        args.update(layer)

    return args
19
20
7def load_text_embedding(embeddings, token_id, file): 21def load_text_embedding(embeddings, token_id, file):
8 data = torch.load(file, map_location="cpu") 22 data = torch.load(file, map_location="cpu")
9 23
diff --git a/data/csv.py b/data/csv.py
index 0810c2c..0ad36dc 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -51,7 +51,7 @@ class CSVDataModule():
51 prompt_processor: PromptProcessor, 51 prompt_processor: PromptProcessor,
52 class_subdir: str = "cls", 52 class_subdir: str = "cls",
53 num_class_images: int = 1, 53 num_class_images: int = 1,
54 size: int = 512, 54 size: int = 768,
55 repeats: int = 1, 55 repeats: int = 1,
56 dropout: float = 0, 56 dropout: float = 0,
57 interpolation: str = "bicubic", 57 interpolation: str = "bicubic",
@@ -196,7 +196,7 @@ class CSVDataset(Dataset):
196 prompt_processor: PromptProcessor, 196 prompt_processor: PromptProcessor,
197 batch_size: int = 1, 197 batch_size: int = 1,
198 num_class_images: int = 0, 198 num_class_images: int = 0,
199 size: int = 512, 199 size: int = 768,
200 repeats: int = 1, 200 repeats: int = 1,
201 dropout: float = 0, 201 dropout: float = 0,
202 interpolation: str = "bicubic", 202 interpolation: str = "bicubic",
diff --git a/infer.py b/infer.py
index f566114..ae0b4da 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer
24from slugify import slugify 24from slugify import slugify
25 25
26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 26from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
27from common import load_text_embeddings 27from common import load_text_embeddings, load_config
28 28
29 29
30torch.backends.cuda.matmul.allow_tf32 = True 30torch.backends.cuda.matmul.allow_tf32 = True
@@ -46,8 +46,8 @@ default_cmds = {
46 "negative_prompt": None, 46 "negative_prompt": None,
47 "image": None, 47 "image": None,
48 "image_noise": .7, 48 "image_noise": .7,
49 "width": 512, 49 "width": 768,
50 "height": 512, 50 "height": 768,
51 "batch_size": 1, 51 "batch_size": 1,
52 "batch_num": 1, 52 "batch_num": 1,
53 "steps": 30, 53 "steps": 30,
@@ -163,9 +163,8 @@ def run_parser(parser, defaults, input=None):
163 conf_args = argparse.Namespace() 163 conf_args = argparse.Namespace()
164 164
165 if args.config is not None: 165 if args.config is not None:
166 with open(args.config, 'rt') as f: 166 args = load_config(args.config)
167 conf_args = parser.parse_known_args( 167 args = parser.parse_args(namespace=argparse.Namespace(**args))
168 namespace=argparse.Namespace(**json.load(f)["args"]))[0]
169 168
170 res = defaults.copy() 169 res = defaults.copy()
171 for dict in [vars(conf_args), vars(args)]: 170 for dict in [vars(conf_args), vars(args)]:
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index a43a8e4..53b5eea 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -318,8 +318,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
318 negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, 318 negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None,
319 num_images_per_prompt: Optional[int] = 1, 319 num_images_per_prompt: Optional[int] = 1,
320 strength: float = 0.8, 320 strength: float = 0.8,
321 height: Optional[int] = 512, 321 height: Optional[int] = 768,
322 width: Optional[int] = 512, 322 width: Optional[int] = 768,
323 num_inference_steps: Optional[int] = 50, 323 num_inference_steps: Optional[int] = 50,
324 guidance_scale: Optional[float] = 7.5, 324 guidance_scale: Optional[float] = 7.5,
325 eta: Optional[float] = 0.0, 325 eta: Optional[float] = 0.0,
@@ -342,9 +342,9 @@ class VlpnStableDiffusion(DiffusionPipeline):
342 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added 342 number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added
343 noise will be maximum and the denoising process will run for the full number of iterations specified in 343 noise will be maximum and the denoising process will run for the full number of iterations specified in
344 `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. 344 `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`.
345 height (`int`, *optional*, defaults to 512): 345 height (`int`, *optional*, defaults to 768):
346 The height in pixels of the generated image. 346 The height in pixels of the generated image.
347 width (`int`, *optional*, defaults to 512): 347 width (`int`, *optional*, defaults to 768):
348 The width in pixels of the generated image. 348 The width in pixels of the generated image.
349 num_inference_steps (`int`, *optional*, defaults to 50): 349 num_inference_steps (`int`, *optional*, defaults to 50):
350 The number of denoising steps. More denoising steps usually lead to a higher quality image at the 350 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 2c765ec..08bc9e0 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings 23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
@@ -355,9 +355,8 @@ def parse_args():
355 355
356 args = parser.parse_args() 356 args = parser.parse_args()
357 if args.config is not None: 357 if args.config is not None:
358 with open(args.config, 'rt') as f: 358 args = load_config(args.config)
359 args = parser.parse_args( 359 args = parser.parse_args(namespace=argparse.Namespace(**args))
360 namespace=argparse.Namespace(**json.load(f)["args"]))
361 360
362 if args.train_data_file is None: 361 if args.train_data_file is None:
363 raise ValueError("You must specify --train_data_file") 362 raise ValueError("You must specify --train_data_file")
diff --git a/train_lora.py b/train_lora.py
index 34e1008..ffca304 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings 23from common import load_text_embeddings, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule 25from data.csv import CSVDataModule
26from training.lora import LoraAttnProcessor 26from training.lora import LoraAttnProcessor
@@ -317,9 +317,8 @@ def parse_args():
317 317
318 args = parser.parse_args() 318 args = parser.parse_args()
319 if args.config is not None: 319 if args.config is not None:
320 with open(args.config, 'rt') as f: 320 args = load_config(args.config)
321 args = parser.parse_args( 321 args = parser.parse_args(namespace=argparse.Namespace(**args))
322 namespace=argparse.Namespace(**json.load(f)["args"]))
323 322
324 if args.train_data_file is None: 323 if args.train_data_file is None:
325 raise ValueError("You must specify --train_data_file") 324 raise ValueError("You must specify --train_data_file")
diff --git a/train_ti.py b/train_ti.py
index a228795..6e30ac3 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -20,7 +20,7 @@ from tqdm.auto import tqdm
20from transformers import CLIPTextModel, CLIPTokenizer 20from transformers import CLIPTextModel, CLIPTokenizer
21from slugify import slugify 21from slugify import slugify
22 22
23from common import load_text_embeddings, load_text_embedding 23from common import load_text_embeddings, load_text_embedding, load_config
24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 24from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
25from data.csv import CSVDataModule, CSVDataItem 25from data.csv import CSVDataModule, CSVDataItem
26from training.optimization import get_one_cycle_schedule 26from training.optimization import get_one_cycle_schedule
@@ -225,7 +225,7 @@ def parse_args():
225 parser.add_argument( 225 parser.add_argument(
226 "--adam_weight_decay", 226 "--adam_weight_decay",
227 type=float, 227 type=float,
228 default=1e-2, 228 default=0,
229 help="Weight decay to use." 229 help="Weight decay to use."
230 ) 230 )
231 parser.add_argument( 231 parser.add_argument(
@@ -324,9 +324,8 @@ def parse_args():
324 324
325 args = parser.parse_args() 325 args = parser.parse_args()
326 if args.config is not None: 326 if args.config is not None:
327 with open(args.config, 'rt') as f: 327 args = load_config(args.config)
328 args = parser.parse_args( 328 args = parser.parse_args(namespace=argparse.Namespace(**args))
329 namespace=argparse.Namespace(**json.load(f)["args"]))
330 329
331 if args.train_data_file is None: 330 if args.train_data_file is None:
332 raise ValueError("You must specify --train_data_file") 331 raise ValueError("You must specify --train_data_file")
@@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase):
407 406
408 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): 407 for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id):
409 # Save a checkpoint 408 # Save a checkpoint
410 learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] 409 learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id]
411 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} 410 learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()}
412 411
413 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) 412 filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix)