1mod error;
38mod tensor;
39
40pub mod layers;
42pub mod simd;
43
44#[cfg(feature = "backbone")]
46pub mod backbone;
47#[cfg(feature = "backbone")]
48pub mod embedding;
49
50pub mod contrastive;
52
53pub use error::{CnnError, CnnResult};
54pub use tensor::Tensor;
55
56#[cfg(feature = "backbone")]
58pub use backbone::{
59 Backbone, BackboneExt, BackboneType,
60 MobileNetV3, MobileNetV3Config,
61 MobileNetV3Small, MobileNetV3Large, MobileNetConfig,
62 ConvBNActivation, InvertedResidual, SqueezeExcitation,
63 create_backbone, mobilenet_v3_small, mobilenet_v3_large,
64};
65
66#[cfg(feature = "backbone")]
68pub use embedding::{
69 MobileNetEmbedder, EmbeddingExtractorExt,
70 EmbeddingConfig as MobileNetEmbeddingConfig,
71 cosine_similarity, euclidean_distance,
72};
73
74use serde::{Deserialize, Serialize};
79
/// Settings that control how a [`CnnEmbedder`] processes an image.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Edge length, in pixels, that the embedder hands to its first
    /// convolution as both the height and the width of the input.
    pub input_size: u32,
    /// Number of values in the embedding returned by the projection step.
    pub embedding_dim: usize,
    /// When `true`, the embedding is scaled to unit L2 norm before it is returned.
    pub normalize: bool,
    /// NOTE(review): not read anywhere in this file; presumably selects a
    /// quantized inference path implemented elsewhere — confirm.
    pub quantized: bool,
}
92
93impl Default for EmbeddingConfig {
94 fn default() -> Self {
95 Self {
96 input_size: 224,
97 embedding_dim: 512,
98 normalize: true,
99 quantized: false,
100 }
101 }
102}
103
/// Small convolutional embedder that maps RGBA image bytes to a
/// fixed-length `f32` vector.
///
/// The pipeline, as implemented by `extract`, is: preprocess → 3x3
/// convolution → batch normalization → ReLU → global average pooling →
/// linear projection → optional L2 normalization.
#[derive(Debug, Clone)]
pub struct CnnEmbedder {
    /// Configuration supplied at construction time.
    config: EmbeddingConfig,
    /// Parameters of the convolution, batch-norm and projection stages.
    weights: EmbedderWeights,
}
110
/// Flat parameter buffers used by [`CnnEmbedder`].
#[derive(Debug, Clone)]
struct EmbedderWeights {
    /// Kernel values handed to `layers::conv2d_3x3` (3 input channels,
    /// 16 output channels); the default holds `3 * 3 * 3 * 16` values.
    conv_weights: Vec<f32>,
    /// Batch-norm scale, passed as the gamma argument of `layers::batch_norm`.
    bn_gamma: Vec<f32>,
    /// Batch-norm shift, passed as the beta argument of `layers::batch_norm`.
    bn_beta: Vec<f32>,
    /// Batch-norm running mean.
    bn_mean: Vec<f32>,
    /// Batch-norm running variance.
    bn_var: Vec<f32>,
    /// Projection matrix, read as `projection[i * out_dim + o]` (input
    /// index `i`, output index `o`); the default holds `16 * 512` values.
    projection: Vec<f32>,
}
124
125impl Default for EmbedderWeights {
126 fn default() -> Self {
127 use rand::Rng;
128 let mut rng = rand::thread_rng();
129
130 let conv_size = 3 * 3 * 3 * 16;
131 let bn_size = 16;
132 let proj_size = 16 * 512;
133
134 Self {
135 conv_weights: (0..conv_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
136 bn_gamma: vec![1.0; bn_size],
137 bn_beta: vec![0.0; bn_size],
138 bn_mean: vec![0.0; bn_size],
139 bn_var: vec![1.0; bn_size],
140 projection: (0..proj_size).map(|_| rng.gen_range(-0.1..0.1)).collect(),
141 }
142 }
143}
144
145impl CnnEmbedder {
146 pub fn new(config: EmbeddingConfig) -> CnnResult<Self> {
148 let weights = EmbedderWeights::default();
149 Ok(Self { config, weights })
150 }
151
152 pub fn new_v3_small() -> CnnResult<Self> {
154 Self::new(EmbeddingConfig {
155 input_size: 224,
156 embedding_dim: 576,
157 normalize: true,
158 quantized: false,
159 })
160 }
161
162 pub fn new_v3_large() -> CnnResult<Self> {
164 Self::new(EmbeddingConfig {
165 input_size: 224,
166 embedding_dim: 960,
167 normalize: true,
168 quantized: false,
169 })
170 }
171
172 pub fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
174 let expected_size = (width * height * 4) as usize;
175 if image_data.len() != expected_size {
176 return Err(CnnError::InvalidInput(format!(
177 "Expected {} bytes for {}x{} RGBA image, got {}",
178 expected_size, width, height, image_data.len()
179 )));
180 }
181
182 let rgb_float = self.preprocess(image_data, width, height)?;
183 let features = self.forward(&rgb_float)?;
184 let pooled = self.global_avg_pool(&features)?;
185 let mut embedding = self.project(&pooled)?;
186
187 if self.config.normalize {
188 self.l2_normalize(&mut embedding);
189 }
190
191 Ok(embedding)
192 }
193
194 pub fn embedding_dim(&self) -> usize {
196 self.config.embedding_dim
197 }
198
199 pub fn input_size(&self) -> u32 {
201 self.config.input_size
202 }
203
204 fn preprocess(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
205 let pixels = (width * height) as usize;
206 let mut rgb = Vec::with_capacity(pixels * 3);
207
208 let mean = [0.485, 0.456, 0.406];
209 let std = [0.229, 0.224, 0.225];
210
211 for i in 0..pixels {
212 let offset = i * 4;
213 rgb.push((image_data[offset] as f32 / 255.0 - mean[0]) / std[0]);
214 rgb.push((image_data[offset + 1] as f32 / 255.0 - mean[1]) / std[1]);
215 rgb.push((image_data[offset + 2] as f32 / 255.0 - mean[2]) / std[2]);
216 }
217
218 Ok(rgb)
219 }
220
221 fn forward(&self, input: &[f32]) -> CnnResult<Vec<f32>> {
222 let conv_out = layers::conv2d_3x3(
223 input,
224 &self.weights.conv_weights,
225 3,
226 16,
227 self.config.input_size as usize,
228 self.config.input_size as usize,
229 );
230
231 let bn_out = layers::batch_norm(
232 &conv_out,
233 &self.weights.bn_gamma,
234 &self.weights.bn_beta,
235 &self.weights.bn_mean,
236 &self.weights.bn_var,
237 1e-5,
238 );
239
240 let activated: Vec<f32> = bn_out.iter().map(|&x| x.max(0.0)).collect();
241 Ok(activated)
242 }
243
244 fn global_avg_pool(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
245 let channels = 16;
246 let spatial = features.len() / channels;
247 let mut pooled = vec![0.0f32; channels];
248
249 for i in 0..spatial {
250 for c in 0..channels {
251 pooled[c] += features[i * channels + c];
252 }
253 }
254
255 let inv_spatial = 1.0 / spatial as f32;
256 for p in pooled.iter_mut() {
257 *p *= inv_spatial;
258 }
259
260 Ok(pooled)
261 }
262
263 fn project(&self, features: &[f32]) -> CnnResult<Vec<f32>> {
264 let in_dim = features.len();
265 let out_dim = self.config.embedding_dim;
266 let mut output = vec![0.0f32; out_dim];
267
268 for o in 0..out_dim {
269 let mut sum = 0.0f32;
270 for i in 0..in_dim {
271 sum += features[i] * self.weights.projection[i * out_dim + o];
272 }
273 output[o] = sum;
274 }
275
276 Ok(output)
277 }
278
279 fn l2_normalize(&self, vec: &mut [f32]) {
280 let norm: f32 = vec.iter().map(|x| x * x).sum::<f32>().sqrt();
281 if norm > 1e-10 {
282 for x in vec.iter_mut() {
283 *x /= norm;
284 }
285 }
286 }
287}
288
/// Interface for types that turn raw image bytes into an embedding vector.
pub trait EmbeddingExtractor {
    /// Computes an embedding for the image stored in `image_data`, whose
    /// dimensions are given by `width` and `height`.
    ///
    /// NOTE(review): the pixel format is not fixed by this trait; the
    /// `CnnEmbedder` implementation in this file expects 4 bytes per pixel
    /// (RGBA) — confirm for other implementors.
    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>>;
    /// Returns the embedding dimensionality reported by the implementation.
    fn embedding_dim(&self) -> usize;
}
294
/// Exposes [`CnnEmbedder`] through the [`EmbeddingExtractor`] trait by
/// forwarding to its inherent methods of the same names.
impl EmbeddingExtractor for CnnEmbedder {
    fn extract(&self, image_data: &[u8], width: u32, height: u32) -> CnnResult<Vec<f32>> {
        // The explicit `CnnEmbedder::` path names the inherent method, which
        // makes it clear this is delegation rather than a recursive call.
        CnnEmbedder::extract(self, image_data, width, height)
    }

