Skip to main content

sensorlm/inference/
retrieval.rs

//! Cross-modal retrieval: sensor → text and text → sensor.
//!
//! Given a database of pre-computed embeddings, retrieval finds the top-k
//! most similar items from the other modality.
//!
//! # Use cases
//!
//! * **Sensor → Text**: "Given this 24-hour recording, find the most similar
//!   textual descriptions."
//! * **Text → Sensor**: "Given this query text, find the most similar sensor
//!   recordings."
//!
//! Both directions use the same L2-normalised embedding space, so cosine
//! similarity equals the dot product.
15
16use burn::tensor::{backend::Backend, Tensor, Int};
17
18use crate::model::sensorlm::SensorLMModel;
19use crate::loss::recall_at_k;
20
// ---------------------------------------------------------------------------
// Result type
// ---------------------------------------------------------------------------
24
/// A single retrieval hit: a database position and its similarity to the query.
#[derive(Debug, Clone, PartialEq)]
pub struct RetrievalResult {
    /// Index into the database being searched, in the order the database
    /// embeddings were indexed.
    pub index: usize,
    /// Similarity score: the dot product of the query and database embeddings.
    /// This equals cosine similarity (and lies in `[-1, 1]`) only when both
    /// embeddings are L2-normalised.
    pub score: f32,
}
33
// ---------------------------------------------------------------------------
// Retrieval engine
// ---------------------------------------------------------------------------
37
38/// Cross-modal retrieval engine.
39///
40/// Pre-computes embeddings for an indexed database of sensor recordings
41/// and/or text captions and enables fast nearest-neighbour search.
42pub struct RetrievalEngine<B: Backend> {
43    model: SensorLMModel<B>,
44    /// Optional pre-computed sensor embeddings `(N, D)`.
45    sensor_embeddings: Option<Tensor<B, 2>>,
46    /// Optional pre-computed text embeddings `(M, D)`.
47    text_embeddings: Option<Tensor<B, 2>>,
48}
49
50impl<B: Backend> RetrievalEngine<B> {
51    /// Create a new retrieval engine.
52    pub fn new(model: SensorLMModel<B>) -> Self {
53        Self {
54            model,
55            sensor_embeddings: None,
56            text_embeddings: None,
57        }
58    }
59
60    /// Pre-compute and store sensor embeddings from a batch of sensor data.
61    ///
62    /// # Arguments
63    ///
64    /// * `sensor_batches` – Iterator of `(B, T, C)` sensor tensors.
65    pub fn index_sensor(&mut self, sensor_batches: impl IntoIterator<Item = Tensor<B, 3>>) {
66        let embeddings: Vec<Tensor<B, 2>> = sensor_batches
67            .into_iter()
68            .map(|batch| self.model.encode_sensor(batch))
69            .collect();
70        self.sensor_embeddings = Some(Tensor::cat(embeddings, 0));
71    }
72
73    /// Pre-compute and store text embeddings from batches of token sequences.
74    pub fn index_text(
75        &mut self,
76        text_batches: impl IntoIterator<Item = (Tensor<B, 2, Int>, Tensor<B, 2, Int>)>,
77    ) {
78        let embeddings: Vec<Tensor<B, 2>> = text_batches
79            .into_iter()
80            .map(|(ids, mask)| self.model.encode_text(ids, mask))
81            .collect();
82        self.text_embeddings = Some(Tensor::cat(embeddings, 0));
83    }
84
85    /// Retrieve the top-k most similar texts for each sensor query.
86    ///
87    /// # Arguments
88    ///
89    /// * `query_sensor` – `(Q, T, C)` query sensor tensors.
90    /// * `top_k`        – Number of results to return per query.
91    ///
92    /// # Returns
93    ///
94    /// A vector of `Q` lists, each containing `top_k` [`RetrievalResult`]s
95    /// sorted by descending similarity.
96    ///
97    /// # Panics
98    ///
99    /// Panics if no text embeddings have been indexed.
100    pub fn sensor_to_text(
101        &self,
102        query_sensor: Tensor<B, 3>,
103        top_k: usize,
104    ) -> Vec<Vec<RetrievalResult>> {
105        let db = self
106            .text_embeddings
107            .as_ref()
108            .expect("No text embeddings indexed");
109
110        let q_emb = self.model.encode_sensor(query_sensor); // (Q, D)
111        self.top_k_search(q_emb, db.clone(), top_k)
112    }
113
114    /// Retrieve the top-k most similar sensor recordings for each text query.
115    ///
116    /// # Panics
117    ///
118    /// Panics if no sensor embeddings have been indexed.
119    pub fn text_to_sensor(
120        &self,
121        query_ids: Tensor<B, 2, Int>,
122        query_mask: Tensor<B, 2, Int>,
123        top_k: usize,
124    ) -> Vec<Vec<RetrievalResult>> {
125        let db = self
126            .sensor_embeddings
127            .as_ref()
128            .expect("No sensor embeddings indexed");
129
130        let q_emb = self.model.encode_text(query_ids, query_mask); // (Q, D)
131        self.top_k_search(q_emb, db.clone(), top_k)
132    }
133
134    /// Core top-k nearest-neighbour search.
135    ///
136    /// `queries` is `(Q, D)`, `database` is `(N, D)`.
137    /// Returns `Q` lists of the top-k `(index, score)` pairs.
138    fn top_k_search(
139        &self,
140        queries: Tensor<B, 2>,
141        database: Tensor<B, 2>,
142        top_k: usize,
143    ) -> Vec<Vec<RetrievalResult>> {
144        let q = queries.dims()[0];
145        let n = database.dims()[0];
146
147        // (Q, N) similarity matrix.
148        let sim = queries.matmul(database.transpose());
149        let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
150
151        (0..q)
152            .map(|qi| {
153                let row = &data[qi * n..(qi + 1) * n];
154                let mut indexed: Vec<(usize, f32)> =
155                    row.iter().copied().enumerate().collect();
156                indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
157                indexed
158                    .into_iter()
159                    .take(top_k)
160                    .map(|(i, s)| RetrievalResult { index: i, score: s })
161                    .collect()
162            })
163            .collect()
164    }
165
166    /// Evaluate Recall@k on a paired (sensor, text) evaluation set.
167    ///
168    /// Assumes ground truth is the diagonal: sensor `i` corresponds to text `i`.
169    ///
170    /// # Arguments
171    ///
172    /// * `sensor`  – `(N, T, C)` evaluation sensor data.
173    /// * `ids`     – `(N, L)` token IDs.
174    /// * `mask`    – `(N, L)` attention mask.
175    /// * `k`       – Recall@k.
176    pub fn evaluate_recall(
177        &self,
178        sensor: Tensor<B, 3>,
179        ids: Tensor<B, 2, Int>,
180        mask: Tensor<B, 2, Int>,
181        k: usize,
182    ) -> (f32, f32) {
183        let z_s = self.model.encode_sensor(sensor);  // (N, D)
184        let z_t = self.model.encode_text(ids, mask); // (N, D)
185
186        let logits = z_s.matmul(z_t.clone().transpose()); // (N, N)
187        let r_s2t = recall_at_k(logits.clone(), k);
188        let r_t2s = recall_at_k(logits.transpose(), k);
189
190        (r_s2t, r_t2s)
191    }
192}