-rw-r--r--  train_ti.py                       46
-rw-r--r--  training/strategy/dreambooth.py  183
2 files changed, 206 insertions, 23 deletions
diff --git a/train_ti.py b/train_ti.py
index 77dec12..2497519 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -557,15 +557,6 @@ def main():
     else:
         optimizer_class = torch.optim.AdamW

-    optimizer = optimizer_class(
-        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
-        lr=args.learning_rate,
-        betas=(args.adam_beta1, args.adam_beta2),
-        weight_decay=args.adam_weight_decay,
-        eps=args.adam_epsilon,
-        amsgrad=args.adam_amsgrad,
-    )
-
     weight_dtype = torch.float32
     if args.mixed_precision == "fp16":
         weight_dtype = torch.float16
@@ -624,6 +615,29 @@ def main():
         args.sample_steps
     )

+    trainer = partial(
+        train,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        vae=vae,
+        noise_scheduler=noise_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        seed=args.seed,
+        callbacks_fn=textual_inversion_strategy
+    )
+
+    optimizer = optimizer_class(
+        text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
+        lr=args.learning_rate,
+        betas=(args.adam_beta1, args.adam_beta2),
+        weight_decay=args.adam_weight_decay,
+        eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
+    )
+
     if args.find_lr:
         lr_scheduler = None
     else:
@@ -642,20 +656,6 @@ def main():
         warmup_epochs=args.lr_warmup_epochs,
     )

-    trainer = partial(
-        train,
-        accelerator=accelerator,
-        unet=unet,
-        text_encoder=text_encoder,
-        vae=vae,
-        noise_scheduler=noise_scheduler,
-        train_dataloader=train_dataloader,
-        val_dataloader=val_dataloader,
-        dtype=weight_dtype,
-        seed=args.seed,
-        callbacks_fn=textual_inversion_strategy
-    )
-
     trainer(
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
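
Note on the train_ti.py hunks above: the change only moves code. The trainer = partial(train, ...) binding now happens before the optimizer and lr_scheduler are created; the blocks removed in the first and third hunks are re-added verbatim in the second, and the final trainer(optimizer=..., lr_scheduler=...) call is untouched. A minimal, self-contained sketch of the partial-bind pattern (the names below are illustrative placeholders, not the real train_ti.py objects):

from functools import partial


def train(*, model, dataloader, optimizer, lr_scheduler=None, seed=0):
    # Stand-in for the real train() entry point: it receives both the
    # pre-bound run-wide arguments and the per-call optimizer/scheduler.
    return {"optimizer": optimizer, "lr_scheduler": lr_scheduler, "seed": seed}


# Bind the arguments that are fixed for the whole run once...
trainer = partial(train, model="unet", dataloader=["batch0", "batch1"], seed=42)

# ...and pass the pieces chosen later (optimizer, scheduler) at call time,
# mirroring the trainer(optimizer=..., lr_scheduler=...) call at the end of main().
result = trainer(optimizer="adamw", lr_scheduler=None)
print(result["optimizer"])  # -> adamw
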
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
new file mode 100644
index 0000000..6e7ebe2
--- /dev/null
+++ b/training/strategy/dreambooth.py
@@ -0,0 +1,183 @@
+from contextlib import nullcontext
+from typing import Optional
+from functools import partial
+from contextlib import contextmanager, nullcontext
+from pathlib import Path
+import itertools
+
+import torch
+from torch.utils.data import DataLoader
+
+from accelerate import Accelerator
+from transformers import CLIPTextModel
+from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler
+
+from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
+from models.clip.tokenizer import MultiCLIPTokenizer
+from training.util import EMAModel
+from training.functional import TrainingCallbacks, save_samples
+
+
+def dreambooth_strategy(
+    accelerator: Accelerator,
+    unet: UNet2DConditionModel,
+    text_encoder: CLIPTextModel,
+    tokenizer: MultiCLIPTokenizer,
+    vae: AutoencoderKL,
+    sample_scheduler: DPMSolverMultistepScheduler,
+    train_dataloader: DataLoader,
+    val_dataloader: DataLoader,
+    output_dir: Path,
+    seed: int,
+    train_text_encoder_epochs: int,
+    max_grad_norm: float = 1.0,
+    use_ema: bool = False,
+    ema_inv_gamma: float = 1.0,
+    ema_power: int = 1,
+    ema_max_decay: float = 0.9999,
+    sample_batch_size: int = 1,
+    sample_num_batches: int = 1,
+    sample_num_steps: int = 20,
+    sample_guidance_scale: float = 7.5,
+    sample_image_size: Optional[int] = None,
+):
+    if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
+        raise ValueError(
+            "Gradient accumulation is not supported when training the text encoder in distributed training. "
+            "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
+        )
+
+    weight_dtype = torch.float32
+    if accelerator.state.mixed_precision == "fp16":
+        weight_dtype = torch.float16
+    elif accelerator.state.mixed_precision == "bf16":
+        weight_dtype = torch.bfloat16
+
+    save_samples_ = partial(
+        save_samples,
+        accelerator=accelerator,
+        unet=unet,
+        text_encoder=text_encoder,
+        tokenizer=tokenizer,
+        vae=vae,
+        sample_scheduler=sample_scheduler,
+        train_dataloader=train_dataloader,
+        val_dataloader=val_dataloader,
+        dtype=weight_dtype,
+        output_dir=output_dir,
+        seed=seed,
+        batch_size=sample_batch_size,
+        num_batches=sample_num_batches,
+        num_steps=sample_num_steps,
+        guidance_scale=sample_guidance_scale,
+        image_size=sample_image_size,
+    )
+
+    if use_ema:
+        ema_unet = EMAModel(
+            unet.parameters(),
+            inv_gamma=ema_inv_gamma,
+            power=ema_power,
+            max_value=ema_max_decay,
+        )
+    else:
+        ema_unet = None
+
+    def ema_context():
+        if use_ema:
+            return ema_unet.apply_temporary(unet.parameters())
+        else:
+            return nullcontext()
+
+    def on_model():
+        return unet
+
+    def on_prepare():
+        unet.requires_grad_(True)
+        text_encoder.requires_grad_(True)
+        text_encoder.text_model.embeddings.persist()
+        text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False)
+
+        if use_ema:
+            ema_unet.to(accelerator.device)
+
+    @contextmanager
+    def on_train(epoch: int):
+        tokenizer.train()
+
+        if epoch < train_text_encoder_epochs:
+            text_encoder.train()
+        elif epoch == train_text_encoder_epochs:
+            text_encoder.requires_grad_(False)
+            text_encoder.eval()
+
+        yield
+
+    @contextmanager
+    def on_eval():
+        tokenizer.eval()
+        text_encoder.eval()
+
+        with ema_context():
+            yield
+
+    def on_before_optimize(epoch: int):
+        if accelerator.sync_gradients:
+            params_to_clip = [unet.parameters()]
+            if epoch < train_text_encoder_epochs:
+                params_to_clip.append(text_encoder.parameters())
+            accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm)
+
+    @torch.no_grad()
+    def on_after_optimize(lr: float):
+        if use_ema:
+            ema_unet.step(unet.parameters())
+
+    def on_log():
+        if use_ema:
+            return {"ema_decay": ema_unet.decay}
+        return {}
+
+    @torch.no_grad()
+    def on_checkpoint(step, postfix):
+        if postfix != "end":
+            return
+
+        print("Saving model...")
+
+        unet_ = accelerator.unwrap_model(unet)
+        text_encoder_ = accelerator.unwrap_model(text_encoder)
+
+        with ema_context():
+            pipeline = VlpnStableDiffusion(
+                text_encoder=text_encoder_,
+                vae=vae,
+                unet=unet_,
+                tokenizer=tokenizer,
+                scheduler=sample_scheduler,
+            )
+            pipeline.save_pretrained(output_dir.joinpath("model"))
+
+        del unet_
+        del text_encoder_
+        del pipeline
+
+        if torch.cuda.is_available():
+            torch.cuda.empty_cache()
+
+    @torch.no_grad()
+    def on_sample(step):
+        with ema_context():
+            save_samples_(step=step)
+
+    return TrainingCallbacks(
+        on_prepare=on_prepare,
+        on_model=on_model,
+        on_train=on_train,
+        on_eval=on_eval,
+        on_before_optimize=on_before_optimize,
+        on_after_optimize=on_after_optimize,
+        on_log=on_log,
+        on_checkpoint=on_checkpoint,
+        on_sample=on_sample,
+    )
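
dreambooth_strategy returns a TrainingCallbacks bundle for the shared training loop, the same way train_ti.py passes textual_inversion_strategy as callbacks_fn. The on_train/on_eval hooks are plain contextmanagers that gate text-encoder training by epoch; a self-contained toy of that gating follows (ToyModule stands in for the real CLIPTextModel, and the driving loop is an assumption about how the callbacks are consumed, not code from this commit):

from contextlib import contextmanager


class ToyModule:
    # Minimal stand-in for an nn.Module's train()/eval() switching.
    def __init__(self):
        self.training = False

    def train(self):
        self.training = True

    def eval(self):
        self.training = False


text_encoder = ToyModule()
train_text_encoder_epochs = 2


@contextmanager
def on_train(epoch: int):
    # Mirrors dreambooth_strategy.on_train: train the text encoder only for the
    # first train_text_encoder_epochs epochs, then freeze it for the rest of the run.
    if epoch < train_text_encoder_epochs:
        text_encoder.train()
    elif epoch == train_text_encoder_epochs:
        text_encoder.eval()
    yield


for epoch in range(4):
    with on_train(epoch):
        print(epoch, "text_encoder.training =", text_encoder.training)
# -> epochs 0 and 1 train the text encoder; epochs 2 and 3 run it frozen.

In the strategy above, on_eval and on_sample additionally enter ema_context(), so validation and sample generation run against the EMA weights whenever use_ema is set.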