// tokenless_semantic/lib.rs

//! Semantic-aware compression of JSON fields.
//!
//! # Levels
//!
//! **Level 1** (default, zero deps): the task context is matched against
//! keywords and fields are classified with compiled-in TOML profiles.  No
//! model download is required.
//!
//! **Level 2** (`onnx` feature): an ONNX embedding model (`all-MiniLM-L6-v2`,
//! ~15 MB) scores each field name by cosine similarity against the user's
//! task context.  On first use the model files are downloaded from GitHub
//! Releases and cached in `~/.tokenfleet-ai/tokenless/models/`.  When the
//! model is unavailable, Level 1 is used automatically.

#![forbid(unsafe_code)]
#![warn(missing_docs, missing_debug_implementations)]
#![allow(
    clippy::pedantic,
    clippy::missing_errors_doc,
    reason = "error types documented at enum level"
)]
22
mod rules;

#[cfg(feature = "onnx")]
mod embedder;

#[cfg(feature = "onnx")]
use std::cell::RefCell;
use std::fmt;

use rules::{FieldAction, classify_field, detect_category};
use serde_json::Value;
34
35/// Errors that can occur during semantic compression.
36#[derive(Debug, thiserror::Error)]
37pub enum EmbedderError {
38    /// I/O error (filesystem or network).
39    #[error("I/O error: {0}")]
40    Io(#[from] std::io::Error),
41
42    /// Model file not found at the expected path.
43    #[error("Model not found at {0}")]
44    ModelNotFound(std::path::PathBuf),
45
46    /// Tokenizer file not found at the expected path.
47    #[error("Tokenizer not found at {0}")]
48    TokenizerNotFound(std::path::PathBuf),
49
50    /// Failed to load the tokenizer.
51    #[error("Tokenizer load error: {0}")]
52    TokenizerLoad(String),
53
54    /// Failed to tokenize input text.
55    #[error("Tokenization error: {0}")]
56    Tokenize(String),
57
58    /// ONNX Runtime error.
59    #[cfg(feature = "onnx")]
60    #[error("ONNX error: {0}")]
61    Ort(String),
62
63    /// Network download error.
64    #[error("Download error: {0}")]
65    Download(String),
66}
67
/// Semantic-aware JSON response compressor.
///
/// Given a user task context string (e.g. `"今天天气怎么样"`, "what is the
/// weather like today"), produces a copy of a JSON value from which the
/// fields judged irrelevant to that context have been removed.
pub struct SemanticCompressor {
    /// Similarity threshold for Level 2.  A field whose name has a cosine
    /// similarity to the task context below this value is dropped.
    /// Default: 0.3.
    threshold: f32,
    /// Level 2 embedder. `None` until [`SemanticCompressor::load_onnx`]
    /// succeeds.
    ///
    /// Wrapped in `RefCell` because embedding needs mutable access to the
    /// embedder while compression only holds `&self`.
    #[cfg(feature = "onnx")]
    embedder: Option<RefCell<embedder::Embedder>>,
}
82
83impl fmt::Debug for SemanticCompressor {
84    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
85        f.debug_struct("SemanticCompressor")
86            .field("threshold", &self.threshold)
87            .finish_non_exhaustive()
88    }
89}
90
91impl Default for SemanticCompressor {
92    fn default() -> Self {
93        Self {
94            threshold: 0.3,
95            #[cfg(feature = "onnx")]
96            embedder: None,
97        }
98    }
99}
100
101impl SemanticCompressor {
102    /// Create a new compressor with default settings.
103    ///
104    /// Level 2 (ONNX) is NOT loaded — call [`load_onnx`] to enable it.
105    #[must_use]
106    pub fn new() -> Self {
107        Self::default()
108    }
109
110    /// Attempt to load the ONNX embedding model.
111    ///
112    /// Downloads model files on first use if they are not cached in
113    /// `~/.tokenfleet-ai/tokenless/models/`.  Returns `Ok(true)` if Level 2 is ready,
114    /// `Ok(false)` if the model is unavailable (Level 1 fallback).
115    ///
116    /// When the `onnx` feature is disabled at compile time, always returns
117    /// `Ok(false)` (Level 1 only).
118    pub fn load_onnx(&mut self) -> Result<bool, EmbedderError> {
119        #[cfg(feature = "onnx")]
120        {
121            let model_dir = model_dir();
122            embedder::ensure_models(&model_dir)?;
123
124            match embedder::Embedder::load(&model_dir) {
125                Ok(e) => {
126                    self.embedder = Some(RefCell::new(e));
127                    tracing::info!("ONNX embedder loaded (Level 2 enabled)");
128                    Ok(true)
129                }
130                Err(e) => {
131                    tracing::warn!("Failed to load ONNX model, falling back to Level 1: {e}");
132                    Ok(false)
133                }
134            }
135        }
136        #[cfg(not(feature = "onnx"))]
137        {
138            let _ = self;
139            Ok(false)
140        }
141    }
142
143    /// Compress a JSON value using semantic rules.
144    ///
145    /// With the `onnx` feature and a loaded embedder: uses cosine similarity
146    /// between field names and the task context.  Fields below `threshold`
147    /// are dropped.
148    ///
149    /// Without ONNX: uses keyword-based TOML rules (Level 1).
150    #[must_use]
151    pub fn compress(&self, value: &Value, context: &str) -> Value {
152        #[cfg(feature = "onnx")]
153        if let Some(ref embedder) = self.embedder
154            && let Ok(ctx_embedding) = embedder.borrow_mut().embed(context)
155        {
156            return self.compress_with_embedding(value, &ctx_embedding, embedder);
157        }
158
159        // Level 1 fallback: rule-based classification.
160        let category = detect_category(context);
161        self.compress_with_rules(value, category, context)
162    }
163
164    /// Check whether a field name should be kept regardless of truncation
165    /// limits, based on the user's task context.
166    #[must_use]
167    pub fn is_field_kept(&self, field_name: &str, context: &str) -> bool {
168        let category = detect_category(context);
169        matches!(classify_field(field_name, category), FieldAction::Keep)
170    }
171
172    /// Return the context category detected from the given text.
173    #[must_use]
174    pub fn detect_category(&self, context: &str) -> &'static str {
175        detect_category(context)
176    }
177
178    // ── Level 1 internals ────────────────────────────────────────────────
179
180    #[allow(
181        clippy::only_used_in_recursion,
182        reason = "parameters needed for recursive calls"
183    )]
184    fn compress_with_rules(&self, value: &Value, category: &str, context: &str) -> Value {
185        match value {
186            Value::Object(obj) => {
187                let mut result = serde_json::Map::new();
188                for (key, val) in obj {
189                    match classify_field(key, category) {
190                        FieldAction::Drop => {}
191                        FieldAction::Keep | FieldAction::Truncate => {
192                            let compressed_val = self.compress_with_rules(val, category, context);
193                            result.insert(key.clone(), compressed_val);
194                        }
195                    }
196                }
197                Value::Object(result)
198            }
199            Value::Array(arr) => {
200                let compressed: Vec<Value> = arr
201                    .iter()
202                    .map(|v| self.compress_with_rules(v, category, context))
203                    .collect();
204                Value::Array(compressed)
205            }
206            other => other.clone(),
207        }
208    }
209
210    // ── Level 2 internals ────────────────────────────────────────────────
211
212    #[cfg(feature = "onnx")]
213    fn compress_with_embedding(
214        &self,
215        value: &Value,
216        ctx_embedding: &[f32],
217        embedder: &RefCell<embedder::Embedder>,
218    ) -> Value {
219        match value {
220            Value::Object(obj) => {
221                let mut result = serde_json::Map::new();
222                for (key, val) in obj {
223                    if let Ok(field_emb) = embedder.borrow_mut().embed(key) {
224                        let sim = embedder::Embedder::cosine_similarity(ctx_embedding, &field_emb);
225                        if sim < self.threshold {
226                            continue; // irrelevant field → drop
227                        }
228                    }
229                    let compressed_val = self.compress_with_embedding(val, ctx_embedding, embedder);
230                    result.insert(key.clone(), compressed_val);
231                }
232                Value::Object(result)
233            }
234            Value::Array(arr) => {
235                let compressed: Vec<Value> = arr
236                    .iter()
237                    .map(|v| self.compress_with_embedding(v, ctx_embedding, embedder))
238                    .collect();
239                Value::Array(compressed)
240            }
241            other => other.clone(),
242        }
243    }
244}
245
/// Location of the model cache: `<home>/.tokenfleet-ai/tokenless/models`.
///
/// When no home directory can be determined, the path is built relative to
/// the current directory (`.`) instead.
#[cfg(feature = "onnx")]
fn model_dir() -> std::path::PathBuf {
    let mut cache = match dirs::home_dir() {
        Some(home) => home,
        None => std::path::PathBuf::from("."),
    };
    cache.extend([".tokenfleet-ai", "tokenless", "models"]);
    cache
}
255
// ── Tests ─────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    #![allow(clippy::unwrap_used, clippy::expect_used)]

    use serde_json::json;

    use super::*;

    /// Assert that every key in `kept` is present in `value` and every key in
    /// `dropped` is absent from it.
    fn assert_fields(value: &Value, kept: &[&str], dropped: &[&str]) {
        for &key in kept {
            assert!(value.get(key).is_some());
        }
        for &key in dropped {
            assert!(value.get(key).is_none());
        }
    }

    /// A weather question keeps readings and drops station bookkeeping.
    #[test]
    fn test_compress_weather_drops_station_id() {
        let input = json!({
            "temperature": 22.5,
            "wind_speed": 12.0,
            "station_id": "WX-001",
            "sensor_version": "3.1.0",
        });
        let output = SemanticCompressor::new().compress(&input, "今天天气怎么样");
        assert_fields(
            &output,
            &["temperature", "wind_speed"],
            &["station_id", "sensor_version"],
        );
    }

    /// A Kubernetes task keeps status fields and drops object metadata.
    #[test]
    fn test_compress_devops_drops_uid() {
        let input = json!({
            "pod_status": "Running",
            "cpu_usage": 0.45,
            "uid": "abc-123-def",
            "self_link": "/api/v1/...",
        });
        let output = SemanticCompressor::new().compress(&input, "deploy to kubernetes");
        assert_fields(&output, &["pod_status", "cpu_usage"], &["uid", "self_link"]);
    }

    /// An unrecognised context still drops debug/trace fields.
    #[test]
    fn test_compress_default_drops_debug() {
        let input = json!({
            "name": "Alice",
            "age": 30,
            "debug": "some debug info",
            "trace": "trace data",
        });
        let output = SemanticCompressor::new().compress(&input, "hello");
        assert_fields(&output, &["name", "age"], &["debug", "trace"]);
    }

    /// Rules are applied at every nesting depth of an object.
    #[test]
    fn test_compress_nested_object() {
        let input = json!({
            "data": {
                "temperature": 22.5,
                "station_id": "WX-001",
                "nested": {
                    "wind_speed": 12.0,
                    "calibration_date": "2025-01-01",
                }
            }
        });
        let output = SemanticCompressor::new().compress(&input, "天气");

        let outer = &output["data"];
        assert!(outer["temperature"].is_f64());
        assert!(outer.get("station_id").is_none());

        let inner = &outer["nested"];
        assert!(inner["wind_speed"].is_f64());
        assert!(inner.get("calibration_date").is_none());
    }

    /// Every element of a top-level array is compressed; none are removed.
    #[test]
    fn test_compress_array_of_objects() {
        let input = json!([
            {"temperature": 22.5, "station_id": "A"},
            {"temperature": 18.0, "station_id": "B"},
        ]);
        let output = SemanticCompressor::new().compress(&input, "天气");
        let items = output.as_array().unwrap();
        assert_eq!(items.len(), 2);
        for item in items {
            assert!(item.get("station_id").is_none());
        }
    }

    /// `is_field_kept` agrees with the weather rules for single fields.
    #[test]
    fn test_is_field_kept() {
        let compressor = SemanticCompressor::new();
        let context = "天气怎么样";
        assert!(compressor.is_field_kept("temperature", context));
        assert!(!compressor.is_field_kept("station_id", context));
    }

    /// The public category accessor reports known and fallback categories.
    #[test]
    fn test_detect_category_public() {
        let compressor = SemanticCompressor::new();
        for (context, expected) in [("天气", "weather"), ("unknown", "default")] {
            assert_eq!(compressor.detect_category(context), expected);
        }
    }
}