diff options
| -rw-r--r-- | common.py | 14 | ||||
| -rw-r--r-- | data/csv.py | 4 | ||||
| -rw-r--r-- | infer.py | 11 | ||||
| -rw-r--r-- | pipelines/stable_diffusion/vlpn_stable_diffusion.py | 8 | ||||
| -rw-r--r-- | train_dreambooth.py | 7 | ||||
| -rw-r--r-- | train_lora.py | 7 | ||||
| -rw-r--r-- | train_ti.py | 11 |
7 files changed, 36 insertions, 26 deletions
| @@ -1,9 +1,23 @@ | |||
| 1 | from pathlib import Path | 1 | from pathlib import Path |
| 2 | import json | ||
| 3 | |||
| 2 | import torch | 4 | import torch |
| 3 | 5 | ||
| 4 | from transformers import CLIPTextModel, CLIPTokenizer | 6 | from transformers import CLIPTextModel, CLIPTokenizer |
| 5 | 7 | ||
| 6 | 8 | ||
| 9 | def load_config(filename): | ||
| 10 | with open(filename, 'rt') as f: | ||
| 11 | config = json.load(f) | ||
| 12 | |||
| 13 | args = config["args"] | ||
| 14 | |||
| 15 | if "base" in config: | ||
| 16 | args = load_config(Path(filename).parent.joinpath(config["base"])) | args | ||
| 17 | |||
| 18 | return args | ||
| 19 | |||
| 20 | |||
| 7 | def load_text_embedding(embeddings, token_id, file): | 21 | def load_text_embedding(embeddings, token_id, file): |
| 8 | data = torch.load(file, map_location="cpu") | 22 | data = torch.load(file, map_location="cpu") |
| 9 | 23 | ||
diff --git a/data/csv.py b/data/csv.py index 0810c2c..0ad36dc 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -51,7 +51,7 @@ class CSVDataModule(): | |||
| 51 | prompt_processor: PromptProcessor, | 51 | prompt_processor: PromptProcessor, |
| 52 | class_subdir: str = "cls", | 52 | class_subdir: str = "cls", |
| 53 | num_class_images: int = 1, | 53 | num_class_images: int = 1, |
| 54 | size: int = 512, | 54 | size: int = 768, |
| 55 | repeats: int = 1, | 55 | repeats: int = 1, |
| 56 | dropout: float = 0, | 56 | dropout: float = 0, |
| 57 | interpolation: str = "bicubic", | 57 | interpolation: str = "bicubic", |
| @@ -196,7 +196,7 @@ class CSVDataset(Dataset): | |||
| 196 | prompt_processor: PromptProcessor, | 196 | prompt_processor: PromptProcessor, |
| 197 | batch_size: int = 1, | 197 | batch_size: int = 1, |
| 198 | num_class_images: int = 0, | 198 | num_class_images: int = 0, |
| 199 | size: int = 512, | 199 | size: int = 768, |
| 200 | repeats: int = 1, | 200 | repeats: int = 1, |
| 201 | dropout: float = 0, | 201 | dropout: float = 0, |
| 202 | interpolation: str = "bicubic", | 202 | interpolation: str = "bicubic", |
| @@ -24,7 +24,7 @@ from transformers import CLIPTextModel, CLIPTokenizer | |||
| 24 | from slugify import slugify | 24 | from slugify import slugify |
| 25 | 25 | ||
| 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 26 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 27 | from common import load_text_embeddings | 27 | from common import load_text_embeddings, load_config |
| 28 | 28 | ||
| 29 | 29 | ||
| 30 | torch.backends.cuda.matmul.allow_tf32 = True | 30 | torch.backends.cuda.matmul.allow_tf32 = True |
| @@ -46,8 +46,8 @@ default_cmds = { | |||
| 46 | "negative_prompt": None, | 46 | "negative_prompt": None, |
| 47 | "image": None, | 47 | "image": None, |
| 48 | "image_noise": .7, | 48 | "image_noise": .7, |
| 49 | "width": 512, | 49 | "width": 768, |
| 50 | "height": 512, | 50 | "height": 768, |
| 51 | "batch_size": 1, | 51 | "batch_size": 1, |
| 52 | "batch_num": 1, | 52 | "batch_num": 1, |
| 53 | "steps": 30, | 53 | "steps": 30, |
| @@ -163,9 +163,8 @@ def run_parser(parser, defaults, input=None): | |||
| 163 | conf_args = argparse.Namespace() | 163 | conf_args = argparse.Namespace() |
| 164 | 164 | ||
| 165 | if args.config is not None: | 165 | if args.config is not None: |
| 166 | with open(args.config, 'rt') as f: | 166 | args = load_config(args.config) |
| 167 | conf_args = parser.parse_known_args( | 167 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
| 168 | namespace=argparse.Namespace(**json.load(f)["args"]))[0] | ||
| 169 | 168 | ||
| 170 | res = defaults.copy() | 169 | res = defaults.copy() |
| 171 | for dict in [vars(conf_args), vars(args)]: | 170 | for dict in [vars(conf_args), vars(args)]: |
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py index a43a8e4..53b5eea 100644 --- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py +++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py | |||
| @@ -318,8 +318,8 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, | 318 | negative_prompt: Optional[Union[str, List[str], List[List[str]]]] = None, |
| 319 | num_images_per_prompt: Optional[int] = 1, | 319 | num_images_per_prompt: Optional[int] = 1, |
| 320 | strength: float = 0.8, | 320 | strength: float = 0.8, |
| 321 | height: Optional[int] = 512, | 321 | height: Optional[int] = 768, |
| 322 | width: Optional[int] = 512, | 322 | width: Optional[int] = 768, |
| 323 | num_inference_steps: Optional[int] = 50, | 323 | num_inference_steps: Optional[int] = 50, |
| 324 | guidance_scale: Optional[float] = 7.5, | 324 | guidance_scale: Optional[float] = 7.5, |
| 325 | eta: Optional[float] = 0.0, | 325 | eta: Optional[float] = 0.0, |
| @@ -342,9 +342,9 @@ class VlpnStableDiffusion(DiffusionPipeline): | |||
| 342 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added | 342 | number of denoising steps depends on the amount of noise initially added. When `strength` is 1, added |
| 343 | noise will be maximum and the denoising process will run for the full number of iterations specified in | 343 | noise will be maximum and the denoising process will run for the full number of iterations specified in |
| 344 | `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. | 344 | `num_inference_steps`. A value of 1, therefore, essentially ignores `init_image`. |
| 345 | height (`int`, *optional*, defaults to 512): | 345 | height (`int`, *optional*, defaults to 768): |
| 346 | The height in pixels of the generated image. | 346 | The height in pixels of the generated image. |
| 347 | width (`int`, *optional*, defaults to 512): | 347 | width (`int`, *optional*, defaults to 768): |
| 348 | The width in pixels of the generated image. | 348 | The width in pixels of the generated image. |
| 349 | num_inference_steps (`int`, *optional*, defaults to 50): | 349 | num_inference_steps (`int`, *optional*, defaults to 50): |
| 350 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the | 350 | The number of denoising steps. More denoising steps usually lead to a higher quality image at the |
diff --git a/train_dreambooth.py b/train_dreambooth.py index 2c765ec..08bc9e0 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
| 20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 21 | from slugify import slugify |
| 22 | 22 | ||
| 23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
| 26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
| @@ -355,9 +355,8 @@ def parse_args(): | |||
| 355 | 355 | ||
| 356 | args = parser.parse_args() | 356 | args = parser.parse_args() |
| 357 | if args.config is not None: | 357 | if args.config is not None: |
| 358 | with open(args.config, 'rt') as f: | 358 | args = load_config(args.config) |
| 359 | args = parser.parse_args( | 359 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
| 360 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
| 361 | 360 | ||
| 362 | if args.train_data_file is None: | 361 | if args.train_data_file is None: |
| 363 | raise ValueError("You must specify --train_data_file") | 362 | raise ValueError("You must specify --train_data_file") |
diff --git a/train_lora.py b/train_lora.py index 34e1008..ffca304 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
| 20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 21 | from slugify import slugify |
| 22 | 22 | ||
| 23 | from common import load_text_embeddings | 23 | from common import load_text_embeddings, load_config |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
| 26 | from training.lora import LoraAttnProcessor | 26 | from training.lora import LoraAttnProcessor |
| @@ -317,9 +317,8 @@ def parse_args(): | |||
| 317 | 317 | ||
| 318 | args = parser.parse_args() | 318 | args = parser.parse_args() |
| 319 | if args.config is not None: | 319 | if args.config is not None: |
| 320 | with open(args.config, 'rt') as f: | 320 | args = load_config(args.config) |
| 321 | args = parser.parse_args( | 321 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
| 322 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
| 323 | 322 | ||
| 324 | if args.train_data_file is None: | 323 | if args.train_data_file is None: |
| 325 | raise ValueError("You must specify --train_data_file") | 324 | raise ValueError("You must specify --train_data_file") |
diff --git a/train_ti.py b/train_ti.py index a228795..6e30ac3 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -20,7 +20,7 @@ from tqdm.auto import tqdm | |||
| 20 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
| 21 | from slugify import slugify | 21 | from slugify import slugify |
| 22 | 22 | ||
| 23 | from common import load_text_embeddings, load_text_embedding | 23 | from common import load_text_embeddings, load_text_embedding, load_config |
| 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
| 25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import CSVDataModule, CSVDataItem |
| 26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
| @@ -225,7 +225,7 @@ def parse_args(): | |||
| 225 | parser.add_argument( | 225 | parser.add_argument( |
| 226 | "--adam_weight_decay", | 226 | "--adam_weight_decay", |
| 227 | type=float, | 227 | type=float, |
| 228 | default=1e-2, | 228 | default=0, |
| 229 | help="Weight decay to use." | 229 | help="Weight decay to use." |
| 230 | ) | 230 | ) |
| 231 | parser.add_argument( | 231 | parser.add_argument( |
| @@ -324,9 +324,8 @@ def parse_args(): | |||
| 324 | 324 | ||
| 325 | args = parser.parse_args() | 325 | args = parser.parse_args() |
| 326 | if args.config is not None: | 326 | if args.config is not None: |
| 327 | with open(args.config, 'rt') as f: | 327 | args = load_config(args.config) |
| 328 | args = parser.parse_args( | 328 | args = parser.parse_args(namespace=argparse.Namespace(**args)) |
| 329 | namespace=argparse.Namespace(**json.load(f)["args"])) | ||
| 330 | 329 | ||
| 331 | if args.train_data_file is None: | 330 | if args.train_data_file is None: |
| 332 | raise ValueError("You must specify --train_data_file") | 331 | raise ValueError("You must specify --train_data_file") |
| @@ -407,7 +406,7 @@ class Checkpointer(CheckpointerBase): | |||
| 407 | 406 | ||
| 408 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): | 407 | for (placeholder_token, placeholder_token_id) in zip(self.placeholder_token, self.placeholder_token_id): |
| 409 | # Save a checkpoint | 408 | # Save a checkpoint |
| 410 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight[placeholder_token_id] | 409 | learned_embeds = text_encoder.text_model.embeddings.trainable_embedding.weight.data[placeholder_token_id] |
| 411 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} | 410 | learned_embeds_dict = {placeholder_token: learned_embeds.detach().cpu()} |
| 412 | 411 | ||
| 413 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) | 412 | filename = f"%s_%d_%s.bin" % (slugify(placeholder_token), step, postfix) |
