summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--main.py102
1 files changed, 71 insertions, 31 deletions
diff --git a/main.py b/main.py
index 8be79e5..51b64c1 100644
--- a/main.py
+++ b/main.py
@@ -9,28 +9,21 @@ from typing import Optional
9 9
10import numpy as np 10import numpy as np
11import torch 11import torch
12import torch.nn as nn
13import torch.nn.functional as F 12import torch.nn.functional as F
14import torch.utils.checkpoint 13import torch.utils.checkpoint
15from torch.utils.data import Dataset
16 14
17import PIL
18from accelerate import Accelerator 15from accelerate import Accelerator
19from accelerate.logging import get_logger 16from accelerate.logging import get_logger
20from accelerate.utils import LoggerType, set_seed 17from accelerate.utils import LoggerType, set_seed
21from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel 18from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
22from diffusers.optimization import get_scheduler 19from diffusers.optimization import get_scheduler
23from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
24from einops import rearrange
25from pipelines.stable_diffusion.no_check import NoCheck 20from pipelines.stable_diffusion.no_check import NoCheck
26from huggingface_hub import HfFolder, Repository, whoami
27from PIL import Image 21from PIL import Image
28from tqdm.auto import tqdm 22from tqdm.auto import tqdm
29from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer 23from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
30from slugify import slugify 24from slugify import slugify
31import json 25import json
32import os 26import os
33import sys
34 27
35from data import CSVDataModule 28from data import CSVDataModule
36 29
@@ -39,7 +32,8 @@ logger = get_logger(__name__)
39 32
40def parse_args(): 33def parse_args():
41 parser = argparse.ArgumentParser( 34 parser = argparse.ArgumentParser(
42 description="Simple example of a training script.") 35 description="Simple example of a training script."
36 )
43 parser.add_argument( 37 parser.add_argument(
44 "--pretrained_model_name_or_path", 38 "--pretrained_model_name_or_path",
45 type=str, 39 type=str,
@@ -53,7 +47,10 @@ def parse_args():
53 help="Pretrained tokenizer name or path if not the same as model_name", 47 help="Pretrained tokenizer name or path if not the same as model_name",
54 ) 48 )
55 parser.add_argument( 49 parser.add_argument(
56 "--train_data_dir", type=str, default=None, help="A folder containing the training data." 50 "--train_data_dir",
51 type=str,
52 default=None,
53 help="A folder containing the training data."
57 ) 54 )
58 parser.add_argument( 55 parser.add_argument(
59 "--placeholder_token", 56 "--placeholder_token",
@@ -62,21 +59,33 @@ def parse_args():
62 help="A token to use as a placeholder for the concept.", 59 help="A token to use as a placeholder for the concept.",
63 ) 60 )
64 parser.add_argument( 61 parser.add_argument(
65 "--initializer_token", type=str, default=None, help="A token to use as initializer word." 62 "--initializer_token",
63 type=str,
64 default=None,
65 help="A token to use as initializer word."
66 ) 66 )
67 parser.add_argument( 67 parser.add_argument(
68 "--vectors_per_token", type=int, default=1, help="Vectors per token." 68 "--vectors_per_token",
69 type=int,
70 default=1,
71 help="Vectors per token."
69 ) 72 )
70 parser.add_argument("--repeats", type=int, default=100, 73 parser.add_argument(
71 help="How many times to repeat the training data.") 74 "--repeats",
75 type=int,
76 default=100,
77 help="How many times to repeat the training data.")
72 parser.add_argument( 78 parser.add_argument(
73 "--output_dir", 79 "--output_dir",
74 type=str, 80 type=str,
75 default="text-inversion-model", 81 default="text-inversion-model",
76 help="The output directory where the model predictions and checkpoints will be written.", 82 help="The output directory where the model predictions and checkpoints will be written.",
77 ) 83 )
78 parser.add_argument("--seed", type=int, default=None, 84 parser.add_argument(
79 help="A seed for reproducible training.") 85 "--seed",
86 type=int,
87 default=None,
88 help="A seed for reproducible training.")
80 parser.add_argument( 89 parser.add_argument(
81 "--resolution", 90 "--resolution",
82 type=int, 91 type=int,
@@ -87,12 +96,14 @@ def parse_args():
87 ), 96 ),
88 ) 97 )
89 parser.add_argument( 98 parser.add_argument(
90 "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" 99 "--center_crop",
100 action="store_true",
101 help="Whether to center crop images before resizing to resolution"
91 ) 102 )
92 parser.add_argument( 103 parser.add_argument(
93 "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." 104 "--num_train_epochs",
94 ) 105 type=int,
95 parser.add_argument("--num_train_epochs", type=int, default=100) 106 default=100)
96 parser.add_argument( 107 parser.add_argument(
97 "--max_train_steps", 108 "--max_train_steps",
98 type=int, 109 type=int,
@@ -132,16 +143,35 @@ def parse_args():
132 ), 143 ),
133 ) 144 )
134 parser.add_argument( 145 parser.add_argument(
135 "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." 146 "--lr_warmup_steps",
147 type=int,
148 default=500,
149 help="Number of steps for the warmup in the lr scheduler."
150 )
151 parser.add_argument(
152 "--adam_beta1",
153 type=float,
154 default=0.9,
155 help="The beta1 parameter for the Adam optimizer."
156 )
157 parser.add_argument(
158 "--adam_beta2",
159 type=float,
160 default=0.999,
161 help="The beta2 parameter for the Adam optimizer."
162 )
163 parser.add_argument(
164 "--adam_weight_decay",
165 type=float,
166 default=1e-2,
167 help="Weight decay to use."
168 )
169 parser.add_argument(
170 "--adam_epsilon",
171 type=float,
172 default=1e-08,
173 help="Epsilon value for the Adam optimizer"
136 ) 174 )
137 parser.add_argument("--adam_beta1", type=float, default=0.9,
138 help="The beta1 parameter for the Adam optimizer.")
139 parser.add_argument("--adam_beta2", type=float, default=0.999,
140 help="The beta2 parameter for the Adam optimizer.")
141 parser.add_argument("--adam_weight_decay", type=float,
142 default=1e-2, help="Weight decay to use.")
143 parser.add_argument("--adam_epsilon", type=float, default=1e-08,
144 help="Epsilon value for the Adam optimizer")
145 parser.add_argument( 175 parser.add_argument(
146 "--mixed_precision", 176 "--mixed_precision",
147 type=str, 177 type=str,
@@ -153,8 +183,12 @@ def parse_args():
153 "and an Nvidia Ampere GPU." 183 "and an Nvidia Ampere GPU."
154 ), 184 ),
155 ) 185 )
156 parser.add_argument("--local_rank", type=int, default=-1, 186 parser.add_argument(
157 help="For distributed training: local_rank") 187 "--local_rank",
188 type=int,
189 default=-1,
190 help="For distributed training: local_rank"
191 )
158 parser.add_argument( 192 parser.add_argument(
159 "--checkpoint_frequency", 193 "--checkpoint_frequency",
160 type=int, 194 type=int,
@@ -186,6 +220,12 @@ def parse_args():
186 help="Number of samples to generate per batch", 220 help="Number of samples to generate per batch",
187 ) 221 )
188 parser.add_argument( 222 parser.add_argument(
223 "--train_batch_size",
224 type=int,
225 default=1,
226 help="Batch size (per device) for the training dataloader."
227 )
228 parser.add_argument(
189 "--sample_steps", 229 "--sample_steps",
190 type=int, 230 type=int,
191 default=50, 231 default=50,