diff options
-rw-r--r-- | main.py | 102 |
1 files changed, 71 insertions, 31 deletions
@@ -9,28 +9,21 @@ from typing import Optional | |||
9 | 9 | ||
10 | import numpy as np | 10 | import numpy as np |
11 | import torch | 11 | import torch |
12 | import torch.nn as nn | ||
13 | import torch.nn.functional as F | 12 | import torch.nn.functional as F |
14 | import torch.utils.checkpoint | 13 | import torch.utils.checkpoint |
15 | from torch.utils.data import Dataset | ||
16 | 14 | ||
17 | import PIL | ||
18 | from accelerate import Accelerator | 15 | from accelerate import Accelerator |
19 | from accelerate.logging import get_logger | 16 | from accelerate.logging import get_logger |
20 | from accelerate.utils import LoggerType, set_seed | 17 | from accelerate.utils import LoggerType, set_seed |
21 | from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel | 18 | from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel |
22 | from diffusers.optimization import get_scheduler | 19 | from diffusers.optimization import get_scheduler |
23 | from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | ||
24 | from einops import rearrange | ||
25 | from pipelines.stable_diffusion.no_check import NoCheck | 20 | from pipelines.stable_diffusion.no_check import NoCheck |
26 | from huggingface_hub import HfFolder, Repository, whoami | ||
27 | from PIL import Image | 21 | from PIL import Image |
28 | from tqdm.auto import tqdm | 22 | from tqdm.auto import tqdm |
29 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer | 23 | from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer |
30 | from slugify import slugify | 24 | from slugify import slugify |
31 | import json | 25 | import json |
32 | import os | 26 | import os |
33 | import sys | ||
34 | 27 | ||
35 | from data import CSVDataModule | 28 | from data import CSVDataModule |
36 | 29 | ||
@@ -39,7 +32,8 @@ logger = get_logger(__name__) | |||
39 | 32 | ||
40 | def parse_args(): | 33 | def parse_args(): |
41 | parser = argparse.ArgumentParser( | 34 | parser = argparse.ArgumentParser( |
42 | description="Simple example of a training script.") | 35 | description="Simple example of a training script." |
36 | ) | ||
43 | parser.add_argument( | 37 | parser.add_argument( |
44 | "--pretrained_model_name_or_path", | 38 | "--pretrained_model_name_or_path", |
45 | type=str, | 39 | type=str, |
@@ -53,7 +47,10 @@ def parse_args(): | |||
53 | help="Pretrained tokenizer name or path if not the same as model_name", | 47 | help="Pretrained tokenizer name or path if not the same as model_name", |
54 | ) | 48 | ) |
55 | parser.add_argument( | 49 | parser.add_argument( |
56 | "--train_data_dir", type=str, default=None, help="A folder containing the training data." | 50 | "--train_data_dir", |
51 | type=str, | ||
52 | default=None, | ||
53 | help="A folder containing the training data." | ||
57 | ) | 54 | ) |
58 | parser.add_argument( | 55 | parser.add_argument( |
59 | "--placeholder_token", | 56 | "--placeholder_token", |
@@ -62,21 +59,33 @@ def parse_args(): | |||
62 | help="A token to use as a placeholder for the concept.", | 59 | help="A token to use as a placeholder for the concept.", |
63 | ) | 60 | ) |
64 | parser.add_argument( | 61 | parser.add_argument( |
65 | "--initializer_token", type=str, default=None, help="A token to use as initializer word." | 62 | "--initializer_token", |
63 | type=str, | ||
64 | default=None, | ||
65 | help="A token to use as initializer word." | ||
66 | ) | 66 | ) |
67 | parser.add_argument( | 67 | parser.add_argument( |
68 | "--vectors_per_token", type=int, default=1, help="Vectors per token." | 68 | "--vectors_per_token", |
69 | type=int, | ||
70 | default=1, | ||
71 | help="Vectors per token." | ||
69 | ) | 72 | ) |
70 | parser.add_argument("--repeats", type=int, default=100, | 73 | parser.add_argument( |
71 | help="How many times to repeat the training data.") | 74 | "--repeats", |
75 | type=int, | ||
76 | default=100, | ||
77 | help="How many times to repeat the training data.") | ||
72 | parser.add_argument( | 78 | parser.add_argument( |
73 | "--output_dir", | 79 | "--output_dir", |
74 | type=str, | 80 | type=str, |
75 | default="text-inversion-model", | 81 | default="text-inversion-model", |
76 | help="The output directory where the model predictions and checkpoints will be written.", | 82 | help="The output directory where the model predictions and checkpoints will be written.", |
77 | ) | 83 | ) |
78 | parser.add_argument("--seed", type=int, default=None, | 84 | parser.add_argument( |
79 | help="A seed for reproducible training.") | 85 | "--seed", |
86 | type=int, | ||
87 | default=None, | ||
88 | help="A seed for reproducible training.") | ||
80 | parser.add_argument( | 89 | parser.add_argument( |
81 | "--resolution", | 90 | "--resolution", |
82 | type=int, | 91 | type=int, |
@@ -87,12 +96,14 @@ def parse_args(): | |||
87 | ), | 96 | ), |
88 | ) | 97 | ) |
89 | parser.add_argument( | 98 | parser.add_argument( |
90 | "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution" | 99 | "--center_crop", |
100 | action="store_true", | ||
101 | help="Whether to center crop images before resizing to resolution" | ||
91 | ) | 102 | ) |
92 | parser.add_argument( | 103 | parser.add_argument( |
93 | "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader." | 104 | "--num_train_epochs", |
94 | ) | 105 | type=int, |
95 | parser.add_argument("--num_train_epochs", type=int, default=100) | 106 | default=100) |
96 | parser.add_argument( | 107 | parser.add_argument( |
97 | "--max_train_steps", | 108 | "--max_train_steps", |
98 | type=int, | 109 | type=int, |
@@ -132,16 +143,35 @@ def parse_args(): | |||
132 | ), | 143 | ), |
133 | ) | 144 | ) |
134 | parser.add_argument( | 145 | parser.add_argument( |
135 | "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler." | 146 | "--lr_warmup_steps", |
147 | type=int, | ||
148 | default=500, | ||
149 | help="Number of steps for the warmup in the lr scheduler." | ||
150 | ) | ||
151 | parser.add_argument( | ||
152 | "--adam_beta1", | ||
153 | type=float, | ||
154 | default=0.9, | ||
155 | help="The beta1 parameter for the Adam optimizer." | ||
156 | ) | ||
157 | parser.add_argument( | ||
158 | "--adam_beta2", | ||
159 | type=float, | ||
160 | default=0.999, | ||
161 | help="The beta2 parameter for the Adam optimizer." | ||
162 | ) | ||
163 | parser.add_argument( | ||
164 | "--adam_weight_decay", | ||
165 | type=float, | ||
166 | default=1e-2, | ||
167 | help="Weight decay to use." | ||
168 | ) | ||
169 | parser.add_argument( | ||
170 | "--adam_epsilon", | ||
171 | type=float, | ||
172 | default=1e-08, | ||
173 | help="Epsilon value for the Adam optimizer" | ||
136 | ) | 174 | ) |
137 | parser.add_argument("--adam_beta1", type=float, default=0.9, | ||
138 | help="The beta1 parameter for the Adam optimizer.") | ||
139 | parser.add_argument("--adam_beta2", type=float, default=0.999, | ||
140 | help="The beta2 parameter for the Adam optimizer.") | ||
141 | parser.add_argument("--adam_weight_decay", type=float, | ||
142 | default=1e-2, help="Weight decay to use.") | ||
143 | parser.add_argument("--adam_epsilon", type=float, default=1e-08, | ||
144 | help="Epsilon value for the Adam optimizer") | ||
145 | parser.add_argument( | 175 | parser.add_argument( |
146 | "--mixed_precision", | 176 | "--mixed_precision", |
147 | type=str, | 177 | type=str, |
@@ -153,8 +183,12 @@ def parse_args(): | |||
153 | "and an Nvidia Ampere GPU." | 183 | "and an Nvidia Ampere GPU." |
154 | ), | 184 | ), |
155 | ) | 185 | ) |
156 | parser.add_argument("--local_rank", type=int, default=-1, | 186 | parser.add_argument( |
157 | help="For distributed training: local_rank") | 187 | "--local_rank", |
188 | type=int, | ||
189 | default=-1, | ||
190 | help="For distributed training: local_rank" | ||
191 | ) | ||
158 | parser.add_argument( | 192 | parser.add_argument( |
159 | "--checkpoint_frequency", | 193 | "--checkpoint_frequency", |
160 | type=int, | 194 | type=int, |
@@ -186,6 +220,12 @@ def parse_args(): | |||
186 | help="Number of samples to generate per batch", | 220 | help="Number of samples to generate per batch", |
187 | ) | 221 | ) |
188 | parser.add_argument( | 222 | parser.add_argument( |
223 | "--train_batch_size", | ||
224 | type=int, | ||
225 | default=1, | ||
226 | help="Batch size (per device) for the training dataloader." | ||
227 | ) | ||
228 | parser.add_argument( | ||
189 | "--sample_steps", | 229 | "--sample_steps", |
190 | type=int, | 230 | type=int, |
191 | default=50, | 231 | default=50, |