Skip to main content

sensorlm/inference/
zero_shot.rs

//! Zero-shot classification for wearable sensor data.
//!
//! Zero-shot recognition works by encoding a set of candidate class-name
//! prompts with the text encoder and computing the cosine similarity to each
//! sensor embedding.  The class with the highest similarity is predicted.
//!
//! # Example
//!
//! ```rust,no_run
//! use sensorlm::inference::zero_shot::{ZeroShotClassifier, ClassifierConfig};
//!
//! let cfg = ClassifierConfig {
//!     class_names: vec!["walking".into(), "running".into(), "sleeping".into()],
//!     prompt_template: "The person is {label}.".into(),
//! };
//! // `tokenize` is a closure mapping a prompt to `(token_ids, attention_mask)`.
//! // let clf = ZeroShotClassifier::new(model, &cfg, tokenize);
//! // let predictions = clf.predict(sensor_batch);
//! ```
19
20use burn::tensor::{backend::Backend, Tensor, Int};
21
22use crate::model::sensorlm::SensorLMModel;
23
24// ---------------------------------------------------------------------------
25// Configuration
26// ---------------------------------------------------------------------------
27
/// Settings that drive zero-shot classification.
///
/// One text prompt is produced per entry of `class_names` by substituting the
/// entry into `prompt_template` (see `prompt_for`).
#[derive(Clone, Debug)]
pub struct ClassifierConfig {
    /// Human-readable labels of the candidate classes.
    pub class_names: Vec<String>,
    /// Template for the text prompt.  Every occurrence of the literal
    /// `{label}` is replaced with a class name before tokenisation.
    pub prompt_template: String,
}
37
38impl ClassifierConfig {
39    /// Build the filled-in prompt for one class.
40    pub fn prompt_for(&self, label: &str) -> String {
41        self.prompt_template.replace("{label}", label)
42    }
43}
44
45impl Default for ClassifierConfig {
46    fn default() -> Self {
47        Self {
48            class_names: vec![
49                "walking".to_string(),
50                "running".to_string(),
51                "cycling".to_string(),
52                "sleeping".to_string(),
53                "sedentary".to_string(),
54            ],
55            prompt_template: "The person is {label}.".to_string(),
56        }
57    }
58}
59
60// ---------------------------------------------------------------------------
61// Classifier
62// ---------------------------------------------------------------------------
63
64/// Zero-shot classifier backed by a SensorLM model.
65pub struct ZeroShotClassifier<B: Backend> {
66    model: SensorLMModel<B>,
67    /// Pre-computed text embeddings for all class prompts, shape `(K, D)`.
68    class_embeddings: Tensor<B, 2>,
69    /// Class names in the same order as `class_embeddings`.
70    class_names: Vec<String>,
71}
72
73impl<B: Backend> ZeroShotClassifier<B> {
74    /// Construct the classifier and pre-compute class embeddings.
75    ///
76    /// # Arguments
77    ///
78    /// * `model`     – A trained SensorLM model.
79    /// * `cfg`       – Classifier configuration.
80    /// * `tokenize`  – A closure that converts a prompt string into
81    ///   `(token_ids, attention_mask)` tensors of shape `(1, L)`.
82    pub fn new<F>(model: SensorLMModel<B>, cfg: &ClassifierConfig, tokenize: F) -> Self
83    where
84        F: Fn(&str) -> (Tensor<B, 2, Int>, Tensor<B, 2, Int>),
85    {
86        let embeddings: Vec<Tensor<B, 2>> = cfg
87            .class_names
88            .iter()
89            .map(|name| {
90                let prompt = cfg.prompt_for(name);
91                let (ids, mask) = tokenize(&prompt);
92                model.encode_text(ids, mask) // (1, D)
93            })
94            .collect();
95
96        // Stack into (K, D).
97        let class_embeddings = Tensor::cat(embeddings, 0);
98
99        Self {
100            model,
101            class_embeddings,
102            class_names: cfg.class_names.clone(),
103        }
104    }
105
106    /// Predict the class for each sensor sample in the batch.
107    ///
108    /// # Arguments
109    ///
110    /// * `sensor` – `(B, T, C)` normalised sensor data.
111    ///
112    /// # Returns
113    ///
114    /// A vector of `B` `(class_index, class_name, similarity_score)` tuples.
115    pub fn predict(
116        &self,
117        sensor: Tensor<B, 3>,
118    ) -> Vec<(usize, String, f32)> {
119        let b = sensor.dims()[0];
120        let z_sensor = self.model.encode_sensor(sensor); // (B, D)
121
122        // Similarity matrix: (B, K) = z_sensor @ class_embeddings.T
123        let sim = z_sensor.matmul(self.class_embeddings.clone().transpose()); // (B, K)
124
125        let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
126        let k = self.class_names.len();
127
128        (0..b)
129            .map(|i| {
130                let row = &data[i * k..(i + 1) * k];
131                let (best_idx, &best_score) = row
132                    .iter()
133                    .enumerate()
134                    .max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
135                    .unwrap();
136                (best_idx, self.class_names[best_idx].clone(), best_score)
137            })
138            .collect()
139    }
140
141    /// Predict and return the top-k predictions per sample.
142    pub fn predict_topk(
143        &self,
144        sensor: Tensor<B, 3>,
145        k: usize,
146    ) -> Vec<Vec<(usize, String, f32)>> {
147        let b = sensor.dims()[0];
148        let z_sensor = self.model.encode_sensor(sensor);
149        let sim = z_sensor.matmul(self.class_embeddings.clone().transpose());
150        let data: Vec<f32> = sim.into_data().to_vec::<f32>().unwrap_or_default();
151        let num_classes = self.class_names.len();
152
153        (0..b)
154            .map(|i| {
155                let row = &data[i * num_classes..(i + 1) * num_classes];
156                let mut indexed: Vec<(usize, f32)> =
157                    row.iter().copied().enumerate().collect();
158                indexed.sort_by(|(_, a), (_, b)| b.partial_cmp(a).unwrap());
159                indexed
160                    .into_iter()
161                    .take(k)
162                    .map(|(idx, score)| (idx, self.class_names[idx].clone(), score))
163                    .collect()
164            })
165            .collect()
166    }
167}