author     Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
committer  Volpeon <git@volpeon.ink>  2023-01-14 09:25:13 +0100
commit     e2d3a62bce63fcde940395a1c5618c4eb43385a9 (patch)
tree       574f7a794feab13e1cf0ed18522a66d4737b6db3 /training
parent     Unified training script structure (diff)
Cleanup
Diffstat (limited to 'training')
-rw-r--r--  training/common.py  97
-rw-r--r--  training/util.py    26
2 files changed, 63 insertions, 60 deletions
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
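The restructured generate_class_images keeps the same batching idiom: slice the list of missing items into fixed-size chunks before handing each chunk to the pipeline. As a standalone illustration of that idiom (the helper name is mine; the slicing expression mirrors the batched_data comprehension above):

from typing import List, Sequence, TypeVar

T = TypeVar("T")

def batched(items: Sequence[T], batch_size: int) -> List[Sequence[T]]:
    # Same slicing pattern as batched_data above: every chunk holds at most
    # batch_size elements, and the final chunk may be shorter.
    return [items[i:i + batch_size] for i in range(0, len(items), batch_size)]

print(batched(list(range(7)), 3))  # [[0, 1, 2], [3, 4, 5], [6]]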
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
 
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
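The point of this hunk is reproducible evaluation: torch.randn_like accepts no generator, so the noise is rebuilt with torch.randn, and both the noise and the timesteps now draw from a single generator seeded with seed + step whenever eval is set. A minimal, runnable sketch of the pattern (the helper name and signature are mine; the randn/randint calls mirror the added lines, and the eval parameter keeps the repository's name even though it shadows the builtin):

import torch

def sample_noise_and_timesteps(latents, num_train_timesteps, seed, step, eval=False):
    # One seeded generator per (seed, step) pair makes validation batches
    # deterministic; generator=None falls back to the global RNG for training.
    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None

    noise = torch.randn(
        latents.shape,
        dtype=latents.dtype,
        layout=latents.layout,
        device=latents.device,
        generator=generator,
    )
    timesteps = torch.randint(
        0,
        num_train_timesteps,
        (latents.shape[0],),
        generator=generator,
        device=latents.device,
    ).long()
    return noise, timesteps

latents = torch.zeros(2, 4, 64, 64)
n1, t1 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
n2, t2 = sample_noise_and_timesteps(latents, 1000, seed=42, step=3, eval=True)
assert torch.equal(n1, n2) and torch.equal(t1, t2)  # identical for the same eval step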
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
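This hunk only changes where the prior-preservation flag lives (on the batch instead of in the function signature); the loss itself, visible in the context lines, is the usual chunk-and-weight scheme. A self-contained sketch of that scheme, assuming the conventional layout where instance samples fill the first half of the batch and class ("prior") samples the second half; the chunking mirrors the context lines, while the mse weighting is the standard prior-preservation formulation rather than code copied from the file:

import torch
import torch.nn.functional as F

def prior_preservation_loss(model_pred, target, prior_loss_weight):
    # Split predictions and targets into the instance half and the prior half.
    model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
    target, target_prior = torch.chunk(target, 2, dim=0)

    # Instance loss learns the new concept; prior loss keeps the class
    # distribution from drifting.
    loss = F.mse_loss(model_pred.float(), target.float(), reduction="mean")
    prior_loss = F.mse_loss(model_pred_prior.float(), target_prior.float(), reduction="mean")
    return loss + prior_loss_weight * prior_loss

pred = torch.randn(4, 4, 64, 64)    # 2 instance + 2 prior samples
target = torch.randn(4, 4, 64, 64)
print(prior_preservation_loss(pred, target, prior_loss_weight=1.0))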
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                    checkpointer.save_samples(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt:
diff --git a/training/util.py b/training/util.py
index cc4cdee..1008021 100644
--- a/training/util.py
+++ b/training/util.py
@@ -44,32 +44,29 @@ class CheckpointerBase:
         train_dataloader,
         val_dataloader,
         output_dir: Path,
-        sample_image_size: int,
-        sample_batches: int,
-        sample_batch_size: int,
+        sample_steps: int = 20,
+        sample_guidance_scale: float = 7.5,
+        sample_image_size: int = 768,
+        sample_batches: int = 1,
+        sample_batch_size: int = 1,
         seed: Optional[int] = None
     ):
         self.train_dataloader = train_dataloader
         self.val_dataloader = val_dataloader
         self.output_dir = output_dir
         self.sample_image_size = sample_image_size
-        self.seed = seed if seed is not None else torch.random.seed()
+        self.sample_steps = sample_steps
+        self.sample_guidance_scale = sample_guidance_scale
         self.sample_batches = sample_batches
         self.sample_batch_size = sample_batch_size
+        self.seed = seed if seed is not None else torch.random.seed()
 
     @torch.no_grad()
     def checkpoint(self, step: int, postfix: str):
         pass
 
     @torch.inference_mode()
-    def save_samples(
-        self,
-        pipeline,
-        step: int,
-        num_inference_steps: int,
-        guidance_scale: float = 7.5,
-        eta: float = 0.0
-    ):
+    def save_samples(self, pipeline, step: int):
         samples_path = Path(self.output_dir).joinpath("samples")
 
         generator = torch.Generator(device=pipeline.device).manual_seed(self.seed)
@@ -110,9 +107,8 @@ class CheckpointerBase:
                     height=self.sample_image_size,
                     width=self.sample_image_size,
                     generator=gen,
-                    guidance_scale=guidance_scale,
-                    eta=eta,
-                    num_inference_steps=num_inference_steps,
+                    guidance_scale=self.sample_guidance_scale,
+                    num_inference_steps=self.sample_steps,
                     output_type='pil'
                 ).images
 