// Smart Compression Manager - Token-aware compression for all outputs
// "Compression so smart, it knows when to squeeze!" - Aye

use std::sync::{PoisonError, RwLock};

use anyhow::Result;
use once_cell::sync::Lazy;
use serde_json::Value;

9/// Global compression state
10static COMPRESSION_STATE: Lazy<RwLock<CompressionState>> =
11    Lazy::new(|| RwLock::new(CompressionState::default()));
12
/// Default estimated-token threshold for auto-compression: a safety margin
/// below the 25k limit noted for MCP.
const DEFAULT_MAX_TOKENS: usize = 20_000;

/// Global knobs and bookkeeping that drive response compression.
#[derive(Debug, Clone)]
pub struct CompressionState {
    /// Whether the client can decode compressed content; `None` means not yet known.
    pub client_supports_compression: Option<bool>,

    /// Estimated token count above which responses are compressed automatically.
    pub max_tokens: usize,

    /// Override: compress every response regardless of size or client support.
    pub force_compression: bool,

    /// Override: never compress (checked before `force_compression`, so it wins).
    pub disable_compression: bool,

    /// Running counters, kept for debugging.
    pub stats: CompressionStats,
}

/// Counters accumulated across compression calls.
#[derive(Clone, Debug, Default)]
pub struct CompressionStats {
    /// Number of strings that have been compressed.
    pub total_compressions: usize,
    /// Cumulative estimated bytes saved.
    pub bytes_saved: usize,
    /// Cumulative estimated tokens saved (bytes / 4 heuristic).
    pub tokens_saved: usize,
    /// Decompression failures (not incremented by any code visible in this module).
    pub failed_decompressions: usize,
}

impl Default for CompressionState {
    fn default() -> Self {
        Self {
            // Unknown until the client is probed or declares its capabilities.
            client_supports_compression: None,
            max_tokens: DEFAULT_MAX_TOKENS,
            force_compression: false,
            disable_compression: false,
            stats: Default::default(),
        }
    }
}

