diff options
author | Volpeon <git@volpeon.ink> | 2023-03-03 22:09:24 +0100 |
---|---|---|
committer | Volpeon <git@volpeon.ink> | 2023-03-03 22:09:24 +0100 |
commit | 220806dbd21da3ba83c14096225c31824dfe81df (patch) | |
tree | a201272876bfe894f9d504d1582ac022add4b205 /training | |
parent | Implemented different noise offset (diff) | |
download | textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.tar.gz textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.tar.bz2 textual-inversion-diff-220806dbd21da3ba83c14096225c31824dfe81df.zip |
Removed offset noise from training, added init offset to pipeline
Diffstat (limited to 'training')
-rw-r--r-- | training/functional.py | 12 |
1 file changed, 2 insertions, 10 deletions
diff --git a/training/functional.py b/training/functional.py index 36269f0..1c38635 100644 --- a/training/functional.py +++ b/training/functional.py | |||
@@ -253,7 +253,6 @@ def loss_step( | |||
253 | text_encoder: CLIPTextModel, | 253 | text_encoder: CLIPTextModel, |
254 | with_prior_preservation: bool, | 254 | with_prior_preservation: bool, |
255 | prior_loss_weight: float, | 255 | prior_loss_weight: float, |
256 | noise_offset: float, | ||
257 | seed: int, | 256 | seed: int, |
258 | step: int, | 257 | step: int, |
259 | batch: dict[str, Any], | 258 | batch: dict[str, Any], |
@@ -268,17 +267,12 @@ def loss_step( | |||
268 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None | 267 | generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None |
269 | 268 | ||
270 | # Sample noise that we'll add to the latents | 269 | # Sample noise that we'll add to the latents |
271 | offsets = noise_offset * torch.randn( | 270 | noise = torch.randn( |
272 | latents.shape[0], 1, 1, 1, | 271 | latents.shape, |
273 | dtype=latents.dtype, | 272 | dtype=latents.dtype, |
274 | layout=latents.layout, | 273 | layout=latents.layout, |
275 | device=latents.device, | 274 | device=latents.device, |
276 | generator=generator | 275 | generator=generator |
277 | ).expand(latents.shape) | ||
278 | noise = torch.normal( | ||
279 | mean=offsets, | ||
280 | std=1, | ||
281 | generator=generator, | ||
282 | ) | 276 | ) |
283 | 277 | ||
284 | # Sample a random timestep for each image | 278 | # Sample a random timestep for each image |
@@ -565,7 +559,6 @@ def train( | |||
565 | global_step_offset: int = 0, | 559 | global_step_offset: int = 0, |
566 | with_prior_preservation: bool = False, | 560 | with_prior_preservation: bool = False, |
567 | prior_loss_weight: float = 1.0, | 561 | prior_loss_weight: float = 1.0, |
568 | noise_offset: float = 0.2, | ||
569 | **kwargs, | 562 | **kwargs, |
570 | ): | 563 | ): |
571 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( | 564 | text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare( |
@@ -600,7 +593,6 @@ def train( | |||
600 | text_encoder, | 593 | text_encoder, |
601 | with_prior_preservation, | 594 | with_prior_preservation, |
602 | prior_loss_weight, | 595 | prior_loss_weight, |
603 | noise_offset, | ||
604 | seed, | 596 | seed, |
605 | ) | 597 | ) |
606 | 598 | ||