diff options
author | Volpeon <git@volpeon.ink> | 2023-01-15 22:26:43 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-01-15 22:26:43 +0100 |
commit | 3f922880475c2c0a5679987d4a9a43606e838566 (patch) | |
tree | 757746927e34aa7fddff1e44c837b489233029d7 /training/strategy | |
parent | Restored functional trainer (diff) | |
download | textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.gz textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.tar.bz2 textual-inversion-diff-3f922880475c2c0a5679987d4a9a43606e838566.zip |
Added Dreambooth strategy
Diffstat (limited to 'training/strategy')
-rw-r--r-- | training/strategy/dreambooth.py | 183 |
1 files changed, 183 insertions, 0 deletions
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py new file mode 100644 index 0000000..6e7ebe2 --- /dev/null +++ b/training/strategy/dreambooth.py | |||
@@ -0,0 +1,183 @@ | |||
1 | from contextlib import nullcontext | ||
2 | from typing import Optional | ||
3 | from functools import partial | ||
4 | from contextlib import contextmanager, nullcontext | ||
5 | from pathlib import Path | ||
6 | import itertools | ||
7 | |||
8 | import torch | ||
9 | from torch.utils.data import DataLoader | ||
10 | |||
11 | from accelerate import Accelerator | ||
12 | from transformers import CLIPTextModel | ||
13 | from diffusers import AutoencoderKL, UNet2DConditionModel, DPMSolverMultistepScheduler | ||
14 | |||
15 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | ||
16 | from models.clip.tokenizer import MultiCLIPTokenizer | ||
17 | from training.util import EMAModel | ||
18 | from training.functional import TrainingCallbacks, save_samples | ||
19 | |||
20 | |||
21 | def dreambooth_strategy( | ||
22 | accelerator: Accelerator, | ||
23 | unet: UNet2DConditionModel, | ||
24 | text_encoder: CLIPTextModel, | ||
25 | tokenizer: MultiCLIPTokenizer, | ||
26 | vae: AutoencoderKL, | ||
27 | sample_scheduler: DPMSolverMultistepScheduler, | ||
28 | train_dataloader: DataLoader, | ||
29 | val_dataloader: DataLoader, | ||
30 | output_dir: Path, | ||
31 | seed: int, | ||
32 | train_text_encoder_epochs: int, | ||
33 | max_grad_norm: float = 1.0, | ||
34 | use_ema: bool = False, | ||
35 | ema_inv_gamma: float = 1.0, | ||
36 | ema_power: int = 1, | ||
37 | ema_max_decay: float = 0.9999, | ||
38 | sample_batch_size: int = 1, | ||
39 | sample_num_batches: int = 1, | ||
40 | sample_num_steps: int = 20, | ||
41 | sample_guidance_scale: float = 7.5, | ||
42 | sample_image_size: Optional[int] = None, | ||
43 | ): | ||
44 | if accelerator.gradient_accumulation_steps > 1 and accelerator.num_processes > 1: | ||
45 | raise ValueError( | ||
46 | "Gradient accumulation is not supported when training the text encoder in distributed training. " | ||
47 | "Please set gradient_accumulation_steps to 1. This feature will be supported in the future." | ||
48 | ) | ||
49 | |||
50 | weight_dtype = torch.float32 | ||
51 | if accelerator.state.mixed_precision == "fp16": | ||
52 | weight_dtype = torch.float16 | ||
53 | elif accelerator.state.mixed_precision == "bf16": | ||
54 | weight_dtype = torch.bfloat16 | ||
55 | |||
56 | save_samples_ = partial( | ||
57 | save_samples, | ||
58 | accelerator=accelerator, | ||
59 | unet=unet, | ||
60 | text_encoder=text_encoder, | ||
61 | tokenizer=tokenizer, | ||
62 | vae=vae, | ||
63 | sample_scheduler=sample_scheduler, | ||
64 | train_dataloader=train_dataloader, | ||
65 | val_dataloader=val_dataloader, | ||
66 | dtype=weight_dtype, | ||
67 | output_dir=output_dir, | ||
68 | seed=seed, | ||
69 | batch_size=sample_batch_size, | ||
70 | num_batches=sample_num_batches, | ||
71 | num_steps=sample_num_steps, | ||
72 | guidance_scale=sample_guidance_scale, | ||
73 | image_size=sample_image_size, | ||
74 | ) | ||
75 | |||
76 | if use_ema: | ||
77 | ema_unet = EMAModel( | ||
78 | unet.parameters(), | ||
79 | inv_gamma=ema_inv_gamma, | ||
80 | power=ema_power, | ||
81 | max_value=ema_max_decay, | ||
82 | ) | ||
83 | else: | ||
84 | ema_unet = None | ||
85 | |||
86 | def ema_context(): | ||
87 | if use_ema: | ||
88 | return ema_unet.apply_temporary(unet.parameters()) | ||
89 | else: | ||
90 | return nullcontext() | ||
91 | |||
92 | def on_model(): | ||
93 | return unet | ||
94 | |||
95 | def on_prepare(): | ||
96 | unet.requires_grad_(True) | ||
97 | text_encoder.requires_grad_(True) | ||
98 | text_encoder.text_model.embeddings.persist() | ||
99 | text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(False) | ||
100 | |||
101 | if use_ema: | ||
102 | ema_unet.to(accelerator.device) | ||
103 | |||
104 | @contextmanager | ||
105 | def on_train(epoch: int): | ||
106 | tokenizer.train() | ||
107 | |||
108 | if epoch < train_text_encoder_epochs: | ||
109 | text_encoder.train() | ||
110 | elif epoch == train_text_encoder_epochs: | ||
111 | text_encoder.requires_grad_(False) | ||
112 | text_encoder.eval() | ||
113 | |||
114 | yield | ||
115 | |||
116 | @contextmanager | ||
117 | def on_eval(): | ||
118 | tokenizer.eval() | ||
119 | text_encoder.eval() | ||
120 | |||
121 | with ema_context(): | ||
122 | yield | ||
123 | |||
124 | def on_before_optimize(epoch: int): | ||
125 | if accelerator.sync_gradients: | ||
126 | params_to_clip = [unet.parameters()] | ||
127 | if epoch < train_text_encoder_epochs: | ||
128 | params_to_clip.append(text_encoder.parameters()) | ||
129 | accelerator.clip_grad_norm_(itertools.chain(*params_to_clip), max_grad_norm) | ||
130 | |||
131 | @torch.no_grad() | ||
132 | def on_after_optimize(lr: float): | ||
133 | if use_ema: | ||
134 | ema_unet.step(unet.parameters()) | ||
135 | |||
136 | def on_log(): | ||
137 | if use_ema: | ||
138 | return {"ema_decay": ema_unet.decay} | ||
139 | return {} | ||
140 | |||
141 | @torch.no_grad() | ||
142 | def on_checkpoint(step, postfix): | ||
143 | if postfix != "end": | ||
144 | return | ||
145 | |||
146 | print("Saving model...") | ||
147 | |||
148 | unet_ = accelerator.unwrap_model(unet) | ||
149 | text_encoder_ = accelerator.unwrap_model(text_encoder) | ||
150 | |||
151 | with ema_context(): | ||
152 | pipeline = VlpnStableDiffusion( | ||
153 | text_encoder=text_encoder_, | ||
154 | vae=vae, | ||
155 | unet=unet_, | ||
156 | tokenizer=tokenizer, | ||
157 | scheduler=sample_scheduler, | ||
158 | ) | ||
159 | pipeline.save_pretrained(output_dir.joinpath("model")) | ||
160 | |||
161 | del unet_ | ||
162 | del text_encoder_ | ||
163 | del pipeline | ||
164 | |||
165 | if torch.cuda.is_available(): | ||
166 | torch.cuda.empty_cache() | ||
167 | |||
168 | @torch.no_grad() | ||
169 | def on_sample(step): | ||
170 | with ema_context(): | ||
171 | save_samples_(step=step) | ||
172 | |||
173 | return TrainingCallbacks( | ||
174 | on_prepare=on_prepare, | ||
175 | on_model=on_model, | ||
176 | on_train=on_train, | ||
177 | on_eval=on_eval, | ||
178 | on_before_optimize=on_before_optimize, | ||
179 | on_after_optimize=on_after_optimize, | ||
180 | on_log=on_log, | ||
181 | on_checkpoint=on_checkpoint, | ||
182 | on_sample=on_sample, | ||
183 | ) | ||