author     Volpeon <git@volpeon.ink>  2023-06-21 21:46:15 +0200
committer  Volpeon <git@volpeon.ink>  2023-06-21 21:46:15 +0200
commit     b5d3df18c3a56699a3658ad58a02d4494836972f (patch)
tree       8a43468111eee827564bb5d1561d2d4910915c61
parent     Update (diff)
Update
 train_dreambooth.py             |  9 +--------
 train_lora.py                   |  8 --------
 train_ti.py                     |  7 -------
 training/functional.py          | 13 -------------
 training/strategy/dreambooth.py | 10 ++++++++++
 5 files changed, 11 insertions(+), 36 deletions(-)
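
In short (inferred from the hunks below): the commit deletes the Perlin offset-noise option (--offset_noise_strength) from all three trainers and from the shared loss path in training/functional.py, drops the unused --run_pti flag from train_dreambooth.py, extends gradient checkpointing to the text encoder, rebinds forward on the unwrapped models before building the sampling pipeline, and freezes all but the last N CLIP text-encoder layers in dreambooth_prepare.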
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 659b84c..0543a35 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -246,12 +246,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Perlin offset noise strength.",
-    )
-    parser.add_argument(
         "--input_pertubation",
         type=float,
         default=0,
@@ -496,7 +490,6 @@ def parse_args():
         default=1.0,
         help="The weight of prior preservation loss.",
     )
-    parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.")
     parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha")
     parser.add_argument(
         "--emb_dropout",
@@ -679,6 +672,7 @@ def main():
 
     if args.gradient_checkpointing:
         unet.enable_gradient_checkpointing()
+        text_encoder.gradient_checkpointing_enable()
 
     if len(args.alias_tokens) != 0:
         alias_placeholder_tokens = args.alias_tokens[::2]
@@ -1074,7 +1068,6 @@ def main():
         sample_output_dir=dreambooth_sample_output_dir,
         checkpoint_output_dir=dreambooth_checkpoint_output_dir,
         sample_frequency=dreambooth_sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
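
The single insertion above extends gradient checkpointing from the UNet to the text encoder, which matters now that part of the encoder stays trainable (see the strategy change below). A minimal standalone sketch of the pair of calls, using illustrative Stable Diffusion component names; both methods exist in diffusers and transformers respectively:

    from diffusers import UNet2DConditionModel
    from transformers import CLIPTextModel

    # Illustrative checkpoint; any SD-style repo with these subfolders works.
    repo = "runwayml/stable-diffusion-v1-5"
    unet = UNet2DConditionModel.from_pretrained(repo, subfolder="unet")
    text_encoder = CLIPTextModel.from_pretrained(repo, subfolder="text_encoder")

    # Trade compute for memory: recompute activations in the backward pass.
    unet.enable_gradient_checkpointing()          # diffusers API
    text_encoder.gradient_checkpointing_enable()  # transformers API

Checkpointing only pays off for modules that actually receive gradients, so enabling it on the text encoder goes hand in hand with unfreezing some of its layers.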
diff --git a/train_lora.py b/train_lora.py
index fccf48d..b7ee2d6 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -259,12 +259,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Perlin offset noise strength.",
-    )
-    parser.add_argument(
         "--input_pertubation",
         type=float,
         default=0,
@@ -1138,7 +1132,6 @@ def main():
         sample_output_dir=pti_sample_output_dir,
         checkpoint_output_dir=pti_checkpoint_output_dir,
         sample_frequency=pti_sample_frequency,
-        offset_noise_strength=0,
         input_pertubation=args.input_pertubation,
         no_val=True,
     )
@@ -1291,7 +1284,6 @@ def main():
         sample_output_dir=lora_sample_output_dir,
         checkpoint_output_dir=lora_checkpoint_output_dir,
         sample_frequency=lora_sample_frequency,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         no_val=args.valid_set_size == 0,
         avg_loss=avg_loss,
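
Note that the PTI pass had already pinned offset_noise_strength=0, so only the final LoRA pass could ever have used the option. Since these scripts use plain argparse, a stale command line that still passes the removed flag now fails at startup; a tiny self-contained repro (not from the repo):

    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--input_pertubation", type=float, default=0)

    # Exits with: error: unrecognized arguments: --offset_noise_strength 0.02
    parser.parse_args(["--offset_noise_strength", "0.02"])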
diff --git a/train_ti.py b/train_ti.py
index c6f0b3a..da0c03e 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -230,12 +230,6 @@ def parse_args():
         help="Vector shuffling algorithm.",
     )
     parser.add_argument(
-        "--offset_noise_strength",
-        type=float,
-        default=0,
-        help="Offset noise strength.",
-    )
-    parser.add_argument(
         "--input_pertubation",
         type=float,
         default=0,
@@ -876,7 +870,6 @@ def main():
         checkpoint_frequency=args.checkpoint_frequency,
         milestone_checkpoints=not args.no_milestone_checkpoints,
         global_step_offset=global_step_offset,
-        offset_noise_strength=args.offset_noise_strength,
         input_pertubation=args.input_pertubation,
         # --
         use_emb_decay=args.use_emb_decay,
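
With the option gone from train_ti.py as well, all three trainers share a single remaining noise tweak: the Gaussian input perturbation in loss_step. A simplified, self-contained sketch of the surviving path; the trailing arguments of the kept torch.randn call are assumed to mirror the deleted block, and the identifier input_pertubation is spelled that way in the codebase:

    import torch

    generator = torch.Generator().manual_seed(0)
    latents = torch.randn(1, 4, 64, 64)  # stand-in for VAE latents
    input_pertubation = 0.1              # train()'s default after this commit

    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        device=latents.device,
        generator=generator,
    )
    applied_noise = noise
    if input_pertubation != 0:
        # Perturb the noise that is *added* to the latents, while the
        # unperturbed `noise` presumably remains the prediction target.
        applied_noise = applied_noise + input_pertubation * torch.randn(
            latents.shape,
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        )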
diff --git a/training/functional.py b/training/functional.py
index f68faf9..3c7848f 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -348,7 +348,6 @@ def loss_step(
     guidance_scale: float,
     prior_loss_weight: float,
     seed: int,
-    offset_noise_strength: float,
     input_pertubation: float,
     disc: Optional[ConvNeXtDiscriminator],
     min_snr_gamma: int,
@@ -377,16 +376,6 @@ def loss_step(
     )
     applied_noise = noise
 
-    if offset_noise_strength != 0:
-        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
-            latents.shape,
-            res=1,
-            octaves=4,
-            dtype=latents.dtype,
-            device=latents.device,
-            generator=generator,
-        )
-
     if input_pertubation != 0:
         applied_noise = applied_noise + input_pertubation * torch.randn(
             latents.shape,
@@ -751,7 +740,6 @@ def train(
     global_step_offset: int = 0,
     guidance_scale: float = 0.0,
     prior_loss_weight: float = 1.0,
-    offset_noise_strength: float = 0.01,
     input_pertubation: float = 0.1,
     disc: Optional[ConvNeXtDiscriminator] = None,
     schedule_sampler: Optional[ScheduleSampler] = None,
@@ -814,7 +802,6 @@ def train(
         guidance_scale,
         prior_loss_weight,
         seed,
-        offset_noise_strength,
         input_pertubation,
         disc,
         min_snr_gamma,
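
For reference, the branch deleted from loss_step is reconstructed below (comments added; perlin_noise is a repo-local helper, not a library function). Note the asymmetry the removal resolves: train() defaulted offset_noise_strength to 0.01, but every CLI passed its own default of 0, so the branch was effectively opt-in dead code:

    # Reconstructed from the removed lines:
    if offset_noise_strength != 0:
        # Add low-frequency Perlin noise on top of the Gaussian noise,
        # a variant of the "offset noise" trick for letting the model
        # shift overall image brightness.
        applied_noise = applied_noise + offset_noise_strength * perlin_noise(
            latents.shape,
            res=1,       # coarse base resolution
            octaves=4,   # layers of progressively finer detail
            dtype=latents.dtype,
            device=latents.device,
            generator=generator,
        )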
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py
index 88b441b..43fe838 100644
--- a/training/strategy/dreambooth.py
+++ b/training/strategy/dreambooth.py
@@ -1,4 +1,5 @@
 from typing import Optional
+from types import MethodType
 from functools import partial
 from contextlib import contextmanager, nullcontext
 from pathlib import Path
@@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks(
         unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False)
         text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False)
 
+        unet_.forward = MethodType(unet_.forward, unet_)
+        text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_)
+
         with ema_context():
             pipeline = VlpnStableDiffusion(
                 text_encoder=text_encoder_,
@@ -185,6 +189,7 @@ def dreambooth_prepare(
     train_dataloader: DataLoader,
     val_dataloader: Optional[DataLoader],
     lr_scheduler: torch.optim.lr_scheduler._LRScheduler,
+    text_encoder_unfreeze_last_n_layers: int = 2,
     **kwargs
 ):
     (
@@ -198,6 +203,11 @@ def dreambooth_prepare(
         text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
 
+    for layer in text_encoder.text_model.encoder.layers[
+        : (-1 * text_encoder_unfreeze_last_n_layers)
+    ]:
+        layer.requires_grad_(False)
+
     text_encoder.text_model.embeddings.requires_grad_(False)
 
     return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler
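
Two independent changes land in this file. First, after accelerator.unwrap_model(..., keep_fp32_wrapper=False), forward can be left behind as a plain, unbound function attribute (accelerate swaps it out during prepare for mixed-precision handling); rebinding it with types.MethodType turns it back into a proper bound method so the pipeline calls the unwrapped module cleanly. That reading is inferred from accelerate's wrapping behaviour, not stated in the commit. Second, dreambooth_prepare now freezes everything except the last text_encoder_unfreeze_last_n_layers encoder blocks. A standalone sketch of that freezing logic (module structure as in transformers' CLIPTextModel; checkpoint name illustrative):

    from transformers import CLIPTextModel

    text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
    n = 2  # text_encoder_unfreeze_last_n_layers default

    # Freeze all but the last n transformer blocks...
    for layer in text_encoder.text_model.encoder.layers[:-n]:
        layer.requires_grad_(False)
    # ...and keep the token/position embeddings frozen, as before this commit.
    text_encoder.text_model.embeddings.requires_grad_(False)

One caveat worth flagging: with n == 0 the slice becomes layers[:0], which is empty, so nothing is frozen; zero effectively means "train the whole encoder" rather than "freeze it all".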