author    Volpeon <git@volpeon.ink>  2022-10-13 21:11:53 +0200
committer Volpeon <git@volpeon.ink>  2022-10-13 21:11:53 +0200
commit    515f0f1fdc9a76bf63bd746c291dcfec7fc747fb (patch)
tree      cf4bbf1cae822cf7c7f388b6918154032def0376
parent    Added TI+Dreambooth training (diff)
download  textual-inversion-diff-515f0f1fdc9a76bf63bd746c291dcfec7fc747fb.tar.gz
          textual-inversion-diff-515f0f1fdc9a76bf63bd746c291dcfec7fc747fb.tar.bz2
          textual-inversion-diff-515f0f1fdc9a76bf63bd746c291dcfec7fc747fb.zip

Added support for Aesthetic Gradients
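The new aesthetic_gradient.py script encodes a set of reference images with openai/clip-vit-large-patch14, averages their image features, and saves the result as a single tensor named after the slugified token. infer.py now loads Textual Inversion embeddings from embeddings_ti/ and aesthetic gradient embeddings from embeddings_ag/, and VlpnStableDiffusion briefly fine-tunes the CLIP text encoder toward a stored embedding whenever its token occurs in the prompt, before encoding that prompt.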
-rw-r--r--  .gitignore                                             4
-rw-r--r--  aesthetic_gradient.py                                137
-rw-r--r--  data/csv.py                                            2
-rw-r--r--  dreambooth.py                                         10
-rw-r--r--  dreambooth_plus.py                                    16
-rw-r--r--  infer.py                                              75
-rw-r--r--  pipelines/stable_diffusion/vlpn_stable_diffusion.py   52
-rw-r--r--  textual_inversion.py                                  10
8 files changed, 245 insertions, 61 deletions
diff --git a/.gitignore b/.gitignore
index 6b9605f..d84b4dd 100644
--- a/.gitignore
+++ b/.gitignore
@@ -161,5 +161,7 @@ cython_debug/
 
 output/
 conf/
-embeddings/
+embeddings_ti/
+embeddings_ag/
 v1-inference.yaml*
+*.old
diff --git a/aesthetic_gradient.py b/aesthetic_gradient.py
new file mode 100644
index 0000000..5386d0f
--- /dev/null
+++ b/aesthetic_gradient.py
@@ -0,0 +1,137 @@
+import argparse
+import datetime
+import logging
+import json
+from pathlib import Path
+
+import torch
+import torch.utils.checkpoint
+from torchvision import transforms
+import pandas as pd
+
+from accelerate.logging import get_logger
+from PIL import Image
+from tqdm import tqdm
+from transformers import CLIPModel
+from slugify import slugify
+
+logger = get_logger(__name__)
+
+
+torch.backends.cuda.matmul.allow_tf32 = True
+
+
+def parse_args():
+    parser = argparse.ArgumentParser(
+        description="Simple example of a training script."
+    )
+    parser.add_argument(
+        "--pretrained_model_name_or_path",
+        type=str,
+        default=None,
+        help="Path to pretrained model or model identifier from huggingface.co/models.",
+    )
+    parser.add_argument(
+        "--train_data_file",
+        type=str,
+        default=None,
+        help="A directory."
+    )
+    parser.add_argument(
+        "--token",
+        type=str,
+        default=None,
+        help="A token to use as a placeholder for the concept.",
+    )
+    parser.add_argument(
+        "--resolution",
+        type=int,
+        default=224,
+        help=(
+            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
+            " resolution"
+        ),
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        default="output/aesthetic-gradient",
+        help="The output directory where the model predictions and checkpoints will be written.",
+    )
+    parser.add_argument(
+        "--config",
+        type=str,
+        default=None,
+        help="Path to a JSON configuration file containing arguments for invoking this script. If resume_from is given, its resume.json takes priority over this."
+    )
+
+    args = parser.parse_args()
+    if args.config is not None:
+        with open(args.config, 'rt') as f:
+            args = parser.parse_args(
+                namespace=argparse.Namespace(**json.load(f)["args"]))
+
+    if args.train_data_file is None:
+        raise ValueError("You must specify --train_data_file")
+
+    if args.token is None:
+        raise ValueError("You must specify --token")
+
+    if args.output_dir is None:
+        raise ValueError("You must specify --output_dir")
+
+    return args
+
+
+def main():
+    args = parse_args()
+
+    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")
+    basepath = Path(args.output_dir)
+    basepath.mkdir(parents=True, exist_ok=True)
+    target = basepath.joinpath(f"{slugify(args.token)}-{now}.pt")
+
+    logging.basicConfig(filename=basepath.joinpath("log.txt"), level=logging.DEBUG)
+
+    data_file = Path(args.train_data_file)
+    if not data_file.is_file():
+        raise ValueError("data_file must be a file")
+    data_root = data_file.parent
+    metadata = pd.read_csv(data_file)
+    image_paths = [
+        data_root.joinpath(item.image)
+        for item in metadata.itertuples()
+        if "skip" not in item or item.skip != "x"
+    ]
+
+    model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+
+    image_transforms = transforms.Compose(
+        [
+            transforms.Resize(args.resolution, interpolation=transforms.InterpolationMode.LANCZOS),
+            transforms.RandomCrop(args.resolution),
+            transforms.ToTensor(),
+            transforms.Normalize([0.5], [0.5]),
+        ]
+    )
+
+    with torch.no_grad():
+        embs = []
+        for path in tqdm(image_paths):
+            image = Image.open(path)
+            if not image.mode == "RGB":
+                image = image.convert("RGB")
+            image = image_transforms(image).unsqueeze(0)
+            emb = model.get_image_features(image)
+            print(f">>>> {emb.shape}")
+            embs.append(emb)
+
+        embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)
+
+        print(embs.shape)
+
+        torch.save(embs, target)
+
+
+if __name__ == "__main__":
+    main()
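Note: the script above averages the raw (unnormalized) CLIP image features of the kept images into a single tensor and writes it to the output directory. A quick sanity check of a saved file might look like this; the file name is hypothetical, and the expected shape follows from the ViT-L/14 projection dimension:

import torch

# Inspect an embedding produced by aesthetic_gradient.py (name is made up;
# actual files are named "<slugified-token>-<timestamp>.pt").
emb = torch.load("output/aesthetic-gradient/mystyle-2022-10-13T21-11-53.pt", map_location="cpu")
print(emb.shape)   # torch.Size([1, 768]) for openai/clip-vit-large-patch14
print(emb.norm())  # saved unnormalized; the pipeline normalizes it before use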
diff --git a/data/csv.py b/data/csv.py
index 253ce9e..aad970c 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -23,7 +23,7 @@ class CSVDataModule(pl.LightningDataModule):
         tokenizer,
         instance_identifier,
         class_identifier=None,
-        class_subdir="db_cls",
+        class_subdir="cls",
         num_class_images=100,
         size=512,
         repeats=100,
diff --git a/dreambooth.py b/dreambooth.py
index 699313e..072142e 100644
--- a/dreambooth.py
+++ b/dreambooth.py
@@ -216,12 +216,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -287,10 +281,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 
diff --git a/dreambooth_plus.py b/dreambooth_plus.py
index 9e482b3..7996bc2 100644
--- a/dreambooth_plus.py
+++ b/dreambooth_plus.py
@@ -112,7 +112,7 @@ def parse_args():
     parser.add_argument(
         "--max_train_steps",
         type=int,
-        default=3000,
+        default=1600,
         help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
     )
     parser.add_argument(
@@ -129,13 +129,13 @@ def parse_args():
     parser.add_argument(
         "--learning_rate_unet",
         type=float,
-        default=1e-5,
+        default=5e-6,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
         "--learning_rate_text",
         type=float,
-        default=1e-4,
+        default=5e-4,
         help="Initial learning rate (after the potential warmup period) to use.",
     )
     parser.add_argument(
@@ -222,12 +222,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--sample_frequency",
         type=int,
         default=100,
@@ -293,10 +287,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 
diff --git a/infer.py b/infer.py
index 63b16d8..650c119 100644
--- a/infer.py
+++ b/infer.py
@@ -23,7 +23,8 @@ default_args = {
23 "model": None, 23 "model": None,
24 "scheduler": "euler_a", 24 "scheduler": "euler_a",
25 "precision": "fp32", 25 "precision": "fp32",
26 "embeddings_dir": "embeddings", 26 "ti_embeddings_dir": "embeddings_ti",
27 "ag_embeddings_dir": "embeddings_ag",
27 "output_dir": "output/inference", 28 "output_dir": "output/inference",
28 "config": None, 29 "config": None,
29} 30}
@@ -73,7 +74,11 @@ def create_args_parser():
         choices=["fp32", "fp16", "bf16"],
     )
     parser.add_argument(
-        "--embeddings_dir",
+        "--ti_embeddings_dir",
+        type=str,
+    )
+    parser.add_argument(
+        "--ag_embeddings_dir",
         type=str,
     )
     parser.add_argument(
@@ -167,42 +172,63 @@ def save_args(basepath, args, extra={}):
         json.dump(info, f, indent=4)
 
 
-def load_embeddings(tokenizer, text_encoder, embeddings_dir):
+def load_embeddings_ti(tokenizer, text_encoder, embeddings_dir):
+    print(f"Loading Textual Inversion embeddings")
+
     embeddings_dir = Path(embeddings_dir)
     embeddings_dir.mkdir(parents=True, exist_ok=True)
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
+        if file.is_file():
+            placeholder_token = file.stem
 
-        num_added_tokens = tokenizer.add_tokens(placeholder_token)
-        if num_added_tokens == 0:
-            raise ValueError(
-                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
-                " `placeholder_token` that is not already in the tokenizer."
-            )
+            num_added_tokens = tokenizer.add_tokens(placeholder_token)
+            if num_added_tokens == 0:
+                raise ValueError(
+                    f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                    " `placeholder_token` that is not already in the tokenizer."
+                )
 
     text_encoder.resize_token_embeddings(len(tokenizer))
 
     token_embeds = text_encoder.get_input_embeddings().weight.data
 
     for file in embeddings_dir.iterdir():
-        placeholder_token = file.stem
-        placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
+        if file.is_file():
+            placeholder_token = file.stem
+            placeholder_token_id = tokenizer.convert_tokens_to_ids(placeholder_token)
 
-        data = torch.load(file, map_location="cpu")
+            data = torch.load(file, map_location="cpu")
 
-        assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
+            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'
 
-        emb = next(iter(data.values()))
-        if len(emb.shape) == 1:
-            emb = emb.unsqueeze(0)
+            emb = next(iter(data.values()))
+            if len(emb.shape) == 1:
+                emb = emb.unsqueeze(0)
 
-        token_embeds[placeholder_token_id] = emb
+            token_embeds[placeholder_token_id] = emb
 
-        print(f"Loaded embedding: {placeholder_token}")
+            print(f"Loaded {placeholder_token}")
 
 
-def create_pipeline(model, scheduler, embeddings_dir, dtype):
+def load_embeddings_ag(pipeline, embeddings_dir):
+    print(f"Loading Aesthetic Gradient embeddings")
+
+    embeddings_dir = Path(embeddings_dir)
+    embeddings_dir.mkdir(parents=True, exist_ok=True)
+
+    for file in embeddings_dir.iterdir():
+        if file.is_file():
+            placeholder_token = file.stem
+
+            data = torch.load(file, map_location="cpu")
+
+            pipeline.add_aesthetic_gradient_embedding(placeholder_token, data)
+
+            print(f"Loaded {placeholder_token}")
+
+
+def create_pipeline(model, scheduler, ti_embeddings_dir, ag_embeddings_dir, dtype):
     print("Loading Stable Diffusion pipeline...")
 
     tokenizer = CLIPTokenizer.from_pretrained(model, subfolder='tokenizer', torch_dtype=dtype)
@@ -210,7 +236,7 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
     vae = AutoencoderKL.from_pretrained(model, subfolder='vae', torch_dtype=dtype)
     unet = UNet2DConditionModel.from_pretrained(model, subfolder='unet', torch_dtype=dtype)
 
-    load_embeddings(tokenizer, text_encoder, embeddings_dir)
+    load_embeddings_ti(tokenizer, text_encoder, ti_embeddings_dir)
 
     if scheduler == "plms":
         scheduler = PNDMScheduler(
@@ -236,10 +262,13 @@ def create_pipeline(model, scheduler, embeddings_dir, dtype):
         tokenizer=tokenizer,
         scheduler=scheduler,
     )
+    pipeline.aesthetic_gradient_iters = 30
     pipeline.to("cuda")
 
     print("Pipeline loaded.")
 
+    load_embeddings_ag(pipeline, ag_embeddings_dir)
+
     return pipeline
 
 
@@ -259,7 +288,7 @@ def generate(output_dir, pipeline, args):
     else:
         init_image = None
 
-    with torch.autocast("cuda"), torch.inference_mode():
+    with torch.autocast("cuda"):
         for i in range(args.batch_num):
             pipeline.set_progress_bar_config(
                 desc=f"Batch {i + 1} of {args.batch_num}",
@@ -337,7 +366,7 @@ def main():
     output_dir = Path(args.output_dir)
     dtype = {"fp32": torch.float32, "fp16": torch.float16, "bf16": torch.bfloat16}[args.precision]
 
-    pipeline = create_pipeline(args.model, args.scheduler, args.embeddings_dir, dtype)
+    pipeline = create_pipeline(args.model, args.scheduler, args.ti_embeddings_dir, args.ag_embeddings_dir, dtype)
     cmd_parser = create_cmd_parser()
     cmd_prompt = CmdParse(output_dir, pipeline, cmd_parser)
     cmd_prompt.cmdloop()
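For reference, a minimal sketch of calling the reworked create_pipeline from Python instead of the CLI flags. The checkpoint path is an assumption, the directory names are the new defaults from default_args, and a CUDA device is required because create_pipeline moves the pipeline to "cuda":

import torch
from infer import create_pipeline

# Loads TI embeddings into the tokenizer/text encoder and registers aesthetic
# gradient embeddings on the pipeline, keyed by file stem.
pipeline = create_pipeline(
    model="CompVis/stable-diffusion-v1-4",  # assumption: any Stable Diffusion checkpoint path
    scheduler="euler_a",
    ti_embeddings_dir="embeddings_ti",
    ag_embeddings_dir="embeddings_ag",
    dtype=torch.float32,
)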
diff --git a/pipelines/stable_diffusion/vlpn_stable_diffusion.py b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
index 8927a78..1a84c8d 100644
--- a/pipelines/stable_diffusion/vlpn_stable_diffusion.py
+++ b/pipelines/stable_diffusion/vlpn_stable_diffusion.py
@@ -4,13 +4,14 @@ from typing import List, Optional, Union
 
 import numpy as np
 import torch
+import torch.optim as optim
 import PIL
 
 from diffusers.configuration_utils import FrozenDict
 from diffusers import AutoencoderKL, DiffusionPipeline, DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler, UNet2DConditionModel
 from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
 from diffusers.utils import logging
-from transformers import CLIPTextModel, CLIPTokenizer
+from transformers import CLIPTextModel, CLIPTokenizer, CLIPModel
 from schedulers.scheduling_euler_a import EulerAScheduler
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -50,6 +51,10 @@ class VlpnStableDiffusion(DiffusionPipeline):
             new_config["steps_offset"] = 1
             scheduler._internal_dict = FrozenDict(new_config)
 
+        self.aesthetic_gradient_embeddings = {}
+        self.aesthetic_gradient_lr = 1e-4
+        self.aesthetic_gradient_iters = 10
+
         self.register_modules(
             vae=vae,
             text_encoder=text_encoder,
@@ -58,6 +63,47 @@ class VlpnStableDiffusion(DiffusionPipeline):
             scheduler=scheduler,
         )
 
+    def add_aesthetic_gradient_embedding(self, keyword: str, tensor: torch.IntTensor):
+        self.aesthetic_gradient_embeddings[keyword] = tensor
+
+    def get_text_embeddings(self, prompt, text_input_ids):
+        prompt = " ".join(prompt)
+
+        embeddings = [
+            embedding
+            for key, embedding in self.aesthetic_gradient_embeddings.items()
+            if key in prompt
+        ]
+
+        if len(embeddings) != 0:
+            with torch.enable_grad():
+                full_clip_model = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
+                full_clip_model.to(self.device)
+                full_clip_model.text_model.train()
+
+                optimizer = optim.Adam(full_clip_model.text_model.parameters(), lr=self.aesthetic_gradient_lr)
+
+                for embs in embeddings:
+                    embs = embs.clone().detach().to(self.device)
+                    embs /= embs.norm(dim=-1, keepdim=True)
+
+                    for i in range(self.aesthetic_gradient_iters):
+                        text_embs = full_clip_model.get_text_features(text_input_ids)
+                        text_embs /= text_embs.norm(dim=-1, keepdim=True)
+                        sim = text_embs @ embs.T
+                        loss = -sim
+                        loss = loss.mean()
+
+                        loss.backward()
+                        optimizer.step()
+                        optimizer.zero_grad()
+
+                full_clip_model.text_model.eval()
+
+                return full_clip_model.text_model(text_input_ids)[0]
+        else:
+            return self.text_encoder(text_input_ids)[0]
+
     def enable_attention_slicing(self, slice_size: Optional[Union[str, int]] = "auto"):
         r"""
         Enable sliced attention computation.
@@ -195,7 +241,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             )
             print(f"Too many tokens: {removed_text}")
             text_input_ids = text_input_ids[:, : self.tokenizer.model_max_length]
-        text_embeddings = self.text_encoder(text_input_ids.to(self.device))[0]
+        text_embeddings = self.get_text_embeddings(prompt, text_input_ids.to(self.device))
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
         # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
@@ -207,7 +253,7 @@ class VlpnStableDiffusion(DiffusionPipeline):
             uncond_input = self.tokenizer(
                 negative_prompt, padding="max_length", max_length=max_length, return_tensors="pt"
             )
-            uncond_embeddings = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
+            uncond_embeddings = self.get_text_embeddings(negative_prompt, uncond_input.input_ids.to(self.device))
 
             # For classifier free guidance, we need to do two forward passes.
             # Here we concatenate the unconditional and text embeddings into a single batch
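The core of the feature is the short personalization loop in get_text_embeddings above: before encoding a prompt that contains a registered keyword, the full CLIP text encoder is fine-tuned for a few steps to maximize cosine similarity between the prompt's pooled text features and the stored image-feature centroid, and the tuned encoder's hidden states are then used as conditioning. A self-contained sketch of the same idea, using the same CLIP checkpoint as the diff; the function name and defaults are illustrative, not part of the pipeline API:

import torch
import torch.optim as optim
from transformers import CLIPModel, CLIPTokenizer


def personalize_text_embeddings(prompt: str, aesthetic_emb: torch.Tensor,
                                lr: float = 1e-4, iters: int = 10) -> torch.Tensor:
    clip = CLIPModel.from_pretrained("openai/clip-vit-large-patch14")
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    text_input_ids = tokenizer(
        prompt, padding="max_length", truncation=True,
        max_length=tokenizer.model_max_length, return_tensors="pt",
    ).input_ids

    # Normalized target: the averaged CLIP image features of the reference set.
    target = aesthetic_emb.detach()
    target = target / target.norm(dim=-1, keepdim=True)

    clip.text_model.train()
    optimizer = optim.Adam(clip.text_model.parameters(), lr=lr)

    for _ in range(iters):
        text_feats = clip.get_text_features(text_input_ids)
        text_feats = text_feats / text_feats.norm(dim=-1, keepdim=True)
        loss = -(text_feats @ target.T).mean()  # maximize cosine similarity
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    clip.text_model.eval()
    # Per-token hidden states from the tuned encoder, used as UNet conditioning.
    return clip.text_model(text_input_ids)[0]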
diff --git a/textual_inversion.py b/textual_inversion.py
index 181a318..9d2840d 100644
--- a/textual_inversion.py
+++ b/textual_inversion.py
@@ -193,12 +193,6 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--local_rank",
-        type=int,
-        default=-1,
-        help="For distributed training: local_rank"
-    )
-    parser.add_argument(
         "--checkpoint_frequency",
         type=int,
         default=500,
@@ -280,10 +274,6 @@ def parse_args():
             args = parser.parse_args(
                 namespace=argparse.Namespace(**json.load(f)["args"]))
 
-    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
-    if env_local_rank != -1 and env_local_rank != args.local_rank:
-        args.local_rank = env_local_rank
-
     if args.train_data_file is None:
         raise ValueError("You must specify --train_data_file")
 