author    Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
committer Volpeon <git@volpeon.ink>  2023-01-08 09:43:22 +0100
commit    5571c4ebcb39813e2bd8585de30c64bb02f9d7fa (patch)
tree      a073f625eaa49c3cd908aacb3debae23e5badbf7
parent    Cleanup (diff)
Improved aspect ratio bucketing
-rw-r--r--  data/csv.py          | 273
-rw-r--r--  train_dreambooth.py  | 100
-rw-r--r--  train_ti.py          |  85
-rw-r--r--  training/util.py     |   2
4 files changed, 237 insertions, 223 deletions
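
The new bucketing keeps one square size x size bucket and adds, for each step i up to num_buckets, a landscape and a portrait bucket whose long edge grows by 64 px, giving the aspect ratios (size + 64*i)/size and size/(size + 64*i). A minimal sketch of that ratio math, mirroring generate_buckets in data/csv.py below (the concrete values are only examples):

    # Bucket aspect ratios as laid out in data/csv.py: one square bucket,
    # then a landscape/portrait pair per 64 px step of the long edge.
    def bucket_ratios(size: int, num_buckets: int) -> list[float]:
        ratios = [1.0]
        for i in range(1, num_buckets + 1):
            s = size + i * 64
            ratios.append(s / size)  # landscape bucket, s x size
            ratios.append(size / s)  # portrait bucket, size x s
        return ratios

    # size=768, num_buckets=2 -> crops 768x768, 832x768, 768x832, 896x768, 768x896
    print(bucket_ratios(768, 2))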
diff --git a/data/csv.py b/data/csv.py
index 654aec1..9be36ba 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,20 +2,28 @@ import math
 import torch
 import json
 from pathlib import Path
+from typing import NamedTuple, Optional, Union, Callable
+
 from PIL import Image
-from torch.utils.data import Dataset, DataLoader, random_split
-from torchvision import transforms
-from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
-import numpy as np
+from torch.utils.data import IterableDataset, DataLoader, random_split
+from torchvision import transforms
 
-from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
+from models.clip.prompt import PromptProcessor
 
 
 image_cache: dict[str, Image.Image] = {}
 
 
+interpolations = {
+    "linear": transforms.InterpolationMode.NEAREST,
+    "bilinear": transforms.InterpolationMode.BILINEAR,
+    "bicubic": transforms.InterpolationMode.BICUBIC,
+    "lanczos": transforms.InterpolationMode.LANCZOS,
+}
+
+
 def get_image(path):
     if path in image_cache:
         return image_cache[path]
@@ -28,10 +36,46 @@ def get_image(path):
     return image
 
 
-def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+def prepare_prompt(prompt: Union[str, dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
+    item_order: list[int] = []
+    item_buckets: list[int] = []
+    buckets = [1.0]
+
+    for i in range(1, num_buckets + 1):
+        s = size + i * 64
+        buckets.append(s / size)
+        buckets.append(size / s)
+
+    buckets = torch.tensor(buckets)
+    bucket_indices = torch.arange(len(buckets))
+
+    for i, item in enumerate(items):
+        image = get_image(item)
+        ratio = image.width / image.height
+
+        if ratio >= 1:
+            mask = torch.bitwise_and(buckets >= 1, buckets <= ratio)
+        else:
+            mask = torch.bitwise_and(buckets <= 1, buckets >= ratio)
+
+        if not progressive_buckets:
+            mask = (buckets + (~mask) * math.inf - ratio).abs().argmin()
+
+        indices = bucket_indices[mask]
+
+        if len(indices.shape) == 0:
+            indices = indices.unsqueeze(0)
+
+        item_order += [i] * len(indices)
+        item_buckets += indices
+
+    return buckets.tolist(), item_order, item_buckets
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]
 
 
-class VlpnDataBucket():
-    def __init__(self, width: int, height: int):
-        self.width = width
-        self.height = height
-        self.ratio = width / height
-        self.items: list[VlpnDataItem] = []
-
-
 class VlpnDataModule():
     def __init__(
         self,
@@ -60,7 +96,6 @@ class VlpnDataModule():
         size: int = 768,
         num_aspect_ratio_buckets: int = 0,
         progressive_aspect_ratio_buckets: bool = False,
-        repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -86,7 +121,6 @@ class VlpnDataModule():
         self.size = size
         self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
         self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
-        self.repeats = repeats
         self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -146,36 +180,6 @@ class VlpnDataModule():
             for i in range(image_multiplier)
         ]
 
-    def generate_buckets(self, items: list[VlpnDataItem]):
-        buckets = [VlpnDataBucket(self.size, self.size)]
-
-        for i in range(1, self.num_aspect_ratio_buckets + 1):
-            s = self.size + i * 64
-            buckets.append(VlpnDataBucket(s, self.size))
-            buckets.append(VlpnDataBucket(self.size, s))
-
-        buckets = np.array(buckets)
-        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
-
-        for item in items:
-            image = get_image(item.instance_image_path)
-            ratio = image.width / image.height
-
-            if ratio >= 1:
-                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
-            else:
-                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
-
-            if not self.progressive_aspect_ratio_buckets:
-                ratios = bucket_ratios.copy()
-                ratios[~mask] = math.inf
-                mask = [np.argmin(np.abs(ratios - ratio))]
-
-            for bucket in buckets[mask]:
-                bucket.items.append(item)
-
-        return [bucket for bucket in buckets if len(bucket.items) != 0]
-
     def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
@@ -201,105 +205,136 @@ class VlpnDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-        buckets = self.generate_buckets(data_train)
-
-        train_datasets = [
-            VlpnDataset(
-                bucket.items, self.prompt_processor,
-                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
-                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
-            )
-            for bucket in buckets
-        ]
+        train_dataset = VlpnDataset(
+            self.data_train, self.prompt_processor,
+            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
+            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
+        )
 
         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor,
-            width=self.size, height=self.size, interpolation=self.interpolation,
+            self.data_val, self.prompt_processor,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
         )
 
-        self.train_dataloaders = [
-            DataLoader(
-                dataset, batch_size=self.batch_size, shuffle=True,
-                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
-            )
-            for dataset in train_datasets
-        ]
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
 
         self.val_dataloader = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            val_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
        )
 
 
-class VlpnDataset(Dataset):
+class VlpnDataset(IterableDataset):
     def __init__(
         self,
-        data: List[VlpnDataItem],
+        items: list[VlpnDataItem],
         prompt_processor: PromptProcessor,
+        num_buckets: int = 1,
+        progressive_buckets: bool = False,
+        batch_size: int = 1,
         num_class_images: int = 0,
-        width: int = 768,
-        height: int = 768,
-        repeats: int = 1,
+        size: int = 768,
        dropout: float = 0,
+        shuffle: bool = False,
         interpolation: str = "bicubic",
+        generator: Optional[torch.Generator] = None,
     ):
+        self.items = items
+        self.batch_size = batch_size
 
-        self.data = data
         self.prompt_processor = prompt_processor
         self.num_class_images = num_class_images
+        self.size = size
         self.dropout = dropout
-
-        self.num_instance_images = len(self.data)
-        self._length = self.num_instance_images * repeats
+        self.shuffle = shuffle
+        self.interpolation = interpolations[interpolation]
+        self.generator = generator
 
-        self.interpolation = {
-            "linear": transforms.InterpolationMode.NEAREST,
-            "bilinear": transforms.InterpolationMode.BILINEAR,
-            "bicubic": transforms.InterpolationMode.BICUBIC,
-            "lanczos": transforms.InterpolationMode.LANCZOS,
-        }[interpolation]
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(min(width, height), interpolation=self.interpolation),
-                transforms.RandomCrop((height, width)),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
+        buckets, item_order, item_buckets = generate_buckets(
+            [item.instance_image_path for item in items],
+            size,
+            num_buckets,
+            progressive_buckets
         )
 
-    def __len__(self):
-        return self._length
+        self.buckets = torch.tensor(buckets)
+        self.item_order = torch.tensor(item_order)
+        self.item_buckets = torch.tensor(item_buckets)
 
-    def get_example(self, i):
-        item = self.data[i % self.num_instance_images]
+    def __len__(self):
+        return len(self.item_buckets)
 
-        example = {}
-        example["prompts"] = item.prompt
-        example["cprompts"] = item.cprompt
-        example["nprompts"] = item.nprompt
-        example["instance_images"] = get_image(item.instance_image_path)
-        if self.num_class_images != 0:
-            example["class_images"] = get_image(item.class_image_path)
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+
+        if self.shuffle:
+            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
+            self.item_order = self.item_order[perm]
+            self.item_buckets = self.item_buckets[perm]
 
-        return example
+        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
+        bucket = -1
+        image_transforms = None
+        batch = []
+        batch_size = self.batch_size
+
+        if worker_info is not None:
+            batch_size = math.ceil(batch_size / worker_info.num_workers)
+            worker_batch = math.ceil(len(self) / worker_info.num_workers)
+            start = worker_info.id * worker_batch
+            end = start + worker_batch
+            item_mask[:start] = False
+            item_mask[end:] = False
+
+        while item_mask.any():
+            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+
+            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+                yield batch
+                batch = []
+
+            if len(item_indices) == 0:
+                bucket = self.item_buckets[item_mask][0]
+                ratio = self.buckets[bucket]
+                width = self.size * ratio if ratio > 1 else self.size
+                height = self.size / ratio if ratio < 1 else self.size
+
+                image_transforms = transforms.Compose(
+                    [
+                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.RandomCrop((height, width)),
+                        transforms.RandomHorizontalFlip(),
+                        transforms.ToTensor(),
+                        transforms.Normalize([0.5], [0.5]),
+                    ]
+                )
+            else:
+                item_index = item_indices[0]
+                item = self.items[item_index]
+                item_mask[item_index] = False
 
-    def __getitem__(self, i):
-        unprocessed_example = self.get_example(i)
+                example = {}
 
-        example = {}
+                example["prompts"] = keywords_to_prompt(item.prompt)
+                example["cprompts"] = item.cprompt
+                example["nprompts"] = item.nprompt
 
-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"])
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+                example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+                example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+                    keywords_to_prompt(item.prompt, self.dropout, True)
+                )
 
-        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
-        )
+                if self.num_class_images != 0:
+                    example["class_images"] = image_transforms(get_image(item.class_image_path))
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
 
-        if self.num_class_images != 0:
-            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                batch.append(example)
 
-        return example
+        if len(batch) != 0:
+            yield batch
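
When progressive_buckets is off, an image is assigned only to the nearest compatible bucket; when it is on, it joins every bucket between square and its own aspect ratio. A small sketch of the nearest-bucket selection with an assumed example ratio, written with torch.where in place of the infinity-masking one-liner used above:

    import torch

    # Ratios for size=768, num_buckets=2: square, 832x768, 768x832, 896x768, 768x896
    buckets = torch.tensor([1.0, 832 / 768, 768 / 832, 896 / 768, 768 / 896])
    ratio = 1.2  # hypothetical image aspect ratio (width / height)

    # Landscape images may only land in landscape-or-square buckets, and vice versa
    if ratio >= 1:
        mask = (buckets >= 1) & (buckets <= ratio)
    else:
        mask = (buckets <= 1) & (buckets >= ratio)

    # Push incompatible buckets to infinity so argmin picks the closest compatible one
    distance = torch.where(mask, (buckets - ratio).abs(), torch.tensor(float("inf")))
    print(distance.argmin().item())  # -> 3, the 896x768 bucket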
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -134,12 +134,6 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/dreambooth",
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -770,8 +763,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
             text_encoder.requires_grad_(False)
 
             with on_train():
-                for train_dataloader in train_dataloaders:
-                    for step, batch in enumerate(train_dataloader):
-                        with accelerator.accumulate(unet):
-                            loss, acc, bsz = loop(step, batch)
-
-                            accelerator.backward(loss)
-
-                            if accelerator.sync_gradients:
-                                params_to_clip = (
-                                    itertools.chain(unet.parameters(), text_encoder.parameters())
-                                    if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                    else unet.parameters()
-                                )
-                                accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                            optimizer.step()
-                            if not accelerator.optimizer_step_was_skipped:
-                                lr_scheduler.step()
-                            if args.use_ema:
-                                ema_unet.step(unet.parameters())
-                            optimizer.zero_grad(set_to_none=True)
-
-                            avg_loss.update(loss.detach_(), bsz)
-                            avg_acc.update(acc.detach_(), bsz)
-
-                        # Checks if the accelerator has performed an optimization step behind the scenes
-                        if accelerator.sync_gradients:
-                            local_progress_bar.update(1)
-                            global_progress_bar.update(1)
-
-                            global_step += 1
-
-                            logs = {
-                                "train/loss": avg_loss.avg.item(),
-                                "train/acc": avg_acc.avg.item(),
-                                "train/cur_loss": loss.item(),
-                                "train/cur_acc": acc.item(),
-                                "lr": lr_scheduler.get_last_lr()[0]
-                            }
-                            if args.use_ema:
-                                logs["ema_decay"] = 1 - ema_unet.decay
-
-                            accelerator.log(logs, step=global_step)
-
-                            local_progress_bar.set_postfix(**logs)
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(unet):
+                        loss, acc, bsz = loop(step, batch)
+
+                        accelerator.backward(loss)
+
+                        if accelerator.sync_gradients:
+                            params_to_clip = (
+                                itertools.chain(unet.parameters(), text_encoder.parameters())
+                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                                else unet.parameters()
+                            )
+                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        if args.use_ema:
+                            ema_unet.step(unet.parameters())
+                        optimizer.zero_grad(set_to_none=True)
+
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        global_step += 1
+
+                        logs = {
+                            "train/loss": avg_loss.avg.item(),
+                            "train/acc": avg_acc.avg.item(),
+                            "train/cur_loss": loss.item(),
+                            "train/cur_acc": acc.item(),
+                            "lr": lr_scheduler.get_last_lr()[0]
+                        }
+                        if args.use_ema:
+                            logs["ema_decay"] = 1 - ema_unet.decay
+
+                        accelerator.log(logs, step=global_step)
+
+                        local_progress_bar.set_postfix(**logs)
 
                     if global_step >= args.max_train_steps:
                         break
 
             accelerator.wait_for_everyone()
 
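
Since VlpnDataset now yields whole batches, both training scripts consume a single DataLoader built with batch_size=None, which disables automatic batching and passes each yielded list to collate_fn as-is. A self-contained sketch of that pattern, using a hypothetical toy dataset:

    import torch
    from torch.utils.data import IterableDataset, DataLoader

    class PreBatched(IterableDataset):
        """Yields ready-made batches instead of single examples."""

        def __init__(self, num_items: int, batch_size: int):
            self.num_items = num_items
            self.batch_size = batch_size

        def __iter__(self):
            batch = []
            for i in range(self.num_items):
                batch.append({"x": torch.tensor(i)})
                if len(batch) == self.batch_size:
                    yield batch
                    batch = []
            if batch:
                yield batch  # trailing partial batch

    # batch_size=None: the loader forwards each yielded batch unchanged
    loader = DataLoader(PreBatched(10, 4), batch_size=None)
    print([len(batch) for batch in loader])  # [4, 4, 2]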
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -107,12 +107,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/text-inversion",
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
        dropout=args.tag_dropout,
        template_key=args.train_data_template,
        valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
             text_encoder.train()
 
             with on_train():
-                for train_dataloader in train_dataloaders:
-                    for step, batch in enumerate(train_dataloader):
-                        with accelerator.accumulate(text_encoder):
-                            loss, acc, bsz = loop(step, batch)
-
-                            accelerator.backward(loss)
-
-                            optimizer.step()
-                            if not accelerator.optimizer_step_was_skipped:
-                                lr_scheduler.step()
-                            optimizer.zero_grad(set_to_none=True)
-
-                            avg_loss.update(loss.detach_(), bsz)
-                            avg_acc.update(acc.detach_(), bsz)
-
-                        # Checks if the accelerator has performed an optimization step behind the scenes
-                        if accelerator.sync_gradients:
-                            if args.use_ema:
-                                ema_embeddings.step(
-                                    text_encoder.text_model.embeddings.temp_token_embedding.parameters())
-
-                            local_progress_bar.update(1)
-                            global_progress_bar.update(1)
-
-                            global_step += 1
-
-                            logs = {
-                                "train/loss": avg_loss.avg.item(),
-                                "train/acc": avg_acc.avg.item(),
-                                "train/cur_loss": loss.item(),
-                                "train/cur_acc": acc.item(),
-                                "lr": lr_scheduler.get_last_lr()[0],
-                            }
-                            if args.use_ema:
-                                logs["ema_decay"] = ema_embeddings.decay
-
-                            accelerator.log(logs, step=global_step)
-
-                            local_progress_bar.set_postfix(**logs)
-
-                    if global_step >= args.max_train_steps:
-                        break
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(text_encoder):
+                        loss, acc, bsz = loop(step, batch)
+
+                        accelerator.backward(loss)
+
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
+
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
+
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        if args.use_ema:
+                            ema_embeddings.step(
+                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
+
+                        global_step += 1
+
+                        logs = {
+                            "train/loss": avg_loss.avg.item(),
+                            "train/acc": avg_acc.avg.item(),
+                            "train/cur_loss": loss.item(),
+                            "train/cur_acc": acc.item(),
+                            "lr": lr_scheduler.get_last_lr()[0],
+                        }
+                        if args.use_ema:
+                            logs["ema_decay"] = ema_embeddings.decay
+
+                        accelerator.log(logs, step=global_step)
+
+                        local_progress_bar.set_postfix(**logs)
+
+                    if global_step >= args.max_train_steps:
+                        break
 
             accelerator.wait_for_everyone()
 
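
Moving batching into the dataset also means DataLoader workers must not duplicate items; VlpnDataset.__iter__ gives each worker a contiguous slice of the (shuffled) item order via torch.utils.data.get_worker_info(). A sketch of that sharding logic under assumed sizes:

    import math
    import torch

    num_items, num_workers = 10, 3
    per_worker = math.ceil(num_items / num_workers)

    for worker_id in range(num_workers):
        # Mirrors the item_mask slicing in VlpnDataset.__iter__
        item_mask = torch.ones(num_items, dtype=torch.bool)
        start = worker_id * per_worker
        item_mask[:start] = False
        item_mask[start + per_worker:] = False
        print(worker_id, torch.nonzero(item_mask).flatten().tolist())

    # -> 0 [0, 1, 2, 3], 1 [4, 5, 6, 7], 2 [8, 9]: disjoint and exhaustive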
diff --git a/training/util.py b/training/util.py
index 2b7f71d..ae6bfc4 100644
--- a/training/util.py
+++ b/training/util.py
@@ -59,7 +59,7 @@ class CheckpointerBase:
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloaders[0]
+        train_data = self.datamodule.train_dataloader
         val_data = self.datamodule.val_dataloader
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)