 data/csv.py         | 10
 train_dreambooth.py | 24
 train_ti.py         | 24
 training/lora.py    | 92
 training/lr.py      |  6
 5 files changed, 113 insertions(+), 43 deletions(-)
diff --git a/data/csv.py b/data/csv.py
index af36d9e..e901ab4 100644
--- a/data/csv.py
+++ b/data/csv.py
@@ -59,7 +59,7 @@ class CSVDataModule():
         center_crop: bool = False,
         template_key: str = "template",
         valid_set_size: Optional[int] = None,
-        generator: Optional[torch.Generator] = None,
+        seed: Optional[int] = None,
         filter: Optional[Callable[[CSVDataItem], bool]] = None,
         collate_fn=None,
         num_workers: int = 0
@@ -84,7 +84,7 @@ class CSVDataModule():
         self.template_key = template_key
         self.interpolation = interpolation
         self.valid_set_size = valid_set_size
-        self.generator = generator
+        self.seed = seed
         self.filter = filter
         self.collate_fn = collate_fn
         self.num_workers = num_workers
@@ -155,7 +155,11 @@ class CSVDataModule():
         valid_set_size = max(valid_set_size, 1)
         train_set_size = num_images - valid_set_size

-        data_train, data_val = random_split(items, [train_set_size, valid_set_size], self.generator)
+        generator = torch.Generator(device="cpu")
+        if self.seed is not None:
+            generator = generator.manual_seed(self.seed)
+
+        data_train, data_val = random_split(items, [train_set_size, valid_set_size], generator=generator)

         self.data_train = self.pad_items(data_train, self.num_class_images)
         self.data_val = self.pad_items(data_val)
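Replacing the torch.Generator argument with a plain integer seed keeps the data module's constructor simple and makes the train/validation split reproducible whenever the caller passes the same seed. A minimal standalone sketch of the same pattern (illustrative only, not the repo's CSVDataModule):

import torch
from torch.utils.data import random_split

items = list(range(100))

def split_indices(seed=None):
    # Mirror the change above: build a CPU generator and seed it only if a seed was given.
    generator = torch.Generator(device="cpu")
    if seed is not None:
        generator = generator.manual_seed(seed)
    train, val = random_split(items, [90, 10], generator=generator)
    return list(val.indices)

# The same seed always produces the same validation subset.
assert split_indices(seed=42) == split_indices(seed=42)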
diff --git a/train_dreambooth.py b/train_dreambooth.py
index df8b54c..6d9bae8 100644
--- a/train_dreambooth.py
+++ b/train_dreambooth.py
@@ -320,6 +320,12 @@ def parse_args():
         help="Epsilon value for the Adam optimizer"
     )
     parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
+    parser.add_argument(
         "--mixed_precision",
         type=str,
         default="no",
@@ -642,7 +648,7 @@ def main():
     )

     if args.find_lr:
-        args.learning_rate = 1e-4
+        args.learning_rate = 1e-6

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -674,6 +680,7 @@ def main():
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
     )

     weight_dtype = torch.float32
@@ -730,6 +737,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        seed=args.seed,
         filter=keyword_filter,
         collate_fn=collate_fn
     )
@@ -840,7 +848,7 @@ def main():
     def on_eval():
         tokenizer.eval()

-    def loop(batch):
+    def loop(batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample()
         latents = latents * 0.18215
@@ -849,8 +857,14 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps = torch.randint(
+            0,
+            noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            generator=timesteps_gen,
+            device=latents.device,
+        )
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep
@@ -1051,7 +1065,7 @@ def main():

         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(batch, True)

                 loss = loss.detach_()
                 acc = acc.detach_()
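The new eval flag on loop() pins the timestep sampling during validation: with eval=True the timesteps are drawn from a generator seeded with args.seed, so every validation pass scores the model on the same noise levels and epoch-to-epoch validation losses stay comparable, while training batches keep using the global RNG. A small self-contained sketch of that behavior (names and the seed value are illustrative):

import torch

def sample_timesteps(bsz, num_train_timesteps, device, eval=False, seed=0):
    # Seeded generator only in eval mode, the same pattern added to loop() above.
    generator = torch.Generator(device=device).manual_seed(seed) if eval else None
    return torch.randint(0, num_train_timesteps, (bsz,), generator=generator, device=device).long()

a = sample_timesteps(4, 1000, "cpu", eval=True)
b = sample_timesteps(4, 1000, "cpu", eval=True)
assert torch.equal(a, b)  # identical draws across validation passes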
diff --git a/train_ti.py b/train_ti.py
index 1685dc4..5d6eafc 100644
--- a/train_ti.py
+++ b/train_ti.py
@@ -289,6 +289,12 @@ def parse_args():
         help="Epsilon value for the Adam optimizer"
     )
     parser.add_argument(
+        "--adam_amsgrad",
+        type=bool,
+        default=False,
+        help="Amsgrad value for the Adam optimizer"
+    )
+    parser.add_argument(
         "--mixed_precision",
         type=str,
         default="no",
@@ -592,7 +598,7 @@ def main():
     )

     if args.find_lr:
-        args.learning_rate = 1e-4
+        args.learning_rate = 1e-6

     # Use 8-bit Adam for lower memory usage or to fine-tune the model in 16GB GPUs
     if args.use_8bit_adam:
@@ -612,6 +618,7 @@ def main():
         betas=(args.adam_beta1, args.adam_beta2),
         weight_decay=args.adam_weight_decay,
         eps=args.adam_epsilon,
+        amsgrad=args.adam_amsgrad,
     )

     weight_dtype = torch.float32
@@ -673,6 +680,7 @@ def main():
         template_key=args.train_data_template,
         valid_set_size=args.valid_set_size,
         num_workers=args.dataloader_num_workers,
+        seed=args.seed,
         filter=keyword_filter,
         collate_fn=collate_fn
     )
@@ -791,7 +799,7 @@ def main():
     def on_eval():
         tokenizer.eval()

-    def loop(batch):
+    def loop(batch, eval: bool = False):
         # Convert images to latent space
         latents = vae.encode(batch["pixel_values"]).latent_dist.sample().detach()
         latents = latents * 0.18215
@@ -800,8 +808,14 @@ def main():
         noise = torch.randn_like(latents)
         bsz = latents.shape[0]
         # Sample a random timestep for each image
-        timesteps = torch.randint(0, noise_scheduler.config.num_train_timesteps,
-                                  (bsz,), device=latents.device)
+        timesteps_gen = torch.Generator(device=latents.device).manual_seed(args.seed) if eval else None
+        timesteps = torch.randint(
+            0,
+            noise_scheduler.config.num_train_timesteps,
+            (bsz,),
+            generator=timesteps_gen,
+            device=latents.device,
+        )
         timesteps = timesteps.long()

         # Add noise to the latents according to the noise magnitude at each timestep
@@ -984,7 +998,7 @@ def main():

         with torch.inference_mode():
             for step, batch in enumerate(val_dataloader):
-                loss, acc, bsz = loop(batch)
+                loss, acc, bsz = loop(batch, True)

                 loss = loss.detach_()
                 acc = acc.detach_()
diff --git a/training/lora.py b/training/lora.py
index e1c0971..3857d78 100644
--- a/training/lora.py
+++ b/training/lora.py
@@ -1,3 +1,4 @@
+import torch
 import torch.nn as nn

 from diffusers import ModelMixin, ConfigMixin
@@ -13,56 +14,93 @@ else:
     xformers = None


-class LoraAttnProcessor(ModelMixin, ConfigMixin):
-    @register_to_config
-    def __init__(
-        self,
-        cross_attention_dim,
-        inner_dim,
-        r: int = 4
-    ):
+class LoRALinearLayer(nn.Module):
+    def __init__(self, in_features, out_features, rank=4):
         super().__init__()

-        if r > min(cross_attention_dim, inner_dim):
+        if rank > min(in_features, out_features):
             raise ValueError(
-                f"LoRA rank {r} must be less or equal than {min(cross_attention_dim, inner_dim)}"
+                f"LoRA rank {rank} must be less or equal than {min(in_features, out_features)}"
             )

-        self.lora_k_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_k_up = nn.Linear(r, inner_dim, bias=False)
+        self.lora_down = nn.Linear(in_features, rank, bias=False)
+        self.lora_up = nn.Linear(rank, out_features, bias=False)
+        self.scale = 1.0

-        self.lora_v_down = nn.Linear(cross_attention_dim, r, bias=False)
-        self.lora_v_up = nn.Linear(r, inner_dim, bias=False)
+        nn.init.normal_(self.lora_down.weight, std=1 / rank)
+        nn.init.zeros_(self.lora_up.weight)

-        self.scale = 1.0
+    def forward(self, hidden_states):
+        down_hidden_states = self.lora_down(hidden_states)
+        up_hidden_states = self.lora_up(down_hidden_states)

-        nn.init.normal_(self.lora_k_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_k_up.weight)
+        return up_hidden_states

-        nn.init.normal_(self.lora_v_down.weight, std=1 / r**2)
-        nn.init.zeros_(self.lora_v_up.weight)

-    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
-        batch_size, sequence_length, _ = hidden_states.shape
+class LoRACrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim=None, rank=4):
+        super().__init__()

+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)

-        query = attn.to_q(hidden_states)
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
+        query = attn.head_to_batch_dim(query)

         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-        key = attn.to_k(encoder_hidden_states) + self.lora_k_up(self.lora_k_down(encoder_hidden_states)) * self.scale
-        value = attn.to_v(encoder_hidden_states) + self.lora_v_up(self.lora_v_down(encoder_hidden_states)) * self.scale

+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
+        key = attn.head_to_batch_dim(key)
+        value = attn.head_to_batch_dim(value)
+
+        attention_probs = attn.get_attention_scores(query, key, attention_mask)
+        hidden_states = torch.bmm(attention_probs, value)
+        hidden_states = attn.batch_to_head_dim(hidden_states)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        return hidden_states
+
+
+class LoRAXFormersCrossAttnProcessor(nn.Module):
+    def __init__(self, hidden_size, cross_attention_dim, rank=4):
+        super().__init__()
+
+        self.to_q_lora = LoRALinearLayer(hidden_size, hidden_size)
+        self.to_k_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_v_lora = LoRALinearLayer(cross_attention_dim or hidden_size, hidden_size)
+        self.to_out_lora = LoRALinearLayer(hidden_size, hidden_size)
+
+    def __call__(self, attn: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None, scale=1.0):
+        batch_size, sequence_length, _ = hidden_states.shape
+        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length)
+
+        query = attn.to_q(hidden_states) + scale * self.to_q_lora(hidden_states)
         query = attn.head_to_batch_dim(query).contiguous()
+
+        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+
+        key = attn.to_k(encoder_hidden_states) + scale * self.to_k_lora(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states) + scale * self.to_v_lora(encoder_hidden_states)
+
         key = attn.head_to_batch_dim(key).contiguous()
         value = attn.head_to_batch_dim(value).contiguous()

         hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
-        hidden_states = hidden_states.to(query.dtype)
-        hidden_states = attn.batch_to_head_dim(hidden_states)

         # linear proj
-        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[0](hidden_states) + scale * self.to_out_lora(hidden_states)
         # dropout
         hidden_states = attn.to_out[1](hidden_states)

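The rewritten processors follow the LoRA attention-processor interface from diffusers: each one adds a scaled low-rank update on top of the frozen to_q/to_k/to_v/to_out projections. The diff itself does not show how they are attached to a model; below is a hedged sketch of the usual wiring for a Stable-Diffusion-style UNet (set_attn_processor and attn_processors are diffusers UNet APIs, while the hidden-size lookup and the helper name are illustrative assumptions, not code from this repository):

from training.lora import LoRACrossAttnProcessor

def add_lora_processors(unet, rank=4):
    processors = {}
    for name in unet.attn_processors.keys():
        # attn1 blocks are self-attention; attn2 blocks cross-attend over the text embeddings.
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        else:  # down_blocks
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        processors[name] = LoRACrossAttnProcessor(hidden_size, cross_attention_dim, rank)
    unet.set_attn_processor(processors)
    return processors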
diff --git a/training/lr.py b/training/lr.py
index 37588b6..a3144ba 100644
--- a/training/lr.py
+++ b/training/lr.py
@@ -1,6 +1,6 @@
 import math
 import copy
-from typing import Callable
+from typing import Callable, Any, Tuple, Union
 from functools import partial

 import matplotlib.pyplot as plt
@@ -24,7 +24,7 @@ class LRFinder():
         optimizer,
         train_dataloader,
         val_dataloader,
-        loss_fn,
+        loss_fn: Union[Callable[[Any], Tuple[Any, Any, int]], Callable[[Any, bool], Tuple[Any, Any, int]]],
         on_train: Callable[[], None] = noop,
         on_eval: Callable[[], None] = noop
     ):
@@ -108,7 +108,7 @@ class LRFinder():
                 if step >= num_val_batches:
                     break

-                loss, acc, bsz = self.loss_fn(batch)
+                loss, acc, bsz = self.loss_fn(batch, True)
                 avg_loss.update(loss.detach_(), bsz)
                 avg_acc.update(acc.detach_(), bsz)

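The widened loss_fn annotation reflects that LRFinder's validation sweep now calls the loss function with a second positional argument, loss_fn(batch, True), matching the two-argument loop(batch, eval) defined in the training scripts. A trivial callable satisfying that contract (placeholder values, for illustration only):

import torch
from typing import Any, Tuple

def example_loss_fn(batch: Any, eval: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
    # eval=True is what LRFinder passes during validation; a real loss_fn would use it
    # to make its sampling deterministic, as loop() does in the training scripts.
    bsz = len(batch)
    return torch.zeros(()), torch.zeros(()), bsz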