diff options
| -rw-r--r-- | train_dreambooth.py | 9 | ||||
| -rw-r--r-- | train_lora.py | 8 | ||||
| -rw-r--r-- | train_ti.py | 7 | ||||
| -rw-r--r-- | training/functional.py | 13 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 10 |
5 files changed, 11 insertions, 36 deletions
diff --git a/train_dreambooth.py b/train_dreambooth.py index 659b84c..0543a35 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py | |||
| @@ -246,12 +246,6 @@ def parse_args(): | |||
| 246 | ), | 246 | ), |
| 247 | ) | 247 | ) |
| 248 | parser.add_argument( | 248 | parser.add_argument( |
| 249 | "--offset_noise_strength", | ||
| 250 | type=float, | ||
| 251 | default=0, | ||
| 252 | help="Perlin offset noise strength.", | ||
| 253 | ) | ||
| 254 | parser.add_argument( | ||
| 255 | "--input_pertubation", | 249 | "--input_pertubation", |
| 256 | type=float, | 250 | type=float, |
| 257 | default=0, | 251 | default=0, |
| @@ -496,7 +490,6 @@ def parse_args(): | |||
| 496 | default=1.0, | 490 | default=1.0, |
| 497 | help="The weight of prior preservation loss.", | 491 | help="The weight of prior preservation loss.", |
| 498 | ) | 492 | ) |
| 499 | parser.add_argument("--run_pti", action="store_true", help="Whether to run PTI.") | ||
| 500 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") | 493 | parser.add_argument("--emb_alpha", type=float, default=1.0, help="Embedding alpha") |
| 501 | parser.add_argument( | 494 | parser.add_argument( |
| 502 | "--emb_dropout", | 495 | "--emb_dropout", |
| @@ -679,6 +672,7 @@ def main(): | |||
| 679 | 672 | ||
| 680 | if args.gradient_checkpointing: | 673 | if args.gradient_checkpointing: |
| 681 | unet.enable_gradient_checkpointing() | 674 | unet.enable_gradient_checkpointing() |
| 675 | text_encoder.gradient_checkpointing_enable() | ||
| 682 | 676 | ||
| 683 | if len(args.alias_tokens) != 0: | 677 | if len(args.alias_tokens) != 0: |
| 684 | alias_placeholder_tokens = args.alias_tokens[::2] | 678 | alias_placeholder_tokens = args.alias_tokens[::2] |
| @@ -1074,7 +1068,6 @@ def main(): | |||
| 1074 | sample_output_dir=dreambooth_sample_output_dir, | 1068 | sample_output_dir=dreambooth_sample_output_dir, |
| 1075 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, | 1069 | checkpoint_output_dir=dreambooth_checkpoint_output_dir, |
| 1076 | sample_frequency=dreambooth_sample_frequency, | 1070 | sample_frequency=dreambooth_sample_frequency, |
| 1077 | offset_noise_strength=args.offset_noise_strength, | ||
| 1078 | input_pertubation=args.input_pertubation, | 1071 | input_pertubation=args.input_pertubation, |
| 1079 | no_val=args.valid_set_size == 0, | 1072 | no_val=args.valid_set_size == 0, |
| 1080 | avg_loss=avg_loss, | 1073 | avg_loss=avg_loss, |
diff --git a/train_lora.py b/train_lora.py index fccf48d..b7ee2d6 100644 --- a/train_lora.py +++ b/train_lora.py | |||
| @@ -259,12 +259,6 @@ def parse_args(): | |||
| 259 | ), | 259 | ), |
| 260 | ) | 260 | ) |
| 261 | parser.add_argument( | 261 | parser.add_argument( |
| 262 | "--offset_noise_strength", | ||
| 263 | type=float, | ||
| 264 | default=0, | ||
| 265 | help="Perlin offset noise strength.", | ||
| 266 | ) | ||
| 267 | parser.add_argument( | ||
| 268 | "--input_pertubation", | 262 | "--input_pertubation", |
| 269 | type=float, | 263 | type=float, |
| 270 | default=0, | 264 | default=0, |
| @@ -1138,7 +1132,6 @@ def main(): | |||
| 1138 | sample_output_dir=pti_sample_output_dir, | 1132 | sample_output_dir=pti_sample_output_dir, |
| 1139 | checkpoint_output_dir=pti_checkpoint_output_dir, | 1133 | checkpoint_output_dir=pti_checkpoint_output_dir, |
| 1140 | sample_frequency=pti_sample_frequency, | 1134 | sample_frequency=pti_sample_frequency, |
| 1141 | offset_noise_strength=0, | ||
| 1142 | input_pertubation=args.input_pertubation, | 1135 | input_pertubation=args.input_pertubation, |
| 1143 | no_val=True, | 1136 | no_val=True, |
| 1144 | ) | 1137 | ) |
| @@ -1291,7 +1284,6 @@ def main(): | |||
| 1291 | sample_output_dir=lora_sample_output_dir, | 1284 | sample_output_dir=lora_sample_output_dir, |
| 1292 | checkpoint_output_dir=lora_checkpoint_output_dir, | 1285 | checkpoint_output_dir=lora_checkpoint_output_dir, |
| 1293 | sample_frequency=lora_sample_frequency, | 1286 | sample_frequency=lora_sample_frequency, |
| 1294 | offset_noise_strength=args.offset_noise_strength, | ||
| 1295 | input_pertubation=args.input_pertubation, | 1287 | input_pertubation=args.input_pertubation, |
| 1296 | no_val=args.valid_set_size == 0, | 1288 | no_val=args.valid_set_size == 0, |
| 1297 | avg_loss=avg_loss, | 1289 | avg_loss=avg_loss, |
diff --git a/train_ti.py b/train_ti.py index c6f0b3a..da0c03e 100644 --- a/train_ti.py +++ b/train_ti.py | |||
| @@ -230,12 +230,6 @@ def parse_args(): | |||
| 230 | help="Vector shuffling algorithm.", | 230 | help="Vector shuffling algorithm.", |
| 231 | ) | 231 | ) |
| 232 | parser.add_argument( | 232 | parser.add_argument( |
| 233 | "--offset_noise_strength", | ||
| 234 | type=float, | ||
| 235 | default=0, | ||
| 236 | help="Offset noise strength.", | ||
| 237 | ) | ||
| 238 | parser.add_argument( | ||
| 239 | "--input_pertubation", | 233 | "--input_pertubation", |
| 240 | type=float, | 234 | type=float, |
| 241 | default=0, | 235 | default=0, |
| @@ -876,7 +870,6 @@ def main(): | |||
| 876 | checkpoint_frequency=args.checkpoint_frequency, | 870 | checkpoint_frequency=args.checkpoint_frequency, |
| 877 | milestone_checkpoints=not args.no_milestone_checkpoints, | 871 | milestone_checkpoints=not args.no_milestone_checkpoints, |
| 878 | global_step_offset=global_step_offset, | 872 | global_step_offset=global_step_offset, |
| 879 | offset_noise_strength=args.offset_noise_strength, | ||
| 880 | input_pertubation=args.input_pertubation, | 873 | input_pertubation=args.input_pertubation, |
| 881 | # -- | 874 | # -- |
| 882 | use_emb_decay=args.use_emb_decay, | 875 | use_emb_decay=args.use_emb_decay, |
diff --git a/training/functional.py b/training/functional.py index f68faf9..3c7848f 100644 --- a/training/functional.py +++ b/training/functional.py | |||
| @@ -348,7 +348,6 @@ def loss_step( | |||
| 348 | guidance_scale: float, | 348 | guidance_scale: float, |
| 349 | prior_loss_weight: float, | 349 | prior_loss_weight: float, |
| 350 | seed: int, | 350 | seed: int, |
| 351 | offset_noise_strength: float, | ||
| 352 | input_pertubation: float, | 351 | input_pertubation: float, |
| 353 | disc: Optional[ConvNeXtDiscriminator], | 352 | disc: Optional[ConvNeXtDiscriminator], |
| 354 | min_snr_gamma: int, | 353 | min_snr_gamma: int, |
| @@ -377,16 +376,6 @@ def loss_step( | |||
| 377 | ) | 376 | ) |
| 378 | applied_noise = noise | 377 | applied_noise = noise |
| 379 | 378 | ||
| 380 | if offset_noise_strength != 0: | ||
| 381 | applied_noise = applied_noise + offset_noise_strength * perlin_noise( | ||
| 382 | latents.shape, | ||
| 383 | res=1, | ||
| 384 | octaves=4, | ||
| 385 | dtype=latents.dtype, | ||
| 386 | device=latents.device, | ||
| 387 | generator=generator, | ||
| 388 | ) | ||
| 389 | |||
| 390 | if input_pertubation != 0: | 379 | if input_pertubation != 0: |
| 391 | applied_noise = applied_noise + input_pertubation * torch.randn( | 380 | applied_noise = applied_noise + input_pertubation * torch.randn( |
| 392 | latents.shape, | 381 | latents.shape, |
| @@ -751,7 +740,6 @@ def train( | |||
| 751 | global_step_offset: int = 0, | 740 | global_step_offset: int = 0, |
| 752 | guidance_scale: float = 0.0, | 741 | guidance_scale: float = 0.0, |
| 753 | prior_loss_weight: float = 1.0, | 742 | prior_loss_weight: float = 1.0, |
| 754 | offset_noise_strength: float = 0.01, | ||
| 755 | input_pertubation: float = 0.1, | 743 | input_pertubation: float = 0.1, |
| 756 | disc: Optional[ConvNeXtDiscriminator] = None, | 744 | disc: Optional[ConvNeXtDiscriminator] = None, |
| 757 | schedule_sampler: Optional[ScheduleSampler] = None, | 745 | schedule_sampler: Optional[ScheduleSampler] = None, |
| @@ -814,7 +802,6 @@ def train( | |||
| 814 | guidance_scale, | 802 | guidance_scale, |
| 815 | prior_loss_weight, | 803 | prior_loss_weight, |
| 816 | seed, | 804 | seed, |
| 817 | offset_noise_strength, | ||
| 818 | input_pertubation, | 805 | input_pertubation, |
| 819 | disc, | 806 | disc, |
| 820 | min_snr_gamma, | 807 | min_snr_gamma, |
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 88b441b..43fe838 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py | |||
| @@ -1,4 +1,5 @@ | |||
| 1 | from typing import Optional | 1 | from typing import Optional |
| 2 | from types import MethodType | ||
| 2 | from functools import partial | 3 | from functools import partial |
| 3 | from contextlib import contextmanager, nullcontext | 4 | from contextlib import contextmanager, nullcontext |
| 4 | from pathlib import Path | 5 | from pathlib import Path |
| @@ -130,6 +131,9 @@ def dreambooth_strategy_callbacks( | |||
| 130 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 131 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) |
| 131 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 132 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) |
| 132 | 133 | ||
| 134 | unet_.forward = MethodType(unet_.forward, unet_) | ||
| 135 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 136 | |||
| 133 | with ema_context(): | 137 | with ema_context(): |
| 134 | pipeline = VlpnStableDiffusion( | 138 | pipeline = VlpnStableDiffusion( |
| 135 | text_encoder=text_encoder_, | 139 | text_encoder=text_encoder_, |
| @@ -185,6 +189,7 @@ def dreambooth_prepare( | |||
| 185 | train_dataloader: DataLoader, | 189 | train_dataloader: DataLoader, |
| 186 | val_dataloader: Optional[DataLoader], | 190 | val_dataloader: Optional[DataLoader], |
| 187 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, | 191 | lr_scheduler: torch.optim.lr_scheduler._LRScheduler, |
| 192 | text_encoder_unfreeze_last_n_layers: int = 2, | ||
| 188 | **kwargs | 193 | **kwargs |
| 189 | ): | 194 | ): |
| 190 | ( | 195 | ( |
| @@ -198,6 +203,11 @@ def dreambooth_prepare( | |||
| 198 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 203 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
| 199 | ) | 204 | ) |
| 200 | 205 | ||
| 206 | for layer in text_encoder.text_model.encoder.layers[ | ||
| 207 | : (-1 * text_encoder_unfreeze_last_n_layers) | ||
| 208 | ]: | ||
| 209 | layer.requires_grad_(False) | ||
| 210 | |||
| 201 | text_encoder.text_model.embeddings.requires_grad_(False) | 211 | text_encoder.text_model.embeddings.requires_grad_(False) |
| 202 | 212 | ||
| 203 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler | 213 | return text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler |
