summaryrefslogtreecommitdiffstats
path: root/train_lora.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2023-05-05 10:51:14 +0200
committerVolpeon <git@volpeon.ink>2023-05-05 10:51:14 +0200
commit8d2aa65402c829583e26cdf2c336b8d3057657d6 (patch)
treecc2d47f56d1433e7600abd494361b1ae0a068f80 /train_lora.py
parenttorch.compile won't work yet, keep code prepared (diff)
downloadtextual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.tar.gz
textual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.tar.bz2
textual-inversion-diff-8d2aa65402c829583e26cdf2c336b8d3057657d6.zip
Update
Diffstat (limited to 'train_lora.py')
-rw-r--r--train_lora.py8
1 files changed, 8 insertions, 0 deletions
diff --git a/train_lora.py b/train_lora.py
index 3c8fc97..cc7c1ec 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -251,6 +251,12 @@ def parse_args():
251 help="Perlin offset noise strength.", 251 help="Perlin offset noise strength.",
252 ) 252 )
253 parser.add_argument( 253 parser.add_argument(
254 "--input_pertubation",
255 type=float,
256 default=0,
257 help="The scale of input pretubation. Recommended 0.1."
258 )
259 parser.add_argument(
254 "--num_train_epochs", 260 "--num_train_epochs",
255 type=int, 261 type=int,
256 default=None 262 default=None
@@ -1040,6 +1046,7 @@ def main():
1040 checkpoint_output_dir=pti_checkpoint_output_dir, 1046 checkpoint_output_dir=pti_checkpoint_output_dir,
1041 sample_frequency=pti_sample_frequency, 1047 sample_frequency=pti_sample_frequency,
1042 offset_noise_strength=0, 1048 offset_noise_strength=0,
1049 input_pertubation=args.input_pertubation,
1043 no_val=True, 1050 no_val=True,
1044 ) 1051 )
1045 1052
@@ -1195,6 +1202,7 @@ def main():
1195 checkpoint_output_dir=lora_checkpoint_output_dir, 1202 checkpoint_output_dir=lora_checkpoint_output_dir,
1196 sample_frequency=lora_sample_frequency, 1203 sample_frequency=lora_sample_frequency,
1197 offset_noise_strength=args.offset_noise_strength, 1204 offset_noise_strength=args.offset_noise_strength,
1205 input_pertubation=args.input_pertubation,
1198 no_val=args.valid_set_size == 0, 1206 no_val=args.valid_set_size == 0,
1199 avg_loss=avg_loss, 1207 avg_loss=avg_loss,
1200 avg_acc=avg_acc, 1208 avg_acc=avg_acc,