author     Volpeon <git@volpeon.ink>   2022-10-03 17:38:44 +0200
committer  Volpeon <git@volpeon.ink>   2022-10-03 17:38:44 +0200
commit     f23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
tree       d4c5666b291316ed95437cc1c917b03ef3b679da
parent     Added negative prompt support for training scripts (diff)
Small perf improvements
 data/dreambooth/csv.py         |  5
 data/textual_inversion/csv.py  |  4
 dreambooth.py                  | 89
 infer.py                       |  5
 textual_inversion.py           |  6
 5 files changed, 58 insertions(+), 51 deletions(-)
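In practice, the "small perf improvements" are: pin_memory=True on all four DataLoaders, TF32 matmuls enabled at import time in dreambooth.py, infer.py, and textual_inversion.py, the NoCheck safety checker and CLIP feature extractor dropped everywhere the custom VlpnStableDiffusion pipeline is built, and the DreamBooth class-image generation moved to after model loading so it reuses the already-loaded components instead of a second StableDiffusionPipeline.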
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
                                    size=self.size, interpolation=self.interpolation, identifier=self.identifier,
                                    center_crop=self.center_crop, batch_size=self.batch_size)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
+                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
+                                          pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
         return self.train_dataloader_
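For context: pin_memory=True makes a DataLoader stage each fetched batch in page-locked (pinned) host RAM, which speeds up the host-to-GPU copy and lets it run asynchronously. A minimal sketch of the pattern, separate from this repo's CSVDataModule (the toy dataset is illustrative):

import torch
from torch.utils.data import DataLoader, TensorDataset

# Toy stand-in for the CSV datasets in this repo.
dataset = TensorDataset(torch.randn(256, 3, 64, 64))

# pin_memory=True allocates batches in page-locked host memory, making
# the later host-to-device copy faster and eligible for async transfer.
loader = DataLoader(dataset, batch_size=8, shuffle=True, pin_memory=True)

for (batch,) in loader:
    # non_blocking=True overlaps the copy with compute only when the
    # source tensor is pinned, which is exactly what pin_memory provides.
    batch = batch.to("cuda", non_blocking=True)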
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
 
     def train_dataloader(self):
         return self.train_dataloader_
diff --git a/dreambooth.py b/dreambooth.py
index 5fbf172..9d6b8d6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
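The new allow_tf32 line (added identically to infer.py and textual_inversion.py below) is the second perf lever: on Ampere and newer GPUs it lets float32 matmuls run on TensorFloat-32 tensor cores, which keep float32 range but round the mantissa to roughly 10 bits for a large throughput gain. A minimal illustration; the cuDNN flag is included only for completeness and is not part of this commit:

import torch

# What this commit enables: float32 matmuls may use TF32 tensor cores.
torch.backends.cuda.matmul.allow_tf32 = True

# The analogous switch for cuDNN convolutions (not touched by this commit).
torch.backends.cudnn.allow_tf32 = True

a = torch.randn(1024, 1024, device="cuda")
b = torch.randn(1024, 1024, device="cuda")
c = a @ b  # eligible for TF32 on supporting hardware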
@@ -346,7 +349,7 @@ class Checkpointer:
         print("Saving model...")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=self.accelerator.unwrap_model(self.unet),
@@ -354,8 +357,6 @@ class Checkpointer:
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
-            safety_checker=NoCheck(),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
         pipeline.save_pretrained(f"{self.output_dir}/model")
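Dropping safety_checker=NoCheck() and the feature_extractor works here because the checkpoint is now saved through the repo's own VlpnStableDiffusion pipeline, which evidently takes neither component; a side benefit is that saving a checkpoint no longer has to fetch openai/clip-vit-base-patch32 at all.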
@@ -381,7 +382,6 @@ class Checkpointer:
             unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
@@ -459,44 +459,6 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float32
-            if accelerator.device.type == "cuda":
-                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
-
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-            pipeline.to(accelerator.device)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-            ):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -526,6 +488,47 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
+    # Generate class images, if necessary
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        class_images_dir.mkdir(parents=True, exist_ok=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+
+            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
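This block is the prior-preservation sampling code removed above, relocated below model loading and reworked to reuse the text encoder, VAE, and UNet that training loads anyway (wrapped in VlpnStableDiffusion with an EulerAScheduler), instead of instantiating a second StableDiffusionPipeline from pretrained weights just to render the class images.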
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype):
     text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype):
         unet=unet,
         tokenizer=tokenizer,
         scheduler=scheduler,
-        feature_extractor=feature_extractor
     )
     # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
diff --git a/textual_inversion.py b/textual_inversion.py
index 00d460f..5fc2338 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -14,7 +14,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
@@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -370,7 +373,6 @@ class Checkpointer:
             unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 