summaryrefslogtreecommitdiffstats
path: root/training
diff options
context:
space:
mode:
Diffstat (limited to 'training')
-rw-r--r--training/common.py264
-rw-r--r--training/modules/dreambooth.py0
-rw-r--r--training/modules/lora.py0
-rw-r--r--training/modules/ti.py284
-rw-r--r--training/optimization.py53
5 files changed, 70 insertions, 531 deletions
diff --git a/training/common.py b/training/common.py
index 73ce814..b6964a3 100644
--- a/training/common.py
+++ b/training/common.py
@@ -1,52 +1,24 @@
1import math 1import math
2from pathlib import Path
3from contextlib import _GeneratorContextManager, nullcontext 2from contextlib import _GeneratorContextManager, nullcontext
4from typing import Callable, Any, Tuple, Union, Literal, Optional, NamedTuple 3from typing import Callable, Any, Tuple, Union
5import datetime
6import logging
7 4
8import torch 5import torch
9import torch.nn.functional as F 6import torch.nn.functional as F
10from torch.utils.data import DataLoader 7from torch.utils.data import DataLoader
11 8
12from accelerate import Accelerator 9from accelerate import Accelerator
13from accelerate.utils import LoggerType, set_seed
14from transformers import CLIPTextModel 10from transformers import CLIPTextModel
15from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler 11from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSolverMultistepScheduler
16from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
17 12
18from tqdm.auto import tqdm 13from tqdm.auto import tqdm
19from slugify import slugify
20 14
21from data.csv import VlpnDataModule, VlpnDataItem
22from util import load_embeddings_from_dir
23from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 15from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
24from models.clip.embeddings import patch_managed_embeddings 16from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
25from models.clip.util import get_extended_embeddings 17from models.clip.util import get_extended_embeddings
26from models.clip.tokenizer import MultiCLIPTokenizer 18from models.clip.tokenizer import MultiCLIPTokenizer
27from training.optimization import get_one_cycle_schedule
28from training.util import AverageMeter, CheckpointerBase 19from training.util import AverageMeter, CheckpointerBase
29 20
30 21
31class TrainingSetup(NamedTuple):
32 accelerator: Accelerator
33 tokenizer: MultiCLIPTokenizer
34 text_encoder: CLIPTextModel
35 vae: AutoencoderKL
36 unet: UNet2DConditionModel
37 noise_scheduler: DDPMScheduler
38 checkpoint_scheduler: DPMSolverMultistepScheduler
39 optimizer_class: Callable
40 learning_rate: float
41 weight_dtype: torch.dtype
42 output_dir: Path
43 seed: int
44 train_dataloader: DataLoader
45 val_dataloader: DataLoader
46 placeholder_token: list[str]
47 placeholder_token_ids: list[list[int]]
48
49
50def noop(*args, **kwards): 22def noop(*args, **kwards):
51 pass 23 pass
52 24
@@ -59,57 +31,6 @@ def noop_on_log():
59 return {} 31 return {}
60 32
61 33
62def get_scheduler(
63 id: str,
64 optimizer: torch.optim.Optimizer,
65 num_training_steps_per_epoch: int,
66 gradient_accumulation_steps: int,
67 min_lr: float = 0.04,
68 warmup_func: str = "cos",
69 annealing_func: str = "cos",
70 warmup_exp: int = 1,
71 annealing_exp: int = 1,
72 cycles: int = 1,
73 train_epochs: int = 100,
74 warmup_epochs: int = 10,
75):
76 num_training_steps_per_epoch = math.ceil(
77 num_training_steps_per_epoch / gradient_accumulation_steps
78 ) * gradient_accumulation_steps
79 num_training_steps = train_epochs * num_training_steps_per_epoch
80 num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
81
82 if id == "one_cycle":
83 lr_scheduler = get_one_cycle_schedule(
84 optimizer=optimizer,
85 num_training_steps=num_training_steps,
86 warmup=warmup_func,
87 annealing=annealing_func,
88 warmup_exp=warmup_exp,
89 annealing_exp=annealing_exp,
90 min_lr=min_lr,
91 )
92 elif id == "cosine_with_restarts":
93 if cycles is None:
94 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
95
96 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
97 optimizer=optimizer,
98 num_warmup_steps=num_warmup_steps,
99 num_training_steps=num_training_steps,
100 num_cycles=cycles,
101 )
102 else:
103 lr_scheduler = get_scheduler_(
104 id,
105 optimizer=optimizer,
106 num_warmup_steps=num_warmup_steps,
107 num_training_steps=num_training_steps,
108 )
109
110 return lr_scheduler
111
112
113def generate_class_images( 34def generate_class_images(
114 accelerator, 35 accelerator,
115 text_encoder, 36 text_encoder,
@@ -162,194 +83,43 @@ def generate_class_images(
162 torch.cuda.empty_cache() 83 torch.cuda.empty_cache()
163 84
164 85
165def train_setup( 86def get_models(pretrained_model_name_or_path: str):
166 output_dir: str,
167 project: str,
168 pretrained_model_name_or_path: str,
169 learning_rate: float,
170 data_file: str,
171 gradient_accumulation_steps: int = 1,
172 mixed_precision: Literal["no", "fp16", "bf16"] = "no",
173 seed: Optional[int] = None,
174 vector_shuffle: Union[bool, Literal["all", "trailing", "leading", "between", "off"]] = "auto",
175 vector_dropout: float = 0.1,
176 gradient_checkpointing: bool = True,
177 embeddings_dir: Optional[str] = None,
178 placeholder_token: list[str] = [],
179 initializer_token: list[str] = [],
180 num_vectors: int = 1,
181 scale_lr: bool = False,
182 use_8bit_adam: bool = False,
183 train_batch_size: int = 1,
184 class_image_dir: Optional[str] = None,
185 num_class_images: int = 0,
186 resolution: int = 768,
187 num_buckets: int = 0,
188 progressive_buckets: bool = False,
189 bucket_step_size: int = 64,
190 bucket_max_pixels: Optional[int] = None,
191 tag_dropout: float = 0.1,
192 tag_shuffle: bool = True,
193 data_template: str = "template",
194 valid_set_size: Optional[int] = None,
195 valid_set_repeat: int = 1,
196 data_filter: Optional[Callable[[VlpnDataItem], bool]] = None,
197 sample_batch_size: int = 1,
198 sample_image_size: int = 768,
199 sample_steps: int = 20,
200) -> TrainingSetup:
201 now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
202 output_dir = Path(output_dir).joinpath(slugify(project), now)
203 output_dir.mkdir(parents=True, exist_ok=True)
204
205 accelerator = Accelerator(
206 log_with=LoggerType.TENSORBOARD,
207 logging_dir=f"{output_dir}",
208 gradient_accumulation_steps=gradient_accumulation_steps,
209 mixed_precision=mixed_precision
210 )
211
212 logging.basicConfig(filename=output_dir.joinpath("log.txt"), level=logging.DEBUG)
213
214 seed = seed or (torch.random.seed() >> 32)
215 set_seed(seed)
216
217 # Load the tokenizer and add the placeholder token as a additional special token
218 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer') 87 tokenizer = MultiCLIPTokenizer.from_pretrained(pretrained_model_name_or_path, subfolder='tokenizer')
219 tokenizer.set_use_vector_shuffle(vector_shuffle)
220 tokenizer.set_dropout(vector_dropout)
221
222 # Load models and create wrapper for stable diffusion
223 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder') 88 text_encoder = CLIPTextModel.from_pretrained(pretrained_model_name_or_path, subfolder='text_encoder')
224 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae') 89 vae = AutoencoderKL.from_pretrained(pretrained_model_name_or_path, subfolder='vae')
225 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet') 90 unet = UNet2DConditionModel.from_pretrained(pretrained_model_name_or_path, subfolder='unet')
226 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler') 91 noise_scheduler = DDPMScheduler.from_pretrained(pretrained_model_name_or_path, subfolder='scheduler')
227 checkpoint_scheduler = DPMSolverMultistepScheduler.from_pretrained( 92 sample_scheduler = DPMSolverMultistepScheduler.from_pretrained(
228 pretrained_model_name_or_path, subfolder='scheduler') 93 pretrained_model_name_or_path, subfolder='scheduler')
229 94
230 vae.enable_slicing() 95 vae.enable_slicing()
231 vae.set_use_memory_efficient_attention_xformers(True) 96 vae.set_use_memory_efficient_attention_xformers(True)
232 unet.set_use_memory_efficient_attention_xformers(True) 97 unet.set_use_memory_efficient_attention_xformers(True)
233 98
234 if gradient_checkpointing:
235 unet.enable_gradient_checkpointing()
236 text_encoder.gradient_checkpointing_enable()
237
238 embeddings = patch_managed_embeddings(text_encoder) 99 embeddings = patch_managed_embeddings(text_encoder)
239 100
240 if embeddings_dir is not None: 101 return tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings
241 embeddings_dir = Path(embeddings_dir)
242 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
243 raise ValueError("--embeddings_dir must point to an existing directory")
244 102
245 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
246 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
247 103
248 # Convert the initializer_token, placeholder_token to ids 104def add_placeholder_tokens(
105 tokenizer: MultiCLIPTokenizer,
106 embeddings: ManagedCLIPTextEmbeddings,
107 placeholder_tokens: list[str],
108 initializer_tokens: list[str],
109 num_vectors: Union[list[int], int]
110):
249 initializer_token_ids = [ 111 initializer_token_ids = [
250 tokenizer.encode(token, add_special_tokens=False) 112 tokenizer.encode(token, add_special_tokens=False)
251 for token in initializer_token 113 for token in initializer_tokens
252 ] 114 ]
115 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_tokens, num_vectors)
253 116
254 placeholder_token_ids = tokenizer.add_multi_tokens(placeholder_token, num_vectors)
255 embeddings.resize(len(tokenizer)) 117 embeddings.resize(len(tokenizer))
256 118
257 for (new_id, init_ids) in zip(placeholder_token_ids, initializer_token_ids): 119 for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
258 embeddings.add_embed(new_id, init_ids) 120 embeddings.add_embed(placeholder_token_id, initializer_token_id)
259
260 init_ratios = [
261 f"{len(init_ids)} / {len(new_id)}"
262 for new_id, init_ids in zip(placeholder_token_ids, initializer_token_ids)
263 ]
264
265 print(f"Added {len(placeholder_token_ids)} new tokens: {list(zip(placeholder_token, placeholder_token_ids, init_ratios))}")
266 121
267 vae.requires_grad_(False) 122 return placeholder_token_ids
268 unet.requires_grad_(False)
269 text_encoder.requires_grad_(False)
270
271 if scale_lr:
272 learning_rate = (
273 learning_rate * gradient_accumulation_steps *
274 train_batch_size * accelerator.num_processes
275 )
276
277 # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
278 if use_8bit_adam:
279 try:
280 import bitsandbytes as bnb
281 except ImportError:
282 raise ImportError("To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.")
283
284 optimizer_class = bnb.optim.AdamW8bit
285 else:
286 optimizer_class = torch.optim.AdamW
287
288 weight_dtype = torch.float32
289 if mixed_precision == "fp16":
290 weight_dtype = torch.float16
291 elif mixed_precision == "bf16":
292 weight_dtype = torch.bfloat16
293
294 datamodule = VlpnDataModule(
295 data_file=data_file,
296 batch_size=train_batch_size,
297 tokenizer=tokenizer,
298 class_subdir=class_image_dir,
299 num_class_images=num_class_images,
300 size=resolution,
301 num_buckets=num_buckets,
302 progressive_buckets=progressive_buckets,
303 bucket_step_size=bucket_step_size,
304 bucket_max_pixels=bucket_max_pixels,
305 dropout=tag_dropout,
306 shuffle=tag_shuffle,
307 template_key=data_template,
308 valid_set_size=valid_set_size,
309 valid_set_repeat=valid_set_repeat,
310 seed=seed,
311 filter=data_filter,
312 dtype=weight_dtype
313 )
314 datamodule.setup()
315
316 train_dataloader = datamodule.train_dataloader
317 val_dataloader = datamodule.val_dataloader
318
319 train_dataloader, val_dataloader = accelerator.prepare(train_dataloader, val_dataloader)
320
321 if num_class_images != 0:
322 generate_class_images(
323 accelerator,
324 text_encoder,
325 vae,
326 unet,
327 tokenizer,
328 checkpoint_scheduler,
329 datamodule.data_train,
330 sample_batch_size,
331 sample_image_size,
332 sample_steps
333 )
334
335 return TrainingSetup(
336 accelerator=accelerator,
337 tokenizer=tokenizer,
338 text_encoder=text_encoder,
339 vae=vae,
340 unet=unet,
341 noise_scheduler=noise_scheduler,
342 checkpoint_scheduler=checkpoint_scheduler,
343 optimizer_class=optimizer_class,
344 learning_rate=learning_rate,
345 output_dir=output_dir,
346 weight_dtype=weight_dtype,
347 seed=seed,
348 train_dataloader=train_dataloader,
349 val_dataloader=val_dataloader,
350 placeholder_token=placeholder_token,
351 placeholder_token_ids=placeholder_token_ids
352 )
353 123
354 124
355def loss_step( 125def loss_step(
diff --git a/training/modules/dreambooth.py b/training/modules/dreambooth.py
deleted file mode 100644
index e69de29..0000000
--- a/training/modules/dreambooth.py
+++ /dev/null
diff --git a/training/modules/lora.py b/training/modules/lora.py
deleted file mode 100644
index e69de29..0000000
--- a/training/modules/lora.py
+++ /dev/null
diff --git a/training/modules/ti.py b/training/modules/ti.py
deleted file mode 100644
index 2db6f88..0000000
--- a/training/modules/ti.py
+++ /dev/null
@@ -1,284 +0,0 @@
1from typing import Literal
2from functools import partial
3from contextlib import contextmanager, nullcontext
4
5import torch
6
7from slugify import slugify
8
9from accelerate import Accelerator
10from transformers import CLIPTextModel
11from diffusers import AutoencoderKL, UNet2DConditionModel
12
13from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
14from models.clip.tokenizer import MultiCLIPTokenizer
15
16from training.common import TrainingSetup, get_scheduler, train_loop, loss_step
17from training.util import EMAModel, CheckpointerBase
18
19
20class Checkpointer(CheckpointerBase):
21 def __init__(
22 self,
23 accelerator: Accelerator,
24 vae: AutoencoderKL,
25 unet: UNet2DConditionModel,
26 tokenizer: MultiCLIPTokenizer,
27 text_encoder: CLIPTextModel,
28 ema_embeddings: EMAModel,
29 weight_dtype: torch.dtype,
30 scheduler,
31 placeholder_token,
32 placeholder_token_ids,
33 *args,
34 **kwargs
35 ):
36 super().__init__(*args, **kwargs)
37
38 self.weight_dtype = weight_dtype
39 self.accelerator = accelerator
40 self.vae = vae
41 self.unet = unet
42 self.tokenizer = tokenizer
43 self.text_encoder = text_encoder
44 self.ema_embeddings = ema_embeddings
45 self.scheduler = scheduler
46 self.placeholder_token = placeholder_token
47 self.placeholder_token_ids = placeholder_token_ids
48
49 @torch.no_grad()
50 def checkpoint(self, step, postfix):
51 print("Saving checkpoint for step %d..." % step)
52
53 checkpoints_path = self.output_dir.joinpath("checkpoints")
54 checkpoints_path.mkdir(parents=True, exist_ok=True)
55
56 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
57
58 ema_context = nullcontext()
59 if self.ema_embeddings is not None:
60 ema_context = self.ema_embeddings.apply_temporary(
61 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
62
63 with ema_context:
64 for (token, ids) in zip(self.placeholder_token, self.placeholder_token_ids):
65 text_encoder.text_model.embeddings.save_embed(
66 ids,
67 checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
68 )
69
70 del text_encoder
71
72 @torch.no_grad()
73 def save_samples(self, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
74 text_encoder = self.accelerator.unwrap_model(self.text_encoder)
75
76 ema_context = nullcontext()
77 if self.ema_embeddings is not None:
78 ema_context = self.ema_embeddings.apply_temporary(
79 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
80
81 with ema_context:
82 orig_dtype = text_encoder.dtype
83 text_encoder.to(dtype=self.weight_dtype)
84
85 pipeline = VlpnStableDiffusion(
86 text_encoder=text_encoder,
87 vae=self.vae,
88 unet=self.unet,
89 tokenizer=self.tokenizer,
90 scheduler=self.scheduler,
91 ).to(self.accelerator.device)
92 pipeline.set_progress_bar_config(dynamic_ncols=True)
93
94 super().save_samples(pipeline, step, num_inference_steps, guidance_scale, eta)
95
96 text_encoder.to(dtype=orig_dtype)
97
98 del text_encoder
99 del pipeline
100
101 if torch.cuda.is_available():
102 torch.cuda.empty_cache()
103
104
105def train_ti(
106 setup: TrainingSetup,
107 num_train_epochs: int = 100,
108 num_class_images: int = 0,
109 prior_loss_weight: float = 1.0,
110 use_ema: bool = False,
111 ema_inv_gamma: float = 1.0,
112 ema_power: float = 4/5,
113 ema_max_decay: float = .9999,
114 adam_beta1: float = 0.9,
115 adam_beta2: float = 0.999,
116 adam_weight_decay: float = 0,
117 adam_epsilon: float = 1e-08,
118 adam_amsgrad: bool = False,
119 lr_scheduler: Literal[
120 "linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup", "one_cycle"
121 ] = "one_cycle",
122 lr_min_lr: float = 0.04,
123 lr_warmup_func: Literal["linear", "cos"] = "cos",
124 lr_annealing_func: Literal["linear", "half_cos", "cos"] = "cos",
125 lr_warmup_exp: int = 1,
126 lr_annealing_exp: int = 1,
127 lr_cycles: int = 1,
128 lr_warmup_epochs: int = 10,
129 emb_decay_target: float = 0.4,
130 emb_decay_factor: float = 1,
131 emb_decay_start: float = 1e-4,
132 sample_image_size: int = 768,
133 sample_batch_size: int = 1,
134 sample_batches: int = 1,
135 sample_frequency: int = 10,
136 sample_steps: int = 20,
137 checkpoint_frequency: int = 50,
138 global_step_offset: int = 0,
139):
140 if use_ema:
141 ema_embeddings = EMAModel(
142 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
143 inv_gamma=ema_inv_gamma,
144 power=ema_power,
145 max_value=ema_max_decay,
146 )
147 else:
148 ema_embeddings = None
149
150 setup.text_encoder.requires_grad_(True)
151 setup.text_encoder.text_model.encoder.requires_grad_(False)
152 setup.text_encoder.text_model.final_layer_norm.requires_grad_(False)
153 setup.text_encoder.text_model.embeddings.position_embedding.requires_grad_(False)
154 setup.text_encoder.text_model.embeddings.token_embedding.requires_grad_(False)
155
156 # Initialize the optimizer
157 optimizer = setup.optimizer_class(
158 setup.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
159 lr=setup.learning_rate,
160 betas=(adam_beta1, adam_beta2),
161 weight_decay=adam_weight_decay,
162 eps=adam_epsilon,
163 amsgrad=adam_amsgrad,
164 )
165
166 lr_scheduler = get_scheduler(
167 lr_scheduler,
168 optimizer=optimizer,
169 min_lr=lr_min_lr,
170 warmup_func=lr_warmup_func,
171 annealing_func=lr_annealing_func,
172 warmup_exp=lr_warmup_exp,
173 annealing_exp=lr_annealing_exp,
174 cycles=lr_cycles,
175 train_epochs=num_train_epochs,
176 warmup_epochs=lr_warmup_epochs,
177 num_training_steps_per_epoch=len(setup.train_dataloader),
178 gradient_accumulation_steps=setup.accelerator.gradient_accumulation_steps
179 )
180
181 text_encoder, optimizer, lr_scheduler = setup.accelerator.prepare(
182 setup.text_encoder, optimizer, lr_scheduler
183 )
184
185 # Move vae and unet to device
186 setup.vae.to(setup.accelerator.device, dtype=setup.weight_dtype)
187 setup.unet.to(setup.accelerator.device, dtype=setup.weight_dtype)
188
189 if use_ema:
190 ema_embeddings.to(setup.accelerator.device)
191
192 setup.unet.train()
193
194 @contextmanager
195 def on_train(epoch: int):
196 try:
197 setup.tokenizer.train()
198 yield
199 finally:
200 pass
201
202 @contextmanager
203 def on_eval():
204 try:
205 setup.tokenizer.eval()
206
207 ema_context = nullcontext()
208 if use_ema:
209 ema_context = ema_embeddings.apply_temporary(
210 text_encoder.text_model.embeddings.temp_token_embedding.parameters())
211
212 with ema_context:
213 yield
214 finally:
215 pass
216
217 @torch.no_grad()
218 def on_after_optimize(lr: float):
219 text_encoder.text_model.embeddings.normalize(
220 emb_decay_target,
221 min(1.0, max(0.0, emb_decay_factor * ((lr - emb_decay_start) / (setup.learning_rate - emb_decay_start))))
222 )
223
224 if use_ema:
225 ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
226
227 def on_log():
228 if use_ema:
229 return {"ema_decay": ema_embeddings.decay}
230 return {}
231
232 loss_step_ = partial(
233 loss_step,
234 setup.vae,
235 setup.noise_scheduler,
236 setup.unet,
237 text_encoder,
238 num_class_images != 0,
239 prior_loss_weight,
240 setup.seed,
241 )
242
243 checkpointer = Checkpointer(
244 accelerator=setup.accelerator,
245 vae=setup.vae,
246 unet=setup.unet,
247 tokenizer=setup.tokenizer,
248 text_encoder=text_encoder,
249 ema_embeddings=ema_embeddings,
250 weight_dtype=setup.weight_dtype,
251 scheduler=setup.checkpoint_scheduler,
252 placeholder_token=setup.placeholder_token,
253 placeholder_token_ids=setup.placeholder_token_ids,
254 train_dataloader=setup.train_dataloader,
255 val_dataloader=setup.val_dataloader,
256 output_dir=setup.output_dir,
257 seed=setup.seed,
258 sample_image_size=sample_image_size,
259 sample_batch_size=sample_batch_size,
260 sample_batches=sample_batches
261 )
262
263 if setup.accelerator.is_main_process:
264 setup.accelerator.init_trackers("textual_inversion")
265
266 train_loop(
267 accelerator=setup.accelerator,
268 optimizer=optimizer,
269 lr_scheduler=lr_scheduler,
270 model=text_encoder,
271 checkpointer=checkpointer,
272 train_dataloader=setup.train_dataloader,
273 val_dataloader=setup.val_dataloader,
274 loss_step=loss_step_,
275 sample_frequency=sample_frequency,
276 sample_steps=sample_steps,
277 checkpoint_frequency=checkpoint_frequency,
278 global_step_offset=global_step_offset,
279 num_epochs=num_train_epochs,
280 on_log=on_log,
281 on_train=on_train,
282 on_after_optimize=on_after_optimize,
283 on_eval=on_eval
284 )
diff --git a/training/optimization.py b/training/optimization.py
index dd84f9c..5db7794 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -5,6 +5,8 @@ from functools import partial
5import torch 5import torch
6from torch.optim.lr_scheduler import LambdaLR 6from torch.optim.lr_scheduler import LambdaLR
7 7
8from diffusers.optimization import get_scheduler as get_scheduler_, get_cosine_with_hard_restarts_schedule_with_warmup
9
8 10
9class OneCyclePhase(NamedTuple): 11class OneCyclePhase(NamedTuple):
10 step_min: int 12 step_min: int
@@ -83,3 +85,54 @@ def get_one_cycle_schedule(
83 return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min) 85 return phase.min + phase.func((current_step - phase.step_min) / (phase.step_max - phase.step_min)) * (phase.max - phase.min)
84 86
85 return LambdaLR(optimizer, lr_lambda, last_epoch) 87 return LambdaLR(optimizer, lr_lambda, last_epoch)
88
89
90def get_scheduler(
91 id: str,
92 optimizer: torch.optim.Optimizer,
93 num_training_steps_per_epoch: int,
94 gradient_accumulation_steps: int,
95 min_lr: float = 0.04,
96 warmup_func: str = "cos",
97 annealing_func: str = "cos",
98 warmup_exp: int = 1,
99 annealing_exp: int = 1,
100 cycles: int = 1,
101 train_epochs: int = 100,
102 warmup_epochs: int = 10,
103):
104 num_training_steps_per_epoch = math.ceil(
105 num_training_steps_per_epoch / gradient_accumulation_steps
106 ) * gradient_accumulation_steps
107 num_training_steps = train_epochs * num_training_steps_per_epoch
108 num_warmup_steps = warmup_epochs * num_training_steps_per_epoch
109
110 if id == "one_cycle":
111 lr_scheduler = get_one_cycle_schedule(
112 optimizer=optimizer,
113 num_training_steps=num_training_steps,
114 warmup=warmup_func,
115 annealing=annealing_func,
116 warmup_exp=warmup_exp,
117 annealing_exp=annealing_exp,
118 min_lr=min_lr,
119 )
120 elif id == "cosine_with_restarts":
121 if cycles is None:
122 cycles = math.ceil(math.sqrt(((num_training_steps - num_warmup_steps) / num_training_steps_per_epoch)))
123
124 lr_scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
125 optimizer=optimizer,
126 num_warmup_steps=num_warmup_steps,
127 num_training_steps=num_training_steps,
128 num_cycles=cycles,
129 )
130 else:
131 lr_scheduler = get_scheduler_(
132 id,
133 optimizer=optimizer,
134 num_warmup_steps=num_warmup_steps,
135 num_training_steps=num_training_steps,
136 )
137
138 return lr_scheduler