diff options
-rw-r--r-- | data/csv.py | 2 | ||||
-rw-r--r-- | train_ti.py | 6 |
2 files changed, 2 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py index 6857b6f..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
@@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
283 | 283 | ||
284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
285 | data_train, data_val = items, [] | 285 | data_train, data_val = items, items[:1] |
286 | else: | 286 | else: |
287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
288 | 288 | ||
diff --git a/train_ti.py b/train_ti.py index adba8d4..e696577 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -409,7 +409,7 @@ def parse_args(): | |||
409 | ) | 409 | ) |
410 | parser.add_argument( | 410 | parser.add_argument( |
411 | "--emb_decay_start", | 411 | "--emb_decay_start", |
412 | default=1e-4, | 412 | default=0, |
413 | type=float, | 413 | type=float, |
414 | help="Embedding decay start offset." | 414 | help="Embedding decay start offset." |
415 | ) | 415 | ) |
@@ -514,8 +514,6 @@ def main(): | |||
514 | 514 | ||
515 | set_seed(args.seed) | 515 | set_seed(args.seed) |
516 | 516 | ||
517 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
518 | |||
519 | save_args(output_dir, args) | 517 | save_args(output_dir, args) |
520 | 518 | ||
521 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 519 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
@@ -595,8 +593,6 @@ def main(): | |||
595 | print( | 593 | print( |
596 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | 594 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") |
597 | 595 | ||
598 | args.seed = seed_generator.seed() | ||
599 | |||
600 | datamodule = VlpnDataModule( | 596 | datamodule = VlpnDataModule( |
601 | data_file=args.train_data_file, | 597 | data_file=args.train_data_file, |
602 | batch_size=args.train_batch_size, | 598 | batch_size=args.train_batch_size, |