1use crate::{nn, IndexOp, Kind, Tensor};
7
// Nominal (square) input resolution in pixels; positional embeddings are sized for this.
const IMG_SIZE: i64 = 518;
// Side length in pixels of each square patch fed to the patch embedding.
const PATCH_SIZE: i64 = 14;
// Output width of the classification head (presumably ImageNet-1k — TODO confirm).
const NUM_CLASSES: i64 = 1000;
11
/// Multi-head self-attention layer with a fused query/key/value projection.
#[derive(Debug)]
struct Attention {
    // Fused projection producing q, k and v in one pass (dim -> 3 * dim).
    qkv: nn::Linear,
    // Output projection applied after attention (dim -> dim).
    proj: nn::Linear,
    // Number of attention heads; `dim` must be divisible by this.
    num_heads: i64,
    // 1 / sqrt(head_dim), applied to queries before the dot product.
    scale: f64,
}
19
20impl Attention {
21 fn new(vs: nn::Path, dim: i64, num_heads: i64, qkv_bias: bool, proj_bias: bool) -> Self {
22 let qkv_config = nn::LinearConfig { bias: qkv_bias, ..Default::default() };
23 let proj_config = nn::LinearConfig { bias: proj_bias, ..Default::default() };
24 let qkv = nn::linear(&vs / "qkv", dim, dim * 3, qkv_config);
25 let proj = nn::linear(&vs / "proj", dim, dim, proj_config);
26 let scale = 1. / ((dim / num_heads) as f64).sqrt();
27 Self { qkv, proj, num_heads, scale }
28 }
29}
30
31impl nn::Module for Attention {
32 fn forward(&self, xs: &Tensor) -> Tensor {
33 let (b, n, c) = xs.size3().unwrap();
34 let qkv = self
35 .qkv
36 .forward(xs)
37 .reshape([b, n, 3, self.num_heads, c / self.num_heads])
38 .permute([2, 0, 3, 1, 4]);
39 let q = qkv.get(0) * self.scale;
40 let k = qkv.get(1);
41 let v = qkv.get(2);
42 let attn = q.matmul(&k.transpose(-2, -1)).softmax(-1, Kind::Float);
43 attn.matmul(&v).transpose(1, 2).reshape([b, n, c]).apply(&self.proj)
44 }
45}
46
/// Learned per-channel scaling of a residual branch (LayerScale).
#[derive(Debug)]
struct LayerScale {
    // Per-dimension multiplier of length `dim`, broadcast over batch and tokens.
    gamma: Tensor,
}
51
52impl LayerScale {
53 fn new(vs: nn::Path, dim: i64) -> Self {
54 let gamma = vs.var("gamma", &[dim], nn::Init::Const(0.));
55 Self { gamma }
56 }
57}
58
59impl nn::Module for LayerScale {
60 fn forward(&self, xs: &Tensor) -> Tensor {
61 xs * &self.gamma
62 }
63}
64
/// Two-layer feed-forward block used inside each transformer block.
#[derive(Debug)]
struct Mlp {
    // Expansion projection (in_features -> hidden_features).
    fc1: nn::Linear,
    // Contraction back to the input width (hidden_features -> in_features).
    fc2: nn::Linear,
}
70
71impl Mlp {
72 fn new(vs: nn::Path, in_features: i64, hidden_features: i64, bias: bool) -> Self {
73 let out_features = in_features;
74 let config = nn::LinearConfig { bias, ..Default::default() };
75 let fc1 = nn::linear(&vs / "fc1", in_features, hidden_features, config);
76 let fc2 = nn::linear(&vs / "fc2", hidden_features, out_features, config);
77 Self { fc1, fc2 }
78 }
79}
80
81impl nn::Module for Mlp {
82 fn forward(&self, xs: &Tensor) -> Tensor {
83 xs.apply(&self.fc1).gelu("none").apply(&self.fc2)
84 }
85}
86
/// One pre-norm transformer encoder block with LayerScale on both residual branches.
#[derive(Debug)]
struct Block {
    // Normalization applied before attention.
    norm1: nn::LayerNorm,
    attn: Attention,
    // Scales the attention branch before the residual add.
    ls1: LayerScale,
    // Normalization applied before the MLP.
    norm2: nn::LayerNorm,
    mlp: Mlp,
    // Scales the MLP branch before the residual add.
    ls2: LayerScale,
}
96
97impl Block {
98 fn new(vs: nn::Path, dim: i64, num_heads: i64) -> Self {
99 let norm1 = nn::layer_norm(&vs / "norm1", vec![dim], Default::default());
100 let attn = Attention::new(&vs / "attn", dim, num_heads, true, true);
101 let ls1 = LayerScale::new(&vs / "ls1", dim);
102 let norm2 = nn::layer_norm(&vs / "norm2", vec![dim], Default::default());
103 let mlp = Mlp::new(&vs / "mlp", dim, dim * 4, true);
104 let ls2 = LayerScale::new(&vs / "ls2", dim);
105 Self { norm1, attn, ls1, norm2, mlp, ls2 }
106 }
107}
108
109impl nn::Module for Block {
110 fn forward(&self, xs: &Tensor) -> Tensor {
111 let xs = xs + xs.apply(&self.norm1).apply(&self.attn).apply(&self.ls1);
112 &xs + xs.apply(&self.norm2).apply(&self.mlp).apply(&self.ls2)
113 }
114}
115
/// Splits an image into non-overlapping patches and projects each to the embed dim.
#[derive(Debug)]
struct PatchEmbed {
    // Conv2d with kernel == stride == patch size; each output pixel is one patch token.
    proj: nn::Conv2D,
    // (height, width) of a patch, in pixels.
    patch_size: (i64, i64),
    // Patch count for the nominal square training image (used to size pos_embed).
    num_patches: i64,
}
122
123impl PatchEmbed {
124 fn new(vs: nn::Path, img_size: i64, patch_size: i64, in_chans: i64, embed_dim: i64) -> Self {
125 let config = nn::ConvConfig { stride: patch_size, ..Default::default() };
126 let proj = nn::conv2d(vs / "proj", in_chans, embed_dim, patch_size, config);
127 let num_patches = (img_size / patch_size) * (img_size / patch_size);
128 Self { proj, patch_size: (patch_size, patch_size), num_patches }
129 }
130}
131
132impl nn::Module for PatchEmbed {
133 fn forward(&self, xs: &Tensor) -> Tensor {
134 let (_b, _c, h, w) = xs.size4().unwrap();
135 let (patch_h, patch_w) = self.patch_size;
136 if (h % patch_h) != 0 {
137 panic!("image height {h} is not a multiple of patch height {patch_h}")
138 }
139 if (w % patch_w) != 0 {
140 panic!("image width {w} is not a multiple of patch width {patch_w}")
141 }
142 let xs = xs.apply(&self.proj);
143 let (b, c, h, w) = xs.size4().unwrap();
144 xs.reshape([b, c, h * w]).transpose(1, 2)
146 }
147}
148
/// DINOv2-style vision transformer with a linear classification head.
#[derive(Debug)]
pub struct DinoVisionTransformer {
    patch_embed: PatchEmbed,
    // Learnable classification token prepended to the patch tokens.
    cls_token: Tensor,
    // Positional embeddings for [cls] + patch tokens at the nominal image size;
    // interpolated at forward time for other resolutions.
    pos_embed: Tensor,
    blocks: Vec<Block>,
    norm: nn::LayerNorm,
    // Head over [cls ; mean(patch tokens)], hence 2 * embed_dim inputs.
    head: nn::Linear,
}
158
impl DinoVisionTransformer {
    /// Builds the full model under `vs`: patch embedding, `depth` encoder
    /// blocks, a final LayerNorm, and a linear head over 2 * embed_dim
    /// (concatenated cls token and mean patch token).
    ///
    /// All variables are zero-initialized here; presumably real values come
    /// from loading a checkpoint into the var store — TODO confirm.
    pub fn new(vs: &nn::Path, depth: usize, embed_dim: i64, num_heads: i64) -> Self {
        let patch_embed = PatchEmbed::new(vs / "patch_embed", IMG_SIZE, PATCH_SIZE, 3, embed_dim);
        let cls_token = vs.var("cls_token", &[1, 1, embed_dim], nn::Init::Const(0.));
        // One extra positional slot for the cls token.
        let num_tokens = 1;
        let pos_embed = vs.var(
            "pos_embed",
            &[1, patch_embed.num_patches + num_tokens, embed_dim],
            nn::Init::Const(0.),
        );
        let head = nn::linear(vs / "head", 2 * embed_dim, NUM_CLASSES, Default::default());
        let norm = nn::layer_norm(vs / "norm", vec![embed_dim], Default::default());
        let blocks =
            (0..depth).map(|i| Block::new(vs / "blocks" / i, embed_dim, num_heads)).collect();
        Self { patch_embed, cls_token, pos_embed, blocks, norm, head }
    }

