diff options
Diffstat (limited to 'train_ti.py')
-rw-r--r-- | train_ti.py | 14 |
1 files changed, 14 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py index d7878cd..082e9b7 100644 --- a/train_ti.py +++ b/train_ti.py | |||
@@ -13,10 +13,12 @@ from accelerate import Accelerator | |||
13 | from accelerate.logging import get_logger | 13 | from accelerate.logging import get_logger |
14 | from accelerate.utils import LoggerType, set_seed | 14 | from accelerate.utils import LoggerType, set_seed |
15 | from slugify import slugify | 15 | from slugify import slugify |
16 | from timm.models import create_model | ||
16 | import transformers | 17 | import transformers |
17 | 18 | ||
18 | from util.files import load_config, load_embeddings_from_dir | 19 | from util.files import load_config, load_embeddings_from_dir |
19 | from data.csv import VlpnDataModule, keyword_filter | 20 | from data.csv import VlpnDataModule, keyword_filter |
21 | from models.convnext.discriminator import ConvNeXtDiscriminator | ||
20 | from training.functional import train, add_placeholder_tokens, get_models | 22 | from training.functional import train, add_placeholder_tokens, get_models |
21 | from training.strategy.ti import textual_inversion_strategy | 23 | from training.strategy.ti import textual_inversion_strategy |
22 | from training.optimization import get_scheduler | 24 | from training.optimization import get_scheduler |
@@ -661,6 +663,17 @@ def main(): | |||
661 | unet.enable_gradient_checkpointing() | 663 | unet.enable_gradient_checkpointing() |
662 | text_encoder.gradient_checkpointing_enable() | 664 | text_encoder.gradient_checkpointing_enable() |
663 | 665 | ||
666 | convnext = create_model( | ||
667 | "convnext_tiny", | ||
668 | pretrained=False, | ||
669 | num_classes=3, | ||
670 | drop_path_rate=0.0, | ||
671 | ) | ||
672 | convnext.to(accelerator.device, dtype=weight_dtype) | ||
673 | convnext.requires_grad_(False) | ||
674 | convnext.eval() | ||
675 | disc = ConvNeXtDiscriminator(convnext, input_size=384) | ||
676 | |||
664 | if len(args.alias_tokens) != 0: | 677 | if len(args.alias_tokens) != 0: |
665 | alias_placeholder_tokens = args.alias_tokens[::2] | 678 | alias_placeholder_tokens = args.alias_tokens[::2] |
666 | alias_initializer_tokens = args.alias_tokens[1::2] | 679 | alias_initializer_tokens = args.alias_tokens[1::2] |
@@ -802,6 +815,7 @@ def main(): | |||
802 | milestone_checkpoints=not args.no_milestone_checkpoints, | 815 | milestone_checkpoints=not args.no_milestone_checkpoints, |
803 | global_step_offset=global_step_offset, | 816 | global_step_offset=global_step_offset, |
804 | offset_noise_strength=args.offset_noise_strength, | 817 | offset_noise_strength=args.offset_noise_strength, |
818 | disc=disc, | ||
805 | # -- | 819 | # -- |
806 | use_emb_decay=args.use_emb_decay, | 820 | use_emb_decay=args.use_emb_decay, |
807 | emb_decay_target=args.emb_decay_target, | 821 | emb_decay_target=args.emb_decay_target, |