Diffstat (limited to 'trainer_old/ti.py')
-rw-r--r--  trainer_old/ti.py  168
1 file changed, 168 insertions, 0 deletions
diff --git a/trainer_old/ti.py b/trainer_old/ti.py
new file mode 100644
index 0000000..66393af
--- /dev/null
+++ b/trainer_old/ti.py
@@ -0,0 +1,168 @@
from contextlib import contextmanager, nullcontext

import torch

from slugify import slugify

from diffusers import UNet2DConditionModel
from transformers import CLIPTextModel

from trainer.base import TrainingStrategy, Checkpointer
from training.util import EMAModel


class TextualInversionCheckpointer(Checkpointer):
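    """Checkpointer that saves the learned placeholder-token embeddings to disk.

    Assumes the base ``Checkpointer`` (trainer.base) provides ``self.output_dir``,
    ``self.accelerator`` and ``self.text_encoder``, which are used below.
    """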
    def __init__(
        self,
        ema_embeddings: EMAModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)

        self.ema_embeddings = ema_embeddings
        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

    @torch.no_grad()
    def checkpoint(self, step, postfix):
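        """Save one ``.bin`` embedding file per placeholder token for this step.

        If an EMA copy of the embeddings exists, it is swapped in temporarily
        while saving.
        """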
        print(f"Saving checkpoint for step {step}...")

        checkpoints_path = self.output_dir.joinpath("checkpoints")
        checkpoints_path.mkdir(parents=True, exist_ok=True)

        text_encoder = self.accelerator.unwrap_model(self.text_encoder)

        ema_context = self.ema_embeddings.apply_temporary(
            text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            for token, ids in zip(self.placeholder_tokens, self.placeholder_token_ids):
                text_encoder.text_model.embeddings.save_embed(
                    ids,
                    checkpoints_path.joinpath(f"{slugify(token)}_{step}_{postfix}.bin")
                )

    @torch.no_grad()
    def save_samples(self, step):
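        """Generate samples with the EMA embeddings applied, when available."""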
        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context:
            super().save_samples(step)


class TextualInversionTrainingStrategy(TrainingStrategy):
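    """Training strategy for Textual Inversion.

    Only the temporary placeholder-token embeddings of the text encoder are
    optimized (``main_model`` is the text encoder, not the UNet). Optionally
    decays the embedding norms towards a target and keeps an EMA copy of the
    embeddings.
    """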
    def __init__(
        self,
        unet: UNet2DConditionModel,
        text_encoder: CLIPTextModel,
        placeholder_tokens: list[str],
        placeholder_token_ids: list[list[int]],
        learning_rate: float,
        gradient_checkpointing: bool = False,
        use_emb_decay: bool = False,
        emb_decay_target: float = 0.4,
        emb_decay_factor: float = 1,
        emb_decay_start: float = 1e-4,
        use_ema: bool = False,
        ema_inv_gamma: float = 1.0,
        ema_power: int = 1,
        ema_max_decay: float = 0.9999,
        *args,
        **kwargs,
    ):
        super().__init__(
            unet=unet,
            text_encoder=text_encoder,
            *args,
            **kwargs
        )

        self.text_encoder = text_encoder
        self.unet = unet

        self.placeholder_tokens = placeholder_tokens
        self.placeholder_token_ids = placeholder_token_ids

        self.gradient_checkpointing = gradient_checkpointing

        self.learning_rate = learning_rate
        self.use_emb_decay = use_emb_decay
        self.emb_decay_target = emb_decay_target
        self.emb_decay_factor = emb_decay_factor
        self.emb_decay_start = emb_decay_start

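        # Only the temporary token embedding table receives gradients here;
        # freezing of the remaining text encoder weights is assumed to happen
        # in the base strategy.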
        self.text_encoder.text_model.embeddings.temp_token_embedding.requires_grad_(True)

        self.ema_embeddings = None

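        # Optional exponential moving average of the embeddings; it is swapped
        # in temporarily for checkpointing, sample generation and evaluation.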
        if use_ema:
            self.ema_embeddings = EMAModel(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters(),
                inv_gamma=ema_inv_gamma,
                power=ema_power,
                max_value=ema_max_decay,
            )

        self.checkpointer = TextualInversionCheckpointer(
            unet=unet,
            text_encoder=text_encoder,
            ema_embeddings=self.ema_embeddings,
            placeholder_tokens=placeholder_tokens,
            placeholder_token_ids=placeholder_token_ids,
            *args,
            **kwargs
        )

    @property
    def main_model(self):
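        """The trained module (the text encoder); assumed to be the model the
        base strategy hands to the optimizer and accelerator."""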
        return self.text_encoder

    @contextmanager
    def on_train(self, epoch: int):
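        # Gradient checkpointing in diffusers models is only active in train
        # mode, so the otherwise frozen UNet is switched to train() here and
        # back to eval() in on_eval().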
        if self.gradient_checkpointing:
            self.unet.train()

        with super().on_train(epoch):
            yield

    @contextmanager
    def on_eval(self):
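        """Evaluation context: restore UNet eval mode and temporarily swap in
        the EMA embeddings, if available."""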
        if self.gradient_checkpointing:
            self.unet.eval()

        ema_context = self.ema_embeddings.apply_temporary(
            self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
        ) if self.ema_embeddings is not None else nullcontext()

        with ema_context, super().on_eval():
            yield

    @torch.no_grad()
    def on_after_optimize(self, lr: float):
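        """Hook run after each optimizer step.

        Optionally decays the placeholder embedding norms towards
        ``emb_decay_target`` with a strength of
        ``emb_decay_factor * (lr - emb_decay_start) / (learning_rate - emb_decay_start)``,
        clamped to [0, 1], and updates the EMA copy of the embeddings.
        """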
        if self.use_emb_decay:
            self.text_encoder.text_model.embeddings.normalize(
                self.emb_decay_target,
                min(1.0, max(0.0, self.emb_decay_factor * (
                    (lr - self.emb_decay_start) / (self.learning_rate - self.emb_decay_start)
                )))
            )

        if self.ema_embeddings is not None:
            self.ema_embeddings.step(
                self.text_encoder.text_model.embeddings.temp_token_embedding.parameters()
            )

    def on_log(self):
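        """Extend the base log dict with the current EMA decay, when EMA is enabled."""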
        log = super().on_log()
        added = {}

        if self.ema_embeddings is not None:
            added = {"ema_decay": self.ema_embeddings.decay}

        return {**log, **added}
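
# Illustrative construction (hypothetical values; the remaining keyword
# arguments are forwarded to the base TrainingStrategy / Checkpointer):
#
#   strategy = TextualInversionTrainingStrategy(
#       unet=unet,
#       text_encoder=text_encoder,
#       placeholder_tokens=["<concept>"],
#       placeholder_token_ids=[[49408]],
#       learning_rate=1e-4,
#       gradient_checkpointing=True,
#       use_ema=True,
#       ...,
#   )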
