From 64c79cc3e7fad49131f90fbb0648b6d5587563e5 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Sat, 10 Dec 2022 08:43:34 +0100
Subject: Various updates; shuffle prompt content during training

---
 dreambooth.py | 10 ++++++----
 1 file changed, 6 insertions(+), 4 deletions(-)

diff --git a/dreambooth.py b/dreambooth.py
index ec9531e..0044c1e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -1,13 +1,11 @@
 import argparse
 import itertools
 import math
-import os
 import datetime
 import logging
 import json
 from pathlib import Path
 
-import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.utils.checkpoint
@@ -299,7 +297,7 @@ def parse_args():
     parser.add_argument(
         "--sample_steps",
         type=int,
-        default=20,
+        default=15,
         help="Number of steps for sample generation. Higher values will result in more detailed samples, but longer runtimes.",
     )
     parser.add_argument(
@@ -613,7 +611,7 @@ def main():
     )
 
     # Freeze text_encoder and vae
-    freeze_params(vae.parameters())
+    vae.requires_grad_(False)
 
     if len(args.placeholder_token) != 0:
         print(f"Adding text embeddings: {args.placeholder_token}")
@@ -629,6 +627,10 @@ def main():
 
     placeholder_token_id = tokenizer.convert_tokens_to_ids(args.placeholder_token)
 
+    print(f"Token ID mappings:")
+    for (token_id, token) in zip(placeholder_token_id, args.placeholder_token):
+        print(f"- {token_id} {token}")
+
     # Resize the token embeddings as we are adding new special tokens to the tokenizer
     text_encoder.resize_token_embeddings(len(tokenizer))
-- 
cgit v1.2.3-54-g00ecf
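
Note on the freezing change above: the removed freeze_params helper presumably looped over the given parameters and cleared their requires_grad flags, while torch.nn.Module.requires_grad_(False) flips the flag on every parameter of the module in a single built-in call. A minimal sketch of the equivalence, with an assumed freeze_params body and a stand-in module in place of the real VAE:

    import torch

    def freeze_params(params):
        # Assumed body of the removed helper: clear requires_grad per parameter.
        for param in params:
            param.requires_grad = False

    # Stand-in module; the actual script freezes a diffusers VAE instance.
    vae = torch.nn.Sequential(torch.nn.Conv2d(3, 8, 3), torch.nn.Conv2d(8, 3, 3))

    freeze_params(vae.parameters())  # old call being removed
    vae.requires_grad_(False)        # new call: applies to all parameters recursively

    assert all(not p.requires_grad for p in vae.parameters())

The new token-ID printout works because a transformers tokenizer's convert_tokens_to_ids returns a list of integer IDs when given a list of token strings, so it can be zipped back against args.placeholder_token.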