author     Volpeon <git@volpeon.ink>  2022-10-03 17:38:44 +0200
committer  Volpeon <git@volpeon.ink>  2022-10-03 17:38:44 +0200
commit     f23fd5184b8ba4ec04506495f4a61726e50756f7 (patch)
tree       d4c5666b291316ed95437cc1c917b03ef3b679da
parent     Added negative prompt support for training scripts (diff)
Small perf improvements
-rw-r--r--  data/dreambooth/csv.py          5
-rw-r--r--  data/textual_inversion/csv.py   4
-rw-r--r--  dreambooth.py                  89
-rw-r--r--  infer.py                        5
-rw-r--r--  textual_inversion.py            6

5 files changed, 58 insertions, 51 deletions
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
                                    size=self.size, interpolation=self.interpolation, identifier=self.identifier,
                                    center_crop=self.center_crop, batch_size=self.batch_size)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
+                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
+                                          pin_memory=True, collate_fn=self.collate_fn)
 
     def train_dataloader(self):
         return self.train_dataloader_
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
                                    placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)
 
     def train_dataloader(self):
         return self.train_dataloader_
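
Note: both data modules now pass pin_memory=True to their DataLoaders. A minimal
sketch of what that buys, assuming a CUDA device (the toy dataset and shapes are
made up, not from this repository):

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in for the CSV datasets (hypothetical tensor shapes).
    dataset = TensorDataset(torch.randn(64, 3, 512, 512))

    # pin_memory=True collates batches into page-locked host RAM, which the
    # CUDA driver can transfer to the GPU faster and asynchronously.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True)

    for (batch,) in loader:
        # non_blocking=True overlaps the copy with compute; it only helps
        # when the source tensor is pinned, as it is here.
        batch = batch.to("cuda", non_blocking=True)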
diff --git a/dreambooth.py b/dreambooth.py
index 5fbf172..9d6b8d6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
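
Note: the new module-level flag opts float32 matrix multiplies into TF32 on
Ampere-and-newer GPUs; the same line is added to infer.py and
textual_inversion.py below. An illustrative sketch, not from this repository:

    import torch

    # TF32 runs float32 matmuls on tensor cores with a ~10-bit mantissa,
    # usually a sizeable speedup at a small precision cost.
    torch.backends.cuda.matmul.allow_tf32 = True
    # The matching cuDNN switch for convolutions defaults to True already:
    torch.backends.cudnn.allow_tf32 = True

    a = torch.randn(4096, 4096, device="cuda")
    b = torch.randn(4096, 4096, device="cuda")
    c = a @ b  # eligible for TF32 tensor-core execution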
@@ -346,7 +349,7 @@ class Checkpointer:
         print("Saving model...")
 
         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=self.accelerator.unwrap_model(self.unet),
@@ -354,8 +357,6 @@ class Checkpointer:
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
-            safety_checker=NoCheck(),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
         pipeline.save_pretrained(f"{self.output_dir}/model")
@@ -381,7 +382,6 @@ class Checkpointer:
             unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 
@@ -459,44 +459,6 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)
 
-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float32
-            if accelerator.device.type == "cuda":
-                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
-
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-            pipeline.to(accelerator.device)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-            ):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
@@ -526,6 +488,47 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())
 
+    # Generate class images, if necessary
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        class_images_dir.mkdir(parents=True, exist_ok=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+
+            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
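
Note: the prior-preservation block itself is largely unchanged; it has moved
below the model-loading code so that class images are sampled with a
VlpnStableDiffusion wrapped around the already-loaded tokenizer, text encoder,
VAE and UNet (driven by EulerAScheduler), instead of instantiating a second
StableDiffusionPipeline from disk. That avoids holding a duplicate copy of the
weights while sampling.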
diff --git a/infer.py b/infer.py
index b15b17f..3dc0f32 100644
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype):
     text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype):
         unet=unet,
         tokenizer=tokenizer,
         scheduler=scheduler,
-        feature_extractor=feature_extractor
     )
     # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
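
Note: with the safety checker gone from the custom pipeline, infer.py no longer
downloads and instantiates a CLIP feature extractor at startup. A hedged sketch
of the resulting create_pipeline essentials (the model path and tokenizer
subfolder are assumptions; VlpnStableDiffusion and EulerAScheduler are this
repository's own classes, not diffusers APIs):

    from diffusers import AutoencoderKL, UNet2DConditionModel
    from transformers import CLIPTextModel, CLIPTokenizer

    from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
    from schedulers.scheduling_euler_a import EulerAScheduler

    model = "path/to/model"  # placeholder

    tokenizer = CLIPTokenizer.from_pretrained(model + '/tokenizer')  # assumed layout
    text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder')
    vae = AutoencoderKL.from_pretrained(model + '/vae')
    unet = UNet2DConditionModel.from_pretrained(model + '/unet')

    # No safety_checker or feature_extractor arguments anymore.
    pipeline = VlpnStableDiffusion(
        text_encoder=text_encoder,
        vae=vae,
        unet=unet,
        tokenizer=tokenizer,
        scheduler=EulerAScheduler(
            beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
        ),
    ).to("cuda")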
diff --git a/textual_inversion.py b/textual_inversion.py
index 00d460f..5fc2338 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -14,7 +14,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
@@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule
 logger = get_logger(__name__)
 
 
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -370,7 +373,6 @@ class Checkpointer:
             unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()
 