From e32b4d4c04a31b22051740e5f26e16960464f787 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Fri, 3 Mar 2023 18:53:15 +0100 Subject: Implemented different noise offset --- environment.yaml | 2 +- train_dreambooth.py | 4 ++-- train_lora.py | 2 +- train_ti.py | 4 ++-- training/functional.py | 31 ++++++++++--------------------- training/util.py | 1 - 6 files changed, 16 insertions(+), 28 deletions(-) diff --git a/environment.yaml b/environment.yaml index 4899709..018a9ab 100644 --- a/environment.yaml +++ b/environment.yaml @@ -24,4 +24,4 @@ dependencies: - setuptools==65.6.3 - test-tube>=0.7.5 - transformers==4.26.1 - - triton==2.0.0a2 + - triton==2.0.0 diff --git a/train_dreambooth.py b/train_dreambooth.py index 6d699f3..8571dff 100644 --- a/train_dreambooth.py +++ b/train_dreambooth.py @@ -438,7 +438,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, - logging_dir=f"{output_dir}", + project_dir=f"{output_dir}", mixed_precision=args.mixed_precision ) @@ -526,7 +526,7 @@ def main(): with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, no_val=args.valid_set_size == 0, - # low_freq_noise=0, + # noise_offset=0, ) checkpoint_output_dir = output_dir / "model" diff --git a/train_lora.py b/train_lora.py index 0a3d4c9..e213e3d 100644 --- a/train_lora.py +++ b/train_lora.py @@ -398,7 +398,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, - logging_dir=f"{output_dir}", + project_dir=f"{output_dir}", mixed_precision=args.mixed_precision ) diff --git a/train_ti.py b/train_ti.py index 394711f..bc9348d 100644 --- a/train_ti.py +++ b/train_ti.py @@ -517,7 +517,7 @@ def main(): accelerator = Accelerator( log_with=LoggerType.TENSORBOARD, - logging_dir=f"{output_dir}", + project_dir=f"{output_dir}", mixed_precision=args.mixed_precision ) @@ -607,7 +607,7 @@ def main(): with_prior_preservation=args.num_class_images != 0, prior_loss_weight=args.prior_loss_weight, no_val=args.valid_set_size == 0, - # low_freq_noise=0, + noise_offset=0, strategy=textual_inversion_strategy, num_train_epochs=args.num_train_epochs, gradient_accumulation_steps=args.gradient_accumulation_steps, diff --git a/training/functional.py b/training/functional.py index 2d582bf..36269f0 100644 --- a/training/functional.py +++ b/training/functional.py @@ -253,7 +253,7 @@ def loss_step( text_encoder: CLIPTextModel, with_prior_preservation: bool, prior_loss_weight: float, - low_freq_noise: float, + noise_offset: float, seed: int, step: int, batch: dict[str, Any], @@ -268,30 +268,19 @@ def loss_step( generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None # Sample noise that we'll add to the latents - noise = torch.randn( - latents.shape, + offsets = noise_offset * torch.randn( + latents.shape[0], 1, 1, 1, dtype=latents.dtype, layout=latents.layout, device=latents.device, generator=generator + ).expand(latents.shape) + noise = torch.normal( + mean=offsets, + std=1, + generator=generator, ) - if low_freq_noise != 0: - low_freq_factor = low_freq_noise * torch.randn( - latents.shape[0], 1, 1, 1, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - noise = noise * (1 - low_freq_factor) + low_freq_factor * torch.randn( - latents.shape[0], latents.shape[1], 1, 1, - dtype=latents.dtype, - layout=latents.layout, - device=latents.device, - generator=generator - ) - # Sample a random timestep for each image timesteps = torch.randint( 0, @@ -576,7 +565,7 @@ def train( global_step_offset: int = 0, with_prior_preservation: bool = False, prior_loss_weight: float = 1.0, - low_freq_noise: float = 0.1, + noise_offset: float = 0.2, **kwargs, ): text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( @@ -611,7 +600,7 @@ def train( text_encoder, with_prior_preservation, prior_loss_weight, - low_freq_noise, + noise_offset, seed, ) diff --git a/training/util.py b/training/util.py index c8524de..8bd8a83 100644 --- a/training/util.py +++ b/training/util.py @@ -1,6 +1,5 @@ from pathlib import Path import json -import copy from typing import Iterable, Any from contextlib import contextmanager -- cgit v1.2.3-54-g00ecf