semantic_sift_core/
lib.rs1use regex::Regex;
5use once_cell::sync::Lazy;
6use std::path::Path;
7use ndarray::Array2;
8use ort::session::Session;
9use tokenizers::Tokenizer;
10
11pub struct SemanticEngine {
12 session: Session,
13 tokenizer: Tokenizer,
14}
15
16impl SemanticEngine {
17 pub fn new<P: AsRef<Path>>(model_dir: P) -> Result<Self, String> {
19 let model_dir = model_dir.as_ref();
20 let model_path = model_dir.join("model.onnx");
21 let tokenizer_path = model_dir.join("tokenizer.json");
22
23 if !model_path.exists() {
24 return Err(format!("Model not found at {:?}", model_path));
25 }
26 if !tokenizer_path.exists() {
27 return Err(format!("Tokenizer not found at {:?}", tokenizer_path));
28 }
29
30 let session = Session::builder()
31 .map_err(|e: ort::Error| e.to_string())?
32 .commit_from_file(model_path)
33 .map_err(|e: ort::Error| e.to_string())?;
34
35 let tokenizer = Tokenizer::from_file(tokenizer_path)
36 .map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())?;
37
38 Ok(Self { session, tokenizer })
39 }
40
41 pub fn compress(&mut self, text: &str, rate: f32) -> Result<String, String> {
43 let encoding = self.tokenizer.encode(text, true)
44 .map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())?;
45
46 let ids = encoding.get_ids();
47 let attention_mask = encoding.get_attention_mask();
48 let seq_len = ids.len();
49
50 if seq_len == 0 {
51 return Ok(String::new());
52 }
53
54 let input_ids_array = Array2::from_shape_vec((1, seq_len), ids.iter().map(|&x| x as i64).collect())
55 .map_err(|e: ndarray::ShapeError| e.to_string())?;
56 let attention_mask_array = Array2::from_shape_vec((1, seq_len), attention_mask.iter().map(|&x| x as i64).collect())
57 .map_err(|e: ndarray::ShapeError| e.to_string())?;
58 let token_type_ids_array = Array2::from_shape_vec((1, seq_len), vec![0i64; seq_len])
59 .map_err(|e: ndarray::ShapeError| e.to_string())?;
60
61 let input_ids_value = ort::value::Value::from_array(input_ids_array)
63 .map_err(|e: ort::Error| e.to_string())?;
64 let attention_mask_value = ort::value::Value::from_array(attention_mask_array)
65 .map_err(|e: ort::Error| e.to_string())?;
66 let token_type_ids_value = ort::value::Value::from_array(token_type_ids_array)
67 .map_err(|e: ort::Error| e.to_string())?;
68
69 let inputs = ort::inputs![
70 "input_ids" => &input_ids_value,
71 "attention_mask" => &attention_mask_value,
72 "token_type_ids" => &token_type_ids_value,
73 ];
74
75 let outputs = self.session.run(inputs).map_err(|e: ort::Error| e.to_string())?;
76
77 let logits_data = outputs["logits"].try_extract_tensor::<f32>()
79 .map_err(|e: ort::Error| e.to_string())?;
80
81 let (_shape, data) = logits_data;
82
83 let mut scores = Vec::with_capacity(seq_len);
85 for i in 0..seq_len {
86 let offset_discard = i * 2;
87 let offset_preserve = i * 2 + 1;
88
89 let logit_discard = data[offset_discard];
90 let logit_preserve = data[offset_preserve];
91
92 let exp_discard = logit_discard.exp();
93 let exp_preserve = logit_preserve.exp();
94 let prob_preserve = exp_preserve / (exp_discard + exp_preserve);
95 scores.push(prob_preserve);
96 }
97
98 let mut sorted_scores = scores.clone();
100 sorted_scores.sort_by(|a: &f32, b: &f32| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
101
102 let threshold_idx = ((1.0 - rate) * seq_len as f32) as usize;
103 let threshold = if threshold_idx < seq_len {
104 sorted_scores[threshold_idx]
105 } else {
106 0.0
107 };
108
109 let mut result_ids = Vec::new();
111 for (i, &score) in scores.iter().enumerate() {
112 if score >= threshold {
113 result_ids.push(ids[i]);
114 }
115 }
116
117 self.tokenizer.decode(&result_ids, true).map_err(|e: Box<dyn std::error::Error + Send + Sync>| e.to_string())
118 }
119}
120
121pub fn apply_heuristic_sieve(text: &str) -> String {
124 static TIMESTAMP_PATTERN: Lazy<Regex> = Lazy::new(|| {
125 Regex::new(r"(\d{4}-\d{2}-\d{2}[T\s]\d{2}:\d{2}:\d{2}([\.,]\d+)?Z?)|(\d{6}\s\d{6}\s\d+)|(\[\d{2}:\d{2}:\d{2}(\.\d+)?\])").unwrap()
126 });
127 static PROGRESS_PATTERN: Lazy<Regex> = Lazy::new(|| {
128 Regex::new(r"\[\d+/\d+\]|[\.]{3,}|\d+%\s*").unwrap()
129 });
130 static METADATA_PATTERN: Lazy<Regex> = Lazy::new(|| {
131 Regex::new(r"\s*(INFO|DEBUG|WARN|ERROR)\s+dfs\..*?:\s*").unwrap()
132 });
133 static MODULE_PATTERN: Lazy<Regex> = Lazy::new(|| {
134 Regex::new(r"(?i)^\s*[\d\.]+\s+(MB|KB|bytes|B)\s+[\w\-\.\/]+.*$").unwrap()
135 });
136
137 let mut sifted = Vec::new();
138
139 for line in text.lines() {
140 let clean_line = TIMESTAMP_PATTERN.replace_all(line, "").trim().to_string();
141 let clean_line = METADATA_PATTERN.replace_all(&clean_line, "").trim().to_string();
142
143 if clean_line.is_empty()
144 || PROGRESS_PATTERN.is_match(&clean_line)
145 || MODULE_PATTERN.is_match(&clean_line)
146 {
147 continue;
148 }
149
150 sifted.push(clean_line);
151 }
152
153 sifted.join("\n")
154}
155
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that sieving `input` produces exactly `expected`.
    #[track_caller]
    fn assert_sieved(input: &str, expected: &str) {
        assert_eq!(apply_heuristic_sieve(input), expected);
    }

    /// A leading ISO-8601 timestamp is stripped; the rest of the line stays.
    #[test]
    fn test_heuristic_sieve_removes_timestamps() {
        assert_sieved("2026-05-01T12:00:00Z INFO some message", "INFO some message");
    }

    /// A line carrying progress markers is dropped entirely.
    #[test]
    fn test_heuristic_sieve_removes_progress() {
        assert_sieved("Compiling... [1/42] 25%", "");
    }

    /// An error line without noise markers passes through unchanged.
    #[test]
    fn test_heuristic_sieve_preserves_errors() {
        assert_sieved(
            "ERROR: connection refused at line 42",
            "ERROR: connection refused at line 42",
        );
    }

    /// Timestamp plus `<LEVEL> dfs.<component>:` prefix are both removed.
    #[test]
    fn test_heuristic_sieve_strips_hdfs_metadata() {
        assert_sieved(
            "2026-05-01T12:00:00Z INFO dfs.DataNode: Receiving block",
            "Receiving block",
        );
    }
}