summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Volpeon <git@volpeon.ink> 2023-01-16 15:52:43 +0100
committer Volpeon <git@volpeon.ink> 2023-01-16 15:52:43 +0100
commit 6c8cffe28baeafac77d047ff3f8ded9418033e2f (patch)
tree 807c527deb1b15ef795f5cd8a7682151c69a037e
parent Pad dataset if len(items) < batch_size (diff)
download textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.tar.gz
textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.tar.bz2
textual-inversion-diff-6c8cffe28baeafac77d047ff3f8ded9418033e2f.zip
More training adjustments
-rw-r--r-- data/csv.py 39
-rw-r--r-- train_dreambooth.py 71
-rw-r--r-- train_ti.py 17
-rw-r--r-- training/functional.py 5
-rw-r--r-- training/optimization.py 10
-rw-r--r-- training/strategy/ti.py 2
6 files changed, 101 insertions, 43 deletions
diff --git a/data/csv.py b/data/csv.py
index dec66d7..85b98f8 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -174,7 +174,8 @@ class VlpnDataModule():
174 interpolation: str = "bicubic", 174 interpolation: str = "bicubic",
175 template_key: str = "template", 175 template_key: str = "template",
176 valid_set_size: Optional[int] = None, 176 valid_set_size: Optional[int] = None,
177 valid_set_repeat: int = 1, 177 train_set_pad: Optional[int] = None,
178 valid_set_pad: Optional[int] = None,
178 seed: Optional[int] = None, 179 seed: Optional[int] = None,
179 filter: Optional[Callable[[VlpnDataItem], bool]] = None, 180 filter: Optional[Callable[[VlpnDataItem], bool]] = None,
180 dtype: torch.dtype = torch.float32, 181 dtype: torch.dtype = torch.float32,
@@ -202,7 +203,8 @@ class VlpnDataModule():
202 self.template_key = template_key 203 self.template_key = template_key
203 self.interpolation = interpolation 204 self.interpolation = interpolation
204 self.valid_set_size = valid_set_size 205 self.valid_set_size = valid_set_size
205 self.valid_set_repeat = valid_set_repeat 206 self.train_set_pad = train_set_pad if train_set_pad is not None else batch_size
207 self.valid_set_pad = valid_set_pad if valid_set_pad is not None else batch_size
206 self.seed = seed 208 self.seed = seed
207 self.filter = filter 209 self.filter = filter
208 self.batch_size = batch_size 210 self.batch_size = batch_size
@@ -267,9 +269,6 @@ class VlpnDataModule():
267 items = self.prepare_items(template, expansions, items) 269 items = self.prepare_items(template, expansions, items)
268 items = self.filter_items(items) 270 items = self.filter_items(items)
269 271
270 if (len(items) < self.batch_size):
271 items = (items * self.batch_size)[:self.batch_size]
272
273 num_images = len(items) 272 num_images = len(items)
274 273
275 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10 274 valid_set_size = min(self.valid_set_size, num_images) if self.valid_set_size is not None else num_images // 10
@@ -283,14 +282,17 @@ class VlpnDataModule():
283 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) 282 collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0)
284 283
285 if valid_set_size == 0: 284 if valid_set_size == 0:
286 data_train, data_val = items, [] 285 data_train, data_val = items, items[:1]
287 else: 286 else:
288 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) 287 data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)
289 288
290 self.data_train = self.pad_items(data_train, self.num_class_images) 289 data_train = self.pad_items(data_train, self.num_class_images)
290
291 if len(data_train) < self.train_set_pad:
292 data_train *= math.ceil(self.train_set_pad / len(data_train))
291 293
292 train_dataset = VlpnDataset( 294 self.train_dataset = VlpnDataset(
293 self.data_train, self.tokenizer, 295 data_train, self.tokenizer,
294 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets, 296 num_buckets=self.num_buckets, progressive_buckets=self.progressive_buckets,
295 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 297 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
296 batch_size=self.batch_size, generator=generator, 298 batch_size=self.batch_size, generator=generator,
@@ -299,24 +301,26 @@ class VlpnDataModule():
299 ) 301 )
300 302
301 self.train_dataloader = DataLoader( 303 self.train_dataloader = DataLoader(
302 train_dataset, 304 self.train_dataset,
303 batch_size=None, pin_memory=True, collate_fn=collate_fn_ 305 batch_size=None, pin_memory=True, collate_fn=collate_fn_
304 ) 306 )
305 307
306 if valid_set_size != 0: 308 if len(data_val) != 0:
307 self.data_val = self.pad_items(data_val) 309 data_val = self.pad_items(data_val)
310
311 if len(data_val) < self.valid_set_pad:
312 data_val *= math.ceil(self.valid_set_pad / len(data_val))
308 313
309 val_dataset = VlpnDataset( 314 self.val_dataset = VlpnDataset(
310 self.data_val, self.tokenizer, 315 data_val, self.tokenizer,
311 num_buckets=self.num_buckets, progressive_buckets=True, 316 num_buckets=self.num_buckets, progressive_buckets=True,
312 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels, 317 bucket_step_size=self.bucket_step_size, bucket_max_pixels=self.bucket_max_pixels,
313 repeat=self.valid_set_repeat,
314 batch_size=self.batch_size, generator=generator, 318 batch_size=self.batch_size, generator=generator,
315 size=self.size, interpolation=self.interpolation, 319 size=self.size, interpolation=self.interpolation,
316 ) 320 )
317 321
318 self.val_dataloader = DataLoader( 322 self.val_dataloader = DataLoader(
319 val_dataset, 323 self.val_dataset,
320 batch_size=None, pin_memory=True, collate_fn=collate_fn_ 324 batch_size=None, pin_memory=True, collate_fn=collate_fn_
321 ) 325 )
322 else: 326 else:
@@ -332,7 +336,6 @@ class VlpnDataset(IterableDataset):
332 bucket_step_size: int = 64, 336 bucket_step_size: int = 64,
333 bucket_max_pixels: Optional[int] = None, 337 bucket_max_pixels: Optional[int] = None,
334 progressive_buckets: bool = False, 338 progressive_buckets: bool = False,
335 repeat: int = 1,
336 batch_size: int = 1, 339 batch_size: int = 1,
337 num_class_images: int = 0, 340 num_class_images: int = 0,
338 size: int = 768, 341 size: int = 768,
@@ -341,7 +344,7 @@ class VlpnDataset(IterableDataset):
341 interpolation: str = "bicubic", 344 interpolation: str = "bicubic",
342 generator: Optional[torch.Generator] = None, 345 generator: Optional[torch.Generator] = None,
343 ): 346 ):
344 self.items = items * repeat 347 self.items = items
345 self.batch_size = batch_size 348 self.batch_size = batch_size
346 349
347 self.tokenizer = tokenizer 350 self.tokenizer = tokenizer
diff --git a/train_dreambooth.py b/train_dreambooth.py
index a9fbbbd..1dc41b1 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -55,6 +55,18 @@ def parse_args():
55 default="template", 55 default="template",
56 ) 56 )
57 parser.add_argument( 57 parser.add_argument(
58 "--train_set_pad",
59 type=int,
60 default=None,
61 help="The number to fill train dataset items up to."
62 )
63 parser.add_argument(
64 "--valid_set_pad",
65 type=int,
66 default=None,
67 help="The number to fill validation dataset items up to."
68 )
69 parser.add_argument(
58 "--project", 70 "--project",
59 type=str, 71 type=str,
60 default=None, 72 default=None,
@@ -188,11 +200,23 @@ def parse_args():
188 default=100 200 default=100
189 ) 201 )
190 parser.add_argument( 202 parser.add_argument(
203 "--ti_data_template",
204 type=str,
205 nargs='*',
206 default=[],
207 )
208 parser.add_argument(
191 "--ti_num_train_epochs", 209 "--ti_num_train_epochs",
192 type=int, 210 type=int,
193 default=10 211 default=10
194 ) 212 )
195 parser.add_argument( 213 parser.add_argument(
214 "--ti_batch_size",
215 type=int,
216 default=1,
217 help="Batch size (per device) for the training dataloader."
218 )
219 parser.add_argument(
196 "--max_train_steps", 220 "--max_train_steps",
197 type=int, 221 type=int,
198 default=None, 222 default=None,
@@ -458,6 +482,12 @@ def parse_args():
458 if len(args.placeholder_tokens) != len(args.num_vectors): 482 if len(args.placeholder_tokens) != len(args.num_vectors):
459 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items") 483 raise ValueError("--placeholder_tokens and --num_vectors must have the same number of items")
460 484
485 if isinstance(args.ti_data_template, str):
486 args.ti_data_template = [args.ti_data_template]
487
488 if len(args.ti_data_template) == 0:
489 raise ValueError("You must specify --ti_data_template")
490
461 if isinstance(args.collection, str): 491 if isinstance(args.collection, str):
462 args.collection = [args.collection] 492 args.collection = [args.collection]
463 493
@@ -491,6 +521,8 @@ def main():
491 521
492 set_seed(args.seed) 522 set_seed(args.seed)
493 523
524 seed_generator = torch.Generator().manual_seed(args.seed)
525
494 save_args(output_dir, args) 526 save_args(output_dir, args)
495 527
496 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( 528 tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models(
@@ -512,6 +544,8 @@ def main():
512 if not embeddings_dir.exists() or not embeddings_dir.is_dir(): 544 if not embeddings_dir.exists() or not embeddings_dir.is_dir():
513 raise ValueError("--embeddings_dir must point to an existing directory") 545 raise ValueError("--embeddings_dir must point to an existing directory")
514 546
547 embeddings.persist()
548
515 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir) 549 added_tokens, added_ids = load_embeddings_from_dir(tokenizer, embeddings, embeddings_dir)
516 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}") 550 print(f"Added {len(added_tokens)} tokens from embeddings dir: {list(zip(added_tokens, added_ids))}")
517 551
@@ -545,7 +579,6 @@ def main():
545 vae=vae, 579 vae=vae,
546 noise_scheduler=noise_scheduler, 580 noise_scheduler=noise_scheduler,
547 dtype=weight_dtype, 581 dtype=weight_dtype,
548 seed=args.seed,
549 with_prior_preservation=args.num_class_images != 0, 582 with_prior_preservation=args.num_class_images != 0,
550 prior_loss_weight=args.prior_loss_weight, 583 prior_loss_weight=args.prior_loss_weight,
551 ) 584 )
@@ -557,13 +590,17 @@ def main():
557 cur_dir = output_dir.joinpath("1-ti") 590 cur_dir = output_dir.joinpath("1-ti")
558 cur_dir.mkdir(parents=True, exist_ok=True) 591 cur_dir.mkdir(parents=True, exist_ok=True)
559 592
560 for placeholder_token, initializer_token, num_vectors in zip(args.placeholder_tokens, args.initializer_tokens, args.num_vectors): 593 for i, placeholder_token, initializer_token, num_vectors, data_template in zip(
561 print(f"Phase 1.1: {placeholder_token} ({num_vectors}) ({initializer_token})") 594 range(len(args.placeholder_tokens)),
562 595 args.placeholder_tokens,
596 args.initializer_tokens,
597 args.num_vectors,
598 args.ti_data_template
599 ):
563 cur_subdir = cur_dir.joinpath(placeholder_token) 600 cur_subdir = cur_dir.joinpath(placeholder_token)
564 cur_subdir.mkdir(parents=True, exist_ok=True) 601 cur_subdir.mkdir(parents=True, exist_ok=True)
565 602
566 placeholder_token_ids, _ = add_placeholder_tokens( 603 placeholder_token_ids, initializer_token_ids = add_placeholder_tokens(
567 tokenizer=tokenizer, 604 tokenizer=tokenizer,
568 embeddings=embeddings, 605 embeddings=embeddings,
569 placeholder_tokens=[placeholder_token], 606 placeholder_tokens=[placeholder_token],
@@ -571,17 +608,23 @@ def main():
571 num_vectors=[num_vectors] 608 num_vectors=[num_vectors]
572 ) 609 )
573 610
611 print(
612 f"Phase 1.{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})")
613
614 args.seed = seed_generator.seed()
615
574 datamodule = VlpnDataModule( 616 datamodule = VlpnDataModule(
575 data_file=args.train_data_file, 617 data_file=args.train_data_file,
576 batch_size=args.train_batch_size, 618 batch_size=args.ti_batch_size,
577 tokenizer=tokenizer, 619 tokenizer=tokenizer,
578 class_subdir=args.class_image_dir, 620 class_subdir=args.class_image_dir,
579 num_class_images=args.num_class_images, 621 num_class_images=args.num_class_images,
580 size=args.resolution, 622 size=args.resolution,
581 shuffle=not args.no_tag_shuffle, 623 shuffle=not args.no_tag_shuffle,
582 template_key=args.train_data_template, 624 template_key=data_template,
583 valid_set_size=1, 625 valid_set_size=1,
584 valid_set_repeat=args.valid_set_repeat, 626 train_set_pad=args.train_set_pad,
627 valid_set_pad=args.valid_set_pad,
585 seed=args.seed, 628 seed=args.seed,
586 filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections), 629 filter=partial(keyword_filter, [placeholder_token], args.collection, args.exclude_collections),
587 dtype=weight_dtype 630 dtype=weight_dtype
@@ -591,7 +634,9 @@ def main():
591 optimizer = optimizer_class( 634 optimizer = optimizer_class(
592 text_encoder.text_model.embeddings.temp_token_embedding.parameters(), 635 text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
593 lr=args.ti_learning_rate, 636 lr=args.ti_learning_rate,
637 betas=(args.adam_beta1, args.adam_beta2),
594 weight_decay=0.0, 638 weight_decay=0.0,
639 eps=args.adam_epsilon,
595 ) 640 )
596 641
597 lr_scheduler = get_scheduler( 642 lr_scheduler = get_scheduler(
@@ -600,7 +645,6 @@ def main():
600 num_training_steps_per_epoch=len(datamodule.train_dataloader), 645 num_training_steps_per_epoch=len(datamodule.train_dataloader),
601 gradient_accumulation_steps=args.gradient_accumulation_steps, 646 gradient_accumulation_steps=args.gradient_accumulation_steps,
602 train_epochs=args.ti_num_train_epochs, 647 train_epochs=args.ti_num_train_epochs,
603 warmup_epochs=args.ti_num_train_epochs // 4,
604 ) 648 )
605 649
606 trainer( 650 trainer(
@@ -608,10 +652,11 @@ def main():
608 project="textual_inversion", 652 project="textual_inversion",
609 train_dataloader=datamodule.train_dataloader, 653 train_dataloader=datamodule.train_dataloader,
610 val_dataloader=datamodule.val_dataloader, 654 val_dataloader=datamodule.val_dataloader,
655 seed=args.seed,
611 optimizer=optimizer, 656 optimizer=optimizer,
612 lr_scheduler=lr_scheduler, 657 lr_scheduler=lr_scheduler,
613 num_train_epochs=args.ti_num_train_epochs, 658 num_train_epochs=args.ti_num_train_epochs,
614 sample_frequency=2, 659 sample_frequency=args.ti_num_train_epochs // 5,
615 checkpoint_frequency=9999999, 660 checkpoint_frequency=9999999,
616 # -- 661 # --
617 tokenizer=tokenizer, 662 tokenizer=tokenizer,
@@ -637,7 +682,7 @@ def main():
637 cur_dir = output_dir.joinpath("2-db") 682 cur_dir = output_dir.joinpath("2-db")
638 cur_dir.mkdir(parents=True, exist_ok=True) 683 cur_dir.mkdir(parents=True, exist_ok=True)
639 684
640 args.seed = (args.seed + 28635) >> 32 685 args.seed = seed_generator.seed()
641 686
642 datamodule = VlpnDataModule( 687 datamodule = VlpnDataModule(
643 data_file=args.train_data_file, 688 data_file=args.train_data_file,
@@ -654,7 +699,8 @@ def main():
654 shuffle=not args.no_tag_shuffle, 699 shuffle=not args.no_tag_shuffle,
655 template_key=args.train_data_template, 700 template_key=args.train_data_template,
656 valid_set_size=args.valid_set_size, 701 valid_set_size=args.valid_set_size,
657 valid_set_repeat=args.valid_set_repeat, 702 train_set_pad=args.train_set_pad,
703 valid_set_pad=args.valid_set_pad,
658 seed=args.seed, 704 seed=args.seed,
659 filter=partial(keyword_filter, None, args.collection, args.exclude_collections), 705 filter=partial(keyword_filter, None, args.collection, args.exclude_collections),
660 dtype=weight_dtype 706 dtype=weight_dtype
@@ -697,6 +743,7 @@ def main():
697 project="dreambooth", 743 project="dreambooth",
698 train_dataloader=datamodule.train_dataloader, 744 train_dataloader=datamodule.train_dataloader,
699 val_dataloader=datamodule.val_dataloader, 745 val_dataloader=datamodule.val_dataloader,
746 seed=args.seed,
700 optimizer=optimizer, 747 optimizer=optimizer,
701 lr_scheduler=lr_scheduler, 748 lr_scheduler=lr_scheduler,
702 num_train_epochs=args.num_train_epochs, 749 num_train_epochs=args.num_train_epochs,
diff --git a/train_ti.py b/train_ti.py
index a894ee7..7aecdef 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -360,10 +360,16 @@ def parse_args():
360 help="Number of images in the validation dataset." 360 help="Number of images in the validation dataset."
361 ) 361 )
362 parser.add_argument( 362 parser.add_argument(
363 "--valid_set_repeat", 363 "--train_set_pad",
364 type=int, 364 type=int,
365 default=1, 365 default=None,
366 help="Times the images in the validation dataset are repeated." 366 help="The number to fill train dataset items up to."
367 )
368 parser.add_argument(
369 "--valid_set_pad",
370 type=int,
371 default=None,
372 help="The number to fill validation dataset items up to."
367 ) 373 )
368 parser.add_argument( 374 parser.add_argument(
369 "--train_batch_size", 375 "--train_batch_size",
@@ -575,7 +581,8 @@ def main():
575 shuffle=not args.no_tag_shuffle, 581 shuffle=not args.no_tag_shuffle,
576 template_key=args.train_data_template, 582 template_key=args.train_data_template,
577 valid_set_size=args.valid_set_size, 583 valid_set_size=args.valid_set_size,
578 valid_set_repeat=args.valid_set_repeat, 584 train_set_pad=args.train_set_pad,
585 valid_set_pad=args.valid_set_pad,
579 seed=args.seed, 586 seed=args.seed,
580 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections), 587 filter=partial(keyword_filter, args.placeholder_tokens, args.collection, args.exclude_collections),
581 dtype=weight_dtype 588 dtype=weight_dtype
@@ -590,7 +597,7 @@ def main():
590 unet, 597 unet,
591 tokenizer, 598 tokenizer,
592 sample_scheduler, 599 sample_scheduler,
593 datamodule.data_train, 600 datamodule.train_dataset,
594 args.sample_batch_size, 601 args.sample_batch_size,
595 args.sample_image_size, 602 args.sample_image_size,
596 args.sample_steps 603 args.sample_steps
diff --git a/training/functional.py b/training/functional.py
index c6b4dc3..b6b5d87 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -17,6 +17,7 @@ from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel, DPMSol
17from tqdm.auto import tqdm 17from tqdm.auto import tqdm
18from PIL import Image 18from PIL import Image
19 19
20from data.csv import VlpnDataset
20from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion 21from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
21from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings 22from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embeddings
22from models.clip.util import get_extended_embeddings 23from models.clip.util import get_extended_embeddings
@@ -175,12 +176,12 @@ def generate_class_images(
175 unet: UNet2DConditionModel, 176 unet: UNet2DConditionModel,
176 tokenizer: MultiCLIPTokenizer, 177 tokenizer: MultiCLIPTokenizer,
177 sample_scheduler: DPMSolverMultistepScheduler, 178 sample_scheduler: DPMSolverMultistepScheduler,
178 data_train, 179 train_dataset: VlpnDataset,
179 sample_batch_size: int, 180 sample_batch_size: int,
180 sample_image_size: int, 181 sample_image_size: int,
181 sample_steps: int 182 sample_steps: int
182): 183):
183 missing_data = [item for item in data_train if not item.class_image_path.exists()] 184 missing_data = [item for item in train_dataset.items if not item.class_image_path.exists()]
184 185
185 if len(missing_data) == 0: 186 if len(missing_data) == 0:
186 return 187 return
diff --git a/training/optimization.py b/training/optimization.py
index 5db7794..6dee4bc 100644
--- a/training/optimization.py
+++ b/training/optimization.py
@@ -49,8 +49,8 @@ def get_one_cycle_schedule(
49 annealing: Literal["cos", "half_cos", "linear"] = "cos", 49 annealing: Literal["cos", "half_cos", "linear"] = "cos",
50 warmup_exp: int = 1, 50 warmup_exp: int = 1,
51 annealing_exp: int = 1, 51 annealing_exp: int = 1,
52 min_lr: int = 0.04, 52 min_lr: float = 0.04,
53 mid_point: int = 0.3, 53 mid_point: float = 0.3,
54 last_epoch: int = -1 54 last_epoch: int = -1
55): 55):
56 if warmup == "linear": 56 if warmup == "linear":
@@ -91,10 +91,10 @@ def get_scheduler(
91 id: str, 91 id: str,
92 optimizer: torch.optim.Optimizer, 92 optimizer: torch.optim.Optimizer,
93 num_training_steps_per_epoch: int, 93 num_training_steps_per_epoch: int,
94 gradient_accumulation_steps: int, 94 gradient_accumulation_steps: int = 1,
95 min_lr: float = 0.04, 95 min_lr: float = 0.04,
96 warmup_func: str = "cos", 96 warmup_func: Literal["cos", "linear"] = "cos",
97 annealing_func: str = "cos", 97 annealing_func: Literal["cos", "half_cos", "linear"] = "cos",
98 warmup_exp: int = 1, 98 warmup_exp: int = 1,
99 annealing_exp: int = 1, 99 annealing_exp: int = 1,
100 cycles: int = 1, 100 cycles: int = 1,
diff --git a/training/strategy/ti.py b/training/strategy/ti.py
index 568f9eb..9d39e15 100644
--- a/training/strategy/ti.py
+++ b/training/strategy/ti.py
@@ -36,7 +36,7 @@ def textual_inversion_strategy(
36 use_emb_decay: bool = False, 36 use_emb_decay: bool = False,
37 emb_decay_target: float = 0.4, 37 emb_decay_target: float = 0.4,
38 emb_decay_factor: float = 1, 38 emb_decay_factor: float = 1,
39 emb_decay_start: float = 1e-4, 39 emb_decay_start: float = 0,
40 use_ema: bool = False, 40 use_ema: bool = False,
41 ema_inv_gamma: float = 1.0, 41 ema_inv_gamma: float = 1.0,
42 ema_power: int = 1, 42 ema_power: int = 1,