sensorlm/inference/zero_shot.rs
//! Zero-shot classification for wearable sensor data.
//!
//! Zero-shot recognition works by encoding a set of candidate class-name
//! prompts with the text encoder and scoring each sensor embedding against
//! every prompt embedding with a dot product. This equals the cosine
//! similarity when the encoders return L2-normalised embeddings. The class
//! with the highest score is predicted.
//!
//! # Example
//!
//! ```rust,no_run
//! use sensorlm::inference::zero_shot::{ZeroShotClassifier, ClassifierConfig};
//!
//! let cfg = ClassifierConfig {
//!     class_names: vec!["walking".into(), "running".into(), "sleeping".into()],
//!     prompt_template: "The person is {label}.".into(),
//! };
//! // let clf = ZeroShotClassifier::new(model, &cfg, |prompt| tokenize(prompt));
//! // let predictions = clf.predict(sensor_batch);
//! ```
19
20use burn::tensor::{backend::Backend, Tensor, Int};
21
22use crate::model::sensorlm::SensorLMModel;
23
24// ---------------------------------------------------------------------------
25// Configuration
26// ---------------------------------------------------------------------------
27
/// Configuration for zero-shot classification.
///
/// One text prompt is generated per class by substituting the class name into
/// [`prompt_template`](Self::prompt_template).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassifierConfig {
    /// Human-readable class labels. Prediction indices refer to positions in
    /// this list.
    pub class_names: Vec<String>,
    /// Prompt template. Every occurrence of the substring `{label}` is
    /// replaced with the class name before tokenisation; a template that does
    /// not contain `{label}` yields the same prompt for every class.
    pub prompt_template: String,
}

impl ClassifierConfig {
    /// Build the filled-in prompt for one class.
    ///
    /// All occurrences of `{label}` in the template are replaced with `label`.
    #[must_use]
    pub fn prompt_for(&self, label: &str) -> String {
        self.prompt_template.replace("{label}", label)
    }
}

impl Default for ClassifierConfig {
    /// Five common activity classes with the template `"The person is {label}."`.
    fn default() -> Self {
        Self {
            class_names: ["walking", "running", "cycling", "sleeping", "sedentary"]
                .iter()
                .map(ToString::to_string)
                .collect(),
            prompt_template: "The person is {label}.".to_string(),
        }
    }
}
58}
59
60// ---------------------------------------------------------------------------
61// Classifier
62// ---------------------------------------------------------------------------
63
64/// Zero-shot classifier backed by a SensorLM model.
65pub struct ZeroShotClassifier<B: Backend> {
66 model: SensorLMModel<B>,
67 /// Pre-computed text embeddings for all class prompts, shape `(K, D)`.
68 class_embeddings: Tensor<B, 2>,
69 /// Class names in the same order as `class_embeddings`.
70 class_names: Vec<String>,
71}
72
73impl<B: Backend> ZeroShotClassifier<B> {
74 /// Construct the classifier and pre-compute class embeddings.
75 ///
76 /// # Arguments
77 ///
78 /// * `model` – A trained SensorLM model.
79 /// * `cfg` – Classifier configuration.
80 /// * `tokenize` – A closure that converts a prompt string into
81 /// `(token_ids, attention_mask)` tensors of shape `(1, L)`.
82 pub fn new<F>(model: SensorLMModel<B>, cfg: &ClassifierConfig, tokenize: F) -> Self
83 where
84 F: Fn(&str) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>),
85 {
86 let embeddings: Vec<Tensor<B, 2>> = cfg
87 .class_names
88 .iter()
89 .map(|name| {
90 let prompt = cfg.prompt_for(name);
91 let (ids, mask) = tokenize(&prompt);
92 model.encode_text(ids, mask) // (1, D)
93 })
94 .collect();
95
96 // Stack into (K, D).
97 let class_embeddings = Tensor::cat(embeddings, 0);
98
99 Self {
100 model,
101 class_embeddings,
102 class_names: cfg.class_names.clone(),
103 }
104 }
105
106 /// Predict the class for each sensor sample in the batch.
107 ///
108 /// # Arguments
109 ///
110 /// * `sensor` – `(B, T, C)` normalised sensor data.
111 ///
112 /// # Returns
113 ///
114 /// A vector of `B` `(class_index, class_name, similarity_score)` tuples.
115 pub fn predict(
116 &self,
117 sensor: Tensor<B, 3>,
118 ) -> Vec<(usize, String, f32)> {
119 let b = sensor.dims()[0];
120 let z_sensor = self.model.encode_sensor(sensor); // (B, D)
121
122 // Similarity matrix: (B, K) = z_sensor @ class_embeddings.T
123 let sim = z_sensor.matmul(self.class_embeddings.clone().transpose()); // (B, K)
124
125 let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
126 let k = self.class_names.len();
127
128 (0..b)
129 .map(|i| {
130 let row = &data[i * k..(i + 1) * k];
131 let (best_idx, &best_score) = row
132 .iter()
133 .enumerate()
134 .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
135 .unwrap();
136 (best_idx, self.class_names[best_idx].clone(), best_score)
137 })
138 .collect()
139 }
140
141 /// Predict and return the top-k predictions per sample.
142 pub fn predict_topk(
143 &self,
144 sensor: Tensor<B, 3>,
145 k: usize,
146 ) -> Vec<Vec<(usize, String, f32)>> {
147 let b = sensor.dims()[0];
148 let z_sensor = self.model.encode_sensor(sensor);
149 let sim = z_sensor.matmul(self.class_embeddings.clone().transpose());
150 let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
151 let num_classes = self.class_names.len();
152
153 (0..b)
154 .map(|i| {
155 let row = &data[i * num_classes..(i + 1) * num_classes];
156 let mut indexed: Vec<(usize, f32)> =
157 row.iter().copied().enumerate().collect();
158 indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
159 indexed
160 .into_iter()
161 .take(k)
162 .map(|(idx, score)| (idx, self.class_names[idx].clone(), score))
163 .collect()
164 })
165 .collect()
166 }
167}