summary refs log tree commit diff stats
path: root/train_ti.py
diff options
context:
space:
mode:
Diffstat (limited to 'train_ti.py')
-rw-r--r-- train_ti.py 14
1 file changed, 14 insertions, 0 deletions
diff --git a/train_ti.py b/train_ti.py
index d7878cd..082e9b7 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -13,10 +13,12 @@ from accelerate import Accelerator
13 from accelerate.logging import get_logger 13 from accelerate.logging import get_logger
14 from accelerate.utils import LoggerType, set_seed 14 from accelerate.utils import LoggerType, set_seed
15 from slugify import slugify 15 from slugify import slugify
16 from timm.models import create_model
16 import transformers 17 import transformers
17 18
18 from util.files import load_config, load_embeddings_from_dir 19 from util.files import load_config, load_embeddings_from_dir
19 from data.csv import VlpnDataModule, keyword_filter 20 from data.csv import VlpnDataModule, keyword_filter
21 from models.convnext.discriminator import ConvNeXtDiscriminator
20 from training.functional import train, add_placeholder_tokens, get_models 22 from training.functional import train, add_placeholder_tokens, get_models
21 from training.strategy.ti import textual_inversion_strategy 23 from training.strategy.ti import textual_inversion_strategy
22 from training.optimization import get_scheduler 24 from training.optimization import get_scheduler
@@ -661,6 +663,17 @@ def main():
661 unet.enable_gradient_checkpointing() 663 unet.enable_gradient_checkpointing()
662 text_encoder.gradient_checkpointing_enable() 664 text_encoder.gradient_checkpointing_enable()
663 665
666 convnext = create_model(
667 "convnext_tiny",
668 pretrained=False,
669 num_classes=3,
670 drop_path_rate=0.0,
671 )
672 convnext.to(accelerator.device, dtype=weight_dtype)
673 convnext.requires_grad_(False)
674 convnext.eval()
675 disc = ConvNeXtDiscriminator(convnext, input_size=384)
676
664 if len(args.alias_tokens) != 0: 677 if len(args.alias_tokens) != 0:
665 alias_placeholder_tokens = args.alias_tokens[::2] 678 alias_placeholder_tokens = args.alias_tokens[::2]
666 alias_initializer_tokens = args.alias_tokens[1::2] 679 alias_initializer_tokens = args.alias_tokens[1::2]
@@ -802,6 +815,7 @@ def main():
802 milestone_checkpoints=not args.no_milestone_checkpoints, 815 milestone_checkpoints=not args.no_milestone_checkpoints,
803 global_step_offset=global_step_offset, 816 global_step_offset=global_step_offset,
804 offset_noise_strength=args.offset_noise_strength, 817 offset_noise_strength=args.offset_noise_strength,
818 disc=disc,
805 # -- 819 # --
806 use_emb_decay=args.use_emb_decay, 820 use_emb_decay=args.use_emb_decay,
807 emb_decay_target=args.emb_decay_target, 821 emb_decay_target=args.emb_decay_target,