 data/dreambooth/csv.py        |  5
 data/textual_inversion/csv.py |  4
 dreambooth.py                 | 89
 infer.py                      |  5
 textual_inversion.py          |  6
 5 files changed, 58 insertions(+), 51 deletions(-)
diff --git a/data/dreambooth/csv.py b/data/dreambooth/csv.py
index 71aa1eb..c0b0067 100644
--- a/data/dreambooth/csv.py
+++ b/data/dreambooth/csv.py
@@ -70,8 +70,9 @@ class CSVDataModule(pl.LightningDataModule):
                                    size=self.size, interpolation=self.interpolation, identifier=self.identifier,
                                    center_crop=self.center_crop, batch_size=self.batch_size)
         self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size,
-                                            shuffle=True, collate_fn=self.collate_fn)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, collate_fn=self.collate_fn)
+                                            shuffle=True, pin_memory=True, collate_fn=self.collate_fn)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size,
+                                          pin_memory=True, collate_fn=self.collate_fn)

     def train_dataloader(self):
         return self.train_dataloader_
diff --git a/data/textual_inversion/csv.py b/data/textual_inversion/csv.py
index 64f0c28..852b1cb 100644
--- a/data/textual_inversion/csv.py
+++ b/data/textual_inversion/csv.py
@@ -60,8 +60,8 @@ class CSVDataModule(pl.LightningDataModule):
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
         val_dataset = CSVDataset(self.data_val, self.tokenizer, size=self.size, interpolation=self.interpolation,
                                  placeholder_token=self.placeholder_token, center_crop=self.center_crop)
-        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, shuffle=True)
-        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size)
+        self.train_dataloader_ = DataLoader(train_dataset, batch_size=self.batch_size, pin_memory=True, shuffle=True)
+        self.val_dataloader_ = DataLoader(val_dataset, batch_size=self.batch_size, pin_memory=True)

     def train_dataloader(self):
         return self.train_dataloader_
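
Note: both data modules above now pass pin_memory=True to their DataLoaders. Pinned (page-locked) host memory is what CUDA needs for DMA transfers, and it mainly pays off when the subsequent device copy is issued with non_blocking=True so it can overlap with compute. A minimal sketch of the pattern, with an illustrative stand-in dataset rather than this repo's CSVDataset:

    import torch
    from torch.utils.data import DataLoader, TensorDataset

    # Stand-in dataset; any map-style dataset works the same way.
    dataset = TensorDataset(torch.randn(64, 3, 64, 64))

    # pin_memory=True stages each batch in page-locked host RAM.
    loader = DataLoader(dataset, batch_size=4, shuffle=True, pin_memory=True)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    for (batch,) in loader:
        # non_blocking=True lets the copy from pinned memory run asynchronously.
        batch = batch.to(device, non_blocking=True)
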
diff --git a/dreambooth.py b/dreambooth.py
index 5fbf172..9d6b8d6 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -13,7 +13,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from pipelines.stable_diffusion.no_check import NoCheck
@@ -30,6 +30,9 @@ from data.dreambooth.prompt import PromptDataset
 logger = get_logger(__name__)


+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
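
Note: all three scripts in this commit (dreambooth.py, infer.py, textual_inversion.py) now enable TF32 for matmuls at import time. On Ampere and newer GPUs, TF32 runs float32 matrix multiplies on tensor cores with a reduced (10-bit) mantissa, which is usually a large speedup for a small precision cost. A sketch of the knob together with its cuDNN sibling, which is not part of this diff:

    import torch

    # From this commit: TF32 tensor-core math for float32 matmuls (Ampere+).
    torch.backends.cuda.matmul.allow_tf32 = True

    # Not in this commit: the analogous switch for cuDNN convolutions
    # (enabled by default in recent PyTorch releases).
    torch.backends.cudnn.allow_tf32 = True
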
@@ -346,7 +349,7 @@ class Checkpointer:
         print("Saving model...")

         unwrapped = self.accelerator.unwrap_model(self.unet)
-        pipeline = StableDiffusionPipeline(
+        pipeline = VlpnStableDiffusion(
             text_encoder=self.text_encoder,
             vae=self.vae,
             unet=self.accelerator.unwrap_model(self.unet),
@@ -354,8 +357,6 @@ class Checkpointer:
             scheduler=PNDMScheduler(
                 beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
             ),
-            safety_checker=NoCheck(),
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         )
         pipeline.enable_attention_slicing()
         pipeline.save_pretrained(f"{self.output_dir}/model")
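
Note: the Checkpointer now saves the repo's custom VlpnStableDiffusion pipeline, so the NoCheck safety-checker stub and the CLIP feature extractor (whose only job in the stock pipeline is to preprocess images for the safety checker) no longer need to be passed. A hypothetical sketch of why this serializes cleanly, assuming VlpnStableDiffusion registers only the five components it is given:

    from diffusers import DiffusionPipeline

    class MinimalSDPipeline(DiffusionPipeline):
        # Illustrative only; this is a guess at the shape of
        # VlpnStableDiffusion, not code from this repo.
        def __init__(self, vae, text_encoder, tokenizer, unet, scheduler):
            super().__init__()
            self.register_modules(
                vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
                unet=unet, scheduler=scheduler,
            )

save_pretrained and from_pretrained enumerate whatever register_modules recorded, so a pipeline that never registers a safety checker round-trips without one.
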
@@ -381,7 +382,6 @@ class Checkpointer:
             unet=unwrapped,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()

@@ -459,44 +459,6 @@ def main():
     if args.seed is not None:
         set_seed(args.seed)

-    if args.with_prior_preservation:
-        class_images_dir = Path(args.class_data_dir)
-        class_images_dir.mkdir(parents=True, exist_ok=True)
-        cur_class_images = len(list(class_images_dir.iterdir()))
-
-        if cur_class_images < args.num_class_images:
-            torch_dtype = torch.float32
-            if accelerator.device.type == "cuda":
-                torch_dtype = {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.mixed_precision]
-
-            pipeline = StableDiffusionPipeline.from_pretrained(
-                args.pretrained_model_name_or_path, torch_dtype=torch_dtype)
-            pipeline.enable_attention_slicing()
-            pipeline.set_progress_bar_config(disable=True)
-            pipeline.to(accelerator.device)
-
-            num_new_images = args.num_class_images - cur_class_images
-            logger.info(f"Number of class images to sample: {num_new_images}.")
-
-            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
-            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
-
-            sample_dataloader = accelerator.prepare(sample_dataloader)
-
-            for example in tqdm(
-                sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process
-            ):
-                with accelerator.autocast():
-                    images = pipeline(example["prompt"]).images
-
-                for i, image in enumerate(images):
-                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
-
-            del pipeline
-
-            if torch.cuda.is_available():
-                torch.cuda.empty_cache()
-
     # Load the tokenizer and add the placeholder token as a additional special token
     if args.tokenizer_name:
         tokenizer = CLIPTokenizer.from_pretrained(args.tokenizer_name)
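
Note: the block removed above loaded a complete second StableDiffusionPipeline from disk just to render class images, choosing its dtype from accelerate's mixed-precision mode. That mapping idiom is worth keeping in mind on its own (a sketch; the helper name is hypothetical):

    import torch

    def sampling_dtype(mixed_precision: str, device_type: str) -> torch.dtype:
        # Mirror the removed code: CPU always samples in float32, CUDA follows
        # the --mixed_precision flag ("no", "fp16" or "bf16").
        if device_type != "cuda":
            return torch.float32
        return {"no": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[mixed_precision]

    assert sampling_dtype("fp16", "cuda") is torch.float16

The replacement block appears further down in this file's diff, after the models are loaded.
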
@@ -526,6 +488,47 @@ def main():
     freeze_params(vae.parameters())
     freeze_params(text_encoder.parameters())

+    # Generate class images, if necessary
+    if args.with_prior_preservation:
+        class_images_dir = Path(args.class_data_dir)
+        class_images_dir.mkdir(parents=True, exist_ok=True)
+        cur_class_images = len(list(class_images_dir.iterdir()))
+
+        if cur_class_images < args.num_class_images:
+            scheduler = EulerAScheduler(
+                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
+            )
+
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder,
+                vae=vae,
+                unet=unet,
+                tokenizer=tokenizer,
+                scheduler=scheduler,
+            ).to(accelerator.device)
+            pipeline.enable_attention_slicing()
+            pipeline.set_progress_bar_config(disable=True)
+
+            num_new_images = args.num_class_images - cur_class_images
+            logger.info(f"Number of class images to sample: {num_new_images}.")
+
+            sample_dataset = PromptDataset(args.class_prompt, num_new_images)
+            sample_dataloader = torch.utils.data.DataLoader(sample_dataset, batch_size=args.sample_batch_size)
+
+            sample_dataloader = accelerator.prepare(sample_dataloader)
+
+            for example in tqdm(sample_dataloader, desc="Generating class images", disable=not accelerator.is_local_main_process):
+                with accelerator.autocast():
+                    images = pipeline(example["prompt"]).images
+
+                for i, image in enumerate(images):
+                    image.save(class_images_dir / f"{example['index'][i] + cur_class_images}.jpg")
+
+            del pipeline
+
+            if torch.cuda.is_available():
+                torch.cuda.empty_cache()
+
     if args.scale_lr:
         args.learning_rate = (
             args.learning_rate * args.gradient_accumulation_steps *
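
Note: relocating class-image generation to after model setup means the sampling pipeline wraps the text_encoder, vae and unet that training already holds in memory, instead of reading a duplicate copy of every weight from disk; it also switches sampling to the repo's EulerAScheduler. Condensed into a hypothetical helper (names as in the diff; the function itself is illustrative):

    # Build a sampling pipeline from components already in memory rather than
    # calling StableDiffusionPipeline.from_pretrained a second time.
    def sampling_pipeline(text_encoder, vae, unet, tokenizer, scheduler, device):
        pipe = VlpnStableDiffusion(
            text_encoder=text_encoder, vae=vae, unet=unet,
            tokenizer=tokenizer, scheduler=scheduler,
        ).to(device)
        pipe.enable_attention_slicing()        # lower peak VRAM while sampling
        pipe.set_progress_bar_config(disable=True)
        return pipe
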
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -16,6 +16,9 @@ from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from schedulers.scheduling_euler_a import EulerAScheduler


+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 default_args = {
     "model": None,
     "scheduler": "euler_a",
@@ -166,7 +169,6 @@ def create_pipeline(model, scheduler, dtype):
     text_encoder = CLIPTextModel.from_pretrained(model + '/text_encoder', torch_dtype=dtype)
     vae = AutoencoderKL.from_pretrained(model + '/vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model + '/unet', torch_dtype=dtype)
-    feature_extractor = CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32", torch_dtype=dtype)

     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -191,7 +193,6 @@ def create_pipeline(model, scheduler, dtype):
         unet=unet,
         tokenizer=tokenizer,
         scheduler=scheduler,
-        feature_extractor=feature_extractor
     )
     # pipeline.enable_attention_slicing()
     pipeline.to("cuda")
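
Note: infer.py stops constructing a CLIPFeatureExtractor altogether; without a safety checker it fed nothing, and fetching openai/clip-vit-base-patch32 was a cold-cache network hit at every pipeline build. The surviving scheduler dispatch, sketched from what the hunks show (only the "plms" branch is visible here; the "euler_a" branch is inferred from default_args and the EulerAScheduler kwargs used in dreambooth.py):

    from diffusers import PNDMScheduler
    from schedulers.scheduling_euler_a import EulerAScheduler  # repo-local module

    def make_scheduler(name: str):
        # Hypothetical condensation of create_pipeline's branching.
        if name == "plms":
            return PNDMScheduler(
                beta_start=0.00085, beta_end=0.012,
                beta_schedule="scaled_linear", skip_prk_steps=True,
            )
        if name == "euler_a":
            return EulerAScheduler(
                beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear",
            )
        raise ValueError(f"unknown scheduler: {name}")
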
diff --git a/textual_inversion.py b/textual_inversion.py
index 00d460f..5fc2338 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -14,7 +14,7 @@ import torch.utils.checkpoint
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 from diffusers.optimization import get_scheduler
 from PIL import Image
@@ -30,6 +30,9 @@ from data.textual_inversion.csv import CSVDataModule
 logger = get_logger(__name__)


+torch.backends.cuda.matmul.allow_tf32 = True
+
+
 def parse_args():
     parser = argparse.ArgumentParser(
         description="Simple example of a training script."
@@ -370,7 +373,6 @@ class Checkpointer:
             unet=self.unet,
             tokenizer=self.tokenizer,
             scheduler=scheduler,
-            feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
         ).to(self.accelerator.device)
         pipeline.enable_attention_slicing()

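
Note: a recurring call across these files is enable_attention_slicing(), which computes attention in slices instead of one large batch of matmuls, trading some throughput for a much lower peak-VRAM footprint; that is why the training-time checkpointing paths enable it while infer.py leaves it commented out. Usage against the stock diffusers API (model id shown for illustration):

    from diffusers import StableDiffusionPipeline

    pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
    pipe.enable_attention_slicing()   # slice attention to cut peak VRAM
    # ... run pipe(prompt) ...
    pipe.disable_attention_slicing()  # restore full-speed attention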