51/// Test if client supports compression by including a small compressed hint
52pub fn create_compression_test() -> Value {
53    // Create a small compressed message that won't break non-supporting clients
54    let test_message = "COMPRESSION_SUPPORTED";
55    let compressed = compress_string(test_message).unwrap_or_default();
56
57    serde_json::json!({
58        "_compression_test": compressed,
59        "_compression_hint": "This server supports compressed responses. If you can read the _compression_test field after decompressing, reply with 'compression:ok' in your next request."
60    })
61}
62
63/// Check if a client response indicates compression support
64pub fn check_client_compression_support(request: &Value) -> bool {
65    // Check for explicit compression acknowledgment
66    if let Some(params) = request.get("params") {
67        if let Some(compression) = params.get("compression") {
68            if compression.as_str() == Some("ok") {
69                set_compression_support(true);
70                return true;
71            }
72        }
73
74        // Check for compression capability in client info
75        if let Some(capabilities) = params.get("capabilities") {
76            if let Some(compression) = capabilities.get("compression") {
77                let supported = compression.as_bool().unwrap_or(false);
78                set_compression_support(supported);
79                return supported;
80            }
81        }
82    }
83
84    false
85}
86
87/// Set global compression support status
88pub fn set_compression_support(supported: bool) {
89    if let Ok(mut state) = COMPRESSION_STATE.write() {
90        state.client_supports_compression = Some(supported);
91        eprintln!(
92            "🗜️ Client compression support: {}",
93            if supported { "YES" } else { "NO" }
94        );
95    }
96}
97
98/// Check if we should compress a response based on its size
99pub fn should_compress_response(content: &str) -> bool {
100    let state = COMPRESSION_STATE.read().unwrap();
101
102    // Check overrides
103    if state.disable_compression {
104        return false;
105    }
106    if state.force_compression {
107        return true;
108    }
109
110    // If we don't know if client supports compression, don't compress
111    if state.client_supports_compression != Some(true) {
112        return false;
113    }
114
115    // Estimate tokens (rough: 1 token ≈ 4 characters)
116    let estimated_tokens = content.len() / 4;
117
118    estimated_tokens > state.max_tokens
119}
120
121/// Compress a string using zlib
122pub fn compress_string(content: &str) -> Result<String> {
123    use flate2::write::ZlibEncoder;
124    use flate2::Compression;
125    use std::io::Write;
126
127    let mut encoder = ZlibEncoder::new(Vec::new(), Compression::default());
128    encoder.write_all(content.as_bytes())?;
129    let compressed = encoder.finish()?;
130
131    // Update stats
132    if let Ok(mut state) = COMPRESSION_STATE.write() {
133        state.stats.total_compressions += 1;
134        state.stats.bytes_saved += content.len().saturating_sub(compressed.len());
135        state.stats.tokens_saved += (content.len() / 4).saturating_sub(compressed.len() / 4);
136    }
137
138    Ok(format!("COMPRESSED_V1:{}", hex::encode(&compressed)))
139}
140
141/// Smart compress any MCP response content
142pub fn smart_compress_mcp_response(response: &mut Value) -> Result<()> {
143    // Look for content in the response
144    if let Some(content) = response.get_mut("content") {
145        if let Some(content_array) = content.as_array_mut() {
146            for item in content_array {
147                if let Some(text) = item.get_mut("text") {
148                    if let Some(text_str) = text.as_str() {
149                        // Check if we should compress
150                        if should_compress_response(text_str) {
151                            let compressed = compress_string(text_str)?;
152
153                            // Calculate compression stats
154                            let original_size = text_str.len();
155                            let compressed_size = compressed.len();
156                            let ratio =
157                                100.0 - (compressed_size as f64 / original_size as f64 * 100.0);
158
159                            eprintln!(
160                                "🗜️ Auto-compressed response: {} → {} bytes ({:.1}% reduction)",
161                                original_size, compressed_size, ratio
162                            );
163                            eprintln!(
164                                "💡 Estimated tokens saved: {}",
165                                (original_size / 4).saturating_sub(compressed_size / 4)
166                            );
167
168                            *text = Value::String(compressed);
169
170                            // Add compression metadata
171                            item["_compressed"] = serde_json::json!(true);
172                            item["_original_size"] = serde_json::json!(original_size);
173                            item["_compression_ratio"] = serde_json::json!(ratio);
174                        }
175                    }
176                }
177            }
178        }
179    }
180
181    // Also check result field for tool responses
182    if let Some(result) = response.get_mut("result") {
183        if let Some(content) = result.get_mut("content") {
184            if let Some(content_array) = content.as_array_mut() {
185                for item in content_array {
186                    if let Some(text) = item.get_mut("text") {
187                        if let Some(text_str) = text.as_str() {
188                            if should_compress_response(text_str) {
189                                let compressed = compress_string(text_str)?;
190
191                                let original_size = text_str.len();
192                                let compressed_size = compressed.len();
193                                let ratio =
194                                    100.0 - (compressed_size as f64 / original_size as f64 * 100.0);
195
196                                eprintln!(
197                                    "🗜️ Auto-compressed result: {} → {} bytes ({:.1}% reduction)",
198                                    original_size, compressed_size, ratio
199                                );
200
201                                *text = Value::String(compressed);
202
203                                item["_compressed"] = serde_json::json!(true);
204                                item["_original_size"] = serde_json::json!(original_size);
205                            }
206                        }
207                    }
208                }
209            }
210        }
211    }
212
213    Ok(())
214}
215
216/// Get compression statistics
217pub fn get_compression_stats() -> CompressionStats {
218    COMPRESSION_STATE.read().unwrap().stats.clone()
219}
220
221/// Configure compression settings
222pub fn configure_compression(
223    max_tokens: Option<usize>,
224    force: Option<bool>,
225    disable: Option<bool>,
226) {
227    if let Ok(mut state) = COMPRESSION_STATE.write() {
228        if let Some(max) = max_tokens {
229            state.max_tokens = max;
230        }
231        if let Some(f) = force {
232            state.force_compression = f;
233        }
234        if let Some(d) = disable {
235            state.disable_compression = d;
236        }
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// Highly repetitive input must shrink and carry the versioned prefix.
    #[test]
    fn test_compression() {
        let input = "Hello World!".repeat(1000);
        let output = compress_string(&input).expect("compression should succeed");

        assert!(output.starts_with("COMPRESSED_V1:"), "missing version prefix");
        assert!(output.len() < input.len(), "output should be smaller than input");
    }

    /// With client support enabled, only content above the token threshold compresses.
    #[test]
    fn test_should_compress() {
        set_compression_support(true);

        assert!(!should_compress_response("small"), "tiny input must not compress");

        let big = "x".repeat(100000);
        assert!(should_compress_response(&big), "oversized input must compress");
    }
}