Skip to main content

semantic_sift_core/
lib.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2026 Luis Kobayashi. All rights reserved.
3
4use regex::Regex;
5use once_cell::sync::Lazy;
6use std::path::Path;
7use ndarray::Array2;
8use ort::session::Session;
9use tokenizers::Tokenizer;
10
11pub struct SemanticEngine {
12    session: Session,
13    tokenizer: Tokenizer,
14}
15
16impl SemanticEngine {
17    /// Initialize the semantic engine with the ONNX model and tokenizer.
18    pub fn new<P: AsRef<Path>>(model_dir: P) -> Result<Self, String> {
19        let model_dir = model_dir.as_ref();
20        let model_path = model_dir.join("model.onnx");
21        let tokenizer_path = model_dir.join("tokenizer.json");
22
23        if !model_path.exists() {
24            return Err(format!("Model not found at {:?}", model_path));
25        }
26        if !tokenizer_path.exists() {
27            return Err(format!("Tokenizer not found at {:?}", tokenizer_path));
28        }
29
30        let session = Session::builder()
31            .map_err(|e: ort::Error| e.to_string())?
32            .commit_from_file(model_path)
33            .map_err(|e: ort::Error| e.to_string())?;
34
35        let tokenizer = Tokenizer::from_file(tokenizer_path)
36            .map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())?;
37
38        Ok(Self { session, tokenizer })
39    }
40
41    /// Perform semantic distillation on the given text at the specified rate.
42    pub fn compress(&mut self, text: &str, rate: f32) -> Result<String, String> {
43        let encoding = self.tokenizer.encode(text, true)
44            .map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())?;
45        
46        let ids = encoding.get_ids();
47        let attention_mask = encoding.get_attention_mask();
48        let seq_len = ids.len();
49
50        if seq_len == 0 {
51            return Ok(String::new());
52        }
53
54        let input_ids_array = Array2::from_shape_vec((1, seq_len), ids.iter().map(|&x| x as i64).collect())
55            .map_err(|e: ndarray::ShapeError| e.to_string())?;
56        let attention_mask_array = Array2::from_shape_vec((1, seq_len), attention_mask.iter().map(|&x| x as i64).collect())
57            .map_err(|e: ndarray::ShapeError| e.to_string())?;
58        let token_type_ids_array = Array2::from_shape_vec((1, seq_len), vec![0i64; seq_len])
59            .map_err(|e: ndarray::ShapeError| e.to_string())?;
60
61        // In ort 2.x, we must create Value objects from arrays
62        let input_ids_value = ort::value::Value::from_array(input_ids_array)
63            .map_err(|e: ort::Error| e.to_string())?;
64        let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
65            .map_err(|e: ort::Error| e.to_string())?;
66        let token_type_ids_value = ort::value::Value::from_array(token_type_ids_array)
67            .map_err(|e: ort::Error| e.to_string())?;
68
69        let inputs = ort::inputs![
70            "input_ids" => &input_ids_value,
71            "attention_mask" => &attention_mask_value,
72            "token_type_ids" => &token_type_ids_value,
73        ];
74
75        let outputs = self.session.run(inputs).map_err(|e: ort::Error| e.to_string())?;
76        
77        // Extract raw data and shape
78        let logits_data = outputs["logits"].try_extract_tensor::<f32>()
79            .map_err(|e: ort::Error| e.to_string())?;
80        
81        let (_shape, data) = logits_data;
82        
83        // Extract probabilities for the "preserve" class (index 1)
84        let mut scores = Vec::with_capacity(seq_len);
85        for i in 0..seq_len {
86            let offset_discard = i * 2;
87            let offset_preserve = i * 2 + 1;
88            
89            let logit_discard = data[offset_discard];
90            let logit_preserve = data[offset_preserve];
91            
92            let exp_discard = logit_discard.exp();
93            let exp_preserve = logit_preserve.exp();
94            let prob_preserve = exp_preserve / (exp_discard + exp_preserve);
95            scores.push(prob_preserve);
96        }
97
98        // Determine threshold based on rate
99        let mut sorted_scores = scores.clone();
100        sorted_scores.sort_by(|a: &f32, b: &f32| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101        
102        let threshold_idx = ((1.0 - rate) * seq_len as f32) as usize;
103        let threshold = if threshold_idx < seq_len {
104            sorted_scores[threshold_idx]
105        } else {
106            0.0
107        };
108
109        // Reconstruct text with preserved tokens
110        let mut result_ids = Vec::new();
111        for (i, &score) in scores.iter().enumerate() {
112            if score >= threshold {
113                result_ids.push(ids[i]);
114            }
115        }
116
117        self.tokenizer.decode(&result_ids, true).map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())
118    }
119}
120
121/// Heuristic Sieve: Ported from Python sift_kernel.py
122/// Sifts through raw technical logs to remove noise.
123pub fn apply_heuristic_sieve(text: &str) -> String {
124    static TIMESTAMP_PATTERN: Lazy<Regex> = Lazy::new(|| {
125        Regex::new(r"(\d{4}-\d{2}-\d{2}[T\s]\d{2}:\d{2}:\d{2}([\.,]\d+)?Z?)|(\d{6}\s\d{6}\s\d+)|(\[\d{2}:\d{2}:\d{2}(\.\d+)?\])").unwrap()
126    });
127    static PROGRESS_PATTERN: Lazy<Regex> = Lazy::new(|| {
128        Regex::new(r"\[\d+/\d+\]|[\.]{3,}|\d+%\s*").unwrap()
129    });
130    static METADATA_PATTERN: Lazy<Regex> = Lazy::new(|| {
131        Regex::new(r"\s*(INFO|DEBUG|WARN|ERROR)\s+dfs\..*?:\s*").unwrap()
132    });
133    static MODULE_PATTERN: Lazy<Regex> = Lazy::new(|| {
134        Regex::new(r"(?i)^\s*[\d\.]+\s+(MB|KB|bytes|B)\s+[\w\-\.\/]+.*$").unwrap()
135    });
136
137    let mut sifted = Vec::new();
138    
139    for line in text.lines() {
140        let clean_line = TIMESTAMP_PATTERN.replace_all(line, "").trim().to_string();
141        let clean_line = METADATA_PATTERN.replace_all(&clean_line, "").trim().to_string();
142        
143        if clean_line.is_empty() 
144            || PROGRESS_PATTERN.is_match(&clean_line) 
145            || MODULE_PATTERN.is_match(&clean_line) 
146        {
147            continue;
148        }
149        
150        sifted.push(clean_line);
151    }
152    
153    sifted.join("\n")
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_heuristic_sieve_removes_timestamps() {
        let input = "2026-05-01T12:00:00Z INFO some message";
        let expected = "INFO some message";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    #[test]
    fn test_heuristic_sieve_removes_progress() {
        let input = "Compiling... [1/42] 25%";
        let expected = "";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    #[test]
    fn test_heuristic_sieve_preserves_errors() {
        let input = "ERROR: connection refused at line 42";
        let expected = "ERROR: connection refused at line 42";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    #[test]
    fn test_heuristic_sieve_strips_hdfs_metadata() {
        let input = "2026-05-01T12:00:00Z INFO dfs.DataNode: Receiving block";
        let expected = "Receiving block";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    /// Empty input yields empty output rather than a stray newline.
    #[test]
    fn test_heuristic_sieve_empty_input() {
        assert_eq!(apply_heuristic_sieve(""), "");
    }

    /// Blank and whitespace-only lines are dropped, CRLF line endings are
    /// handled, and survivors are joined with a single `\n`.
    #[test]
    fn test_heuristic_sieve_drops_blank_lines_and_joins_with_newline() {
        let input = "line one\r\n\n   \nline two\n";
        assert_eq!(apply_heuristic_sieve(input), "line one\nline two");
    }

    /// Compact HDFS timestamps (`YYMMDD HHMMSS N`) are stripped along with
    /// the `LEVEL dfs.Component:` prefix. Long digit runs inside the message
    /// are left alone.
    #[test]
    fn test_heuristic_sieve_removes_compact_hdfs_timestamp() {
        let input = "081109 203615 148 INFO dfs.DataNode$PacketResponder: PacketResponder 1 for block blk_38865049064139660 terminating";
        let expected = "PacketResponder 1 for block blk_38865049064139660 terminating";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    /// Bracketed `[HH:MM:SS.fff]` timestamps are stripped.
    #[test]
    fn test_heuristic_sieve_removes_bracketed_timestamp() {
        let input = "[12:34:56.789] Server started";
        assert_eq!(apply_heuristic_sieve(input), "Server started");
    }

    /// Size-listing lines such as `12.5 MB path` are dropped entirely.
    #[test]
    fn test_heuristic_sieve_drops_size_listings() {
        let input = "  12.5 MB target/debug/deps/libfoo.rlib\nBuild finished";
        assert_eq!(apply_heuristic_sieve(input), "Build finished");
    }

    /// Mixed multi-line log: each line is sifted independently.
    #[test]
    fn test_heuristic_sieve_multiline_log() {
        let input = "2026-05-01 12:00:00,123 INFO dfs.FSNamesystem: BLOCK* allocateBlock\n\
                     Downloading... 42%\n\
                     ERROR: connection refused";
        let expected = "BLOCK* allocateBlock\nERROR: connection refused";
        assert_eq!(apply_heuristic_sieve(input), expected);
    }
}