 data/csv.py         | 273
 train_dreambooth.py | 100
 train_ti.py         |  85
 training/util.py    |   2
 4 files changed, 237 insertions(+), 223 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index 654aec1..9be36ba 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -2,20 +2,28 @@ import math
 import torch
 import json
 from pathlib import Path
+from typing import NamedTuple, Optional, Union, Callable
+
 from PIL import Image
-from torch.utils.data import Dataset, DataLoader, random_split
-from torchvision import transforms
-from typing import Dict, NamedTuple, List, Optional, Union, Callable
 
-import numpy as np
+from torch.utils.data import IterableDataset, DataLoader, random_split
+from torchvision import transforms
 
-from models.clip.prompt import PromptProcessor
 from data.keywords import prompt_to_keywords, keywords_to_prompt
+from models.clip.prompt import PromptProcessor
 
 
 image_cache: dict[str, Image.Image] = {}
 
 
+interpolations = {
+    "linear": transforms.InterpolationMode.NEAREST,
+    "bilinear": transforms.InterpolationMode.BILINEAR,
+    "bicubic": transforms.InterpolationMode.BICUBIC,
+    "lanczos": transforms.InterpolationMode.LANCZOS,
+}
+
+
 def get_image(path):
     if path in image_cache:
         return image_cache[path]
@@ -28,10 +36,46 @@ def get_image(path):
     return image
 
 
-def prepare_prompt(prompt: Union[str, Dict[str, str]]):
+def prepare_prompt(prompt: Union[str, dict[str, str]]):
     return {"content": prompt} if isinstance(prompt, str) else prompt
 
 
+def generate_buckets(items: list[str], size: int, num_buckets: int, progressive_buckets: bool):
+    item_order: list[int] = []
+    item_buckets: list[int] = []
+    buckets = [1.0]
+
+    for i in range(1, num_buckets + 1):
+        s = size + i * 64
+        buckets.append(s / size)
+        buckets.append(size / s)
+
+    buckets = torch.tensor(buckets)
+    bucket_indices = torch.arange(len(buckets))
+
+    for i, item in enumerate(items):
+        image = get_image(item)
+        ratio = image.width / image.height
+
+        if ratio >= 1:
+            mask = torch.bitwise_and(buckets >= 1, buckets <= ratio)
+        else:
+            mask = torch.bitwise_and(buckets <= 1, buckets >= ratio)
+
+        if not progressive_buckets:
+            mask = (buckets + (~mask) * math.inf - ratio).abs().argmin()
+
+        indices = bucket_indices[mask]
+
+        if len(indices.shape) == 0:
+            indices = indices.unsqueeze(0)
+
+        item_order += [i] * len(indices)
+        item_buckets += indices
+
+    return buckets.tolist(), item_order, item_buckets
+
+
 class VlpnDataItem(NamedTuple):
     instance_image_path: Path
     class_image_path: Path
@@ -41,14 +85,6 @@ class VlpnDataItem(NamedTuple):
     collection: list[str]
 
 
-class VlpnDataBucket():
-    def __init__(self, width: int, height: int):
-        self.width = width
-        self.height = height
-        self.ratio = width / height
-        self.items: list[VlpnDataItem] = []
-
-
 class VlpnDataModule():
     def __init__(
         self,
@@ -60,7 +96,6 @@ class VlpnDataModule():
         size: int = 768,
         num_aspect_ratio_buckets: int = 0,
         progressive_aspect_ratio_buckets: bool = False,
-        repeats: int = 1,
         dropout: float = 0,
         interpolation: str = "bicubic",
         template_key: str = "template",
@@ -86,7 +121,6 @@ class VlpnDataModule():
         self.size = size
         self.num_aspect_ratio_buckets = num_aspect_ratio_buckets
         self.progressive_aspect_ratio_buckets = progressive_aspect_ratio_buckets
-        self.repeats = repeats
        self.dropout = dropout
         self.template_key = template_key
         self.interpolation = interpolation
@@ -146,36 +180,6 @@ class VlpnDataModule():
             for i in range(image_multiplier)
         ]
 
-    def generate_buckets(self, items: list[VlpnDataItem]):
-        buckets = [VlpnDataBucket(self.size, self.size)]
-
-        for i in range(1, self.num_aspect_ratio_buckets + 1):
-            s = self.size + i * 64
-            buckets.append(VlpnDataBucket(s, self.size))
-            buckets.append(VlpnDataBucket(self.size, s))
-
-        buckets = np.array(buckets)
-        bucket_ratios = np.array([bucket.ratio for bucket in buckets])
-
-        for item in items:
-            image = get_image(item.instance_image_path)
-            ratio = image.width / image.height
-
-            if ratio >= 1:
-                mask = np.bitwise_and(bucket_ratios >= 1, bucket_ratios <= ratio)
-            else:
-                mask = np.bitwise_and(bucket_ratios <= 1, bucket_ratios >= ratio)
-
-            if not self.progressive_aspect_ratio_buckets:
-                ratios = bucket_ratios.copy()
-                ratios[~mask] = math.inf
-                mask = [np.argmin(np.abs(ratios - ratio))]
-
-            for bucket in buckets[mask]:
-                bucket.items.append(item)
-
-        return [bucket for bucket in buckets if len(bucket.items) != 0]
-
     def setup(self):
         with open(self.data_file, 'rt') as f:
             metadata = json.load(f)
@@ -201,105 +205,136 @@ class VlpnDataModule():
         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
 
-        buckets = self.generate_buckets(data_train)
-
-        train_datasets = [
-            VlpnDataset(
-                bucket.items, self.prompt_processor,
-                width=bucket.width, height=bucket.height, interpolation=self.interpolation,
-                num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout,
-            )
-            for bucket in buckets
-        ]
+        train_dataset = VlpnDataset(
+            self.data_train, self.prompt_processor,
+            num_buckets=self.num_aspect_ratio_buckets, progressive_buckets=self.progressive_aspect_ratio_buckets,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
+            num_class_images=self.num_class_images, dropout=self.dropout, shuffle=True,
+        )
 
         val_dataset = VlpnDataset(
-            data_val, self.prompt_processor,
-            width=self.size, height=self.size, interpolation=self.interpolation,
+            self.data_val, self.prompt_processor,
+            batch_size=self.batch_size,
+            size=self.size, interpolation=self.interpolation,
         )
 
-        self.train_dataloaders = [
-            DataLoader(
-                dataset, batch_size=self.batch_size, shuffle=True,
-                pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
-            )
-            for dataset in train_datasets
-        ]
+        self.train_dataloader = DataLoader(
+            train_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+        )
 
         self.val_dataloader = DataLoader(
-            val_dataset, batch_size=self.batch_size,
-            pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
+            val_dataset,
+            batch_size=None, pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers
         )
 
 
-class VlpnDataset(Dataset):
+class VlpnDataset(IterableDataset):
     def __init__(
         self,
-        data: List[VlpnDataItem],
+        items: list[VlpnDataItem],
         prompt_processor: PromptProcessor,
+        num_buckets: int = 1,
+        progressive_buckets: bool = False,
+        batch_size: int = 1,
         num_class_images: int = 0,
-        width: int = 768,
-        height: int = 768,
-        repeats: int = 1,
+        size: int = 768,
         dropout: float = 0,
+        shuffle: bool = False,
         interpolation: str = "bicubic",
+        generator: Optional[torch.Generator] = None,
     ):
+        self.items = items
+        self.batch_size = batch_size
 
-        self.data = data
         self.prompt_processor = prompt_processor
         self.num_class_images = num_class_images
+        self.size = size
         self.dropout = dropout
-
-        self.num_instance_images = len(self.data)
-        self._length = self.num_instance_images * repeats
+        self.shuffle = shuffle
+        self.interpolation = interpolations[interpolation]
+        self.generator = generator
 
-        self.interpolation = {
-            "linear": transforms.InterpolationMode.NEAREST,
-            "bilinear": transforms.InterpolationMode.BILINEAR,
-            "bicubic": transforms.InterpolationMode.BICUBIC,
-            "lanczos": transforms.InterpolationMode.LANCZOS,
-        }[interpolation]
-        self.image_transforms = transforms.Compose(
-            [
-                transforms.Resize(min(width, height), interpolation=self.interpolation),
-                transforms.RandomCrop((height, width)),
-                transforms.RandomHorizontalFlip(),
-                transforms.ToTensor(),
-                transforms.Normalize([0.5], [0.5]),
-            ]
-        )
+        buckets, item_order, item_buckets = generate_buckets(
+            [item.instance_image_path for item in items],
+            size,
+            num_buckets,
+            progressive_buckets
+        )
 
-    def __len__(self):
-        return self._length
+        self.buckets = torch.tensor(buckets)
+        self.item_order = torch.tensor(item_order)
+        self.item_buckets = torch.tensor(item_buckets)
 
-    def get_example(self, i):
-        item = self.data[i % self.num_instance_images]
+    def __len__(self):
+        return len(self.item_buckets)
 
-        example = {}
-        example["prompts"] = item.prompt
-        example["cprompts"] = item.cprompt
-        example["nprompts"] = item.nprompt
-        example["instance_images"] = get_image(item.instance_image_path)
-        if self.num_class_images != 0:
-            example["class_images"] = get_image(item.class_image_path)
+    def __iter__(self):
+        worker_info = torch.utils.data.get_worker_info()
+
+        if self.shuffle:
+            perm = torch.randperm(len(self.item_buckets), generator=self.generator)
+            self.item_order = self.item_order[perm]
+            self.item_buckets = self.item_buckets[perm]
 
-        return example
+        item_mask = torch.ones_like(self.item_buckets, dtype=bool)
+        bucket = -1
+        image_transforms = None
+        batch = []
+        batch_size = self.batch_size
+
+        if worker_info is not None:
+            batch_size = math.ceil(batch_size / worker_info.num_workers)
+            worker_batch = math.ceil(len(self) / worker_info.num_workers)
+            start = worker_info.id * worker_batch
+            end = start + worker_batch
+            item_mask[:start] = False
+            item_mask[end:] = False
+
+        while item_mask.any():
+            item_indices = self.item_order[(self.item_buckets == bucket) & item_mask]
+
+            if len(batch) >= batch_size or (len(item_indices) == 0 and len(batch) != 0):
+                yield batch
+                batch = []
+
+            if len(item_indices) == 0:
+                bucket = self.item_buckets[item_mask][0]
+                ratio = self.buckets[bucket]
+                width = self.size * ratio if ratio > 1 else self.size
+                height = self.size / ratio if ratio < 1 else self.size
+
+                image_transforms = transforms.Compose(
+                    [
+                        transforms.Resize(min(width, height), interpolation=self.interpolation),
+                        transforms.RandomCrop((height, width)),
+                        transforms.RandomHorizontalFlip(),
+                        transforms.ToTensor(),
+                        transforms.Normalize([0.5], [0.5]),
+                    ]
+                )
+            else:
+                item_index = item_indices[0]
+                item = self.items[item_index]
+                item_mask[item_index] = False
 
-    def __getitem__(self, i):
-        unprocessed_example = self.get_example(i)
+                example = {}
 
-        example = {}
+                example["prompts"] = keywords_to_prompt(item.prompt)
+                example["cprompts"] = item.cprompt
+                example["nprompts"] = item.nprompt
 
-        example["prompts"] = keywords_to_prompt(unprocessed_example["prompts"])
-        example["cprompts"] = unprocessed_example["cprompts"]
-        example["nprompts"] = unprocessed_example["nprompts"]
+                example["instance_images"] = image_transforms(get_image(item.instance_image_path))
+                example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
+                    keywords_to_prompt(item.prompt, self.dropout, True)
+                )
 
-        example["instance_images"] = self.image_transforms(unprocessed_example["instance_images"])
-        example["instance_prompt_ids"] = self.prompt_processor.get_input_ids(
-            keywords_to_prompt(unprocessed_example["prompts"], self.dropout, True)
-        )
+                if self.num_class_images != 0:
+                    example["class_images"] = image_transforms(get_image(item.class_image_path))
+                    example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
 
-        if self.num_class_images != 0:
-            example["class_images"] = self.image_transforms(unprocessed_example["class_images"])
-            example["class_prompt_ids"] = self.prompt_processor.get_input_ids(example["cprompts"])
+                batch.append(example)
 
-        return example
+        if len(batch) != 0:
+            yield batch
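
Note: to illustrate the aspect ratios the new generate_buckets helper builds (for size=768 and two buckets per direction), here is a minimal standalone sketch of just the ratio construction; the per-image bucket assignment is omitted. Illustrative only, not part of the commit.

    import torch

    def bucket_ratios(size: int = 768, num_buckets: int = 2) -> torch.Tensor:
        # One square bucket plus a landscape and a portrait ratio per 64 px step,
        # mirroring the loop at the top of generate_buckets.
        buckets = [1.0]
        for i in range(1, num_buckets + 1):
            s = size + i * 64
            buckets.append(s / size)  # wider than tall
            buckets.append(size / s)  # taller than wide
        return torch.tensor(buckets)

    print(bucket_ratios())  # tensor([1.0000, 1.0833, 0.9231, 1.1667, 0.8571])
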
diff --git a/train_dreambooth.py b/train_dreambooth.py
index 589af59..42a7d0f 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -134,12 +134,6 @@ def parse_args():
         help="The directory where class images will be saved.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/dreambooth",
@@ -738,7 +732,6 @@ def main():
         class_subdir=args.class_image_dir,
         num_class_images=args.num_class_images,
         size=args.resolution,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -751,7 +744,7 @@ def main():
     datamodule.prepare_data()
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -770,8 +763,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -820,8 +812,7 @@ def main():
         ema_unet.to(accelerator.device)
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -877,7 +868,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=tokenizer.train,
@@ -960,54 +951,53 @@ def main():
         text_encoder.requires_grad_(False)
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(unet):
-                        loss, acc, bsz = loop(step, batch)
-
-                        accelerator.backward(loss)
-
-                        if accelerator.sync_gradients:
-                            params_to_clip = (
-                                itertools.chain(unet.parameters(), text_encoder.parameters())
-                                if args.train_text_encoder and epoch < args.train_text_encoder_epochs
-                                else unet.parameters()
-                            )
-                            accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
-
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        if args.use_ema:
-                            ema_unet.step(unet.parameters())
-                        optimizer.zero_grad(set_to_none=True)
-
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
-
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(unet):
+                    loss, acc, bsz = loop(step, batch)
 
-                        global_step += 1
+                    accelerator.backward(loss)
 
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0]
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = 1 - ema_unet.decay
+                    if accelerator.sync_gradients:
+                        params_to_clip = (
+                            itertools.chain(unet.parameters(), text_encoder.parameters())
+                            if args.train_text_encoder and epoch < args.train_text_encoder_epochs
+                            else unet.parameters()
+                        )
+                        accelerator.clip_grad_norm_(params_to_clip, args.max_grad_norm)
+
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    if args.use_ema:
+                        ema_unet.step(unet.parameters())
+                    optimizer.zero_grad(set_to_none=True)
 
-                        accelerator.log(logs, step=global_step)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                        local_progress_bar.set_postfix(**logs)
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
+
+                    global_step += 1
+
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0]
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = 1 - ema_unet.decay
+
+                    accelerator.log(logs, step=global_step)
+
+                    local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
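
Note: the inner dataloader loop disappears because VlpnDataset now yields ready-made batches and the DataLoader is created with batch_size=None, which disables automatic batching and hands each yielded list directly to collate_fn. A self-contained toy sketch of that pattern (illustrative only; the class and collate function here are made up, not from the commit):

    import torch
    from torch.utils.data import IterableDataset, DataLoader

    class ToyBatchedDataset(IterableDataset):
        # Yields ready-made batches of examples, like the reworked VlpnDataset.
        def __iter__(self):
            yield [{"x": torch.zeros(3)} for _ in range(4)]
            yield [{"x": torch.ones(3)} for _ in range(2)]  # smaller trailing batch

    def collate(batch):
        return torch.stack([example["x"] for example in batch])

    # batch_size=None: each yielded batch is passed to collate as-is.
    loader = DataLoader(ToyBatchedDataset(), batch_size=None, collate_fn=collate)
    for batch in loader:
        print(batch.shape)  # torch.Size([4, 3]) then torch.Size([2, 3])
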
diff --git a/train_ti.py b/train_ti.py
index b4b602b..727b591 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -107,12 +107,6 @@ def parse_args():
         help="Exclude all items with a listed collection.",
     )
     parser.add_argument(
-        "--repeats",
-        type=int,
-        default=1,
-        help="How many times to repeat the training data."
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
         default="output/text-inversion",
@@ -722,7 +716,6 @@ def main():
         size=args.resolution,
         num_aspect_ratio_buckets=args.num_aspect_ratio_buckets,
         progressive_aspect_ratio_buckets=args.progressive_aspect_ratio_buckets,
-        repeats=args.repeats,
         dropout=args.tag_dropout,
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
@@ -733,7 +726,7 @@ def main():
     )
     datamodule.setup()
 
-    train_dataloaders = datamodule.train_dataloaders
+    train_dataloader = datamodule.train_dataloader
     val_dataloader = datamodule.val_dataloader
 
     if args.num_class_images != 0:
@@ -752,8 +745,7 @@ def main():
 
     # Scheduler and math around the number of training steps.
     overrode_max_train_steps = False
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if args.max_train_steps is None:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
         overrode_max_train_steps = True
@@ -790,10 +782,9 @@ def main():
         num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
     )
 
-    text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare(
-        text_encoder, optimizer, val_dataloader, lr_scheduler
+    text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare(
+        text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler
     )
-    train_dataloaders = accelerator.prepare(*train_dataloaders)
 
     # Move vae and unet to device
     vae.to(accelerator.device, dtype=weight_dtype)
@@ -811,8 +802,7 @@ def main():
     unet.eval()
 
     # We need to recalculate our total training steps as the size of the training dataloader may have changed.
-    num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders)
-    num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps)
+    num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
     if overrode_max_train_steps:
         args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
 
@@ -870,7 +860,7 @@ def main():
         accelerator,
         text_encoder,
         optimizer,
-        train_dataloaders[0],
+        train_dataloader,
         val_dataloader,
         loop,
         on_train=on_train,
@@ -949,48 +939,47 @@ def main():
         text_encoder.train()
 
         with on_train():
-            for train_dataloader in train_dataloaders:
-                for step, batch in enumerate(train_dataloader):
-                    with accelerator.accumulate(text_encoder):
-                        loss, acc, bsz = loop(step, batch)
+            for step, batch in enumerate(train_dataloader):
+                with accelerator.accumulate(text_encoder):
+                    loss, acc, bsz = loop(step, batch)
 
-                        accelerator.backward(loss)
+                    accelerator.backward(loss)
 
-                        optimizer.step()
-                        if not accelerator.optimizer_step_was_skipped:
-                            lr_scheduler.step()
-                        optimizer.zero_grad(set_to_none=True)
+                    optimizer.step()
+                    if not accelerator.optimizer_step_was_skipped:
+                        lr_scheduler.step()
+                    optimizer.zero_grad(set_to_none=True)
 
-                        avg_loss.update(loss.detach_(), bsz)
-                        avg_acc.update(acc.detach_(), bsz)
+                    avg_loss.update(loss.detach_(), bsz)
+                    avg_acc.update(acc.detach_(), bsz)
 
-                    # Checks if the accelerator has performed an optimization step behind the scenes
-                    if accelerator.sync_gradients:
-                        if args.use_ema:
-                            ema_embeddings.step(
-                                text_encoder.text_model.embeddings.temp_token_embedding.parameters())
+                # Checks if the accelerator has performed an optimization step behind the scenes
+                if accelerator.sync_gradients:
+                    if args.use_ema:
+                        ema_embeddings.step(
+                            text_encoder.text_model.embeddings.temp_token_embedding.parameters())
 
-                        local_progress_bar.update(1)
-                        global_progress_bar.update(1)
+                    local_progress_bar.update(1)
+                    global_progress_bar.update(1)
 
-                        global_step += 1
+                    global_step += 1
 
-                        logs = {
-                            "train/loss": avg_loss.avg.item(),
-                            "train/acc": avg_acc.avg.item(),
-                            "train/cur_loss": loss.item(),
-                            "train/cur_acc": acc.item(),
-                            "lr": lr_scheduler.get_last_lr()[0],
-                        }
-                        if args.use_ema:
-                            logs["ema_decay"] = ema_embeddings.decay
+                    logs = {
+                        "train/loss": avg_loss.avg.item(),
+                        "train/acc": avg_acc.avg.item(),
+                        "train/cur_loss": loss.item(),
+                        "train/cur_acc": acc.item(),
+                        "lr": lr_scheduler.get_last_lr()[0],
+                    }
+                    if args.use_ema:
+                        logs["ema_decay"] = ema_embeddings.decay
 
-                        accelerator.log(logs, step=global_step)
+                    accelerator.log(logs, step=global_step)
 
-                        local_progress_bar.set_postfix(**logs)
+                    local_progress_bar.set_postfix(**logs)
 
-                    if global_step >= args.max_train_steps:
-                        break
+                if global_step >= args.max_train_steps:
+                    break
 
         accelerator.wait_for_everyone()
 
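
Note: with a single train_dataloader, the scheduler math in both training scripts reduces to one ceil over its length. A worked example with made-up numbers (not taken from the commit):

    import math

    batches_per_epoch = 250                 # stands in for len(train_dataloader)
    gradient_accumulation_steps = 4
    num_train_epochs = 10

    num_update_steps_per_epoch = math.ceil(batches_per_epoch / gradient_accumulation_steps)
    max_train_steps = num_train_epochs * num_update_steps_per_epoch
    print(num_update_steps_per_epoch, max_train_steps)  # 63 630
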
diff --git a/training/util.py b/training/util.py
index 2b7f71d..ae6bfc4 100644
--- a/training/util.py
+++ b/training/util.py
@@ -59,7 +59,7 @@ class CheckpointerBase:
     def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0):
         samples_path = Path(self.output_dir).joinpath("samples")
 
-        train_data = self.datamodule.train_dataloaders[0]
+        train_data = self.datamodule.train_dataloader
         val_data = self.datamodule.val_dataloader
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)