diff options
author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /train_dreambooth.py | |
parent | Fix training (diff) | |
download | textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.gz textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.bz2 textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.zip |
Moved common training code into a separate module
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r-- | train_dreambooth.py | 126 |
1 file changed, 16 insertions, 110 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 0f8fece..9749c62 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed | |||
16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel | 16 | from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel |
17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup | 17 | from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup |
18 | from diffusers.training_utils import EMAModel | 18 | from diffusers.training_utils import EMAModel |
19 | from PIL import Image | ||
20 | from tqdm.auto import tqdm | 19 | from tqdm.auto import tqdm |
21 | from transformers import CLIPTextModel, CLIPTokenizer | 20 | from transformers import CLIPTextModel, CLIPTokenizer |
22 | from slugify import slugify | 21 | from slugify import slugify |
@@ -25,6 +24,7 @@ from common import load_text_embeddings | |||
25 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
26 | from data.csv import CSVDataModule | 25 | from data.csv import CSVDataModule |
27 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args | ||
28 | from models.clip.prompt import PromptProcessor | 28 | from models.clip.prompt import PromptProcessor |
29 | 29 | ||
30 | logger = get_logger(__name__) | 30 | logger = get_logger(__name__) |
@@ -385,41 +385,7 @@ def parse_args(): | |||
385 | return args | 385 | return args |
386 | 386 | ||
387 | 387 | ||
388 | def save_args(basepath: Path, args, extra={}): | 388 | class Checkpointer(CheckpointerBase): |
389 | info = {"args": vars(args)} | ||
390 | info["args"].update(extra) | ||
391 | with open(basepath.joinpath("args.json"), "w") as f: | ||
392 | json.dump(info, f, indent=4) | ||
393 | |||
394 | |||
395 | def freeze_params(params): | ||
396 | for param in params: | ||
397 | param.requires_grad = False | ||
398 | |||
399 | |||
400 | def make_grid(images, rows, cols): | ||
401 | w, h = images[0].size | ||
402 | grid = Image.new('RGB', size=(cols*w, rows*h)) | ||
403 | for i, image in enumerate(images): | ||
404 | grid.paste(image, box=(i % cols*w, i//cols*h)) | ||
405 | return grid | ||
406 | |||
407 | |||
408 | class AverageMeter: | ||
409 | def __init__(self, name=None): | ||
410 | self.name = name | ||
411 | self.reset() | ||
412 | |||
413 | def reset(self): | ||
414 | self.sum = self.count = self.avg = 0 | ||
415 | |||
416 | def update(self, val, n=1): | ||
417 | self.sum += val * n | ||
418 | self.count += n | ||
419 | self.avg = self.sum / self.count | ||
420 | |||
421 | |||
422 | class Checkpointer: | ||
423 | def __init__( | 389 | def __init__( |
424 | self, | 390 | self, |
425 | datamodule, | 391 | datamodule, |
@@ -437,9 +403,20 @@ class Checkpointer: | |||
437 | sample_image_size, | 403 | sample_image_size, |
438 | sample_batches, | 404 | sample_batches, |
439 | sample_batch_size, | 405 | sample_batch_size, |
440 | seed | 406 | seed, |
441 | ): | 407 | ): |
442 | self.datamodule = datamodule | 408 | super().__init__( |
409 | datamodule=datamodule, | ||
410 | output_dir=output_dir, | ||
411 | instance_identifier=instance_identifier, | ||
412 | placeholder_token=placeholder_token, | ||
413 | placeholder_token_id=placeholder_token_id, | ||
414 | sample_image_size=sample_image_size, | ||
415 | seed=seed or torch.random.seed(), | ||
416 | sample_batches=sample_batches, | ||
417 | sample_batch_size=sample_batch_size | ||
418 | ) | ||
419 | |||
443 | self.accelerator = accelerator | 420 | self.accelerator = accelerator |
444 | self.vae = vae | 421 | self.vae = vae |
445 | self.unet = unet | 422 | self.unet = unet |
@@ -447,14 +424,6 @@ class Checkpointer: | |||
447 | self.tokenizer = tokenizer | 424 | self.tokenizer = tokenizer |
448 | self.text_encoder = text_encoder | 425 | self.text_encoder = text_encoder |
449 | self.scheduler = scheduler | 426 | self.scheduler = scheduler |
450 | self.output_dir = output_dir | ||
451 | self.instance_identifier = instance_identifier | ||
452 | self.placeholder_token = placeholder_token | ||
453 | self.placeholder_token_id = placeholder_token_id | ||
454 | self.sample_image_size = sample_image_size | ||
455 | self.seed = seed or torch.random.seed() | ||
456 | self.sample_batches = sample_batches | ||
457 | self.sample_batch_size = sample_batch_size | ||
458 | 427 | ||
459 | @torch.no_grad() | 428 | @torch.no_grad() |
460 | def save_model(self): | 429 | def save_model(self): |
@@ -481,8 +450,6 @@ class Checkpointer: | |||
481 | 450 | ||
482 | @torch.no_grad() | 451 | @torch.no_grad() |
483 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 452 | def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
484 | samples_path = Path(self.output_dir).joinpath("samples") | ||
485 | |||
486 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) | 453 | unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet) |
487 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) | 454 | text_encoder = self.accelerator.unwrap_model(self.text_encoder) |
488 | 455 | ||
@@ -495,72 +462,11 @@ class Checkpointer: | |||
495 | ).to(self.accelerator.device) | 462 | ).to(self.accelerator.device) |
496 | pipeline.set_progress_bar_config(dynamic_ncols=True) | 463 | pipeline.set_progress_bar_config(dynamic_ncols=True) |
497 | 464 | ||
498 | train_data = self.datamodule.train_dataloader() | 465 | super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta) |
499 | val_data = self.datamodule.val_dataloader() | ||
500 | |||
501 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | ||
502 | stable_latents = torch.randn( | ||
503 | (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8), | ||
504 | device=pipeline.device, | ||
505 | generator=generator, | ||
506 | ) | ||
507 | |||
508 | with torch.autocast("cuda"), torch.inference_mode(): | ||
509 | for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]: | ||
510 | all_samples = [] | ||
511 | file_path = samples_path.joinpath(pool, f"step_{step}.jpg") | ||
512 | file_path.parent.mkdir(parents=True, exist_ok=True) | ||
513 | |||
514 | data_enum = enumerate(data) | ||
515 | |||
516 | batches = [ | ||
517 | batch | ||
518 | for j, batch in data_enum | ||
519 | if j * data.batch_size < self.sample_batch_size * self.sample_batches | ||
520 | ] | ||
521 | prompts = [ | ||
522 | prompt.format(identifier=self.instance_identifier) | ||
523 | for batch in batches | ||
524 | for prompt in batch["prompts"] | ||
525 | ] | ||
526 | nprompts = [ | ||
527 | prompt | ||
528 | for batch in batches | ||
529 | for prompt in batch["nprompts"] | ||
530 | ] | ||
531 | |||
532 | for i in range(self.sample_batches): | ||
533 | prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
534 | nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size] | ||
535 | |||
536 | samples = pipeline( | ||
537 | prompt=prompt, | ||
538 | negative_prompt=nprompt, | ||
539 | height=self.sample_image_size, | ||
540 | width=self.sample_image_size, | ||
541 | image=latents[:len(prompt)] if latents is not None else None, | ||
542 | generator=generator if latents is not None else None, | ||
543 | guidance_scale=guidance_scale, | ||
544 | eta=eta, | ||
545 | num_inference_steps=num_inference_steps, | ||
546 | output_type='pil' | ||
547 | ).images | ||
548 | |||
549 | all_samples += samples | ||
550 | |||
551 | del samples | ||
552 | |||
553 | image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size) | ||
554 | image_grid.save(file_path, quality=85) | ||
555 | |||
556 | del all_samples | ||
557 | del image_grid | ||
558 | 466 | ||
559 | del unet | 467 | del unet |
560 | del text_encoder | 468 | del text_encoder |
561 | del pipeline | 469 | del pipeline |
562 | del generator | ||
563 | del stable_latents | ||
564 | 470 | ||
565 | if torch.cuda.is_available(): | 471 | if torch.cuda.is_available(): |
566 | torch.cuda.empty_cache() | 472 | torch.cuda.empty_cache() |