From 27b18776ba6d38d6bda5e5bafee3e7c4ca8c9712 Mon Sep 17 00:00:00 2001 From: Volpeon Date: Sat, 24 Jun 2023 16:26:22 +0200 Subject: Fixes --- training/functional.py | 7 +++++-- training/strategy/dreambooth.py | 3 --- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'training') diff --git a/training/functional.py b/training/functional.py index 43b03ac..546aaff 100644 --- a/training/functional.py +++ b/training/functional.py @@ -2,6 +2,7 @@ from dataclasses import dataclass import math from contextlib import _GeneratorContextManager, nullcontext from typing import Callable, Any, Tuple, Union, Optional, Protocol +from types import MethodType from functools import partial from pathlib import Path import itertools @@ -108,6 +109,7 @@ def save_samples( output_dir: Path, seed: int, step: int, + validation_prompts: list[str] = [], cycle: int = 1, batch_size: int = 1, num_batches: int = 1, @@ -136,7 +138,6 @@ def save_samples( if val_dataloader is not None: datasets.append(("stable", val_dataloader, generator)) - datasets.append(("val", val_dataloader, None)) for pool, data, gen in datasets: all_samples = [] @@ -165,7 +166,6 @@ def save_samples( guidance_scale=guidance_scale, sag_scale=0, num_inference_steps=num_steps, - output_type=None, ).images all_samples.append(torch.from_numpy(samples)) @@ -803,4 +803,7 @@ def train( accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) accelerator.unwrap_model(unet, keep_fp32_wrapper=False) + text_encoder.forward = MethodType(text_encoder.forward, text_encoder) + unet.forward = MethodType(unet.forward, unet) + accelerator.free_memory() diff --git a/training/strategy/dreambooth.py b/training/strategy/dreambooth.py index 0f64747..bd853e2 100644 --- a/training/strategy/dreambooth.py +++ b/training/strategy/dreambooth.py @@ -155,9 +155,6 @@ def dreambooth_strategy_callbacks( unet_ = accelerator.unwrap_model(unet, keep_fp32_wrapper=False) text_encoder_ = accelerator.unwrap_model(text_encoder, keep_fp32_wrapper=False) - unet_.forward = MethodType(unet_.forward, unet_) - text_encoder_.forward = MethodType(text_encoder_.forward, text_encoder_) - with ema_context(): pipeline = VlpnStableDiffusion( text_encoder=text_encoder_, -- cgit v1.2.3-70-g09d2