path: root/training/common.py
Diffstat (limited to 'training/common.py')
-rw-r--r--  training/common.py  97
1 file changed, 52 insertions, 45 deletions
diff --git a/training/common.py b/training/common.py
index b6964a3..f5ab326 100644
--- a/training/common.py
+++ b/training/common.py
@@ -45,42 +45,44 @@ def generate_class_images(
 ):
     missing_data = [item for item in data_train if not item.class_image_path.exists()]
 
-    if len(missing_data) != 0:
-        batched_data = [
-            missing_data[i:i+sample_batch_size]
-            for i in range(0, len(missing_data), sample_batch_size)
-        ]
-
-        pipeline = VlpnStableDiffusion(
-            text_encoder=text_encoder,
-            vae=vae,
-            unet=unet,
-            tokenizer=tokenizer,
-            scheduler=scheduler,
-        ).to(accelerator.device)
-        pipeline.set_progress_bar_config(dynamic_ncols=True)
-
-        with torch.inference_mode():
-            for batch in batched_data:
-                image_name = [item.class_image_path for item in batch]
-                prompt = [item.cprompt for item in batch]
-                nprompt = [item.nprompt for item in batch]
-
-                images = pipeline(
-                    prompt=prompt,
-                    negative_prompt=nprompt,
-                    height=sample_image_size,
-                    width=sample_image_size,
-                    num_inference_steps=sample_steps
-                ).images
-
-                for i, image in enumerate(images):
-                    image.save(image_name[i])
-
-        del pipeline
-
-        if torch.cuda.is_available():
-            torch.cuda.empty_cache()
+    if len(missing_data) == 0:
+        return
+
+    batched_data = [
+        missing_data[i:i+sample_batch_size]
+        for i in range(0, len(missing_data), sample_batch_size)
+    ]
+
+    pipeline = VlpnStableDiffusion(
+        text_encoder=text_encoder,
+        vae=vae,
+        unet=unet,
+        tokenizer=tokenizer,
+        scheduler=scheduler,
+    ).to(accelerator.device)
+    pipeline.set_progress_bar_config(dynamic_ncols=True)
+
+    with torch.inference_mode():
+        for batch in batched_data:
+            image_name = [item.class_image_path for item in batch]
+            prompt = [item.cprompt for item in batch]
+            nprompt = [item.nprompt for item in batch]
+
+            images = pipeline(
+                prompt=prompt,
+                negative_prompt=nprompt,
+                height=sample_image_size,
+                width=sample_image_size,
+                num_inference_steps=sample_steps
+            ).images
+
+            for i, image in enumerate(images):
+                image.save(image_name[i])
+
+    del pipeline
+
+    if torch.cuda.is_available():
+        torch.cuda.empty_cache()
 
 
 def get_models(pretrained_model_name_or_path: str):
@@ -119,7 +121,7 @@ def add_placeholder_tokens(
     for (placeholder_token_id, initializer_token_id) in zip(placeholder_token_ids, initializer_token_ids):
         embeddings.add_embed(placeholder_token_id, initializer_token_id)
 
-    return placeholder_token_ids
+    return placeholder_token_ids, initializer_token_ids
 
 
 def loss_step(
@@ -127,7 +129,6 @@ def loss_step(
     noise_scheduler: DDPMScheduler,
     unet: UNet2DConditionModel,
     text_encoder: CLIPTextModel,
-    with_prior: bool,
     prior_loss_weight: float,
     seed: int,
     step: int,
@@ -138,16 +139,23 @@ def loss_step(
     latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
     latents = latents * 0.18215
 
+    generator = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
+
     # Sample noise that we'll add to the latents
-    noise = torch.randn_like(latents)
+    noise = torch.randn(
+        latents.shape,
+        dtype=latents.dtype,
+        layout=latents.layout,
+        device=latents.device,
+        generator=generator
+    )
     bsz = latents.shape[0]
     # Sample a random timestep for each image
-    timesteps_gen = torch.Generator(device=latents.device).manual_seed(seed + step) if eval else None
     timesteps = torch.randint(
         0,
         noise_scheduler.config.num_train_timesteps,
         (bsz,),
-        generator=timesteps_gen,
+        generator=generator,
         device=latents.device,
     )
     timesteps = timesteps.long()
@@ -176,7 +184,7 @@ def loss_step(
     else:
         raise ValueError(f"Unknown prediction type {noise_scheduler.config.prediction_type}")
 
-    if with_prior:
+    if batch["with_prior"]:
         # Chunk the noise and model_pred into two parts and compute the loss on each part separately.
         model_pred, model_pred_prior = torch.chunk(model_pred, 2, dim=0)
         target, target_prior = torch.chunk(target, 2, dim=0)
@@ -207,7 +215,6 @@ def train_loop(
     val_dataloader: DataLoader,
     loss_step: Union[Callable[[int, Any], Tuple[Any, Any, int]], Callable[[int, Any, bool], Tuple[Any, Any, int]]],
     sample_frequency: int = 10,
-    sample_steps: int = 20,
     checkpoint_frequency: int = 50,
     global_step_offset: int = 0,
     num_epochs: int = 100,
@@ -251,7 +258,7 @@ def train_loop(
         for epoch in range(num_epochs):
             if accelerator.is_main_process:
                 if epoch % sample_frequency == 0:
-                    checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+                    checkpointer.save_samples(global_step + global_step_offset)
 
                 if epoch % checkpoint_frequency == 0 and epoch != 0:
                     checkpointer.checkpoint(global_step + global_step_offset, "training")
@@ -353,7 +360,7 @@ def train_loop(
         if accelerator.is_main_process:
             print("Finished!")
             checkpointer.checkpoint(global_step + global_step_offset, "end")
-            checkpointer.save_samples(global_step + global_step_offset, sample_steps)
+            checkpointer.save_samples(global_step + global_step_offset)
             accelerator.end_training()
 
     except KeyboardInterrupt: