diff options
| -rw-r--r-- | data/csv.py | 45 | ||||
| -rw-r--r-- | infer.py | 56 | ||||
| -rw-r--r-- | train_dreambooth.py | 8 | ||||
| -rw-r--r-- | train_ti.py | 8 | ||||
| -rw-r--r-- | training/optimization.py | 4 |
5 files changed, 87 insertions, 34 deletions
diff --git a/data/csv.py b/data/csv.py index e901ab4..c505230 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -165,19 +165,27 @@ class CSVDataModule(): | |||
| 165 | self.data_val = self.pad_items(data_val) | 165 | self.data_val = self.pad_items(data_val) |
| 166 | 166 | ||
| 167 | def setup(self, stage=None): | 167 | def setup(self, stage=None): |
| 168 | train_dataset = CSVDataset(self.data_train, self.prompt_processor, batch_size=self.batch_size, | 168 | train_dataset = CSVDataset( |
| 169 | num_class_images=self.num_class_images, | 169 | self.data_train, self.prompt_processor, batch_size=self.batch_size, |
| 170 | size=self.size, interpolation=self.interpolation, | 170 | num_class_images=self.num_class_images, |
| 171 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout) | 171 | size=self.size, interpolation=self.interpolation, |
| 172 | val_dataset = CSVDataset(self.data_val, self.prompt_processor, batch_size=self.batch_size, | 172 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout |
| 173 | size=self.size, interpolation=self.interpolation, | 173 | ) |
| 174 | center_crop=self.center_crop) | 174 | val_dataset = CSVDataset( |
| 175 | self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, | 175 | self.data_val, self.prompt_processor, batch_size=self.batch_size, |
| 176 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | 176 | size=self.size, interpolation=self.interpolation, |
| 177 | num_workers=self.num_workers) | 177 | center_crop=self.center_crop |
| 178 | self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, | 178 | ) |
| 179 | pin_memory=True, collate_fn=self.collate_fn, | 179 | self.train_dataloader_ = DataLoader( |
| 180 | num_workers=self.num_workers) | 180 | train_dataset, batch_size=self.batch_size, |
| 181 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | ||
| 182 | num_workers=self.num_workers | ||
| 183 | ) | ||
| 184 | self.val_dataloader_ = DataLoader( | ||
| 185 | val_dataset, batch_size=self.batch_size, | ||
| 186 | pin_memory=True, collate_fn=self.collate_fn, | ||
| 187 | num_workers=self.num_workers | ||
| 188 | ) | ||
| 181 | 189 | ||
| 182 | def train_dataloader(self): | 190 | def train_dataloader(self): |
| 183 | return self.train_dataloader_ | 191 | return self.train_dataloader_ |
| @@ -210,11 +218,12 @@ class CSVDataset(Dataset): | |||
| 210 | self.num_instance_images = len(self.data) | 218 | self.num_instance_images = len(self.data) |
| 211 | self._length = self.num_instance_images * repeats | 219 | self._length = self.num_instance_images * repeats |
| 212 | 220 | ||
| 213 | self.interpolation = {"linear": transforms.InterpolationMode.NEAREST, | 221 | self.interpolation = { |
| 214 | "bilinear": transforms.InterpolationMode.BILINEAR, | 222 | "linear": transforms.InterpolationMode.NEAREST, |
| 215 | "bicubic": transforms.InterpolationMode.BICUBIC, | 223 | "bilinear": transforms.InterpolationMode.BILINEAR, |
| 216 | "lanczos": transforms.InterpolationMode.LANCZOS, | 224 | "bicubic": transforms.InterpolationMode.BICUBIC, |
| 217 | }[interpolation] | 225 | "lanczos": transforms.InterpolationMode.LANCZOS, |
| 226 | }[interpolation] | ||
| 218 | self.image_transforms = transforms.Compose( | 227 | self.image_transforms = transforms.Compose( |
| 219 | [ | 228 | [ |
| 220 | transforms.Resize(size, interpolation=self.interpolation), | 229 | transforms.Resize(size, interpolation=self.interpolation), |
| @@ -45,6 +45,7 @@ default_args = { | |||
| 45 | 45 | ||
| 46 | 46 | ||
| 47 | default_cmds = { | 47 | default_cmds = { |
| 48 | "project": "", | ||
| 48 | "scheduler": "dpmsm", | 49 | "scheduler": "dpmsm", |
| 49 | "prompt": None, | 50 | "prompt": None, |
| 50 | "negative_prompt": None, | 51 | "negative_prompt": None, |
| @@ -104,6 +105,12 @@ def create_cmd_parser(): | |||
| 104 | description="Simple example of a training script." | 105 | description="Simple example of a training script." |
| 105 | ) | 106 | ) |
| 106 | parser.add_argument( | 107 | parser.add_argument( |
| 108 | "--project", | ||
| 109 | type=str, | ||
| 110 | default=None, | ||
| 111 | help="The name of the current project.", | ||
| 112 | ) | ||
| 113 | parser.add_argument( | ||
| 107 | "--scheduler", | 114 | "--scheduler", |
| 108 | type=str, | 115 | type=str, |
| 109 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], | 116 | choices=["plms", "ddim", "klms", "dpmsm", "dpmss", "euler_a", "kdpm2", "kdpm2_a"], |
| @@ -184,7 +191,16 @@ def save_args(basepath, args, extra={}): | |||
| 184 | json.dump(info, f, indent=4) | 191 | json.dump(info, f, indent=4) |
| 185 | 192 | ||
| 186 | 193 | ||
| 187 | def create_pipeline(model, embeddings_dir, dtype): | 194 | def load_embeddings(pipeline, embeddings_dir): |
| 195 | added_tokens = load_embeddings_from_dir( | ||
| 196 | pipeline.tokenizer, | ||
| 197 | pipeline.text_encoder.text_model.embeddings, | ||
| 198 | Path(embeddings_dir) | ||
| 199 | ) | ||
| 200 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
| 201 | |||
| 202 | |||
| 203 | def create_pipeline(model, dtype): | ||
| 188 | print("Loading Stable Diffusion pipeline...") | 204 | print("Loading Stable Diffusion pipeline...") |
| 189 | 205 | ||
| 190 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) | 206 | tokenizer = MultiCLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype) |
| @@ -193,10 +209,7 @@ def create_pipeline(model, embeddings_dir, dtype): | |||
| 193 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) | 209 | unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype) |
| 194 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) | 210 | scheduler = DDIMScheduler.from_pretrained(model, subfolder='scheduler', torch_dtype=dtype) |
| 195 | 211 | ||
| 196 | embeddings = patch_managed_embeddings(text_encoder) | 212 | patch_managed_embeddings(text_encoder) |
| 197 | added_tokens = load_embeddings_from_dir(tokenizer, embeddings, Path(embeddings_dir)) | ||
| 198 | |||
| 199 | print(f"Added {len(added_tokens)} tokens from embeddings dir: {added_tokens}") | ||
| 200 | 213 | ||
| 201 | pipeline = VlpnStableDiffusion( | 214 | pipeline = VlpnStableDiffusion( |
| 202 | text_encoder=text_encoder, | 215 | text_encoder=text_encoder, |
| @@ -220,7 +233,14 @@ def generate(output_dir, pipeline, args): | |||
| 220 | args.prompt = [args.prompt] | 233 | args.prompt = [args.prompt] |
| 221 | 234 | ||
| 222 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 235 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
| 223 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 236 | use_subdirs = len(args.prompt) != 1 |
| 237 | if use_subdirs: | ||
| 238 | if len(args.project) != 0: | ||
| 239 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.project)}") | ||
| 240 | else: | ||
| 241 | output_dir = output_dir.joinpath(now) | ||
| 242 | else: | ||
| 243 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | ||
| 224 | output_dir.mkdir(parents=True, exist_ok=True) | 244 | output_dir.mkdir(parents=True, exist_ok=True) |
| 225 | 245 | ||
| 226 | args.seed = args.seed or torch.random.seed() | 246 | args.seed = args.seed or torch.random.seed() |
| @@ -257,7 +277,8 @@ def generate(output_dir, pipeline, args): | |||
| 257 | dynamic_ncols=True | 277 | dynamic_ncols=True |
| 258 | ) | 278 | ) |
| 259 | 279 | ||
| 260 | generator = torch.Generator(device="cuda").manual_seed(args.seed + i) | 280 | seed = args.seed + i |
| 281 | generator = torch.Generator(device="cuda").manual_seed(seed) | ||
| 261 | images = pipeline( | 282 | images = pipeline( |
| 262 | prompt=args.prompt, | 283 | prompt=args.prompt, |
| 263 | negative_prompt=args.negative_prompt, | 284 | negative_prompt=args.negative_prompt, |
| @@ -272,8 +293,13 @@ def generate(output_dir, pipeline, args): | |||
| 272 | ).images | 293 | ).images |
| 273 | 294 | ||
| 274 | for j, image in enumerate(images): | 295 | for j, image in enumerate(images): |
| 275 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.png")) | 296 | image_dir = output_dir |
| 276 | image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"), quality=85) | 297 | if use_subdirs: |
| 298 | idx = j % len(args.prompt) | ||
| 299 | image_dir = image_dir.joinpath(slugify(args.prompt[idx])[:100]) | ||
| 300 | image_dir.mkdir(parents=True, exist_ok=True) | ||
| 301 | image.save(image_dir.joinpath(f"{seed}_{j}.png")) | ||
| 302 | image.save(image_dir.joinpath(f"{seed}_{j}.jpg"), quality=85) | ||
| 277 | 303 | ||
| 278 | if torch.cuda.is_available(): | 304 | if torch.cuda.is_available(): |
| 279 | torch.cuda.empty_cache() | 305 | torch.cuda.empty_cache() |
| @@ -283,10 +309,11 @@ class CmdParse(cmd.Cmd): | |||
| 283 | prompt = 'dream> ' | 309 | prompt = 'dream> ' |
| 284 | commands = [] | 310 | commands = [] |
| 285 | 311 | ||
| 286 | def __init__(self, output_dir, pipeline, parser): | 312 | def __init__(self, output_dir, ti_embeddings_dir, pipeline, parser): |
| 287 | super().__init__() | 313 | super().__init__() |
| 288 | 314 | ||
| 289 | self.output_dir = output_dir | 315 | self.output_dir = output_dir |
| 316 | self.ti_embeddings_dir = ti_embeddings_dir | ||
| 290 | self.pipeline = pipeline | 317 | self.pipeline = pipeline |
| 291 | self.parser = parser | 318 | self.parser = parser |
| 292 | 319 | ||
| @@ -302,6 +329,10 @@ class CmdParse(cmd.Cmd): | |||
| 302 | if elements[0] == 'q': | 329 | if elements[0] == 'q': |
| 303 | return True | 330 | return True |
| 304 | 331 | ||
| 332 | if elements[0] == 'reload_embeddings': | ||
| 333 | load_embeddings(self.pipeline, self.ti_embeddings_dir) | ||
| 334 | return | ||
| 335 | |||
| 305 | try: | 336 | try: |
| 306 | args = run_parser(self.parser, default_cmds, elements) | 337 | args = run_parser(self.parser, default_cmds, elements) |
| 307 | 338 | ||
| @@ -337,9 +368,10 @@ def main(): | |||
| 337 | output_dir = Path(args.output_dir) | 368 | output_dir = Path(args.output_dir) |
| 338 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] | 369 | dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision] |
| 339 | 370 | ||
| 340 | pipeline = create_pipeline(args.model, args.ti_embeddings_dir, dtype) | 371 | pipeline = create_pipeline(args.model, dtype) |
| 372 | load_embeddings(pipeline, args.ti_embeddings_dir) | ||
| 341 | cmd_parser = create_cmd_parser() | 373 | cmd_parser = create_cmd_parser() |
| 342 | cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser) | 374 | cmd_prompt = CmdParse(output_dir, args.ti_embeddings_dir, pipeline, cmd_parser) |
| 343 | cmd_prompt.cmdloop() | 375 | cmd_prompt.cmdloop() |
| 344 | 376 | ||
| 345 | 377 | ||
diff --git a/train_dreambooth.py b/train_dreambooth.py index 5e6e35d..2e0696b 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -269,6 +269,12 @@ def parse_args(): | |||
| 269 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 269 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' |
| 270 | ) | 270 | ) |
| 271 | parser.add_argument( | 271 | parser.add_argument( |
| 272 | "--lr_min_lr", | ||
| 273 | type=float, | ||
| 274 | default=None, | ||
| 275 | help="Minimum learning rate in the lr scheduler." | ||
| 276 | ) | ||
| 277 | parser.add_argument( | ||
| 272 | "--use_ema", | 278 | "--use_ema", |
| 273 | action="store_true", | 279 | action="store_true", |
| 274 | default=True, | 280 | default=True, |
| @@ -799,6 +805,7 @@ def main(): | |||
| 799 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps | 805 | warmup_steps = args.lr_warmup_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps |
| 800 | 806 | ||
| 801 | if args.lr_scheduler == "one_cycle": | 807 | if args.lr_scheduler == "one_cycle": |
| 808 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
| 802 | lr_scheduler = get_one_cycle_schedule( | 809 | lr_scheduler = get_one_cycle_schedule( |
| 803 | optimizer=optimizer, | 810 | optimizer=optimizer, |
| 804 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 811 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| @@ -806,6 +813,7 @@ def main(): | |||
| 806 | annealing=args.lr_annealing_func, | 813 | annealing=args.lr_annealing_func, |
| 807 | warmup_exp=args.lr_warmup_exp, | 814 | warmup_exp=args.lr_warmup_exp, |
| 808 | annealing_exp=args.lr_annealing_exp, | 815 | annealing_exp=args.lr_annealing_exp, |
| 816 | min_lr=lr_min_lr, | ||
| 809 | ) | 817 | ) |
| 810 | elif args.lr_scheduler == "cosine_with_restarts": | 818 | elif args.lr_scheduler == "cosine_with_restarts": |
| 811 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 819 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
diff --git a/train_ti.py b/train_ti.py index 6f116c3..1b60f64 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -260,6 +260,12 @@ def parse_args(): | |||
| 260 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' | 260 | help='If lr_annealing_func is "half_cos" or "cos", exponent to modify the function' |
| 261 | ) | 261 | ) |
| 262 | parser.add_argument( | 262 | parser.add_argument( |
| 263 | "--lr_min_lr", | ||
| 264 | type=float, | ||
| 265 | default=None, | ||
| 266 | help="Minimum learning rate in the lr scheduler." | ||
| 267 | ) | ||
| 268 | parser.add_argument( | ||
| 263 | "--use_8bit_adam", | 269 | "--use_8bit_adam", |
| 264 | action="store_true", | 270 | action="store_true", |
| 265 | help="Whether or not to use 8-bit Adam from bitsandbytes." | 271 | help="Whether or not to use 8-bit Adam from bitsandbytes." |
| @@ -744,6 +750,7 @@ def main(): | |||
| 744 | if args.find_lr: | 750 | if args.find_lr: |
| 745 | lr_scheduler = None | 751 | lr_scheduler = None |
| 746 | elif args.lr_scheduler == "one_cycle": | 752 | elif args.lr_scheduler == "one_cycle": |
| 753 | lr_min_lr = 0.04 if args.lr_min_lr is None else args.lr_min_lr / args.learning_rate | ||
| 747 | lr_scheduler = get_one_cycle_schedule( | 754 | lr_scheduler = get_one_cycle_schedule( |
| 748 | optimizer=optimizer, | 755 | optimizer=optimizer, |
| 749 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 756 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
| @@ -751,6 +758,7 @@ def main(): | |||
| 751 | annealing=args.lr_annealing_func, | 758 | annealing=args.lr_annealing_func, |
| 752 | warmup_exp=args.lr_warmup_exp, | 759 | warmup_exp=args.lr_warmup_exp, |
| 753 | annealing_exp=args.lr_annealing_exp, | 760 | annealing_exp=args.lr_annealing_exp, |
| 761 | min_lr=lr_min_lr, | ||
| 754 | ) | 762 | ) |
| 755 | elif args.lr_scheduler == "cosine_with_restarts": | 763 | elif args.lr_scheduler == "cosine_with_restarts": |
| 756 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( | 764 | lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup( |
diff --git a/training/optimization.py b/training/optimization.py index 14c2bd5..dd84f9c 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -5,10 +5,6 @@ from functools import partial | |||
| 5 | import torch | 5 | import torch |
| 6 | from torch.optim.lr_scheduler import LambdaLR | 6 | from torch.optim.lr_scheduler import LambdaLR |
| 7 | 7 | ||
| 8 | from diffusers.utils import logging | ||
| 9 | |||
| 10 | logger = logging.get_logger(__name__) | ||
| 11 | |||
| 12 | 8 | ||
| 13 | class OneCyclePhase(NamedTuple): | 9 | class OneCyclePhase(NamedTuple): |
| 14 | step_min: int | 10 | step_min: int |
