diff options
| -rw-r--r-- | data/csv.py | 2 | ||||
| -rw-r--r-- | train_ti.py | 6 |
2 files changed, 2 insertions, 6 deletions
diff --git a/data/csv.py b/data/csv.py index 6857b6f..85b98f8 100644 --- a/data/csv.py +++ b/data/csv.py | |||
| @@ -282,7 +282,7 @@ class VlpnDataModule(): | |||
| 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) | 282 | collate_fn_ = partial(collate_fn, self.dtype, self.tokenizer, self.num_class_images != 0) |
| 283 | 283 | ||
| 284 | if valid_set_size == 0: | 284 | if valid_set_size == 0: |
| 285 | data_train, data_val = items, [] | 285 | data_train, data_val = items, items[:1] |
| 286 | else: | 286 | else: |
| 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) | 287 | data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator) |
| 288 | 288 | ||
diff --git a/train_ti.py b/train_ti.py index adba8d4..e696577 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -409,7 +409,7 @@ def parse_args(): | |||
| 409 | ) | 409 | ) |
| 410 | parser.add_argument( | 410 | parser.add_argument( |
| 411 | "--emb_decay_start", | 411 | "--emb_decay_start", |
| 412 | default=1e-4, | 412 | default=0, |
| 413 | type=float, | 413 | type=float, |
| 414 | help="Embedding decay start offset." | 414 | help="Embedding decay start offset." |
| 415 | ) | 415 | ) |
| @@ -514,8 +514,6 @@ def main(): | |||
| 514 | 514 | ||
| 515 | set_seed(args.seed) | 515 | set_seed(args.seed) |
| 516 | 516 | ||
| 517 | seed_generator = torch.Generator().manual_seed(args.seed) | ||
| 518 | |||
| 519 | save_args(output_dir, args) | 517 | save_args(output_dir, args) |
| 520 | 518 | ||
| 521 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( | 519 | tokenizer, text_encoder, vae, unet, noise_scheduler, sample_scheduler, embeddings = get_models( |
| @@ -595,8 +593,6 @@ def main(): | |||
| 595 | print( | 593 | print( |
| 596 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") | 594 | f"{i + 1}: {placeholder_token}, {placeholder_token_ids[0]} ({initializer_token}, {initializer_token_ids[0]})") |
| 597 | 595 | ||
| 598 | args.seed = seed_generator.seed() | ||
| 599 | |||
| 600 | datamodule = VlpnDataModule( | 596 | datamodule = VlpnDataModule( |
| 601 | data_file=args.train_data_file, | 597 | data_file=args.train_data_file, |
| 602 | batch_size=args.train_batch_size, | 598 | batch_size=args.train_batch_size, |
