sensorlm/inference/retrieval.rs
1//! Cross-modal retrieval: sensor → text and text → sensor.
2//!
3//! Given a database of pre-computed embeddings, retrieval finds the top-k
4//! most similar items from the other modality.
5//!
6//! # Use cases
7//!
8//! * **Sensor → Text**: "Given this 24-hour recording, find the most similar
9//! textual descriptions."
10//! * **Text → Sensor**: "Given this query text, find the most similar sensor
11//! recordings."
12//!
13//! Both directions use the same L2-normalised embedding space, so cosine
14//! similarity equals dot product.
15
16use burn::tensor::{backend::Backend, Tensor, Int};
17
18use crate::model::sensorlm::SensorLMModel;
19use crate::loss::recall_at_k;
20
21// ---------------------------------------------------------------------------
22// Result type
23// ---------------------------------------------------------------------------
24
/// A single retrieval result: one database entry together with its score.
///
/// The type is a small plain-data pair, so it is `Copy` and can be compared
/// with `==` (handy in tests and when de-duplicating result lists).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct RetrievalResult {
    /// Index into the database being searched.
    pub index: usize,
    /// Similarity score (dot product of query and database embedding).
    ///
    /// Equals the cosine similarity, in `[-1, 1]`, when both embeddings are
    /// L2-normalised.
    pub score: f32,
}
33
34// ---------------------------------------------------------------------------
35// Retrieval engine
36// ---------------------------------------------------------------------------
37
38/// Cross-modal retrieval engine.
39///
40/// Pre-computes embeddings for an indexed database of sensor recordings
41/// and/or text captions and enables fast nearest-neighbour search.
42pub struct RetrievalEngine<B: Backend> {
43 model: SensorLMModel<B>,
44 /// Optional pre-computed sensor embeddings `(N, D)`.
45 sensor_embeddings: Option<Tensor<B, 2>>,
46 /// Optional pre-computed text embeddings `(M, D)`.
47 text_embeddings: Option<Tensor<B, 2>>,
48}
49
50impl<B: Backend> RetrievalEngine<B> {
51 /// Create a new retrieval engine.
52 pub fn new(model: SensorLMModel<B>) -> Self {
53 Self {
54 model,
55 sensor_embeddings: None,
56 text_embeddings: None,
57 }
58 }
59
60 /// Pre-compute and store sensor embeddings from a batch of sensor data.
61 ///
62 /// # Arguments
63 ///
64 /// * `sensor_batches` – Iterator of `(B, T, C)` sensor tensors.
65 pub fn index_sensor(&mut self, sensor_batches: impl IntoIterator<Item = Tensor<B, 3>>) {
66 let embeddings: Vec<Tensor<B, 2>> = sensor_batches
67 .into_iter()
68 .map(|batch| self.model.encode_sensor(batch))
69 .collect();
70 self.sensor_embeddings = Some(Tensor::cat(embeddings, 0));
71 }
72
73 /// Pre-compute and store text embeddings from batches of token sequences.
74 pub fn index_text(
75 &mut self,
76 text_batches: impl IntoIterator<Item = (Tensor<B, 2, Int>, Tensor<B, 2, Int>)>,
77 ) {
78 let embeddings: Vec<Tensor<B, 2>> = text_batches
79 .into_iter()
80 .map(|(ids, mask)| self.model.encode_text(ids, mask))
81 .collect();
82 self.text_embeddings = Some(Tensor::cat(embeddings, 0));
83 }
84
85 /// Retrieve the top-k most similar texts for each sensor query.
86 ///
87 /// # Arguments
88 ///
89 /// * `query_sensor` – `(Q, T, C)` query sensor tensors.
90 /// * `top_k` – Number of results to return per query.
91 ///
92 /// # Returns
93 ///
94 /// A vector of `Q` lists, each containing `top_k` [`RetrievalResult`]s
95 /// sorted by descending similarity.
96 ///
97 /// # Panics
98 ///
99 /// Panics if no text embeddings have been indexed.
100 pub fn sensor_to_text(
101 &self,
102 query_sensor: Tensor<B, 3>,
103 top_k: usize,
104 ) -> Vec<Vec<RetrievalResult>> {
105 let db = self
106 .text_embeddings
107 .as_ref()
108 .expect("No text embeddings indexed");
109
110 let q_emb = self.model.encode_sensor(query_sensor); // (Q, D)
111 self.top_k_search(q_emb, db.clone(), top_k)
112 }
113
114 /// Retrieve the top-k most similar sensor recordings for each text query.
115 ///
116 /// # Panics
117 ///
118 /// Panics if no sensor embeddings have been indexed.
119 pub fn text_to_sensor(
120 &self,
121 query_ids: Tensor<B, 2, Int>,
122 query_mask: Tensor<B, 2, Int>,
123 top_k: usize,
124 ) -> Vec<Vec<RetrievalResult>> {
125 let db = self
126 .sensor_embeddings
127 .as_ref()
128 .expect("No sensor embeddings indexed");
129
130 let q_emb = self.model.encode_text(query_ids, query_mask); // (Q, D)
131 self.top_k_search(q_emb, db.clone(), top_k)
132 }
133
134 /// Core top-k nearest-neighbour search.
135 ///
136 /// `queries` is `(Q, D)`, `database` is `(N, D)`.
137 /// Returns `Q` lists of the top-k `(index, score)` pairs.
138 fn top_k_search(
139 &self,
140 queries: Tensor<B, 2>,
141 database: Tensor<B, 2>,
142 top_k: usize,
143 ) -> Vec<Vec<RetrievalResult>> {
144 let q = queries.dims()[0];
145 let n = database.dims()[0];
146
147 // (Q, N) similarity matrix.
148 let sim = queries.matmul(database.transpose());
149 let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
150
151 (0..q)
152 .map(|qi| {
153 let row = &data[qi * n..(qi + 1) * n];
154 let mut indexed: Vec<(usize, f32)> =
155 row.iter().copied().enumerate().collect();
156 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
157 indexed
158 .into_iter()
159 .take(top_k)
160 .map(|(i, s)| RetrievalResult { index: i, score: s })
161 .collect()
162 })
163 .collect()
164 }
165
166 /// Evaluate Recall@k on a paired (sensor, text) evaluation set.
167 ///
168 /// Assumes ground truth is the diagonal: sensor `i` corresponds to text `i`.
169 ///
170 /// # Arguments
171 ///
172 /// * `sensor` – `(N, T, C)` evaluation sensor data.
173 /// * `ids` – `(N, L)` token IDs.
174 /// * `mask` – `(N, L)` attention mask.
175 /// * `k` – Recall@k.
176 pub fn evaluate_recall(
177 &self,
178 sensor: Tensor<B, 3>,
179 ids: Tensor<B, 2, Int>,
180 mask: Tensor<B, 2, Int>,
181 k: usize,
182 ) -> (f32, f32) {
183 let z_s = self.model.encode_sensor(sensor); // (N, D)
184 let z_t = self.model.encode_text(ids, mask); // (N, D)
185
186 let logits = z_s.matmul(z_t.clone().transpose()); // (N, N)
187 let r_s2t = recall_at_k(logits.clone(), k);
188 let r_t2s = recall_at_k(logits.transpose(), k);
189
190 (r_s2t, r_t2s)
191 }
192}