    /// Resizes the stored patch positional embeddings to match the token count
    /// of `xs` via bicubic interpolation; the cls position is passed through.
    ///
    /// NOTE(review): assumes the stored patch grid is square (sqrt(n) x sqrt(n))
    /// — TODO confirm against the checkpoints this is used with.
    fn interpolate_pos_encoding(&self, xs: &Tensor, w: i64, h: i64) -> Tensor {
        // Token counts excluding the cls token.
        let npatch = xs.size()[1] - 1;
        let n = self.pos_embed.size()[1] - 1;
        let sqrt_n = (n as f64).sqrt();
        // Fast path: token count already matches and the input is square.
        if npatch == n && w == h {
            return xs.shallow_clone();
        }
        let class_pos_embed = self.pos_embed.i((.., ..1));
        let patch_pos_embed = self.pos_embed.i((.., 1..));
        let dim = *xs.size().last().unwrap();
        // The +0.1 offset guards against floating-point rounding when the
        // result is truncated back to an integer size below (same trick as
        // the reference DINO implementation).
        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
        // Reshape the flat patch embeddings back onto their 2-D grid,
        // interpolate to the new grid size, then flatten again.
        let patch_pos_embed = patch_pos_embed
            .reshape([1, sqrt_n as i64, sqrt_n as i64, dim])
            .permute([0, 3, 1, 2])
            .upsample_bicubic2d([w0 as i64, h0 as i64], false, w0 / sqrt_n, h0 / sqrt_n);
        let patch_pos_embed = patch_pos_embed.permute([0, 2, 3, 1]).reshape([1, -1, dim]);
        Tensor::cat(&[class_pos_embed, patch_pos_embed], 1)
    }

