commit 179a45253a5b3712f32bd127f693a6bb810a9c17
tree ac9f1152d858089742e4f9ce79e0870e0f2b9a2d
parent Fix TI
Author:     Volpeon <git@volpeon.ink>
AuthorDate: 2023-03-28 16:24:22 +0200
Commit:     Volpeon <git@volpeon.ink>
CommitDate: 2023-03-28 16:24:22 +0200

    Support num_train_steps arg again
---
 train_dreambooth.py | 17 +++++++++++------
 1 file changed, 11 insertions(+), 6 deletions(-)
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 9345797..acb8287 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -4,6 +4,7 @@ import logging
 import itertools
 from pathlib import Path
 from functools import partial
+import math
 
 import torch
 import torch.utils.checkpoint
@@ -189,13 +190,12 @@ def parse_args():
     parser.add_argument(
         "--num_train_epochs",
         type=int,
-        default=100
+        default=None
     )
     parser.add_argument(
-        "--max_train_steps",
+        "--num_train_steps",
         type=int,
-        default=None,
-        help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
+        default=2000
     )
     parser.add_argument(
         "--gradient_accumulation_steps",
@@ -595,6 +595,11 @@ def main():
     )
     datamodule.setup()
 
+    num_train_epochs = args.num_train_epochs
+
+    if num_train_epochs is None:
+        num_train_epochs = math.ceil(len(datamodule.train_dataset) / args.num_train_steps)
+
     params_to_optimize = (unet.parameters(), )
     if args.train_text_encoder_epochs != 0:
         params_to_optimize += (
@@ -619,7 +624,7 @@ def main():
         annealing_exp=args.lr_annealing_exp,
         cycles=args.lr_cycles,
         end_lr=1e2,
-        train_epochs=args.num_train_epochs,
+        train_epochs=num_train_epochs,
         warmup_epochs=args.lr_warmup_epochs,
     )
 
@@ -631,7 +636,7 @@ def main():
         seed=args.seed,
         optimizer=optimizer,
         lr_scheduler=lr_scheduler,
-        num_train_epochs=args.num_train_epochs,
+        num_train_epochs=num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
         sample_frequency=args.sample_frequency,
         offset_noise_strength=args.offset_noise_strength,
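
Taken together, the hunks make --num_train_epochs optional (default None) and reintroduce --num_train_steps (default 2000) as the training budget; both downstream call sites (the lr scheduler and the trainer) then consume the resolved num_train_epochs value instead of args.num_train_epochs. Below is a minimal standalone sketch of the fallback as committed; the helper name resolve_num_train_epochs and the example dataset size are hypothetical, only the formula and the flag defaults come from the diff.

import math

def resolve_num_train_epochs(num_train_epochs, num_train_steps, train_dataset_len):
    # Mirrors the committed fallback: an explicit --num_train_epochs wins;
    # otherwise an epoch count is derived from the --num_train_steps budget
    # as ceil(len(train_dataset) / num_train_steps).
    if num_train_epochs is None:
        num_train_epochs = math.ceil(train_dataset_len / num_train_steps)
    return num_train_epochs

# With the new defaults (epochs=None, steps=2000) and a hypothetical
# 400-image dataset, the fallback resolves to ceil(400 / 2000) = 1 epoch.
print(resolve_num_train_epochs(None, 2000, 400))  # -> 1

Note that, as written, any step budget at least as large as the dataset length resolves to a single epoch.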