diff options
| author | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
|---|---|---|
| committer | Volpeon <git@volpeon.ink> | 2022-12-21 09:17:25 +0100 |
| commit | 68540b27849564994d921968a36faa9b997e626d (patch) | |
| tree | 8fbe834ab4c52f057cd114bbb0e786158f215acc /training | |
| parent | Fix training (diff) | |
| download | textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.gz textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.tar.bz2 textual-inversion-diff-68540b27849564994d921968a36faa9b997e626d.zip | |
Moved common training code into separate module
Diffstat (limited to 'training')
| -rw-r--r-- | training/optimization.py | 2 | ||||
| -rw-r--r-- | training/util.py | 131 |
2 files changed, 132 insertions(+), 1 deletion(-)
diff --git a/training/optimization.py b/training/optimization.py index 0e603fa..c501ed9 100644 --- a/training/optimization.py +++ b/training/optimization.py | |||
| @@ -6,7 +6,7 @@ from diffusers.utils import logging | |||
| 6 | logger = logging.get_logger(__name__) | 6 | logger = logging.get_logger(__name__) |
| 7 | 7 | ||
| 8 | 8 | ||
| 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.05, mid_point=0.4, last_epoch=-1): | 9 | def get_one_cycle_schedule(optimizer, num_training_steps, annealing="cos", min_lr=0.001, mid_point=0.4, last_epoch=-1): |
| 10 | """ | 10 | """ |
| 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after | 11 | Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after |
| 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. | 12 | a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer. |
diff --git a/training/util.py b/training/util.py new file mode 100644 index 0000000..e8d22ae --- /dev/null +++ b/training/util.py | |||
| @@ -0,0 +1,131 @@ | |||
| 1 | from pathlib import Path | ||
| 2 | import json | ||
| 3 | |||
| 4 | import torch | ||
| 5 | from PIL import Image | ||
| 6 | |||
| 7 | |||
def freeze_params(params):
    """Mark every parameter in *params* as non-trainable.

    Disables autograd tracking so the optimizer skips these tensors.
    """
    for parameter in params:
        parameter.requires_grad = False
| 11 | |||
| 12 | |||
def save_args(basepath: Path, args, extra=None):
    """Serialize the run configuration to ``<basepath>/args.json``.

    Args:
        basepath: Directory the ``args.json`` file is written into.
        args: An object with a ``__dict__`` (typically an argparse
            ``Namespace``) holding the run's arguments.
        extra: Optional mapping of additional key/value pairs merged into
            the saved arguments (overriding on key collision).

    The previous implementation called ``info["args"].update(extra)`` on the
    dict returned by ``vars(args)`` — which IS ``args.__dict__`` — silently
    mutating the caller's namespace. We copy it first so ``args`` is left
    untouched. The mutable default ``extra={}`` was also replaced with
    ``None`` (behaviorally identical for all callers).
    """
    info = {"args": dict(vars(args))}  # copy: never mutate args.__dict__
    if extra is not None:
        info["args"].update(extra)
    with open(basepath.joinpath("args.json"), "w") as f:
        json.dump(info, f, indent=4)
| 18 | |||
| 19 | |||
def make_grid(images, rows, cols):
    """Compose *images* into a single rows x cols contact sheet.

    All images are assumed to share the size of the first one; images are
    placed left-to-right, top-to-bottom.
    """
    width, height = images[0].size
    sheet = Image.new('RGB', size=(cols * width, rows * height))
    for index, img in enumerate(images):
        col = index % cols
        row = index // cols
        sheet.paste(img, box=(col * width, row * height))
    return sheet