    fn embedding_dim(&self) -> usize {
        // Same delegation to the inherent accessor.
        CnnEmbedder::embedding_dim(self)
    }
}
304
#[cfg(test)]
mod tests {
    use super::*;

    /// Euclidean length of `values`.
    fn l2_norm(values: &[f32]) -> f32 {
        values.iter().map(|v| v * v).sum::<f32>().sqrt()
    }

    /// Embedder built from the default configuration.
    fn default_embedder() -> CnnEmbedder {
        CnnEmbedder::new(EmbeddingConfig::default()).unwrap()
    }

    /// The default configuration yields a 512-value embedding.
    #[test]
    fn test_embedder_creation() {
        assert_eq!(default_embedder().embedding_dim(), 512);
    }

    /// The "small" preset reports a 576-value embedding.
    #[test]
    fn test_v3_small() {
        let small = CnnEmbedder::new_v3_small().unwrap();
        assert_eq!(small.embedding_dim(), 576);
    }

    /// The "large" preset reports a 960-value embedding.
    #[test]
    fn test_v3_large() {
        let large = CnnEmbedder::new_v3_large().unwrap();
        assert_eq!(large.embedding_dim(), 960);
    }

    /// A 4x4 grey image produces an embedding of the configured length that
    /// is either unit-norm or (numerically) zero.
    #[test]
    fn test_extract_embedding() {
        let tiny = EmbeddingConfig {
            input_size: 4,
            embedding_dim: 8,
            normalize: true,
            quantized: false,
        };
        let embedder = CnnEmbedder::new(tiny).unwrap();

        let grey_rgba = vec![128u8; 4 * 4 * 4];
        let embedding = embedder.extract(&grey_rgba, 4, 4).unwrap();
        assert_eq!(embedding.len(), 8);

        let norm = l2_norm(&embedding);
        assert!((norm - 1.0).abs() < 1e-5 || norm < 1e-10);
    }

    /// A buffer whose length does not match the stated dimensions is rejected.
    #[test]
    fn test_invalid_input() {
        let too_short = vec![0u8; 100];
        let outcome = default_embedder().extract(&too_short, 10, 10);
        assert!(outcome.is_err());
    }
}