author    Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
committer Volpeon <git@volpeon.ink>  2022-12-21 09:17:25 +0100
commit    68540b27849564994d921968a36faa9b997e626d (patch)
tree      8fbe834ab4c52f057cd114bbb0e786158f215acc /train_dreambooth.py
parent    Fix training (diff)
Moved common training code into separate module
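
The other half of this refactor, the new training/util.py module, is outside this diff (the diffstat is limited to train_dreambooth.py). The following sketch reconstructs what that module plausibly contains, assembled from the code removed below and from the new import line; the CheckpointerBase internals, in particular the save_samples body, are assumptions inferred from how the subclass calls them, not code shown in this commit:

# Plausible reconstruction of training/util.py (not part of this diff).
import json
from pathlib import Path

import torch
from PIL import Image


def save_args(basepath: Path, args, extra={}):
    # Persist the argparse namespace (plus extras) as args.json in the run dir.
    info = {"args": vars(args)}
    info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)


def freeze_params(params):
    # Exclude the given parameters from gradient computation.
    for param in params:
        param.requires_grad = False


def make_grid(images, rows, cols):
    # Paste equally sized PIL images into a rows x cols contact sheet.
    w, h = images[0].size
    grid = Image.new('RGB', size=(cols * w, rows * h))
    for i, image in enumerate(images):
        grid.paste(image, box=(i % cols * w, i // cols * h))
    return grid


class AverageMeter:
    # Running average of a scalar, e.g. the training loss.
    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        self.sum = self.count = self.avg = 0

    def update(self, val, n=1):
        self.sum += val * n
        self.count += n
        self.avg = self.sum / self.count


class CheckpointerBase:
    # Holds the sampling configuration shared by the training scripts.
    # Parameter list inferred from the super().__init__ call in the diff.
    def __init__(self, datamodule, output_dir, instance_identifier,
                 placeholder_token, placeholder_token_id, sample_image_size,
                 seed, sample_batches, sample_batch_size):
        self.datamodule = datamodule
        self.output_dir = output_dir
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        self.sample_image_size = sample_image_size
        self.seed = seed
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps,
                     guidance_scale=7.5, eta=0.0):
        # Presumably the sampling loop deleted from train_dreambooth.py below:
        # build stable/val/train prompt pools from the datamodule, render them
        # with the given pipeline, and save make_grid() sheets per pool.
        ...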
Diffstat (limited to 'train_dreambooth.py')
-rw-r--r--  train_dreambooth.py  126
1 file changed, 16 insertions(+), 110 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 0f8fece..9749c62 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -16,7 +16,6 @@ from accelerate.utils import LoggerType, set_seed
 from diffusers import AutoencoderKL, DDPMScheduler, DPMSolverMultistepScheduler, UNet2DConditionModel
 from diffusers.optimization import get_scheduler, get_cosine_with_hard_restarts_schedule_with_warmup
 from diffusers.training_utils import EMAModel
-from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPTextModel, CLIPTokenizer
 from slugify import slugify
@@ -25,6 +24,7 @@ from common import load_text_embeddings
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
 from data.csv import CSVDataModule
 from training.optimization import get_one_cycle_schedule
+from training.util import AverageMeter, CheckpointerBase, freeze_params, save_args
 from models.clip.prompt import PromptProcessor
 
 logger = get_logger(__name__)
@@ -385,41 +385,7 @@ def parse_args():
     return args
 
 
-def save_args(basepath: Path, args, extra={}):
-    info = {"args": vars(args)}
-    info["args"].update(extra)
-    with open(basepath.joinpath("args.json"), "w") as f:
-        json.dump(info, f, indent=4)
-
-
-def freeze_params(params):
-    for param in params:
-        param.requires_grad = False
-
-
-def make_grid(images, rows, cols):
-    w, h = images[0].size
-    grid = Image.new('RGB', size=(cols*w, rows*h))
-    for i, image in enumerate(images):
-        grid.paste(image, box=(i % cols*w, i//cols*h))
-    return grid
-
-
-class AverageMeter:
-    def __init__(self, name=None):
-        self.name = name
-        self.reset()
-
-    def reset(self):
-        self.sum = self.count = self.avg = 0
-
-    def update(self, val, n=1):
-        self.sum += val * n
-        self.count += n
-        self.avg = self.sum / self.count
-
-
-class Checkpointer:
+class Checkpointer(CheckpointerBase):
     def __init__(
         self,
         datamodule,
@@ -437,9 +403,20 @@ class Checkpointer:
         sample_image_size,
         sample_batches,
         sample_batch_size,
-        seed
+        seed,
     ):
-        self.datamodule = datamodule
+        super().__init__(
+            datamodule=datamodule,
+            output_dir=output_dir,
+            instance_identifier=instance_identifier,
+            placeholder_token=placeholder_token,
+            placeholder_token_id=placeholder_token_id,
+            sample_image_size=sample_image_size,
+            seed=seed or torch.random.seed(),
+            sample_batches=sample_batches,
+            sample_batch_size=sample_batch_size
+        )
+
         self.accelerator = accelerator
         self.vae = vae
         self.unet = unet
@@ -447,14 +424,6 @@ class Checkpointer:
         self.tokenizer = tokenizer
         self.text_encoder = text_encoder
         self.scheduler = scheduler
-        self.output_dir = output_dir
-        self.instance_identifier = instance_identifier
-        self.placeholder_token = placeholder_token
-        self.placeholder_token_id = placeholder_token_id
-        self.sample_image_size = sample_image_size
-        self.seed = seed or torch.random.seed()
-        self.sample_batches = sample_batches
-        self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
     def save_model(self):
@@ -481,8 +450,6 @@ class Checkpointer:
 
     @torch.no_grad()
     def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
-        samples_path = Path(self.output_dir).joinpath("samples")
-
         unet = self.ema_unet.averaged_model if self.ema_unet is not None else self.accelerator.unwrap_model(self.unet)
         text_encoder = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -495,72 +462,11 @@ class Checkpointer:
         ).to(self.accelerator.device)
         pipeline.set_progress_bar_config(dynamic_ncols=True)
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
-
-        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
-        stable_latents = torch.randn(
-            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
-            device=pipeline.device,
-            generator=generator,
-        )
-
-        with torch.autocast("cuda"), torch.inference_mode():
-            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
-                all_samples = []
-                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
-                file_path.parent.mkdir(parents=True, exist_ok=True)
-
-                data_enum = enumerate(data)
-
-                batches = [
-                    batch
-                    for j, batch in data_enum
-                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
-                ]
-                prompts = [
-                    prompt.format(identifier=self.instance_identifier)
-                    for batch in batches
-                    for prompt in batch["prompts"]
-                ]
-                nprompts = [
-                    prompt
-                    for batch in batches
-                    for prompt in batch["nprompts"]
-                ]
-
-                for i in range(self.sample_batches):
-                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
-
-                    samples = pipeline(
-                        prompt=prompt,
-                        negative_prompt=nprompt,
-                        height=self.sample_image_size,
-                        width=self.sample_image_size,
-                        image=latents[:len(prompt)] if latents is not None else None,
-                        generator=generator if latents is not None else None,
-                        guidance_scale=guidance_scale,
-                        eta=eta,
-                        num_inference_steps=num_inference_steps,
-                        output_type='pil'
-                    ).images
-
-                    all_samples += samples
-
-                    del samples
-
-                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
-                image_grid.save(file_path, quality=85)
-
-                del all_samples
-                del image_grid
+        super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
 
         del unet
         del text_encoder
         del pipeline
-        del generator
-        del stable_latents
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
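
For reference, a minimal usage sketch of the shared helpers as imported by this commit; the loop and the loss values are hypothetical stand-ins, not code from the repository:

from training.util import AverageMeter, freeze_params

# Track a running loss across steps.
avg_loss = AverageMeter("loss")
for loss in (0.91, 0.64, 0.40):          # stand-in values
    avg_loss.update(loss, n=1)
print(f"{avg_loss.name}: {avg_loss.avg:.3f}")  # prints: loss: 0.650

# Keep the VAE frozen so only the trainable modules receive gradients,
# mirroring how the training scripts use it (vae is a hypothetical object):
# freeze_params(vae.parameters())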