author     Volpeon <git@volpeon.ink>  2023-01-07 13:57:46 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-07 13:57:46 +0100
commit     3ee13893f9a4973ac75f45fe9318c35760dd4b1f
tree       e652a54e6c241eef52ddb30f2d7048da8f306f7b
parent     Update
download   textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.gz
           textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.tar.bz2
           textual-inversion-diff-3ee13893f9a4973ac75f45fe9318c35760dd4b1f.zip
Added progressive aspect ratio bucketing
-rw-r--r--  data/csv.py          144
-rw-r--r--  infer.py              23
-rw-r--r--  train_dreambooth.py   12
-rw-r--r--  train_ti.py           94
-rw-r--r--  training/util.py       4
5 files changed, 151 insertions, 126 deletions
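The change in data/csv.py builds one square bucket at the configured size plus progressively wider and taller buckets in 64-pixel steps, then assigns every training image to each bucket whose aspect ratio it can cover. The following is a minimal standalone sketch of that assignment logic, not the repository code; the 768px base size matches the default in this diff, while the sample image dimensions are purely illustrative.

```python
from collections import defaultdict


def make_buckets(size: int = 768, steps: int = 4, increment: int = 64):
    """One square bucket plus progressively wider and taller variants."""
    buckets = [(size, size)]
    for i in range(1, steps + 1):
        s = size + i * increment
        buckets.append((s, size))  # landscape bucket, ratio > 1
        buckets.append((size, s))  # portrait bucket, ratio < 1
    return buckets


def assign_to_buckets(images: dict[str, tuple[int, int]], size: int = 768):
    """Map each bucket (width, height) to the images whose ratio covers it."""
    assignment: dict[tuple[int, int], list[str]] = defaultdict(list)
    for path, (width, height) in images.items():
        ratio = width / height
        for w, h in make_buckets(size):
            bucket_ratio = w / h
            # Same rule as generate_buckets below: a landscape image goes into
            # every landscape bucket that is no wider (relatively) than itself,
            # and symmetrically for portrait images.
            if ratio >= 1 and bucket_ratio >= 1 and ratio >= bucket_ratio:
                assignment[(w, h)].append(path)
            elif ratio < 1 and bucket_ratio <= 1 and ratio <= bucket_ratio:
                assignment[(w, h)].append(path)
    return dict(assignment)


# A 1024x768 image (ratio 1.33) lands in every bucket from the 768x768 square
# up to 1024x768, while a 768x1024 image fills the portrait buckets instead.
print(assign_to_buckets({"a.png": (1024, 768), "b.png": (768, 1024)}))
```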
diff --git a/data/csv.py b/data/csv.py
index 4986153..59d6d8d 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -11,11 +11,26 @@ from models.clip.prompt import PromptProcessor | |||
11 | from data.keywords import prompt_to_keywords, keywords_to_prompt | 11 | from data.keywords import prompt_to_keywords, keywords_to_prompt |
12 | 12 | ||
13 | 13 | ||
14 | image_cache: dict[str, Image.Image] = {} | ||
15 | |||
16 | |||
17 | def get_image(path): | ||
18 | if path in image_cache: | ||
19 | return image_cache[path] | ||
20 | |||
21 | image = Image.open(path) | ||
22 | if not image.mode == "RGB": | ||
23 | image = image.convert("RGB") | ||
24 | image_cache[path] = image | ||
25 | |||
26 | return image | ||
27 | |||
28 | |||
14 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): | 29 | def prepare_prompt(prompt: Union[str, Dict[str, str]]): |
15 | return {"content": prompt} if isinstance(prompt, str) else prompt | 30 | return {"content": prompt} if isinstance(prompt, str) else prompt |
16 | 31 | ||
17 | 32 | ||
18 | class CSVDataItem(NamedTuple): | 33 | class VlpnDataItem(NamedTuple): |
19 | instance_image_path: Path | 34 | instance_image_path: Path |
20 | class_image_path: Path | 35 | class_image_path: Path |
21 | prompt: list[str] | 36 | prompt: list[str] |
@@ -24,7 +39,15 @@ class CSVDataItem(NamedTuple): | |||
24 | collection: list[str] | 39 | collection: list[str] |
25 | 40 | ||
26 | 41 | ||
27 | class CSVDataModule(): | 42 | class VlpnDataBucket(): |
43 | def __init__(self, width: int, height: int): | ||
44 | self.width = width | ||
45 | self.height = height | ||
46 | self.ratio = width / height | ||
47 | self.items: list[VlpnDataItem] = [] | ||
48 | |||
49 | |||
50 | class VlpnDataModule(): | ||
28 | def __init__( | 51 | def __init__( |
29 | self, | 52 | self, |
30 | batch_size: int, | 53 | batch_size: int, |
@@ -36,11 +59,10 @@ class CSVDataModule(): | |||
36 | repeats: int = 1, | 59 | repeats: int = 1, |
37 | dropout: float = 0, | 60 | dropout: float = 0, |
38 | interpolation: str = "bicubic", | 61 | interpolation: str = "bicubic", |
39 | center_crop: bool = False, | ||
40 | template_key: str = "template", | 62 | template_key: str = "template", |
41 | valid_set_size: Optional[int] = None, | 63 | valid_set_size: Optional[int] = None, |
42 | seed: Optional[int] = None, | 64 | seed: Optional[int] = None, |
43 | filter: Optional[Callable[[CSVDataItem], bool]] = None, | 65 | filter: Optional[Callable[[VlpnDataItem], bool]] = None, |
44 | collate_fn=None, | 66 | collate_fn=None, |
45 | num_workers: int = 0 | 67 | num_workers: int = 0 |
46 | ): | 68 | ): |
@@ -60,7 +82,6 @@ class CSVDataModule(): | |||
60 | self.size = size | 82 | self.size = size |
61 | self.repeats = repeats | 83 | self.repeats = repeats |
62 | self.dropout = dropout | 84 | self.dropout = dropout |
63 | self.center_crop = center_crop | ||
64 | self.template_key = template_key | 85 | self.template_key = template_key |
65 | self.interpolation = interpolation | 86 | self.interpolation = interpolation |
66 | self.valid_set_size = valid_set_size | 87 | self.valid_set_size = valid_set_size |
@@ -70,14 +91,14 @@ class CSVDataModule(): | |||
70 | self.num_workers = num_workers | 91 | self.num_workers = num_workers |
71 | self.batch_size = batch_size | 92 | self.batch_size = batch_size |
72 | 93 | ||
73 | def prepare_items(self, template, expansions, data) -> list[CSVDataItem]: | 94 | def prepare_items(self, template, expansions, data) -> list[VlpnDataItem]: |
74 | image = template["image"] if "image" in template else "{}" | 95 | image = template["image"] if "image" in template else "{}" |
75 | prompt = template["prompt"] if "prompt" in template else "{content}" | 96 | prompt = template["prompt"] if "prompt" in template else "{content}" |
76 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" | 97 | cprompt = template["cprompt"] if "cprompt" in template else "{content}" |
77 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" | 98 | nprompt = template["nprompt"] if "nprompt" in template else "{content}" |
78 | 99 | ||
79 | return [ | 100 | return [ |
80 | CSVDataItem( | 101 | VlpnDataItem( |
81 | self.data_root.joinpath(image.format(item["image"])), | 102 | self.data_root.joinpath(image.format(item["image"])), |
82 | None, | 103 | None, |
83 | prompt_to_keywords( | 104 | prompt_to_keywords( |
@@ -97,17 +118,17 @@ class CSVDataModule(): | |||
97 | for item in data | 118 | for item in data |
98 | ] | 119 | ] |
99 | 120 | ||
100 | def filter_items(self, items: list[CSVDataItem]) -> list[CSVDataItem]: | 121 | def filter_items(self, items: list[VlpnDataItem]) -> list[VlpnDataItem]: |
101 | if self.filter is None: | 122 | if self.filter is None: |
102 | return items | 123 | return items |
103 | 124 | ||
104 | return [item for item in items if self.filter(item)] | 125 | return [item for item in items if self.filter(item)] |
105 | 126 | ||
106 | def pad_items(self, items: list[CSVDataItem], num_class_images: int = 1) -> list[CSVDataItem]: | 127 | def pad_items(self, items: list[VlpnDataItem], num_class_images: int = 1) -> list[VlpnDataItem]: |
107 | image_multiplier = max(num_class_images, 1) | 128 | image_multiplier = max(num_class_images, 1) |
108 | 129 | ||
109 | return [ | 130 | return [ |
110 | CSVDataItem( | 131 | VlpnDataItem( |
111 | item.instance_image_path, | 132 | item.instance_image_path, |
112 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), | 133 | self.class_root.joinpath(f"{item.instance_image_path.stem}_{i}{item.instance_image_path.suffix}"), |
113 | item.prompt, | 134 | item.prompt, |
@@ -119,7 +140,30 @@ class CSVDataModule(): | |||
119 | for i in range(image_multiplier) | 140 | for i in range(image_multiplier) |
120 | ] | 141 | ] |
121 | 142 | ||
122 | def prepare_data(self): | 143 | def generate_buckets(self, items: list[VlpnDataItem]): |
144 | buckets = [VlpnDataBucket(self.size, self.size)] | ||
145 | |||
146 | for i in range(1, 5): | ||
147 | s = self.size + i * 64 | ||
148 | buckets.append(VlpnDataBucket(s, self.size)) | ||
149 | buckets.append(VlpnDataBucket(self.size, s)) | ||
150 | |||
151 | for item in items: | ||
152 | image = get_image(item.instance_image_path) | ||
153 | ratio = image.width / image.height | ||
154 | |||
155 | if ratio >= 1: | ||
156 | candidates = [bucket for bucket in buckets if bucket.ratio >= 1 and ratio >= bucket.ratio] | ||
157 | else: | ||
158 | candidates = [bucket for bucket in buckets if bucket.ratio <= 1 and ratio <= bucket.ratio] | ||
159 | |||
160 | for bucket in candidates: | ||
161 | bucket.items.append(item) | ||
162 | |||
163 | buckets = [bucket for bucket in buckets if len(bucket.items) != 0] | ||
164 | return buckets | ||
165 | |||
166 | def setup(self): | ||
123 | with open(self.data_file, 'rt') as f: | 167 | with open(self.data_file, 'rt') as f: |
124 | metadata = json.load(f) | 168 | metadata = json.load(f) |
125 | template = metadata[self.template_key] if self.template_key in metadata else {} | 169 | template = metadata[self.template_key] if self.template_key in metadata else {} |
@@ -144,48 +188,48 @@ class CSVDataModule(): | |||
144 | self.data_train = self.pad_items(data_train, self.num_class_images) | 188 | self.data_train = self.pad_items(data_train, self.num_class_images) |
145 | self.data_val = self.pad_items(data_val) | 189 | self.data_val = self.pad_items(data_val) |
146 | 190 | ||
147 | def setup(self, stage=None): | 191 | buckets = self.generate_buckets(data_train) |
148 | train_dataset = CSVDataset( | 192 | |
149 | self.data_train, self.prompt_processor, batch_size=self.batch_size, | 193 | train_datasets = [ |
150 | num_class_images=self.num_class_images, | 194 | VlpnDataset( |
151 | size=self.size, interpolation=self.interpolation, | 195 | bucket.items, self.prompt_processor, batch_size=self.batch_size, |
152 | center_crop=self.center_crop, repeats=self.repeats, dropout=self.dropout | 196 | width=bucket.width, height=bucket.height, interpolation=self.interpolation, |
153 | ) | 197 | num_class_images=self.num_class_images, repeats=self.repeats, dropout=self.dropout, |
154 | val_dataset = CSVDataset( | 198 | ) |
155 | self.data_val, self.prompt_processor, batch_size=self.batch_size, | 199 | for bucket in buckets |
156 | size=self.size, interpolation=self.interpolation, | 200 | ] |
157 | center_crop=self.center_crop | 201 | |
158 | ) | 202 | val_dataset = VlpnDataset( |
159 | self.train_dataloader_ = DataLoader( | 203 | data_val, self.prompt_processor, batch_size=self.batch_size, |
160 | train_dataset, batch_size=self.batch_size, | 204 | width=self.size, height=self.size, interpolation=self.interpolation, |
161 | shuffle=True, pin_memory=True, collate_fn=self.collate_fn, | ||
162 | num_workers=self.num_workers | ||
163 | ) | ||
164 | self.val_dataloader_ = DataLoader( | ||
165 | val_dataset, batch_size=self.batch_size, | ||
166 | pin_memory=True, collate_fn=self.collate_fn, | ||
167 | num_workers=self.num_workers | ||
168 | ) | 205 | ) |
169 | 206 | ||
170 | def train_dataloader(self): | 207 | self.train_dataloaders = [ |
171 | return self.train_dataloader_ | 208 | DataLoader( |
209 | dataset, batch_size=self.batch_size, shuffle=True, | ||
210 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | ||
211 | ) | ||
212 | for dataset in train_datasets | ||
213 | ] | ||
172 | 214 | ||
173 | def val_dataloader(self): | 215 | self.val_dataloader = DataLoader( |
174 | return self.val_dataloader_ | 216 | val_dataset, batch_size=self.batch_size, |
217 | pin_memory=True, collate_fn=self.collate_fn, num_workers=self.num_workers | ||
218 | ) | ||
175 | 219 | ||
176 | 220 | ||
177 | class CSVDataset(Dataset): | 221 | class VlpnDataset(Dataset): |
178 | def __init__( | 222 | def __init__( |
179 | self, | 223 | self, |
180 | data: List[CSVDataItem], | 224 | data: List[VlpnDataItem], |
181 | prompt_processor: PromptProcessor, | 225 | prompt_processor: PromptProcessor, |
182 | batch_size: int = 1, | 226 | batch_size: int = 1, |
183 | num_class_images: int = 0, | 227 | num_class_images: int = 0, |
184 | size: int = 768, | 228 | width: int = 768, |
229 | height: int = 768, | ||
185 | repeats: int = 1, | 230 | repeats: int = 1, |
186 | dropout: float = 0, | 231 | dropout: float = 0, |
187 | interpolation: str = "bicubic", | 232 | interpolation: str = "bicubic", |
188 | center_crop: bool = False, | ||
189 | ): | 233 | ): |
190 | 234 | ||
191 | self.data = data | 235 | self.data = data |
@@ -193,7 +237,6 @@ class CSVDataset(Dataset): | |||
193 | self.batch_size = batch_size | 237 | self.batch_size = batch_size |
194 | self.num_class_images = num_class_images | 238 | self.num_class_images = num_class_images |
195 | self.dropout = dropout | 239 | self.dropout = dropout |
196 | self.image_cache = {} | ||
197 | 240 | ||
198 | self.num_instance_images = len(self.data) | 241 | self.num_instance_images = len(self.data) |
199 | self._length = self.num_instance_images * repeats | 242 | self._length = self.num_instance_images * repeats |
@@ -206,8 +249,8 @@ class CSVDataset(Dataset): | |||
206 | }[interpolation] | 249 | }[interpolation] |
207 | self.image_transforms = transforms.Compose( | 250 | self.image_transforms = transforms.Compose( |
208 | [ | 251 | [ |
209 | transforms.Resize(size, interpolation=self.interpolation), | 252 | transforms.Resize(min(width, height), interpolation=self.interpolation), |
210 | transforms.CenterCrop(size) if center_crop else transforms.RandomCrop(size), | 253 | transforms.RandomCrop((height, width)), |
211 | transforms.RandomHorizontalFlip(), | 254 | transforms.RandomHorizontalFlip(), |
212 | transforms.ToTensor(), | 255 | transforms.ToTensor(), |
213 | transforms.Normalize([0.5], [0.5]), | 256 | transforms.Normalize([0.5], [0.5]), |
@@ -217,17 +260,6 @@ class CSVDataset(Dataset): | |||
217 | def __len__(self): | 260 | def __len__(self): |
218 | return math.ceil(self._length / self.batch_size) * self.batch_size | 261 | return math.ceil(self._length / self.batch_size) * self.batch_size |
219 | 262 | ||
220 | def get_image(self, path): | ||
221 | if path in self.image_cache: | ||
222 | return self.image_cache[path] | ||
223 | |||
224 | image = Image.open(path) | ||
225 | if not image.mode == "RGB": | ||
226 | image = image.convert("RGB") | ||
227 | self.image_cache[path] = image | ||
228 | |||
229 | return image | ||
230 | |||
231 | def get_example(self, i): | 263 | def get_example(self, i): |
232 | item = self.data[i % self.num_instance_images] | 264 | item = self.data[i % self.num_instance_images] |
233 | 265 | ||
@@ -235,9 +267,9 @@ class CSVDataset(Dataset): | |||
235 | example["prompts"] = item.prompt | 267 | example["prompts"] = item.prompt |
236 | example["cprompts"] = item.cprompt | 268 | example["cprompts"] = item.cprompt |
237 | example["nprompts"] = item.nprompt | 269 | example["nprompts"] = item.nprompt |
238 | example["instance_images"] = self.get_image(item.instance_image_path) | 270 | example["instance_images"] = get_image(item.instance_image_path) |
239 | if self.num_class_images != 0: | 271 | if self.num_class_images != 0: |
240 | example["class_images"] = self.get_image(item.class_image_path) | 272 | example["class_images"] = get_image(item.class_image_path) |
241 | 273 | ||
242 | return example | 274 | return example |
243 | 275 | ||
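With bucketing in place, VlpnDataset no longer center- or random-crops to a single square size; each bucket's dataset resizes an image's shorter edge to the bucket's shorter edge and then random-crops to the bucket's exact height and width. This only works because generate_buckets assigns an image to a bucket only when its aspect ratio is at least as extreme as the bucket's, so the crop always fits. A small sketch of that pipeline, assuming torchvision is installed; the concrete sizes are illustrative, not taken from the repository.

```python
from torchvision import transforms


def bucket_transforms(width: int, height: int) -> transforms.Compose:
    return transforms.Compose([
        # Resize the image's shorter edge to the bucket's shorter edge ...
        transforms.Resize(min(width, height),
                          interpolation=transforms.InterpolationMode.BICUBIC),
        # ... then crop exactly (height, width); the bucket-assignment rule
        # guarantees the longer edge is at least the bucket's longer edge.
        transforms.RandomCrop((height, width)),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize([0.5], [0.5]),
    ])


# Example: a 1024x768 photo in the 832x768 bucket keeps its size after
# Resize(768) (its shorter edge is already 768) and is then cropped to
# 768x832 (H x W).
tf = bucket_transforms(832, 768)
```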
diff --git a/infer.py b/infer.py
--- a/infer.py
+++ b/infer.py
@@ -238,16 +238,15 @@ def create_pipeline(model, dtype): | |||
238 | return pipeline | 238 | return pipeline |
239 | 239 | ||
240 | 240 | ||
241 | def shuffle_prompts(prompts: list[str]) -> list[str]: | ||
242 | return [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in prompts] | ||
243 | |||
244 | |||
241 | @torch.inference_mode() | 245 | @torch.inference_mode() |
242 | def generate(output_dir: Path, pipeline, args): | 246 | def generate(output_dir: Path, pipeline, args): |
243 | if isinstance(args.prompt, str): | 247 | if isinstance(args.prompt, str): |
244 | args.prompt = [args.prompt] | 248 | args.prompt = [args.prompt] |
245 | 249 | ||
246 | if args.shuffle: | ||
247 | args.prompt *= args.batch_size | ||
248 | args.batch_size = 1 | ||
249 | args.prompt = [keywords_to_prompt(prompt_to_keywords(prompt), shuffle=True) for prompt in args.prompt] | ||
250 | |||
251 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] | 250 | args.prompt = [args.template.format(prompt) for prompt in args.prompt] |
252 | 251 | ||
253 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") | 252 | now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S") |
@@ -263,9 +262,6 @@ def generate(output_dir: Path, pipeline, args): | |||
263 | dir = output_dir.joinpath(slugify(prompt)[:100]) | 262 | dir = output_dir.joinpath(slugify(prompt)[:100]) |
264 | dir.mkdir(parents=True, exist_ok=True) | 263 | dir.mkdir(parents=True, exist_ok=True) |
265 | image_dir.append(dir) | 264 | image_dir.append(dir) |
266 | |||
267 | with open(dir.joinpath('prompt.txt'), 'w') as f: | ||
268 | f.write(prompt) | ||
269 | else: | 265 | else: |
270 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") | 266 | output_dir = output_dir.joinpath(f"{now}_{slugify(args.prompt[0])[:100]}") |
271 | output_dir.mkdir(parents=True, exist_ok=True) | 267 | output_dir.mkdir(parents=True, exist_ok=True) |
@@ -306,9 +302,10 @@ def generate(output_dir: Path, pipeline, args): | |||
306 | ) | 302 | ) |
307 | 303 | ||
308 | seed = args.seed + i | 304 | seed = args.seed + i |
305 | prompt = shuffle_prompts(args.prompt) if args.shuffle else args.prompt | ||
309 | generator = torch.Generator(device="cuda").manual_seed(seed) | 306 | generator = torch.Generator(device="cuda").manual_seed(seed) |
310 | images = pipeline( | 307 | images = pipeline( |
311 | prompt=args.prompt, | 308 | prompt=prompt, |
312 | negative_prompt=args.negative_prompt, | 309 | negative_prompt=args.negative_prompt, |
313 | height=args.height, | 310 | height=args.height, |
314 | width=args.width, | 311 | width=args.width, |
@@ -321,9 +318,13 @@ def generate(output_dir: Path, pipeline, args): | |||
321 | ).images | 318 | ).images |
322 | 319 | ||
323 | for j, image in enumerate(images): | 320 | for j, image in enumerate(images): |
321 | basename = f"{seed}_{j // len(args.prompt)}" | ||
324 | dir = image_dir[j % len(args.prompt)] | 322 | dir = image_dir[j % len(args.prompt)] |
325 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.png")) | 323 | |
326 | image.save(dir.joinpath(f"{seed}_{j // len(args.prompt)}.jpg"), quality=85) | 324 | image.save(dir.joinpath(f"{basename}.png")) |
325 | image.save(dir.joinpath(f"{basename}.jpg"), quality=85) | ||
326 | with open(dir.joinpath(f"{basename}.txt"), 'w') as f: | ||
327 | f.write(prompt[j % len(args.prompt)]) | ||
327 | 328 | ||
328 | if torch.cuda.is_available(): | 329 | if torch.cuda.is_available(): |
329 | torch.cuda.empty_cache() | 330 | torch.cuda.empty_cache() |
diff --git a/train_dreambooth.py b/train_dreambooth.py
index e8256be..d265bcc 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -22,7 +22,7 @@ from slugify import slugify | |||
22 | 22 | ||
23 | from util import load_config, load_embeddings_from_dir | 23 | from util import load_config, load_embeddings_from_dir |
24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 24 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
25 | from data.csv import CSVDataModule, CSVDataItem | 25 | from data.csv import VlpnDataModule, VlpnDataItem |
26 | from training.common import run_model | 26 | from training.common import run_model |
27 | from training.optimization import get_one_cycle_schedule | 27 | from training.optimization import get_one_cycle_schedule |
28 | from training.lr import LRFinder | 28 | from training.lr import LRFinder |
@@ -172,11 +172,6 @@ def parse_args(): | |||
172 | ), | 172 | ), |
173 | ) | 173 | ) |
174 | parser.add_argument( | 174 | parser.add_argument( |
175 | "--center_crop", | ||
176 | action="store_true", | ||
177 | help="Whether to center crop images before resizing to resolution" | ||
178 | ) | ||
179 | parser.add_argument( | ||
180 | "--dataloader_num_workers", | 175 | "--dataloader_num_workers", |
181 | type=int, | 176 | type=int, |
182 | default=0, | 177 | default=0, |
@@ -698,7 +693,7 @@ def main(): | |||
698 | elif args.mixed_precision == "bf16": | 693 | elif args.mixed_precision == "bf16": |
699 | weight_dtype = torch.bfloat16 | 694 | weight_dtype = torch.bfloat16 |
700 | 695 | ||
701 | def keyword_filter(item: CSVDataItem): | 696 | def keyword_filter(item: VlpnDataItem): |
702 | cond3 = args.collection is None or args.collection in item.collection | 697 | cond3 = args.collection is None or args.collection in item.collection |
703 | cond4 = args.exclude_collections is None or not any( | 698 | cond4 = args.exclude_collections is None or not any( |
704 | collection in item.collection | 699 | collection in item.collection |
@@ -733,7 +728,7 @@ def main(): | |||
733 | } | 728 | } |
734 | return batch | 729 | return batch |
735 | 730 | ||
736 | datamodule = CSVDataModule( | 731 | datamodule = VlpnDataModule( |
737 | data_file=args.train_data_file, | 732 | data_file=args.train_data_file, |
738 | batch_size=args.train_batch_size, | 733 | batch_size=args.train_batch_size, |
739 | prompt_processor=prompt_processor, | 734 | prompt_processor=prompt_processor, |
@@ -742,7 +737,6 @@ def main(): | |||
742 | size=args.resolution, | 737 | size=args.resolution, |
743 | repeats=args.repeats, | 738 | repeats=args.repeats, |
744 | dropout=args.tag_dropout, | 739 | dropout=args.tag_dropout, |
745 | center_crop=args.center_crop, | ||
746 | template_key=args.train_data_template, | 740 | template_key=args.train_data_template, |
747 | valid_set_size=args.valid_set_size, | 741 | valid_set_size=args.valid_set_size, |
748 | num_workers=args.dataloader_num_workers, | 742 | num_workers=args.dataloader_num_workers, |
diff --git a/train_ti.py b/train_ti.py
index 0ffc9e6..89c6672 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -21,7 +21,7 @@ from slugify import slugify | |||
21 | 21 | ||
22 | from util import load_config, load_embeddings_from_dir | 22 | from util import load_config, load_embeddings_from_dir |
23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion | 23 | from pipelines.stable_diffusion.vlpn_stable_diffusion import VlpnStableDiffusion |
24 | from data.csv import CSVDataModule, CSVDataItem | 24 | from data.csv import VlpnDataModule, VlpnDataItem |
25 | from training.common import run_model | 25 | from training.common import run_model |
26 | from training.optimization import get_one_cycle_schedule | 26 | from training.optimization import get_one_cycle_schedule |
27 | from training.lr import LRFinder | 27 | from training.lr import LRFinder |
@@ -146,11 +146,6 @@ def parse_args(): | |||
146 | ), | 146 | ), |
147 | ) | 147 | ) |
148 | parser.add_argument( | 148 | parser.add_argument( |
149 | "--center_crop", | ||
150 | action="store_true", | ||
151 | help="Whether to center crop images before resizing to resolution" | ||
152 | ) | ||
153 | parser.add_argument( | ||
154 | "--tag_dropout", | 149 | "--tag_dropout", |
155 | type=float, | 150 | type=float, |
156 | default=0.1, | 151 | default=0.1, |
@@ -668,7 +663,7 @@ def main(): | |||
668 | elif args.mixed_precision == "bf16": | 663 | elif args.mixed_precision == "bf16": |
669 | weight_dtype = torch.bfloat16 | 664 | weight_dtype = torch.bfloat16 |
670 | 665 | ||
671 | def keyword_filter(item: CSVDataItem): | 666 | def keyword_filter(item: VlpnDataItem): |
672 | cond1 = any( | 667 | cond1 = any( |
673 | keyword in part | 668 | keyword in part |
674 | for keyword in args.placeholder_token | 669 | for keyword in args.placeholder_token |
@@ -708,7 +703,7 @@ def main(): | |||
708 | } | 703 | } |
709 | return batch | 704 | return batch |
710 | 705 | ||
711 | datamodule = CSVDataModule( | 706 | datamodule = VlpnDataModule( |
712 | data_file=args.train_data_file, | 707 | data_file=args.train_data_file, |
713 | batch_size=args.train_batch_size, | 708 | batch_size=args.train_batch_size, |
714 | prompt_processor=prompt_processor, | 709 | prompt_processor=prompt_processor, |
@@ -717,7 +712,6 @@ def main(): | |||
717 | size=args.resolution, | 712 | size=args.resolution, |
718 | repeats=args.repeats, | 713 | repeats=args.repeats, |
719 | dropout=args.tag_dropout, | 714 | dropout=args.tag_dropout, |
720 | center_crop=args.center_crop, | ||
721 | template_key=args.train_data_template, | 715 | template_key=args.train_data_template, |
722 | valid_set_size=args.valid_set_size, | 716 | valid_set_size=args.valid_set_size, |
723 | num_workers=args.dataloader_num_workers, | 717 | num_workers=args.dataloader_num_workers, |
@@ -725,8 +719,6 @@ def main(): | |||
725 | filter=keyword_filter, | 719 | filter=keyword_filter, |
726 | collate_fn=collate_fn | 720 | collate_fn=collate_fn |
727 | ) | 721 | ) |
728 | |||
729 | datamodule.prepare_data() | ||
730 | datamodule.setup() | 722 | datamodule.setup() |
731 | 723 | ||
732 | if args.num_class_images != 0: | 724 | if args.num_class_images != 0: |
@@ -769,12 +761,14 @@ def main(): | |||
769 | if torch.cuda.is_available(): | 761 | if torch.cuda.is_available(): |
770 | torch.cuda.empty_cache() | 762 | torch.cuda.empty_cache() |
771 | 763 | ||
772 | train_dataloader = datamodule.train_dataloader() | 764 | train_dataloaders = datamodule.train_dataloaders |
773 | val_dataloader = datamodule.val_dataloader() | 765 | default_train_dataloader = train_dataloaders[0] |
766 | val_dataloader = datamodule.val_dataloader | ||
774 | 767 | ||
775 | # Scheduler and math around the number of training steps. | 768 | # Scheduler and math around the number of training steps. |
776 | overrode_max_train_steps = False | 769 | overrode_max_train_steps = False |
777 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 770 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) |
771 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
778 | if args.max_train_steps is None: | 772 | if args.max_train_steps is None: |
779 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 773 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
780 | overrode_max_train_steps = True | 774 | overrode_max_train_steps = True |
@@ -811,9 +805,10 @@ def main(): | |||
811 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, | 805 | num_training_steps=args.max_train_steps * args.gradient_accumulation_steps, |
812 | ) | 806 | ) |
813 | 807 | ||
814 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler = accelerator.prepare( | 808 | text_encoder, optimizer, val_dataloader, lr_scheduler = accelerator.prepare( |
815 | text_encoder, optimizer, train_dataloader, val_dataloader, lr_scheduler | 809 | text_encoder, optimizer, val_dataloader, lr_scheduler |
816 | ) | 810 | ) |
811 | train_dataloaders = accelerator.prepare(*train_dataloaders) | ||
817 | 812 | ||
818 | # Move vae and unet to device | 813 | # Move vae and unet to device |
819 | vae.to(accelerator.device, dtype=weight_dtype) | 814 | vae.to(accelerator.device, dtype=weight_dtype) |
@@ -831,7 +826,8 @@ def main(): | |||
831 | unet.eval() | 826 | unet.eval() |
832 | 827 | ||
833 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. | 828 | # We need to recalculate our total training steps as the size of the training dataloader may have changed. |
834 | num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps) | 829 | num_update_steps_per_dataloader = sum(len(dataloader) for dataloader in train_dataloaders) |
830 | num_update_steps_per_epoch = math.ceil(num_update_steps_per_dataloader / args.gradient_accumulation_steps) | ||
835 | if overrode_max_train_steps: | 831 | if overrode_max_train_steps: |
836 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch | 832 | args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch |
837 | 833 | ||
@@ -889,7 +885,7 @@ def main(): | |||
889 | accelerator, | 885 | accelerator, |
890 | text_encoder, | 886 | text_encoder, |
891 | optimizer, | 887 | optimizer, |
892 | train_dataloader, | 888 | default_train_dataloader, |
893 | val_dataloader, | 889 | val_dataloader, |
894 | loop, | 890 | loop, |
895 | on_train=on_train, | 891 | on_train=on_train, |
@@ -968,46 +964,48 @@ def main(): | |||
968 | text_encoder.train() | 964 | text_encoder.train() |
969 | 965 | ||
970 | with on_train(): | 966 | with on_train(): |
971 | for step, batch in enumerate(train_dataloader): | 967 | for train_dataloader in train_dataloaders: |
972 | with accelerator.accumulate(text_encoder): | 968 | for step, batch in enumerate(train_dataloader): |
973 | loss, acc, bsz = loop(step, batch) | 969 | with accelerator.accumulate(text_encoder): |
970 | loss, acc, bsz = loop(step, batch) | ||
974 | 971 | ||
975 | accelerator.backward(loss) | 972 | accelerator.backward(loss) |
976 | 973 | ||
977 | optimizer.step() | 974 | optimizer.step() |
978 | if not accelerator.optimizer_step_was_skipped: | 975 | if not accelerator.optimizer_step_was_skipped: |
979 | lr_scheduler.step() | 976 | lr_scheduler.step() |
980 | optimizer.zero_grad(set_to_none=True) | 977 | optimizer.zero_grad(set_to_none=True) |
981 | 978 | ||
982 | avg_loss.update(loss.detach_(), bsz) | 979 | avg_loss.update(loss.detach_(), bsz) |
983 | avg_acc.update(acc.detach_(), bsz) | 980 | avg_acc.update(acc.detach_(), bsz) |
984 | 981 | ||
985 | # Checks if the accelerator has performed an optimization step behind the scenes | 982 | # Checks if the accelerator has performed an optimization step behind the scenes |
986 | if accelerator.sync_gradients: | 983 | if accelerator.sync_gradients: |
987 | if args.use_ema: | 984 | if args.use_ema: |
988 | ema_embeddings.step(text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | 985 | ema_embeddings.step( |
986 | text_encoder.text_model.embeddings.temp_token_embedding.parameters()) | ||
989 | 987 | ||
990 | local_progress_bar.update(1) | 988 | local_progress_bar.update(1) |
991 | global_progress_bar.update(1) | 989 | global_progress_bar.update(1) |
992 | 990 | ||
993 | global_step += 1 | 991 | global_step += 1 |
994 | 992 | ||
995 | logs = { | 993 | logs = { |
996 | "train/loss": avg_loss.avg.item(), | 994 | "train/loss": avg_loss.avg.item(), |
997 | "train/acc": avg_acc.avg.item(), | 995 | "train/acc": avg_acc.avg.item(), |
998 | "train/cur_loss": loss.item(), | 996 | "train/cur_loss": loss.item(), |
999 | "train/cur_acc": acc.item(), | 997 | "train/cur_acc": acc.item(), |
1000 | "lr": lr_scheduler.get_last_lr()[0], | 998 | "lr": lr_scheduler.get_last_lr()[0], |
1001 | } | 999 | } |
1002 | if args.use_ema: | 1000 | if args.use_ema: |
1003 | logs["ema_decay"] = ema_embeddings.decay | 1001 | logs["ema_decay"] = ema_embeddings.decay |
1004 | 1002 | ||
1005 | accelerator.log(logs, step=global_step) | 1003 | accelerator.log(logs, step=global_step) |
1006 | 1004 | ||
1007 | local_progress_bar.set_postfix(**logs) | 1005 | local_progress_bar.set_postfix(**logs) |
1008 | 1006 | ||
1009 | if global_step >= args.max_train_steps: | 1007 | if global_step >= args.max_train_steps: |
1010 | break | 1008 | break |
1011 | 1009 | ||
1012 | accelerator.wait_for_everyone() | 1010 | accelerator.wait_for_everyone() |
1013 | 1011 | ||
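train_ti.py now trains against one dataloader per bucket: an epoch walks every bucket's dataloader in turn, and the number of update steps per epoch is derived from the combined batch count across all of them. A small sketch of that accounting, with made-up dataloader lengths:

```python
import math


def update_steps_per_epoch(dataloader_lengths: list[int],
                           gradient_accumulation_steps: int) -> int:
    # Mirrors the sum-then-ceil computation in the diff above.
    total_batches = sum(dataloader_lengths)
    return math.ceil(total_batches / gradient_accumulation_steps)


# Three buckets with 120, 40 and 8 batches and 4 accumulation steps:
print(update_steps_per_epoch([120, 40, 8], 4))  # -> 42

# Within each epoch the training loop then iterates bucket by bucket:
# for train_dataloader in train_dataloaders:
#     for step, batch in enumerate(train_dataloader):
#         ...
```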
diff --git a/training/util.py b/training/util.py
index bc466e2..6f42228 100644
--- a/training/util.py
+++ b/training/util.py
@@ -58,8 +58,8 @@ class CheckpointerBase: | |||
58 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): | 58 | def save_samples(self, pipeline, step, num_inference_steps, guidance_scale=7.5, eta=0.0): |
59 | samples_path = Path(self.output_dir).joinpath("samples") | 59 | samples_path = Path(self.output_dir).joinpath("samples") |
60 | 60 | ||
61 | train_data = self.datamodule.train_dataloader() | 61 | train_data = self.datamodule.train_dataloaders[0] |
62 | val_data = self.datamodule.val_dataloader() | 62 | val_data = self.datamodule.val_dataloader |
63 | 63 | ||
64 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) | 64 | generator = torch.Generator(device=pipeline.device).manual_seed(self.seed) |
65 | 65 | ||