    /// Tokenizes an image batch: patch-embed, prepend the cls token, and add
    /// (possibly interpolated) positional embeddings.
    ///
    /// NOTE(review): `size4()` yields NCHW, so these names bind `w` to the
    /// height and `h` to the width. The swap cancels out inside
    /// `interpolate_pos_encoding` (its `[w0, h0]` target then matches the
    /// upsample's [out_h, out_w] order), so behavior is correct, but the
    /// naming is misleading — worth renaming consistently in both functions.
    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Tensor {
        let (b, _nc, w, h) = xs.size4().unwrap();
        let xs = xs.apply(&self.patch_embed);
        // Expand the single cls token across the batch and prepend it.
        let xs = Tensor::concat(&[self.cls_token.expand([b, -1, -1], false), xs], 1);
        &xs + &self.interpolate_pos_encoding(&xs, w, h)
    }
}
202
203impl nn::Module for DinoVisionTransformer {
204 fn forward(&self, xs: &Tensor) -> Tensor {
205 let mut xs = self.prepare_tokens_with_mask(xs);
206 for blk in self.blocks.iter() {
207 xs = xs.apply(blk)
208 }
209 let xs = xs.apply(&self.norm);
210 let xs_norm_clstoken = xs.i((.., 0));
211 let xs_norm_patchtokens = xs.i((.., 1..)).mean_dim(1, false, None);
212 let xs = Tensor::concat(&[xs_norm_clstoken, xs_norm_patchtokens], -1);
213 xs.apply(&self.head)
214 }
215}
216
217pub fn vit_small(vs: &nn::Path) -> DinoVisionTransformer {
218 DinoVisionTransformer::new(vs, 12, 384, 6)
219}