From faab86b979e8aad5ff3bb4712399e977a59a2a98 Mon Sep 17 00:00:00 2001
From: Volpeon
Date: Tue, 27 Sep 2022 08:53:10 +0200
Subject: Cleanup

---
 main.py | 102 ++++++++++++++++++++++++++++++++++++++++++++--------------------
 1 file changed, 71 insertions(+), 31 deletions(-)

diff --git a/main.py b/main.py
index 8be79e5..51b64c1 100644
--- a/main.py
+++ b/main.py
@@ -9,28 +9,21 @@ from typing import Optional
 
 import numpy as np
 import torch
-import torch.nn as nn
 import torch.nn.functional as F
 import torch.utils.checkpoint
-from torch.utils.data import Dataset
 
-import PIL
 from accelerate import Accelerator
 from accelerate.logging import get_logger
 from accelerate.utils import LoggerType, set_seed
-from diffusers import AutoencoderKL, DDPMScheduler, PNDMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
+from diffusers import AutoencoderKL, DDPMScheduler, LMSDiscreteScheduler, StableDiffusionPipeline, UNet2DConditionModel
 from diffusers.optimization import get_scheduler
-from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
-from einops import rearrange
 from pipelines.stable_diffusion.no_check import NoCheck
-from huggingface_hub import HfFolder, Repository, whoami
 from PIL import Image
 from tqdm.auto import tqdm
 from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
 from slugify import slugify
 import json
 import os
-import sys
 
 from data import CSVDataModule
 
@@ -39,7 +32,8 @@ logger = get_logger(__name__)
 
 def parse_args():
     parser = argparse.ArgumentParser(
-        description="Simple example of a training script.")
+        description="Simple example of a training script."
+    )
     parser.add_argument(
         "--pretrained_model_name_or_path",
         type=str,
@@ -53,7 +47,10 @@ def parse_args():
         help="Pretrained tokenizer name or path if not the same as model_name",
     )
     parser.add_argument(
-        "--train_data_dir", type=str, default=None, help="A folder containing the training data."
+        "--train_data_dir",
+        type=str,
+        default=None,
+        help="A folder containing the training data."
     )
     parser.add_argument(
         "--placeholder_token",
@@ -62,21 +59,33 @@ def parse_args():
         help="A token to use as a placeholder for the concept.",
     )
     parser.add_argument(
-        "--initializer_token", type=str, default=None, help="A token to use as initializer word."
+        "--initializer_token",
+        type=str,
+        default=None,
+        help="A token to use as initializer word."
     )
     parser.add_argument(
-        "--vectors_per_token", type=int, default=1, help="Vectors per token."
+        "--vectors_per_token",
+        type=int,
+        default=1,
+        help="Vectors per token."
     )
-    parser.add_argument("--repeats", type=int, default=100,
-                        help="How many times to repeat the training data.")
+    parser.add_argument(
+        "--repeats",
+        type=int,
+        default=100,
+        help="How many times to repeat the training data.")
     parser.add_argument(
         "--output_dir",
         type=str,
         default="text-inversion-model",
         help="The output directory where the model predictions and checkpoints will be written.",
     )
-    parser.add_argument("--seed", type=int, default=None,
-                        help="A seed for reproducible training.")
+    parser.add_argument(
+        "--seed",
+        type=int,
+        default=None,
+        help="A seed for reproducible training.")
     parser.add_argument(
         "--resolution",
         type=int,
@@ -87,12 +96,14 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--center_crop", action="store_true", help="Whether to center crop images before resizing to resolution"
+        "--center_crop",
+        action="store_true",
+        help="Whether to center crop images before resizing to resolution"
     )
     parser.add_argument(
-        "--train_batch_size", type=int, default=1, help="Batch size (per device) for the training dataloader."
-    )
-    parser.add_argument("--num_train_epochs", type=int, default=100)
+        "--num_train_epochs",
+        type=int,
+        default=100)
     parser.add_argument(
         "--max_train_steps",
         type=int,
@@ -132,16 +143,35 @@ def parse_args():
         ),
     )
     parser.add_argument(
-        "--lr_warmup_steps", type=int, default=500, help="Number of steps for the warmup in the lr scheduler."
+        "--lr_warmup_steps",
+        type=int,
+        default=500,
+        help="Number of steps for the warmup in the lr scheduler."
+    )
+    parser.add_argument(
+        "--adam_beta1",
+        type=float,
+        default=0.9,
+        help="The beta1 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_beta2",
+        type=float,
+        default=0.999,
+        help="The beta2 parameter for the Adam optimizer."
+    )
+    parser.add_argument(
+        "--adam_weight_decay",
+        type=float,
+        default=1e-2,
+        help="Weight decay to use."
+    )
+    parser.add_argument(
+        "--adam_epsilon",
+        type=float,
+        default=1e-08,
+        help="Epsilon value for the Adam optimizer"
     )
-    parser.add_argument("--adam_beta1", type=float, default=0.9,
-                        help="The beta1 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_beta2", type=float, default=0.999,
-                        help="The beta2 parameter for the Adam optimizer.")
-    parser.add_argument("--adam_weight_decay", type=float,
-                        default=1e-2, help="Weight decay to use.")
-    parser.add_argument("--adam_epsilon", type=float, default=1e-08,
-                        help="Epsilon value for the Adam optimizer")
     parser.add_argument(
         "--mixed_precision",
         type=str,
@@ -153,8 +183,12 @@ def parse_args():
             "and an Nvidia Ampere GPU."
         ),
     )
-    parser.add_argument("--local_rank", type=int, default=-1,
-                        help="For distributed training: local_rank")
+    parser.add_argument(
+        "--local_rank",
+        type=int,
+        default=-1,
+        help="For distributed training: local_rank"
+    )
     parser.add_argument(
         "--checkpoint_frequency",
         type=int,
@@ -185,6 +219,12 @@ def parse_args():
         default=1,
         help="Number of samples to generate per batch",
     )
+    parser.add_argument(
+        "--train_batch_size",
+        type=int,
+        default=1,
+        help="Batch size (per device) for the training dataloader."
+    )
     parser.add_argument(
         "--sample_steps",
         type=int,
-- 
cgit v1.2.3-70-g09d2
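Note: the reshuffle above only reformats `parse_args()`; every option keeps its type, default, and help text, so existing invocations parse identically. A minimal sketch of that behavior, with the parser and three argument definitions copied verbatim from the patched main.py — the argv list handed to `parse_args()` is a made-up example, not from the patch:

```python
# Sketch: parser and argument definitions copied from the patched main.py;
# the argv list passed to parse_args() below is hypothetical.
import argparse

parser = argparse.ArgumentParser(
    description="Simple example of a training script."
)
parser.add_argument(
    "--repeats",
    type=int,
    default=100,
    help="How many times to repeat the training data.")
parser.add_argument(
    "--seed",
    type=int,
    default=None,
    help="A seed for reproducible training.")
parser.add_argument(
    "--train_batch_size",
    type=int,
    default=1,
    help="Batch size (per device) for the training dataloader."
)

# Hypothetical invocation; flags left out fall back to their defaults.
args = parser.parse_args(["--seed", "42", "--train_batch_size", "4"])
print(args.repeats, args.seed, args.train_batch_size)
# prints: 100 42 4
```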