diff options
| author | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 | 
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2023-06-24 16:26:22 +0200 | 
| commit | 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 (patch) | |
| tree | 6c1f2243475778bb5e9e1725bf3969a5442393d8 /training | |
| parent | Update (diff) | |
| download | textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.gz textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.tar.bz2 textual-inversion-diff-27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712.zip  | |
Fixes
Diffstat (limited to 'training')
| -rw-r--r-- | training/functional.py | 7 | ||||
| -rw-r--r-- | training/strategy/dreambooth.py | 3 | 
2 files changed, 5 insertions, 5 deletions
diff --git a/training/functional.py b/training/functional.py index 43b03ac..546aaff 100644 --- a/training/functional.py +++ b/training/functional.py  | |||
| @@ -2,6 +2,7 @@ from dataclasses import dataclass | |||
| 2 | import math | 2 | import math | 
| 3 | from contextlib import _GeneratorContextManager, nullcontext | 3 | from contextlib import _GeneratorContextManager, nullcontext | 
| 4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol | 4 | from typing import Callable, Any, Tuple, Union, Optional, Protocol | 
| 5 | from types import MethodType | ||
| 5 | from functools import partial | 6 | from functools import partial | 
| 6 | from pathlib import Path | 7 | from pathlib import Path | 
| 7 | import itertools | 8 | import itertools | 
| @@ -108,6 +109,7 @@ def save_samples( | |||
| 108 | output_dir: Path, | 109 | output_dir: Path, | 
| 109 | seed: int, | 110 | seed: int, | 
| 110 | step: int, | 111 | step: int, | 
| 112 | validation_prompts: list[str] = [], | ||
| 111 | cycle: int = 1, | 113 | cycle: int = 1, | 
| 112 | batch_size: int = 1, | 114 | batch_size: int = 1, | 
| 113 | num_batches: int = 1, | 115 | num_batches: int = 1, | 
| @@ -136,7 +138,6 @@ def save_samples( | |||
| 136 | 138 | ||
| 137 | if val_dataloader is not None: | 139 | if val_dataloader is not None: | 
| 138 | datasets.append(("stable", val_dataloader, generator)) | 140 | datasets.append(("stable", val_dataloader, generator)) | 
| 139 | datasets.append(("val", val_dataloader, None)) | ||
| 140 | 141 | ||
| 141 | for pool, data, gen in datasets: | 142 | for pool, data, gen in datasets: | 
| 142 | all_samples = [] | 143 | all_samples = [] | 
| @@ -165,7 +166,6 @@ def save_samples( | |||
| 165 | guidance_scale=guidance_scale, | 166 | guidance_scale=guidance_scale, | 
| 166 | sag_scale=0, | 167 | sag_scale=0, | 
| 167 | num_inference_steps=num_steps, | 168 | num_inference_steps=num_steps, | 
| 168 | output_type=None, | ||
| 169 | ).images | 169 | ).images | 
| 170 | 170 | ||
| 171 | all_samples.append(torch.from_numpy(samples)) | 171 | all_samples.append(torch.from_numpy(samples)) | 
| @@ -803,4 +803,7 @@ def train( | |||
| 803 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 803 | accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 
| 804 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 804 | accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 
| 805 | 805 | ||
| 806 | text_encoder.forward = MethodType(text_encoder.forward, text_encoder) | ||
| 807 | unet.forward = MethodType(unet.forward, unet) | ||
| 808 | |||
| 806 | accelerator.free_memory() | 809 | accelerator.free_memory() | 
diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0f64747..bd853e2 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py  | |||
| @@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks( | |||
| 155 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 155 | unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) | 
| 156 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 156 | text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) | 
| 157 | 157 | ||
| 158 | unet_.forward = MethodType(unet_.forward, unet_) | ||
| 159 | text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) | ||
| 160 | |||
| 161 | with ema_context(): | 158 | with ema_context(): | 
| 162 | pipeline = VlpnStableDiffusion( | 159 | pipeline = VlpnStableDiffusion( | 
| 163 | text_encoder=text_encoder_, | 160 | text_encoder=text_encoder_, | 