| 26 | |||
| 27 | |||
class AverageMeter:
    """Running average of a stream of values (e.g. a training loss).

    Exposes ``sum``, ``count`` and their ratio ``avg``, updated incrementally.
    """

    def __init__(self, name=None):
        self.name = name
        self.reset()

    def reset(self):
        """Clear all accumulated statistics back to zero."""
        self.sum = 0
        self.count = 0
        self.avg = 0

    def update(self, val, n=1):
        """Fold *n* observations of value *val* into the running average."""
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
| 40 | |||
| 41 | |||
class CheckpointerBase:
    """Shared base for training checkpointers: renders periodic sample-image
    grids from a diffusion pipeline so training progress can be inspected.

    Subclasses are expected to construct the concrete ``pipeline`` passed to
    :meth:`save_samples`. NOTE(review): the expected interfaces of
    ``datamodule`` (``train_dataloader``/``val_dataloader`` returning loaders
    whose batches carry ``"prompts"``/``"nprompts"`` lists) and of
    ``pipeline`` are not visible here — confirm against the callers.
    """

    def __init__(
        self,
        datamodule,
        output_dir: Path,
        instance_identifier,
        placeholder_token,
        placeholder_token_id,
        sample_image_size,
        sample_batches,
        sample_batch_size,
        seed
    ):
        self.datamodule = datamodule
        self.output_dir = output_dir
        # Identifier substituted into prompt templates via {identifier}.
        self.instance_identifier = instance_identifier
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id
        # Square sample size in pixels; also determines the latent size below.
        self.sample_image_size = sample_image_size
        # Fall back to a fresh random seed when none is given, so one run's
        # samples are still internally consistent with each other.
        self.seed = seed or torch.random.seed()
        self.sample_batches = sample_batches
        self.sample_batch_size = sample_batch_size

    @torch.no_grad()
    def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
        """Generate sample grids for three pools and save them as
        ``<output_dir>/samples/<pool>/step_<step>.jpg``.

        Pools: "stable" (validation prompts with FIXED latents, so images are
        directly comparable across steps), "val" and "train" (fresh noise each
        call).

        Args:
            pipeline: The diffusion pipeline used for generation. NOTE(review):
                it is called with an ``image=`` latent argument — presumably an
                img2img-style pipeline; confirm the parameter semantics.
            step: Global training step, used only in the output filename.
            num_inference_steps: Denoising steps per generated image.
            guidance_scale: Classifier-free guidance scale.
            eta: DDIM eta parameter.
        """
        samples_path = Path(self.output_dir).joinpath("samples")

        train_data = self.datamodule.train_dataloader()
        val_data = self.datamodule.val_dataloader()

        # Seeded generator: the "stable" latents below are identical on every
        # call, which is what makes the stable pool comparable over time.
        generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
        stable_latents = torch.randn(
            # Latent resolution is image size / 8 (the VAE downscale factor
            # used by Stable Diffusion-style models).
            (self.sample_batch_size, pipeline.unet.in_channels, self.sample_image_size // 8, self.sample_image_size // 8),
            device=pipeline.device,
            generator=generator,
        )

        with torch.autocast("cuda"), torch.inference_mode():
            # (pool name, data source, fixed latents or None). "stable"
            # deliberately reuses the validation prompts.
            for pool, data, latents in [("stable", val_data, stable_latents), ("val", val_data, None), ("train", train_data, None)]:
                all_samples = []
                file_path = samples_path.joinpath(pool, f"step_{step}.jpg")
                file_path.parent.mkdir(parents=True, exist_ok=True)

                data_enum = enumerate(data)

                # Pull just enough batches from the loader to cover
                # sample_batches * sample_batch_size prompts.
                batches = [
                    batch
                    for j, batch in data_enum
                    if j * data.batch_size < self.sample_batch_size * self.sample_batches
                ]
                # Flatten per-batch prompt lists, substituting the instance
                # identifier into each prompt template.
                prompts = [
                    prompt.format(identifier=self.instance_identifier)
                    for batch in batches
                    for prompt in batch["prompts"]
                ]
                nprompts = [
                    prompt
                    for batch in batches
                    for prompt in batch["nprompts"]
                ]

                for i in range(self.sample_batches):
                    # Slice the flattened prompt lists into generation batches.
                    prompt = prompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]
                    nprompt = nprompts[i * self.sample_batch_size:(i + 1) * self.sample_batch_size]

                    samples = pipeline(
                        prompt=prompt,
                        negative_prompt=nprompt,
                        height=self.sample_image_size,
                        width=self.sample_image_size,
                        # Fixed latents (trimmed to the batch) only for the
                        # "stable" pool; otherwise the pipeline draws noise.
                        image=latents[:len(prompt)] if latents is not None else None,
                        generator=generator if latents is not None else None,
                        guidance_scale=guidance_scale,
                        eta=eta,
                        num_inference_steps=num_inference_steps,
                        output_type='pil'
                    ).images

                    all_samples += samples

                    del samples

                image_grid = make_grid(all_samples, self.sample_batches, self.sample_batch_size)
                image_grid.save(file_path, quality=85)

                # Explicitly drop large objects between pools to keep peak
                # (GPU/host) memory down during training.
                del all_samples
                del image_grid

        del generator
        del stable_latents
