| author | Volpeon <git@volpeon.ink> | 2023-03-04 15:08:51 +0100 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-03-04 15:08:51 +0100 | 
| commit | 220c842d22f282544e4d12d277a40f39f85d3c35 (patch) | |
| tree | 6649e9603038d0e04a3f865712add5a6952ef81e /training | |
| parent | Update (diff) | |
Added Perlin noise to training
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 17 | 
1 file changed, 17 insertions, 0 deletions
diff --git a/training/functional.py b/training/functional.py
index 1c38635..db46766 100644
--- a/training/functional.py
+++ b/training/functional.py
@@ -23,6 +23,7 @@ from models.clip.embeddings import ManagedCLIPTextEmbeddings, patch_managed_embe
 from models.clip.util import get_extended_embeddings
 from models.clip.tokenizer import MultiCLIPTokenizer
 from training.util import AverageMeter
+from util.noise import perlin_noise
 
 
 def const(result=None):
@@ -253,6 +254,7 @@ def loss_step(
     text_encoder: CLIPTextModel,
     with_prior_preservation: bool,
     prior_loss_weight: float,
+    perlin_strength: float,
     seed: int,
     step: int,
     batch: dict[str, Any],
@@ -275,6 +277,19 @@ def loss_step(
         generator=generator
     )
 
+    if perlin_strength != 0:
+        noise += perlin_strength * perlin_noise(
+            latents.shape[0],
+            latents.shape[1],
+            latents.shape[2],
+            latents.shape[3],
+            res=1,
+            octaves=4,
+            dtype=latents.dtype,
+            device=latents.device,
+            generator=generator
+        )
+
     # Sample a random timestep for each image
     timesteps = torch.randint(
         0,
@@ -559,6 +574,7 @@ def train(
     global_step_offset: int = 0,
     with_prior_preservation: bool = False,
     prior_loss_weight: float = 1.0,
+    perlin_strength: float = 0.1,
     **kwargs,
 ):
     text_encoder, unet, optimizer, train_dataloader, val_dataloader, lr_scheduler, extra = strategy.prepare(
@@ -593,6 +609,7 @@ def train(
         text_encoder,
         with_prior_preservation,
         prior_loss_weight,
+        perlin_strength,
         seed,
     )
 
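The helper imported here, util.noise.perlin_noise, is not part of this diff. For reference, a minimal sketch of a fractal (multi-octave) Perlin noise generator with the same call signature is given below. It is an assumption about the helper's behavior, not the repository's actual util/noise.py, and it assumes height and width are divisible by res * 2**(octaves - 1), which holds for the usual 64x64 SD latents with res=1, octaves=4.

import math

import torch


def _perlin_2d(height, width, res, generator=None, device=None):
    # Single octave of 2D Perlin noise on a (height, width) grid.
    # Assumes height and width are divisible by res.
    delta = (res / height, res / width)
    d = (height // res, width // res)

    ys = torch.arange(0, res, delta[0], device=device)[:height]
    xs = torch.arange(0, res, delta[1], device=device)[:width]
    grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1) % 1

    # Random unit gradient vectors at the lattice corners.
    angles = 2 * math.pi * torch.rand(res + 1, res + 1, generator=generator, device=device)
    gradients = torch.stack((torch.cos(angles), torch.sin(angles)), dim=-1)

    def tile(sy, sx):
        # Expand corner gradients so each lattice cell covers d[0] x d[1] pixels.
        return gradients[sy, sx].repeat_interleave(d[0], 0).repeat_interleave(d[1], 1)

    def dot(grad, shift):
        offset = torch.stack((grid[..., 0] + shift[0], grid[..., 1] + shift[1]), dim=-1)
        return (offset * grad).sum(dim=-1)

    n00 = dot(tile(slice(0, -1), slice(0, -1)), (0, 0))
    n10 = dot(tile(slice(1, None), slice(0, -1)), (-1, 0))
    n01 = dot(tile(slice(0, -1), slice(1, None)), (0, -1))
    n11 = dot(tile(slice(1, None), slice(1, None)), (-1, -1))

    t = 6 * grid**5 - 15 * grid**4 + 10 * grid**3  # smoothstep fade
    n0 = torch.lerp(n00, n10, t[..., 0])
    n1 = torch.lerp(n01, n11, t[..., 0])
    return math.sqrt(2) * torch.lerp(n0, n1, t[..., 1])


def perlin_noise(batch, channels, height, width, res=1, octaves=4,
                 dtype=None, device=None, generator=None):
    # Fractal noise: sum octaves with doubling frequency and halving amplitude.
    out = torch.zeros(batch, channels, height, width, device=device)
    for b in range(batch):
        for c in range(channels):
            amplitude, frequency = 1.0, res
            for _ in range(octaves):
                out[b, c] += amplitude * _perlin_2d(height, width, frequency,
                                                    generator=generator, device=device)
                amplitude *= 0.5
                frequency *= 2
    return out if dtype is None else out.to(dtype)

With the default perlin_strength of 0.1, the added component is a small, low-frequency perturbation on top of the standard Gaussian training noise; setting perlin_strength to 0 skips the blend entirely and restores the previous behaviour.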
