summaryrefslogtreecommitdiffstats
path: root/infer.py
diff options
context:
space:
mode:
authorVolpeon <git@volpeon.ink>2022-10-21 09:50:46 +0200
committerVolpeon <git@volpeon.ink>2022-10-21 09:50:46 +0200
commitca914af018632b6231fb3ee4fcd5cdbdc467c784 (patch)
tree01af701c5ac740518cdbc4001592a3f9a29cc57a /infer.py
parentDreambooth: Added option to insert a new input token; removed Dreambooth Plus (diff)
downloadtextual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.gz
textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.tar.bz2
textual-inversion-diff-ca914af018632b6231fb3ee4fcd5cdbdc467c784.zip
Add optional TI functionality to Dreambooth
Diffstat (limited to 'infer.py')
-rw-r--r--infer.py6
1 files changed, 3 insertions, 3 deletions
diff --git a/infer.py b/infer.py
index 8e17c4e..01010eb 100644
--- a/infer.py
+++ b/infer.py
@@ -258,7 +258,7 @@ def generate(output_dir, pipeline, args):
258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") 258 output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
259 output_dir.mkdir(parents=True, exist_ok=True) 259 output_dir.mkdir(parents=True, exist_ok=True)
260 260
261 seed = args.seed or torch.random.seed() 261 args.seed = args.seed or torch.random.seed()
262 262
263 save_args(output_dir, args) 263 save_args(output_dir, args)
264 264
@@ -276,7 +276,7 @@ def generate(output_dir, pipeline, args):
276 dynamic_ncols=True 276 dynamic_ncols=True
277 ) 277 )
278 278
279 generator = torch.Generator(device="cuda").manual_seed(seed + i) 279 generator = torch.Generator(device="cuda").manual_seed(args.seed + i)
280 images = pipeline( 280 images = pipeline(
281 prompt=args.prompt * (args.batch_size // len(args.prompt)), 281 prompt=args.prompt * (args.batch_size // len(args.prompt)),
282 height=args.height, 282 height=args.height,
@@ -290,7 +290,7 @@ def generate(output_dir, pipeline, args):
290 ).images 290 ).images
291 291
292 for j, image in enumerate(images): 292 for j, image in enumerate(images):
293 image.save(output_dir.joinpath(f"{seed + i}_{j}.jpg")) 293 image.save(output_dir.joinpath(f"{args.seed + i}_{j}.jpg"))
294 294
295 if torch.cuda.is_available(): 295 if torch.cuda.is_available():
296 torch.cuda.empty_cache() 296 torch.cuda.empty_cache()