 data/csv.py         | 144
 infer.py            |  23
 train_dreambooth.py |  12
 train_ti.py         |  94
 training/util.py    |   4
 5 files changed, 151 insertions(+), 126 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 4986153..59d6d8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
 
 
+image_cache: dict[str, Image.Image] = {}
+
+
+def get_image(path):
+    if path in image_cache:
+        return image_cache[path]
+
+    image = Image.open(path)
+    if not image.mode == "RGB":
+        image = image.convert("RGB")
+    image_cache[path] = image
+
+    return image
+
+
 def prepare_prompt(prompt: Union[str, Dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
-class CSVDataItem(NamedTuple):
+class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
     prompt: list[str]
@@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple):
     collection: list[str]
 
 
-class CSVDataModule():
+class VlpnDataBucket():
+    def __init__(self, width: int, height: int):
+        self.width = width
+        self.height = height
+        self.ratio = width / height
+        self.items: list[VlpnDataItem] = []
+
+
+class VlpnDataModule():
     def __init__(
         self,
         batch_size: int,
@@ -36,11 +59,10 @@ class CSVDataModule():
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
         seed: Optional[int] = None,
-        filter: Optional[Callable[[CSVDataItem], bool]] = None,
+        filter: Optional[Callable[[VlpnDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
     ):
@@ -60,7 +82,6 @@ class CSVDataModule():
         self.size = size
         self.repeats = repeats
         self.dropout = dropout
-        self.center_crop = center_crop
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
@@ -70,14 +91,14 @@ class CSVDataModule():
         self.num_workers = num_workers
         self.batch_size = batch_size
 
-    def prepare_items(self, template, expansions, data) -> list[CSVDataItem]:
+    def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]:
         image = template["image"] if "image" in template else "{}"
         prompt = template["prompt"] if "prompt" in template else "{content}"
         cprompt = template["cprompt"] if "cprompt" in template else "{content}"
         nprompt = template["nprompt"] if "nprompt" in template else "{content}"
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 self.data_root.joinpath(image.format(item["image"])),
                 None,
                 prompt_to_keywords(
@@ -97,17 +118,17 @@ class CSVDataModule():
             for item in data
         ]
 
-    def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]:
+    def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]:
         if self.filter is None:
             return items
 
         return [item for item in items if self.filter(item)]
 
-    def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]:
+    def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]:
         image_multiplier = max(num_class_images, 1)
 
         return [
-            CSVDataItem(
+            VlpnDataItem(
                 item.instance_image_path,
                 self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"),
                 item.prompt,
@@ -119,7 +140,30 @@ class CSVDataModule():
             for i in range(image_multiplier)
         ]
 
-    def prepare_data(self):
+    def generate_buckets(self, items: list[VlpnDataItem]):
+        buckets = [VlpnDataBucket(self.size, self.size)]
+
+        for i in range(1, 5):
+            s = self.size + i * 64
+            buckets.append(VlpnDataBucket(s, self.size))
+            buckets.append(VlpnDataBucket(self.size, s))
+
+        for item in items:
+            image = get_image(item.instance_image_path)
+            ratio = image.width / image.height
+
+            if ratio >= 1:
+                candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio]
+            else:
+                candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio]
+
+            for bucket in candidates:
+                bucket.items.append(item)
+
+        buckets = [bucket for bucket in buckets if len(bucket.items) != 0]
+        return buckets
+
+    def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
             template = metadata[self.template_key] if self.template_key in metadata else {}
@@ -144,48 +188,48 @@ class CSVDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-    def setup(self, stage=None):
-        train_dataset = CSVDataset(
-            self.data_train, self.prompt_processor, batch_size=self.batch_size,
-            num_class_images=self.num_class_images,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout
-        )
-        val_dataset = CSVDataset(
-            self.data_val, self.prompt_processor, batch_size=self.batch_size,
-            size=self.size, interpolation=self.interpolation,
-            center_crop=self.center_crop
-        )
-        self.train_dataloader_ = DataLoader(
-            train_dataset, batch_size=self.batch_size,
-            shuffle=True, pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
-        )
-        self.val_dataloader_ = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn,
-            num_workers=self.num_workers
+        buckets = self.generate_buckets(data_train)
+
+        train_datasets = [
+            VlpnDataset(
+                bucket.items, self.prompt_processor, batch_size=self.batch_size,
+                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
+                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
+            )
+            for bucket in buckets
+        ]
+
+        val_dataset = VlpnDataset(
+            data_val, self.prompt_processor, batch_size=self.batch_size,
+            width=self.size, height=self.size, interpolation=self.interpolation,
         )
 
-    def train_dataloader(self):
-        return self.train_dataloader_
+        self.train_dataloaders = [
+            DataLoader(
+                dataset, batch_size=self.batch_size, shuffle=True,
+                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            )
+            for dataset in train_datasets
+        ]
 
-    def val_dataloader(self):
-        return self.val_dataloader_
+        self.val_dataloader = DataLoader(
+            val_dataset, batch_size=self.batch_size,
+            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
 
 
-class CSVDataset(Dataset):
+class VlpnDataset(Dataset):
     def __init__(
         self,
-        data: List[CSVDataItem],
+        data: List[VlpnDataItem],
         prompt_processor: PromptProcessor,
         batch_size: int = 1,
         num_class_images: int = 0,
-        size: int = 768,
+        width: int = 768,
+        height: int = 768,
         repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
-        center_crop: bool = False,
     ):
 
         self.data = data
@@ -193,7 +237,6 @@ class CSVDataset(Dataset):
         self.batch_size = batch_size
         self.num_class_images = num_class_images
         self.dropout = dropout
-        self.image_cache = {}
 
         self.num_instance_images = len(self.data)
         self._length = self.num_instance_images * repeats
@@ -206,8 +249,8 @@ class CSVDataset(Dataset):
         }[interpolation]
         self.image_transforms = transforms.Compose(
             [
-                transforms.Resize(size, interpolation=self.interpolation),
-                transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size),
+                transforms.Resize(min(width, height), interpolation=self.interpolation),
+                transforms.RandomCrop((height, width)),
                 transforms.RandomHorizontalFlip(),
                 transforms.ToTensor(),
                 transforms.Normalize([0.5], [0.5]),
@@ -217,17 +260,6 @@ class CSVDataset(Dataset):
     def __len__(self):
         return math.ceil(self._length / self.batch_size) * self.batch_size
 
-    def get_image(self, path):
-        if path in self.image_cache:
-            return self.image_cache[path]
-
-        image = Image.open(path)
-        if not image.mode == "RGB":
-            image = image.convert("RGB")
-        self.image_cache[path] = image
-
-        return image
-
     def get_example(self, i):
         item = self.data[i % self.num_instance_images]
 
@@ -235,9 +267,9 @@ class CSVDataset(Dataset):
         example["prompts"] = item.prompt
         example["cprompts"] = item.cprompt
         example["nprompts"] = item.nprompt
-        example["instance_images"] = self.get_image(item.instance_image_path)
+        example["instance_images"] = get_image(item.instance_image_path)
         if self.num_class_images != 0:
-            example["class_images"] = self.get_image(item.class_image_path)
+            example["class_images"] = get_image(item.class_image_path)
 
         return example
 
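Note on the bucketing change above: generate_buckets builds one square bucket at the base resolution plus progressively wider and taller buckets in 64-pixel steps, assigns each training image to every bucket whose aspect ratio it can fill, and setup() then builds one VlpnDataset and DataLoader per non-empty bucket. Below is a minimal, self-contained sketch of that assignment logic; the Bucket dataclass and the generate_buckets helper here are illustrative stand-ins, not repository code.

from dataclasses import dataclass, field


@dataclass
class Bucket:
    width: int
    height: int
    items: list = field(default_factory=list)

    @property
    def ratio(self) -> float:
        return self.width / self.height


def generate_buckets(sizes: list[tuple[int, int]], base: int = 768, steps: int = 4) -> list[Bucket]:
    # One square bucket, then landscape/portrait variants in 64px increments.
    buckets = [Bucket(base, base)]
    for i in range(1, steps + 1):
        s = base + i * 64
        buckets.append(Bucket(s, base))   # landscape
        buckets.append(Bucket(base, s))   # portrait

    # An image goes into every bucket whose ratio lies between 1 and its own
    # ratio, so Resize + RandomCrop can fill the bucket shape without padding.
    for w, h in sizes:
        ratio = w / h
        if ratio >= 1:
            candidates = [b for b in buckets if b.ratio >= 1 and ratio >= b.ratio]
        else:
            candidates = [b for b in buckets if b.ratio <= 1 and ratio <= b.ratio]
        for b in candidates:
            b.items.append((w, h))

    return [b for b in buckets if b.items]


if __name__ == "__main__":
    for b in generate_buckets([(1024, 768), (768, 1024), (768, 768)]):
        print(f"{b.width}x{b.height} (ratio {b.ratio:.2f}): {len(b.items)} item(s)")
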
diff --git a/infer.py b/infer.py
index d3d5f1b..2b07b21 100644
--- a/infer.py
+++ b/infer.py
@@ -238,16 +238,15 @@ def create_pipeline(model, dtype):
     return pipeline
 
 
+def shuffle_prompts(prompts: list[str]) -> list[str]:
+    return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts]
+
+
 @torch.inference_mode()
 def generate(output_dir: Path, pipeline, args):
     if isinstance(args.prompt, str):
         args.prompt = [args.prompt]
 
-    if args.shuffle:
-        args.prompt *= args.batch_size
-        args.batch_size = 1
-        args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt]
-
     args.prompt = [args.template.format(prompt) for prompt in args.prompt]
 
     now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
@@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args):
             dir = output_dir.joinpath(slugify(prompt)[:100])
             dir.mkdir(parents=True, exist_ok=True)
             image_dir.append(dir)
-
-            with open(dir.joinpath('prompt.txt'), 'w') as f:
-                f.write(prompt)
     else:
         output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}")
         output_dir.mkdir(parents=True, exist_ok=True)
@@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args):
         )
 
         seed = args.seed + i
+        prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt
         generator = torch.Generator(device="cuda").manual_seed(seed)
         images = pipeline(
-            prompt=args.prompt,
+            prompt=prompt,
             negative_prompt=args.negative_prompt,
             height=args.height,
             width=args.width,
@@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args):
         ).images
 
         for j, image in enumerate(images):
+            basename = f"{seed}_{j // len(args.prompt)}"
             dir = image_dir[j % len(args.prompt)]
-            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png"))
-            image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85)
+
+            image.save(dir.joinpath(f"{basename}.png"))
+            image.save(dir.joinpath(f"{basename}.jpg"), quality=85)
+            with open(dir.joinpath(f"{basename}.txt"), 'w') as f:
+                f.write(prompt[j % len(args.prompt)])
 
         if torch.cuda.is_available():
             torch.cuda.empty_cache()
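Note on the infer.py change: keyword shuffling now happens inside the per-seed loop via shuffle_prompts, so every batch is generated from a freshly reordered prompt, and each saved image gets a PNG, a JPEG, and a .txt containing the exact prompt used. A rough standalone illustration of the shuffling behaviour follows; the split/join helpers are stand-ins for data.keywords.prompt_to_keywords and keywords_to_prompt, whose real implementations may differ.

import random


def prompt_to_keywords(prompt: str) -> list[str]:
    # stand-in: treat the prompt as a comma-separated keyword list
    return [k.strip() for k in prompt.split(",") if k.strip()]


def keywords_to_prompt(keywords: list[str], shuffle: bool = False) -> str:
    if shuffle:
        keywords = random.sample(keywords, len(keywords))
    return ", ".join(keywords)


def shuffle_prompts(prompts: list[str]) -> list[str]:
    return [keywords_to_prompt(prompt_to_keywords(p), shuffle=True) for p in prompts]


if __name__ == "__main__":
    prompts = ["a portrait, oil painting, golden hour"]
    for seed in range(3):
        random.seed(seed)
        # one reshuffled copy per seed, mirroring the per-iteration call in generate()
        print(seed, shuffle_prompts(prompts)[0])
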
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e8256be..d265bcc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule, CSVDataItem
+from data.csv import VlpnDataModule, VlpnDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
@@ -172,11 +172,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
         "--dataloader_num_workers",
         type=int,
         default=0,
@@ -698,7 +693,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: CSVDataItem):
+    def keyword_filter(item: VlpnDataItem):
         cond3 = args.collection is None or args.collection in item.collection
         cond4 = args.exclude_collections is None or not any(
             collection in item.collection
@@ -733,7 +728,7 @@ def main():
         }
         return batch
 
-    datamodule = CSVDataModule(
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
@@ -742,7 +737,6 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         dropout=args.tag_dropout,
-        center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
diff --git a/train_ti.py b/train_ti.py
index 0ffc9e6..89c6672 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,7 +21,7 @@ from slugify import slugify
 
 from util import load_config, load_embeddings_from_dir
 from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion
-from data.csv import CSVDataModule, CSVDataItem
+from data.csv import VlpnDataModule, VlpnDataItem
 from training.common import run_model
 from training.optimization import get_one_cycle_schedule
 from training.lr import LRFinder
@@ -146,11 +146,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop",
-        action="store_true",
-        help="Whether to center crop images before resizing to resolution"
-    )
-    parser.add_argument(
         "--tag_dropout",
         type=float,
         default=0.1,
@@ -668,7 +663,7 @@ def main():
     elif args.mixed_precision == "bf16":
         weight_dtype = torch.bfloat16
 
-    def keyword_filter(item: CSVDataItem):
+    def keyword_filter(item: VlpnDataItem):
         cond1 = any(
             keyword in part
             for keyword in args.placeholder_token
@@ -708,7 +703,7 @@ def main():
         }
         return batch
 
-    datamodule = CSVDataModule(
+    datamodule = VlpnDataModule(
         data_file=args.train_data_file,
         batch_size=args.train_batch_size,
         prompt_processor=prompt_processor,
@@ -717,7 +712,6 @@ def main():
         size=args.resolution,
         repeats=args.repeats,
         dropout=args.tag_dropout,
-        center_crop=args.center_crop,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
@@ -725,8 +719,6 @@ def main():
         filter=keyword_filter,
         collate_fn=collate_fn
     )
-
-    datamodule.prepare_data()
     datamodule.setup()
 
     if args.num_class_images != 0:
@@ -769,12 +761,14 @@ def main():
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
 
-    train_dataloader = datamodule.train_dataloader()
-    val_dataloader = datamodule.val_dataloader()
+    train_dataloaders = datamodule.train_dataloaders
+    default_train_dataloader = train_dataloaders[0]
+    val_dataloader = datamodule.val_dataloader
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -811,9 +805,10 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
+    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, val_dataloader, lr_scheduler
     )
+    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -831,7 +826,8 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
+    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
+    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -889,7 +885,7 @@ def main():
             accelerator,
             text_encoder,
             optimizer,
-            train_dataloader,
+            default_train_dataloader,
             val_dataloader,
             loop,
             on_train=on_train,
@@ -968,46 +964,48 @@ def main():
         text_encoder.train()
 
         with on_train():
-            for step, batch in enumerate(train_dataloader):
-                with accelerator.accumulate(text_encoder):
-                    loss, acc, bsz = loop(step, batch)
+            for train_dataloader in train_dataloaders:
+                for step, batch in enumerate(train_dataloader):
+                    with accelerator.accumulate(text_encoder):
+                        loss, acc, bsz = loop(step, batch)
 
-                    accelerator.backward(loss)
+                        accelerator.backward(loss)
 
-                    optimizer.step()
-                    if not accelerator.optimizer_step_was_skipped:
-                        lr_scheduler.step()
-                    optimizer.zero_grad(set_to_none=True)
+                        optimizer.step()
+                        if not accelerator.optimizer_step_was_skipped:
+                            lr_scheduler.step()
+                        optimizer.zero_grad(set_to_none=True)
 
-                    avg_loss.update(loss.detach_(), bsz)
-                    avg_acc.update(acc.detach_(), bsz)
+                        avg_loss.update(loss.detach_(), bsz)
+                        avg_acc.update(acc.detach_(), bsz)
 
-                # Checks if the accelerator has performed an optimization step behind the scenes
-                if accelerator.sync_gradients:
-                    if args.use_ema:
-                        ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                    # Checks if the accelerator has performed an optimization step behind the scenes
+                    if accelerator.sync_gradients:
+                        if args.use_ema:
+                            ema_embeddings.step(
+                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                    local_progress_bar.update(1)
-                    global_progress_bar.update(1)
+                        local_progress_bar.update(1)
+                        global_progress_bar.update(1)
 
-                    global_step += 1
+                        global_step += 1
 
-                logs = {
-                    "train/loss": avg_loss.avg.item(),
-                    "train/acc": avg_acc.avg.item(),
-                    "train/cur_loss": loss.item(),
-                    "train/cur_acc": acc.item(),
-                    "lr": lr_scheduler.get_last_lr()[0],
-                }
-                if args.use_ema:
-                    logs["ema_decay"] = ema_embeddings.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
 
-                accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)
 
-                local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)
 
-                if global_step >= args.max_train_steps:
-                    break
+                    if global_step >= args.max_train_steps:
+                        break
 
         accelerator.wait_for_everyone()
 
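Note on the train_ti.py loop change: each epoch now walks the bucketed dataloaders one after another, and the scheduler math counts batches across all of them before dividing by the gradient-accumulation factor. A small self-contained sketch of that arithmetic; the batch counts are invented for illustration.

import math

bucket_batches = [120, 45, 30]  # len(dataloader) for each bucket's DataLoader (made-up numbers)
gradient_accumulation_steps = 4

num_update_steps_per_dataloader = sum(bucket_batches)
num_update_steps_per_epoch = math.ceil(
    num_update_steps_per_dataloader / gradient_accumulation_steps
)

print(num_update_steps_per_dataloader)  # 195
print(num_update_steps_per_epoch)       # 49
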
diff --git a/training/util.py b/training/util.py
index bc466e2..6f42228 100644
--- a/training/util.py
+++ b/training/util.py
@@ -58,8 +58,8 @@ class CheckpointerBase:
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloader()
-        val_data = self.datamodule.val_dataloader()
+        train_data = self.datamodule.train_dataloaders[0]
+        val_data = self.datamodule.val_dataloader
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
 