-rw-r--r--  aesthetic_gradient.py                                 137
-rw-r--r--  dreambooth.py                                          10
-rw-r--r--  dreambooth_plus.py                                     59
-rw-r--r--  infer.py                                               32
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py    50
-rw-r--r--  textual_inversion.py                                   32
6 files changed, 77 insertions, 243 deletions
diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
deleted file mode 100644
index 5386d0f..0000000
--- a/aesthetic_gradient.py
+++ /dev/null
@@ -1,137 +0,0 @@
-import argparse
-import datetime
-import logging
-import json
-from pathlib import Path
-
-import torch
-import torch.utils.checkpoint
-from torchvision import transforms
-import pandas as pd
-
-from accelerate.logging import get_logger
-from PIL import Image
-from tqdm import tqdm
-from transformers import CLIPModel
-from slugify import slugify
-
-logger = get_logger(__name__)
-
-
-torch.backends.cuda.matmul.allow_tf32 = True
-
-
-def parse_args():
-    parser = argparse.ArgumentParser(
-        description="Simple example of a training script."
-    )
-    parser.add_argument(
-        "--pretrained_model_name_or_path",
-        type=str,
-        default=None,
-        help="Path to pretrained model or model identifier from huggingface.co/models.",
-    )
-    parser.add_argument(
-        "--train_data_file",
-        type=str,
-        default=None,
-        help="A directory."
-    )
-    parser.add_argument(
-        "--token",
-        type=str,
-        default=None,
-        help="A token to use as a placeholder for the concept.",
-    )
-    parser.add_argument(
-        "--resolution",
-        type=int,
-        default=224,
-        help=(
-            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
-            " resolution"
-        ),
-    )
-    parser.add_argument(
-        "--output_dir",
-        type=str,
-        default="output/aesthetic-gradient",
-        help="The output directory where the model predictions and checkpoints will be written.",
-    )
-    parser.add_argument(
-        "--config",
-        type=str,
-        default=None,
-        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
-    )
-
-    args = parser.parse_args()
-    if args.config is not None:
-        with open(args.config, 'rt') as f:
-            args = parser.parse_args(
-                namespace=argparse.Namespace(**json.load(f)["args"]))
-
-    if args.train_data_file is None:
-        raise ValueError("You must specify --train_data_file")
-
-    if args.token is None:
-        raise ValueError("You must specify --token")
-
-    if args.output_dir is None:
-        raise ValueError("You must specify --output_dir")
-
-    return args
-
-
-def main():
-    args = parse_args()
-
-    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
-    basepath = Path(args.output_dir)
-    basepath.mkdir(parents=True, exist_ok=True)
-    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
-
-    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
-
-    data_file = Path(args.train_data_file)
-    if not data_file.is_file():
-        raise ValueError("data_file must be a file")
-    data_root = data_file.parent
-    metadata = pd.read_csv(data_file)
-    image_paths = [
-        data_root.joinpath(item.image)
-        for item in metadata.itertuples()
-        if "skip" not in item or item.skip != "x"
-    ]
-
-    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-
-    image_transforms = transforms.Compose(
-        [
-            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
-            transforms.RandomCrop(args.resolution),
-            transforms.ToTensor(),
-            transforms.Normalize([0.5], [0.5]),
-        ]
-    )
-
-    with torch.no_grad():
-        embs = []
-        for path in tqdm(image_paths):
-            image = Image.open(path)
-            if not image.mode == "RGB":
-                image = image.convert("RGB")
-            image = image_transforms(image).unsqueeze(0)
-            emb = model.get_image_features(image)
-            print(f">>>> {emb.shape}")
-            embs.append(emb)
-
-        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
-
-        print(embs.shape)
-
-        torch.save(embs, target)
-
-
-if __name__ == "__main__":
-    main()
diff --git a/dreambooth.py b/dreambooth.py
index 072142e..1ba8dc0 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -70,7 +70,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3000,
+        default=2000,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -341,7 +341,7 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self):
+    def save_model(self):
         print("Saving model...")
 
         unwrapped = self.accelerator.unwrap_model(
@@ -839,14 +839,14 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
+            checkpointer.save_model()
 
         accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.save_model()
         accelerator.end_training()
         quit()
 
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 7996bc2..b5ec2fc 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -58,6 +58,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -71,7 +77,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -112,7 +118,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=1600,
+        default=2300,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -135,7 +141,7 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=5e-4,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -222,6 +228,12 @@ def parse_args():
         ),
     )
     parser.add_argument(
+        "--checkpoint_frequency",
+        type=int,
+        default=500,
+        help="How often to save a checkpoint and sample image",
+    )
+    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -352,7 +364,26 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self):
+    def checkpoint(self, step, postfix):
+        print("Saving checkpoint for step %d..." % step)
+
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
+
+        unwrapped = self.accelerator.unwrap_model(self.text_encoder)
+
+        # Save a checkpoint
+        learned_embeds = unwrapped.get_input_embeddings().weight[self.placeholder_token_id]
+        learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
+
+        filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+
+        del unwrapped
+        del learned_embeds
+
+    @torch.no_grad()
+    def save_model(self):
         print("Saving model...")
 
         unwrapped_unet = self.accelerator.unwrap_model(
@@ -612,7 +643,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -648,7 +679,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -842,6 +873,12 @@ def main():
                 if global_step % args.sample_frequency == 0:
                     sample_checkpoint = True
 
+                if global_step % args.checkpoint_frequency == 0 and global_step > 0 and accelerator.is_main_process:
+                    local_progress_bar.clear()
+                    global_progress_bar.clear()
+
+                    checkpointer.checkpoint(global_step + global_step_offset, "training")
+
             logs = {
                 "train/loss": loss,
                 "lr/unet": lr_scheduler.get_last_lr()[0],
@@ -903,6 +940,9 @@ def main():
             global_progress_bar.clear()
 
             if min_val_loss > val_loss:
+                accelerator.print(
+                    f"Global step {global_step}: Validation loss reached new minimum: {min_val_loss:.2e} -> {val_loss:.2e}")
+                checkpointer.checkpoint(global_step + global_step_offset, "milestone")
                 min_val_loss = val_loss
 
             if sample_checkpoint and accelerator.is_main_process:
@@ -913,14 +953,15 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint()
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
         accelerator.end_training()
 
     except KeyboardInterrupt:
         if accelerator.is_main_process:
             print("Interrupted, saving checkpoint and resume state...")
-            checkpointer.checkpoint()
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
+            checkpointer.save_model()
         accelerator.end_training()
         quit()
 
diff --git a/infer.py b/infer.py
index 650c119..1a0baf5 100644
--- a/infer.py
+++ b/infer.py
@@ -24,7 +24,6 @@ default_args = {
     "scheduler": "euler_a",
     "precision": "fp32",
     "ti_embeddings_dir": "embeddings_ti",
-    "ag_embeddings_dir": "embeddings_ag",
     "output_dir": "output/inference",
     "config": None,
 }
@@ -78,10 +77,6 @@ def create_args_parser():
         type=str,
     )
     parser.add_argument(
-        "--ag_embeddings_dir",
-        type=str,
-    )
-    parser.add_argument(
         "--output_dir",
         type=str,
     )
@@ -211,24 +206,7 @@ def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
             print(f"Loaded {placeholder_token}")
 
 
-def load_embeddings_ag(pipeline, embeddings_dir):
-    print(f"Loading Aesthetic Gradient embeddings")
-
-    embeddings_dir = Path(embeddings_dir)
-    embeddings_dir.mkdir(parents=True, exist_ok=True)
-
-    for file in embeddings_dir.iterdir():
-        if file.is_file():
-            placeholder_token = file.stem
-
-            data = torch.load(file, map_location="cpu")
-
-            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
-
-            print(f"Loaded {placeholder_token}")
-
-
-def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
+def create_pipeline(model, scheduler, ti_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -262,13 +240,11 @@ def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtyp
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
-    pipeline.aesthetic_gradient_iters = 30
+    pipeline.aesthetic_gradient_iters = 20
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
-    load_embeddings_ag(pipeline, ag_embeddings_dir)
-
     return pipeline
 
 
@@ -288,7 +264,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"):
+    with torch.autocast("cuda"), torch.inference_mode():
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -366,7 +342,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 1a84c8d..3e41f86 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -51,10 +51,6 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
-        self.aesthetic_gradient_embeddings = {}
-        self.aesthetic_gradient_lr = 1e-4
-        self.aesthetic_gradient_iters = 10
-
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -63,46 +59,8 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
-    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
-        self.aesthetic_gradient_embeddings[keyword] = tensor
-
-    def get_text_embeddings(self, prompt, text_input_ids):
-        prompt = " ".join(prompt)
-
-        embeddings = [
-            embedding
-            for key, embedding in self.aesthetic_gradient_embeddings.items()
-            if key in prompt
-        ]
-
-        if len(embeddings) != 0:
-            with torch.enable_grad():
-                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
-                full_clip_model.to(self.device)
-                full_clip_model.text_model.train()
-
-                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
-
-                for embs in embeddings:
-                    embs = embs.clone().detach().to(self.device)
-                    embs /= embs.norm(dim=-1, keepdim=True)
-
-                    for i in range(self.aesthetic_gradient_iters):
-                        text_embs = full_clip_model.get_text_features(text_input_ids)
-                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
-                        sim = text_embs @ embs.T
-                        loss = -sim
-                        loss = loss.mean()
-
-                        loss.backward()
-                        optimizer.step()
-                        optimizer.zero_grad()
-
-                full_clip_model.text_model.eval()
-
-            return full_clip_model.text_model(text_input_ids)[0]
-        else:
-            return self.text_encoder(text_input_ids)[0]
+    def get_text_embeddings(self, text_input_ids):
+        return self.text_encoder(text_input_ids)[0]
 
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
@@ -241,7 +199,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
+        text_embeddings = self.get_text_embeddings(text_input_ids.to(self.device))
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -253,7 +211,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
+            uncond_embeddings = self.get_text_embeddings(uncond_input.input_ids.to(self.device))
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
diff --git a/textual_inversion.py b/textual_inversion.py
index 9d2840d..6627f1f 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -57,6 +57,12 @@ def parse_args():
     parser.add_argument(
         "--placeholder_token",
         type=str,
+        default="<*>",
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--class_identifier",
+        type=str,
         default=None,
         help="A token to use as a placeholder for the concept.",
     )
@@ -70,7 +76,7 @@ def parse_args():
         "--num_class_images",
         type=int,
         default=400,
-        help="How many class images to generate per training image."
+        help="How many class images to generate."
     )
     parser.add_argument(
         "--repeats",
@@ -344,12 +350,11 @@ class Checkpointer:
         self.sample_batch_size = sample_batch_size
 
     @torch.no_grad()
-    def checkpoint(self, step, postfix, path=None):
+    def checkpoint(self, step, postfix):
         print("Saving checkpoint for step %d..." % step)
 
-        if path is None:
-            checkpoints_path = self.output_dir.joinpath("checkpoints")
-            checkpoints_path.mkdir(parents=True, exist_ok=True)
+        checkpoints_path = self.output_dir.joinpath("checkpoints")
+        checkpoints_path.mkdir(parents=True, exist_ok=True)
 
         unwrapped = self.accelerator.unwrap_model(self.text_encoder)
 
@@ -358,10 +363,7 @@ class Checkpointer:
         learned_embeds_dict = {self.placeholder_token: learned_embeds.detach().cpu()}
 
         filename = f"%s_%d_%s.bin" % (slugify(self.placeholder_token), step, postfix)
-        if path is not None:
-            torch.save(learned_embeds_dict, path)
-        else:
-            torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
+        torch.save(learned_embeds_dict, checkpoints_path.joinpath(filename))
 
         del unwrapped
         del learned_embeds
@@ -595,7 +597,7 @@ def main():
         batch_size=args.train_batch_size,
         tokenizer=tokenizer,
         instance_identifier=args.placeholder_token,
-        class_identifier=args.initializer_token,
+        class_identifier=args.class_identifier,
         class_subdir="cls",
         num_class_images=args.num_class_images,
         size=args.resolution,
@@ -631,7 +633,7 @@ def main():
         with torch.inference_mode():
             for batch in batched_data:
                 image_name = [p.class_image_path for p in batch]
-                prompt = [p.prompt.format(args.initializer_token) for p in batch]
+                prompt = [p.prompt.format(args.class_identifier) for p in batch]
                 nprompt = [p.nprompt for p in batch]
 
                 images = pipeline(
@@ -898,17 +900,11 @@ def main():
         # Create the pipeline using using the trained modules and save it.
         if accelerator.is_main_process:
             print("Finished! Saving final checkpoint and resume state.")
-            checkpointer.checkpoint(
-                global_step + global_step_offset,
-                "end",
-                path=f"{basepath}/learned_embeds.bin"
-            )
-
+            checkpointer.checkpoint(global_step + global_step_offset, "end")
             save_resume_file(basepath, args, {
                 "global_step": global_step + global_step_offset,
                 "resume_checkpoint": f"{basepath}/checkpoints/last.bin"
             })
-
             accelerator.end_training()
 
     except KeyboardInterrupt: