| author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
| commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
| tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /train_dreambooth.py | |
| parent | Fix training (diff) | |
Moved common training code into separate module
Diffstat (limited to 'train_dreambooth.py')
| -rw-r--r-- | train_dreambooth.py | 126 |
1 file changed, 16 insertions(+), 110 deletions(-)
```diff
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0f8fece..9749c62 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -25,6 +24,7 @@ from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -385,41 +385,7 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -437,9 +403,20 @@ class Checkpointer:
         sample_image_size,
         sample_batches,
         sample_batch_size,
-        seed
+        seed,
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -447,14 +424,6 @@ class Checkpointer:
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -481,8 +450,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -495,72 +462,11 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del unet
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
```
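The commit moves the shared helpers out of `train_dreambooth.py`, but the diff is limited to that file, so the new `training/util.py` itself is not shown. The sketch below is a hypothetical reconstruction based on the removed code and the new import line (`AverageMeter`, `CheckpointerBase`, `freeze_params`, `save_args`); the inclusion of `make_grid` and the exact `CheckpointerBase.save_samples` body are assumptions.

```python
# training/util.py -- hypothetical reconstruction from the code removed above;
# the real module introduced by this commit may differ.
import json
from pathlib import Path

import torch
from PIL import Image


def save_args(basepath: Path, args, extra={}):
    # Serialize the parsed CLI arguments (plus any extras) to args.json.
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def freeze_params(params):
    # Disable gradient tracking for the given parameters.
    for param in params:
        param.requires_grad = False


def make_grid(images, rows, cols):
    # Paste PIL images into a single rows x cols contact sheet.
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class AverageMeter:
    # Running average of a scalar metric (e.g. training loss).
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CheckpointerBase:
    # Holds the sampling configuration shared by the per-script Checkpointer
    # subclasses; the parameters mirror the keyword arguments passed via
    # super().__init__() in the diff above.
    def __init__(self, datamodule, output_dir, instance_identifier,
                 placeholder_token, placeholder_token_id, sample_image_size,
                 seed, sample_batches, sample_batch_size):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps,
                     guidance_scale=7.5, eta=0.0):
        # Presumably contains the prompt-batching / image-grid loop removed
        # from train_dreambooth.py above: iterate over the stable/val/train
        # pools, run the pipeline, and save step_{step}.jpg grids under
        # output_dir/samples via make_grid().
        raise NotImplementedError
```

With this split, the `Checkpointer` in `train_dreambooth.py` keeps only the model-specific work (building the `VlpnStableDiffusion` pipeline from its vae/unet/text_encoder) and delegates the generic sampling loop to the base class via `super().save_samples(...)`.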
