// tch_plus/vision/dinov2.rs

//! DINOv2: Learning Robust Visual Features without Supervision
//! https://github.com/facebookresearch/dinov2
//! The weights can be extracted from pre-trained Python models
//! using `python src/vision/export_dinov2.py`.
// TODO: use swiglu.
6use crate::{nn, IndexOp, Kind, Tensor};
7
8const IMG_SIZE: i64 = 518;
9const PATCH_SIZE: i64 = 14;
10const NUM_CLASSES: i64 = 1000;
11
12#[derive(Debug)]
13struct Attention {
14    qkv: nn::Linear,
15    proj: nn::Linear,
16    num_heads: i64,
17    scale: f64,
18}
19
20impl Attention {
21    fn new(vs: nn::Path, dim: i64, num_heads: i64, qkv_bias: bool, proj_bias: bool) -> Self {
22        let qkv_config = nn::LinearConfig { bias: qkv_bias, ..Default::default() };
23        let proj_config = nn::LinearConfig { bias: proj_bias, ..Default::default() };
24        let qkv = nn::linear(&vs / "qkv", dim, dim * 3, qkv_config);
25        let proj = nn::linear(&vs / "proj", dim, dim, proj_config);
26        let scale = 1. / ((dim / num_heads) as f64).sqrt();
27        Self { qkv, proj, num_heads, scale }
28    }
29}
30
31impl nn::Module for Attention {
32    fn forward(&self, xs: &Tensor) -> Tensor {
33        let (b, n, c) = xs.size3().unwrap();
34        let qkv = self
35            .qkv
36            .forward(xs)
37            .reshape([b, n, 3, self.num_heads, c / self.num_heads])
38            .permute([2, 0, 3, 1, 4]);
39        let q = qkv.get(0) * self.scale;
40        let k = qkv.get(1);
41        let v = qkv.get(2);
42        let attn = q.matmul(&k.transpose(-2, -1)).softmax(-1, Kind::Float);
43        attn.matmul(&v).transpose(1, 2).reshape([b, n, c]).apply(&self.proj)
44    }
45}
46
47#[derive(Debug)]
48struct LayerScale {
49    gamma: Tensor,
50}
51
52impl LayerScale {
53    fn new(vs: nn::Path, dim: i64) -> Self {
54        let gamma = vs.var("gamma", &[dim], nn::Init::Const(0.));
55        Self { gamma }
56    }
57}
58
59impl nn::Module for LayerScale {
60    fn forward(&self, xs: &Tensor) -> Tensor {
61        xs * &self.gamma
62    }
63}
64
65#[derive(Debug)]
66struct Mlp {
67    fc1: nn::Linear,
68    fc2: nn::Linear,
69}
70
71impl Mlp {
72    fn new(vs: nn::Path, in_features: i64, hidden_features: i64, bias: bool) -> Self {
73        let out_features = in_features;
74        let config = nn::LinearConfig { bias, ..Default::default() };
75        let fc1 = nn::linear(&vs / "fc1", in_features, hidden_features, config);
76        let fc2 = nn::linear(&vs / "fc2", hidden_features, out_features, config);
77        Self { fc1, fc2 }
78    }
79}
80
81impl nn::Module for Mlp {
82    fn forward(&self, xs: &Tensor) -> Tensor {
83        xs.apply(&self.fc1).gelu("none").apply(&self.fc2)
84    }
85}
86
87#[derive(Debug)]
88struct Block {
89    norm1: nn::LayerNorm,
90    attn: Attention,
91    ls1: LayerScale,
92    norm2: nn::LayerNorm,
93    mlp: Mlp,
94    ls2: LayerScale,
95}
96
97impl Block {
98    fn new(vs: nn::Path, dim: i64, num_heads: i64) -> Self {
99        let norm1 = nn::layer_norm(&vs / "norm1", vec![dim], Default::default());
100        let attn = Attention::new(&vs / "attn", dim, num_heads, true, true);
101        let ls1 = LayerScale::new(&vs / "ls1", dim);
102        let norm2 = nn::layer_norm(&vs / "norm2", vec![dim], Default::default());
103        let mlp = Mlp::new(&vs / "mlp", dim, dim * 4, true);
104        let ls2 = LayerScale::new(&vs / "ls2", dim);
105        Self { norm1, attn, ls1, norm2, mlp, ls2 }
106    }
107}
108
109impl nn::Module for Block {
110    fn forward(&self, xs: &Tensor) -> Tensor {
111        let xs = xs + xs.apply(&self.norm1).apply(&self.attn).apply(&self.ls1);
112        &xs + xs.apply(&self.norm2).apply(&self.mlp).apply(&self.ls2)
113    }
114}
115
116#[derive(Debug)]
117struct PatchEmbed {
118    proj: nn::Conv2D,
119    patch_size: (i64, i64),
120    num_patches: i64,
121}
122
123impl PatchEmbed {
124    fn new(vs: nn::Path, img_size: i64, patch_size: i64, in_chans: i64, embed_dim: i64) -> Self {
125        let config = nn::ConvConfig { stride: patch_size, ..Default::default() };
126        let proj = nn::conv2d(vs / "proj", in_chans, embed_dim, patch_size, config);
127        let num_patches = (img_size / patch_size) * (img_size / patch_size);
128        Self { proj, patch_size: (patch_size, patch_size), num_patches }
129    }
130}
131
132impl nn::Module for PatchEmbed {
133    fn forward(&self, xs: &Tensor) -> Tensor {
134        let (_b, _c, h, w) = xs.size4().unwrap();
135        let (patch_h, patch_w) = self.patch_size;
136        if (h % patch_h) != 0 {
137            panic!("image height {h} is not a multiple of patch height {patch_h}")
138        }
139        if (w % patch_w) != 0 {
140            panic!("image width {w} is not a multiple of patch width {patch_w}")
141        }
142        let xs = xs.apply(&self.proj);
143        let (b, c, h, w) = xs.size4().unwrap();
144        // flatten embeddings.
145        xs.reshape([b, c, h * w]).transpose(1, 2)
146    }
147}
148
149#[derive(Debug)]
150pub struct DinoVisionTransformer {
151    patch_embed: PatchEmbed,
152    cls_token: Tensor,
153    pos_embed: Tensor,
154    blocks: Vec<Block>,
155    norm: nn::LayerNorm,
156    head: nn::Linear,
157}
158
159impl DinoVisionTransformer {
160    pub fn new(vs: &nn::Path, depth: usize, embed_dim: i64, num_heads: i64) -> Self {
161        let patch_embed = PatchEmbed::new(vs / "patch_embed", IMG_SIZE, PATCH_SIZE, 3, embed_dim);
162        let cls_token = vs.var("cls_token", &[1, 1, embed_dim], nn::Init::Const(0.));
163        let num_tokens = 1;
164        let pos_embed = vs.var(
165            "pos_embed",
166            &[1, patch_embed.num_patches + num_tokens, embed_dim],
167            nn::Init::Const(0.),
168        );
169        let head = nn::linear(vs / "head", 2 * embed_dim, NUM_CLASSES, Default::default());
170        let norm = nn::layer_norm(vs / "norm", vec![embed_dim], Default::default());
171        let blocks =
172            (0..depth).map(|i| Block::new(vs / "blocks" / i, embed_dim, num_heads)).collect();
173        Self { patch_embed, cls_token, pos_embed, blocks, norm, head }
174    }
175
176    fn interpolate_pos_encoding(&self, xs: &Tensor, w: i64, h: i64) -> Tensor {
177        let npatch = xs.size()[1] - 1;
178        let n = self.pos_embed.size()[1] - 1;
179        let sqrt_n = (n as f64).sqrt();
180        if npatch == n && w == h {
181            return xs.shallow_clone();
182        }
183        let class_pos_embed = self.pos_embed.i((.., ..1));
184        let patch_pos_embed = self.pos_embed.i((.., 1..));
185        let dim = *xs.size().last().unwrap();
186        let (w0, h0) = ((w / PATCH_SIZE) as f64 + 0.1, (h / PATCH_SIZE) as f64 + 0.1);
187        let patch_pos_embed = patch_pos_embed
188            .reshape([1, sqrt_n as i64, sqrt_n as i64, dim])
189            .permute([0, 3, 1, 2])
190            .upsample_bicubic2d([w0 as i64, h0 as i64], false, w0 / sqrt_n, h0 / sqrt_n);
191        let patch_pos_embed = patch_pos_embed.permute([0, 2, 3, 1]).reshape([1, -1, dim]);
192        Tensor::cat(&[class_pos_embed, patch_pos_embed], 1)
193    }
194
195    fn prepare_tokens_with_mask(&self, xs: &Tensor) -> Tensor {
196        let (b, _nc, w, h) = xs.size4().unwrap();
197        let xs = xs.apply(&self.patch_embed);
198        let xs = Tensor::concat(&[self.cls_token.expand([b, -1, -1], false), xs], 1);
199        &xs + &self.interpolate_pos_encoding(&xs, w, h)
200    }
201}
202
203impl nn::Module for DinoVisionTransformer {
204    fn forward(&self, xs: &Tensor) -> Tensor {
205        let mut xs = self.prepare_tokens_with_mask(xs);
206        for blk in self.blocks.iter() {
207            xs = xs.apply(blk)
208        }
209        let xs = xs.apply(&self.norm);
210        let xs_norm_clstoken = xs.i((.., 0));
211        let xs_norm_patchtokens = xs.i((.., 1..)).mean_dim(1, false, None);
212        let xs = Tensor::concat(&[xs_norm_clstoken, xs_norm_patchtokens], -1);
213        xs.apply(&self.head)
214    }
215}
216
217pub fn vit_small(vs: &nn::Path) -> DinoVisionTransformer {
218    DinoVisionTransformer::new(vs, 12, 384, 6)
219}