Diffstat (limited to 'train_ti.py')
 -rw-r--r--  train_ti.py | 14 ++++++++++++++
 1 file changed, 14 insertions(+), 0 deletions(-)
diff --git a/train_ti.py b/train_ti.py
index d7878cd..082e9b7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,10 +13,12 @@ from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
 from slugify import slugify
+from timm.models import create_model
 import transformers
 
 from util.files import load_config, load_embeddings_from_dir
 from data.csv import VlpnDataModule, keyword_filter
+from models.convnext.discriminator import ConvNeXtDiscriminator
 from training.functional import train, add_placeholder_tokens, get_models
 from training.strategy.ti import textual_inversion_strategy
 from training.optimization import get_scheduler
@@ -661,6 +663,17 @@ def main():
         unet.enable_gradient_checkpointing()
         text_encoder.gradient_checkpointing_enable()
 
+    convnext = create_model(
+        "convnext_tiny",
+        pretrained=False,
+        num_classes=3,
+        drop_path_rate=0.0,
+    )
+    convnext.to(accelerator.device, dtype=weight_dtype)
+    convnext.requires_grad_(False)
+    convnext.eval()
+    disc = ConvNeXtDiscriminator(convnext, input_size=384)
+
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
         alias_initializer_tokens = args.alias_tokens[1::2]
@@ -802,6 +815,7 @@ def main():
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
         offset_noise_strength=args.offset_noise_strength,
+        disc=disc,
         # --
         use_emb_decay=args.use_emb_decay,
         emb_decay_target=args.emb_decay_target,
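The ConvNeXtDiscriminator class imported from models/convnext/discriminator.py is not part of this diff. Below is a minimal sketch of what such a wrapper could look like, assuming it only resizes image batches to the classifier's input resolution and turns the frozen timm backbone's logits into a per-sample score; the get_score method name and the meaning of class index 0 are assumptions, not the repository's actual API.

# Hypothetical sketch only: the real models/convnext/discriminator.py is not
# shown in this diff. Assumed behaviour: resize an image batch to the
# classifier's input size, run the frozen ConvNeXt, and return a per-sample
# score derived from the class probabilities.
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConvNeXtDiscriminator:
    def __init__(self, model: nn.Module, input_size: int = 384):
        self.net = model              # frozen timm ConvNeXt backbone
        self.input_size = input_size  # resolution the classifier expects

    @torch.no_grad()
    def get_score(self, images: torch.Tensor) -> torch.Tensor:
        # images: (batch, 3, H, W) in [0, 1]
        images = F.interpolate(
            images,
            size=(self.input_size, self.input_size),
            mode="bicubic",
            align_corners=False,
        )
        logits = self.net(images)      # (batch, num_classes); num_classes=3 in this commit
        probs = logits.softmax(dim=-1)
        # Assumption: class index 0 is the "acceptable image" class.
        return probs[:, 0]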

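The commit only threads disc through to the train() call; how training/functional.py consumes it is outside this change. As a hedged illustration of one plausible scheme (the names vae, pred_latents, and loss, and the weighting itself, are assumptions for this sketch), a frozen discriminator could rescale the per-sample diffusion loss according to how the decoded prediction scores:

# Hypothetical illustration, not the code behind this diff: decode the
# predicted latents, score them with the frozen discriminator, and up-weight
# samples that score poorly so they contribute more to the loss.
if disc is not None:
    with torch.no_grad():
        images = vae.decode(pred_latents / vae.config.scaling_factor).sample
        images = ((images + 1.0) / 2.0).clamp(0.0, 1.0)  # [-1, 1] -> [0, 1]
        score = disc.get_score(images)                   # (batch,)
    loss = loss * (2.0 - score)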