 environment.yaml       |  2 +-
 train_dreambooth.py    |  4 ++--
 train_lora.py          |  2 +-
 train_ti.py            |  4 ++--
 training/functional.py | 31 ++++++++++---------------------
 training/util.py       |  1 -
 6 files changed, 16 insertions(+), 28 deletions(-)
diff --git a/environment.yaml b/environment.yaml
index 4899709..018a9ab 100644
--- a/environment.yaml
+++ b/environment.yaml
@@ -24,4 +24,4 @@ dependencies:
   - setuptools==65.6.3
   - test-tube>=0.7.5
   - transformers==4.26.1
-  - triton==2.0.0a2
+  - triton==2.0.0
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 6d699f3..8571dff 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -438,7 +438,7 @@ def main():
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
+        project_dir=f"{output_dir}",
         mixed_precision=args.mixed_precision
     )
 
@@ -526,7 +526,7 @@ def main():
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
         no_val=args.valid_set_size == 0,
-        # low_freq_noise=0,
+        # noise_offset=0,
     )
 
     checkpoint_output_dir = output_dir / "model"
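Note on the `logging_dir` → `project_dir` change repeated in the three training scripts: it tracks Hugging Face Accelerate, which deprecated the `logging_dir` argument of `Accelerator` in favor of `project_dir`. A minimal sketch of the updated construction, assuming a `pathlib.Path` output directory and TensorBoard as the tracker (the concrete `output_dir` value and the `"fp16"` setting below are only illustrative, standing in for the scripts' CLI arguments):

from pathlib import Path

from accelerate import Accelerator
from accelerate.utils import LoggerType

# Illustrative value; the scripts derive output_dir from their CLI arguments.
output_dir = Path("output/example-run")

accelerator = Accelerator(
    log_with=LoggerType.TENSORBOARD,
    # project_dir replaces the deprecated logging_dir keyword in newer
    # Accelerate releases; the value passed is unchanged.
    project_dir=f"{output_dir}",
    mixed_precision="fp16",  # assumption: stands in for args.mixed_precision
)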
diff --git a/train_lora.py b/train_lora.py
index 0a3d4c9..e213e3d 100644
--- a/train_lora.py
+++ b/train_lora.py
@@ -398,7 +398,7 @@ def main():
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
+        project_dir=f"{output_dir}",
         mixed_precision=args.mixed_precision
     )
 
diff --git a/train_ti.py b/train_ti.py
index 394711f..bc9348d 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -517,7 +517,7 @@ def main():
 
     accelerator = Accelerator(
         log_with=LoggerType.TENSORBOARD,
-        logging_dir=f"{output_dir}",
+        project_dir=f"{output_dir}",
         mixed_precision=args.mixed_precision
     )
 
@@ -607,7 +607,7 @@ def main():
         with_prior_preservation=args.num_class_images != 0,
         prior_loss_weight=args.prior_loss_weight,
         no_val=args.valid_set_size == 0,
-        # low_freq_noise=0,
+        noise_offset=0,
         strategy=textual_inversion_strategy,
         num_train_epochs=args.num_train_epochs,
         gradient_accumulation_steps=args.gradient_accumulation_steps,
diff --git a/training/functional.py b/training/functional.py
index 2d582bf..36269f0 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -253,7 +253,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
-    low_freq_noise: float,
+    noise_offset: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -268,30 +268,19 @@ def loss_step(
     generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
 
     # Sample noise that we'll add to the latents
-    noise = torch.randn(
-        latents.shape,
+    offsets = noise_offset * torch.randn(
+        latents.shape[0], 1, 1, 1,
         dtype=latents.dtype,
         layout=latents.layout,
         device=latents.device,
         generator=generator
+    ).expand(latents.shape)
+    noise = torch.normal(
+        mean=offsets,
+        std=1,
+        generator=generator,
     )
 
-    if low_freq_noise != 0:
-        low_freq_factor = low_freq_noise * torch.randn(
-            latents.shape[0], 1, 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-        noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn(
-            latents.shape[0], latents.shape[1], 1, 1,
-            dtype=latents.dtype,
-            layout=latents.layout,
-            device=latents.device,
-            generator=generator
-        )
-
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -576,7 +565,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
-    low_freq_noise: float = 0.1,
+    noise_offset: float = 0.2,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -611,7 +600,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
-        low_freq_noise,
+        noise_offset,
         seed,
     )
 
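The `loss_step` refactor above drops the old `low_freq_noise` blending (a second, lower-frequency noise tensor mixed into the base noise) in favor of a noise-offset formulation: one scalar offset is drawn per sample, broadcast over the latent, and used as the mean of the Gaussian noise. A self-contained sketch of the new sampling, assuming `latents` is a `(B, C, H, W)` tensor; the helper name `sample_offset_noise` is ours for illustration, not part of the repository:

import torch


def sample_offset_noise(latents: torch.Tensor, noise_offset: float,
                        generator: torch.Generator | None = None) -> torch.Tensor:
    # One scalar offset per sample, broadcast over channels and spatial dims.
    offsets = noise_offset * torch.randn(
        latents.shape[0], 1, 1, 1,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    ).expand(latents.shape)
    # Unit-variance Gaussian noise whose per-element mean is the offset;
    # equivalent to offsets + torch.randn_like(latents).
    return torch.normal(mean=offsets, std=1, generator=generator)


# Shape check only; real latents come from the VAE encoder.
noise = sample_offset_noise(torch.zeros(4, 4, 64, 64), noise_offset=0.2)
assert noise.shape == (4, 4, 64, 64)

The usual motivation for offset noise is to let the model move away from mean-zero latents (for example, very dark or very bright images); the new default of 0.2 in `train()` enables it out of the box, while `train_ti.py` explicitly passes `noise_offset=0` and `train_dreambooth.py` leaves the override commented out.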
diff --git a/training/util.py b/training/util.py
index c8524de..8bd8a83 100644
--- a/training/util.py
+++ b/training/util.py
@@ -1,6 +1,5 @@
 from pathlib import Path
 import json
-import copy
 from typing import Iterable, Any
 from contextlib import contextmanager
 
