1use std::collections::HashMap;
16
17use crate::error::{Result, SqzError};
18
/// The upstream LLM API wire formats this proxy can recognize and rewrite.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApiFormat {
    /// OpenAI-style `/chat/completions` requests.
    OpenAi,
    /// Anthropic-style `/messages` requests.
    Anthropic,
    /// Google-style `/generateContent` requests.
    Google,
}

impl ApiFormat {
    /// Detects the API format from a request path by substring marker.
    ///
    /// Markers are checked in a fixed order (OpenAI, then Anthropic, then
    /// Google); returns `None` when no marker is present.
    pub fn from_path(path: &str) -> Option<Self> {
        const MARKERS: [(&str, ApiFormat); 3] = [
            ("/chat/completions", ApiFormat::OpenAi),
            ("/messages", ApiFormat::Anthropic),
            ("/generateContent", ApiFormat::Google),
        ];
        MARKERS
            .iter()
            .find(|(marker, _)| path.contains(marker))
            .map(|&(_, format)| format)
    }

    /// The well-known public base URL used when no upstream override exists.
    pub fn default_upstream(&self) -> &'static str {
        match self {
            Self::OpenAi => "https://api.openai.com",
            Self::Anthropic => "https://api.anthropic.com",
            Self::Google => "https://generativelanguage.googleapis.com",
        }
    }
}
53
/// Runtime settings controlling what the proxy compresses.
#[derive(Debug, Clone)]
pub struct ProxyConfig {
    /// Port the proxy listens on.
    pub port: u16,
    /// Upstream URL overrides. Presumably keyed per API/host — TODO confirm
    /// the key scheme against the code that populates this map.
    pub upstreams: HashMap<String, String>,
    /// How many trailing messages are exempt from history summarization.
    pub keep_recent_messages: usize,
    /// Whether system prompts are run through the compression engine.
    pub compress_system: bool,
    /// Whether tool results are run through the compression engine.
    pub compress_tool_results: bool,
    /// Whether old conversation turns are replaced with short summaries.
    pub summarize_history: bool,
}

impl Default for ProxyConfig {
    /// Conservative defaults: listen on 8080, keep the last 10 messages
    /// verbatim, and enable every compression feature.
    fn default() -> Self {
        ProxyConfig {
            port: 8080,
            upstreams: HashMap::default(),
            keep_recent_messages: 10,
            compress_system: true,
            compress_tool_results: true,
            summarize_history: true,
        }
    }
}
84
/// Counters describing the effect of one compression pass over a request.
#[derive(Debug, Clone, Default)]
pub struct ProxyStats {
    /// Estimated token count before any rewriting.
    pub tokens_original: u32,
    /// Estimated token count after rewriting.
    pub tokens_compressed: u32,
    /// Number of messages/blocks run through the compression engine.
    pub messages_compressed: u32,
    /// Number of messages replaced by summaries.
    pub messages_summarized: u32,
    /// Tokens saved specifically on system prompts.
    pub system_tokens_saved: u32,
    /// Tokens saved specifically on tool results.
    pub tool_result_tokens_saved: u32,
}

impl ProxyStats {
    /// Total tokens removed; clamps at zero rather than underflowing.
    pub fn tokens_saved(&self) -> u32 {
        self.tokens_original.saturating_sub(self.tokens_compressed)
    }

    /// Percentage reduction relative to the original size (0.0 when the
    /// original count is zero, avoiding a division by zero).
    pub fn reduction_pct(&self) -> f64 {
        match self.tokens_original {
            0 => 0.0,
            total => (1.0 - f64::from(self.tokens_compressed) / f64::from(total)) * 100.0,
        }
    }
}
115
116pub fn compress_request(
121 body: &str,
122 format: ApiFormat,
123 config: &ProxyConfig,
124 engine: &crate::engine::SqzEngine,
125) -> Result<(String, ProxyStats)> {
126 let mut parsed: serde_json::Value = serde_json::from_str(body)
127 .map_err(|e| SqzError::Other(format!("proxy: invalid JSON body: {e}")))?;
128
129 let mut stats = ProxyStats::default();
130
131 match format {
132 ApiFormat::OpenAi => compress_openai(&mut parsed, config, engine, &mut stats)?,
133 ApiFormat::Anthropic => compress_anthropic(&mut parsed, config, engine, &mut stats)?,
134 ApiFormat::Google => compress_google(&mut parsed, config, engine, &mut stats)?,
135 }
136
137 let compressed_body = serde_json::to_string(&parsed)
138 .map_err(|e| SqzError::Other(format!("proxy: JSON serialize error: {e}")))?;
139
140 Ok((compressed_body, stats))
141}
142
143fn compress_openai(
146 body: &mut serde_json::Value,
147 config: &ProxyConfig,
148 engine: &crate::engine::SqzEngine,
149 stats: &mut ProxyStats,
150) -> Result<()> {
151 let messages = match body.get_mut("messages") {
152 Some(serde_json::Value::Array(arr)) => arr,
153 _ => return Ok(()), };
155
156 let total = messages.len();
157 let keep_recent = config.keep_recent_messages.min(total);
158
159 for (i, msg) in messages.iter_mut().enumerate() {
160 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("").to_string();
161 let is_recent = i >= total.saturating_sub(keep_recent);
162
163 match role.as_str() {
164 "system" if config.compress_system => {
165 compress_message_content(msg, engine, stats, "system")?;
166 }
167 "tool" if config.compress_tool_results => {
168 compress_message_content(msg, engine, stats, "tool")?;
169 }
170 "assistant" | "user" if !is_recent && config.summarize_history => {
171 summarize_message(msg, stats)?;
172 }
173 _ => {}
174 }
175 }
176
177 Ok(())
178}
179
180fn compress_anthropic(
183 body: &mut serde_json::Value,
184 config: &ProxyConfig,
185 engine: &crate::engine::SqzEngine,
186 stats: &mut ProxyStats,
187) -> Result<()> {
188 if config.compress_system {
190 if let Some(system) = body.get_mut("system") {
191 if let Some(text) = system.as_str() {
192 let original_tokens = estimate_tokens(text);
193 let compressed = engine.compress_or_passthrough(text);
194 if compressed.tokens_compressed < original_tokens {
195 *system = serde_json::Value::String(compressed.data);
196 stats.system_tokens_saved += original_tokens - compressed.tokens_compressed;
197 stats.tokens_original += original_tokens;
198 stats.tokens_compressed += compressed.tokens_compressed;
199 }
200 }
201 }
202 }
203
204 let messages = match body.get_mut("messages") {
206 Some(serde_json::Value::Array(arr)) => arr,
207 _ => return Ok(()),
208 };
209
210 let total = messages.len();
211 let keep_recent = config.keep_recent_messages.min(total);
212
213 for (i, msg) in messages.iter_mut().enumerate() {
214 let role = msg.get("role").and_then(|v| v.as_str()).unwrap_or("").to_string();
215 let is_recent = i >= total.saturating_sub(keep_recent);
216
217 if role == "user" && config.compress_tool_results {
219 if let Some(content) = msg.get_mut("content") {
220 if let Some(arr) = content.as_array_mut() {
221 for block in arr.iter_mut() {
222 if block.get("type").and_then(|v| v.as_str()) == Some("tool_result") {
223 compress_content_block(block, engine, stats)?;
224 }
225 }
226 }
227 }
228 }
229
230 if !is_recent && config.summarize_history && (role == "user" || role == "assistant") {
232 summarize_message(msg, stats)?;
233 }
234 }
235
236 Ok(())
237}
238
/// Rewrites a Google `generateContent`-style request body in place.
///
/// Two passes:
/// 1. When `compress_system` is set, every `text` part under
///    `system_instruction.parts` is run through the engine and replaced only
///    if the result is smaller.
/// 2. Each entry of `contents` is walked: text parts of old entries (outside
///    the `keep_recent_messages` window) longer than 200 bytes are replaced
///    by a summary; otherwise, when `compress_tool_results` is set, the text
///    is run through the engine. NOTE(review): Google's format carries no
///    explicit tool role here, so the `compress_tool_results` flag gates
///    compression of *all* remaining text parts — confirm this broad
///    application is intended.
///
/// Bodies missing `system_instruction` or `contents` are left untouched.
fn compress_google(
    body: &mut serde_json::Value,
    config: &ProxyConfig,
    engine: &crate::engine::SqzEngine,
    stats: &mut ProxyStats,
) -> Result<()> {
    if config.compress_system {
        if let Some(si) = body.get_mut("system_instruction") {
            if let Some(parts) = si.get_mut("parts") {
                if let Some(arr) = parts.as_array_mut() {
                    for part in arr.iter_mut() {
                        if let Some(text) = part.get_mut("text") {
                            if let Some(s) = text.as_str() {
                                let original_tokens = estimate_tokens(s);
                                let compressed = engine.compress_or_passthrough(s);
                                // Keep the original unless compression helped.
                                if compressed.tokens_compressed < original_tokens {
                                    *text = serde_json::Value::String(compressed.data);
                                    stats.system_tokens_saved +=
                                        original_tokens - compressed.tokens_compressed;
                                    stats.tokens_original += original_tokens;
                                    stats.tokens_compressed += compressed.tokens_compressed;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    let contents = match body.get_mut("contents") {
        Some(serde_json::Value::Array(arr)) => arr,
        _ => return Ok(()),
    };

    let total = contents.len();
    let keep_recent = config.keep_recent_messages.min(total);

    for (i, content) in contents.iter_mut().enumerate() {
        // The last `keep_recent` entries are "recent" and never summarized.
        let is_recent = i >= total.saturating_sub(keep_recent);

        if let Some(parts) = content.get_mut("parts") {
            if let Some(arr) = parts.as_array_mut() {
                for part in arr.iter_mut() {
                    if let Some(text) = part.get_mut("text") {
                        if let Some(s) = text.as_str() {
                            if !is_recent && config.summarize_history && s.len() > 200 {
                                // Old, long text: replace with a summary
                                // (applied unconditionally, unlike the
                                // engine path below which checks sizes).
                                let summary = summarize_text(s);
                                let original_tokens = estimate_tokens(s);
                                let summary_tokens = estimate_tokens(&summary);
                                *text = serde_json::Value::String(summary);
                                stats.messages_summarized += 1;
                                stats.tokens_original += original_tokens;
                                stats.tokens_compressed += summary_tokens;
                            } else if config.compress_tool_results {
                                let original_tokens = estimate_tokens(s);
                                let compressed = engine.compress_or_passthrough(s);
                                if compressed.tokens_compressed < original_tokens {
                                    *text = serde_json::Value::String(compressed.data);
                                    stats.tokens_original += original_tokens;
                                    stats.tokens_compressed += compressed.tokens_compressed;
                                    stats.messages_compressed += 1;
                                }
                            }
                        }
                    }
                }
            }
        }
    }

    Ok(())
}
316
317fn compress_message_content(
321 msg: &mut serde_json::Value,
322 engine: &crate::engine::SqzEngine,
323 stats: &mut ProxyStats,
324 role: &str,
325) -> Result<()> {
326 let content = match msg.get_mut("content") {
327 Some(v) => v,
328 None => return Ok(()),
329 };
330
331 if let Some(text) = content.as_str() {
332 let original_tokens = estimate_tokens(text);
333 let compressed = engine.compress_or_passthrough(text);
334 if compressed.tokens_compressed < original_tokens {
335 *content = serde_json::Value::String(compressed.data);
336 stats.messages_compressed += 1;
337 let saved = original_tokens - compressed.tokens_compressed;
338 match role {
339 "system" => stats.system_tokens_saved += saved,
340 "tool" => stats.tool_result_tokens_saved += saved,
341 _ => {}
342 }
343 stats.tokens_original += original_tokens;
344 stats.tokens_compressed += compressed.tokens_compressed;
345 }
346 }
347
348 Ok(())
349}
350
351fn compress_content_block(
353 block: &mut serde_json::Value,
354 engine: &crate::engine::SqzEngine,
355 stats: &mut ProxyStats,
356) -> Result<()> {
357 if let Some(content) = block.get_mut("content") {
358 if let Some(text) = content.as_str() {
359 let original_tokens = estimate_tokens(text);
360 let compressed = engine.compress_or_passthrough(text);
361 if compressed.tokens_compressed < original_tokens {
362 *content = serde_json::Value::String(compressed.data);
363 stats.tool_result_tokens_saved +=
364 original_tokens - compressed.tokens_compressed;
365 stats.tokens_original += original_tokens;
366 stats.tokens_compressed += compressed.tokens_compressed;
367 stats.messages_compressed += 1;
368 }
369 }
370 }
371 Ok(())
372}
373
374fn summarize_message(msg: &mut serde_json::Value, stats: &mut ProxyStats) -> Result<()> {
376 let content = match msg.get_mut("content") {
377 Some(v) => v,
378 None => return Ok(()),
379 };
380
381 if let Some(text) = content.as_str() {
382 if text.len() < 200 {
383 return Ok(()); }
385 let original_tokens = estimate_tokens(text);
386 let summary = summarize_text(text);
387 let summary_tokens = estimate_tokens(&summary);
388
389 if summary_tokens < original_tokens {
390 *content = serde_json::Value::String(summary);
391 stats.messages_summarized += 1;
392 stats.tokens_original += original_tokens;
393 stats.tokens_compressed += summary_tokens;
394 }
395 }
396
397 Ok(())
398}
399
/// Collapses multi-line text to its first line, a line/char count marker,
/// and its last line. Text of three lines or fewer is returned unchanged.
fn summarize_text(text: &str) -> String {
    let line_count = text.lines().count();
    if line_count <= 3 {
        return text.to_string();
    }

    // line_count > 3 guarantees at least one line on each end.
    let first = text.lines().next().unwrap_or("");
    let last = text.lines().next_back().unwrap_or("");

    format!(
        "{first}\n[... {line_count} lines, {} chars ...]\n{last}",
        text.len()
    )
}
417
/// Rough token estimate using the common ~4-bytes-per-token heuristic.
///
/// Uses integer ceiling division instead of the previous f64 round-trip:
/// same results for all realistic inputs, but no float conversion, and the
/// `usize -> u32` narrowing is an explicit saturation rather than a cast.
fn estimate_tokens(text: &str) -> u32 {
    u32::try_from(text.len().div_ceil(4)).unwrap_or(u32::MAX)
}
422
423pub fn parse_http_request(raw: &[u8]) -> Result<(String, String, HashMap<String, String>, String)> {
426 let text = String::from_utf8_lossy(raw);
427 let mut lines = text.lines();
428
429 let request_line = lines.next().ok_or_else(|| SqzError::Other("empty request".into()))?;
431 let parts: Vec<&str> = request_line.split_whitespace().collect();
432 if parts.len() < 2 {
433 return Err(SqzError::Other("malformed request line".into()));
434 }
435 let method = parts[0].to_string();
436 let path = parts[1].to_string();
437
438 let mut headers = HashMap::new();
440 for line in lines {
441 if line.is_empty() {
442 break;
443 }
444 if let Some((key, value)) = line.split_once(':') {
445 headers.insert(
446 key.trim().to_lowercase(),
447 value.trim().to_string(),
448 );
449 }
450 }
451
452 let body_start = text.find("\r\n\r\n").map(|p| p + 4)
454 .or_else(|| text.find("\n\n").map(|p| p + 2))
455 .unwrap_or(text.len());
456 let body = text[body_start..].to_string();
457
458 Ok((method, path, headers, body))
459}
460
/// Serializes an HTTP/1.1 response with the given status, headers, and body.
///
/// A `content-length` header is always appended after the caller-supplied
/// headers, followed by the blank separator line and the body bytes.
pub fn build_http_response(status: u16, status_text: &str, headers: &[(&str, &str)], body: &str) -> Vec<u8> {
    use std::fmt::Write as _;

    let mut out = String::new();
    // write! into a String is infallible; errors are impossible here.
    let _ = write!(out, "HTTP/1.1 {status} {status_text}\r\n");
    for (name, value) in headers {
        let _ = write!(out, "{name}: {value}\r\n");
    }
    let _ = write!(out, "content-length: {}\r\n\r\n", body.len());
    out.push_str(body);
    out.into_bytes()
}
472
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_api_format_from_path() {
        assert_eq!(ApiFormat::from_path("/v1/chat/completions"), Some(ApiFormat::OpenAi));
        assert_eq!(ApiFormat::from_path("/v1/messages"), Some(ApiFormat::Anthropic));
        assert_eq!(ApiFormat::from_path("/v1/models/gemini/generateContent"), Some(ApiFormat::Google));
        assert_eq!(ApiFormat::from_path("/unknown"), None);
    }

    #[test]
    fn test_compress_openai_request() {
        let engine = crate::engine::SqzEngine::new().unwrap();
        let config = ProxyConfig::default();

        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are a helpful assistant. You help with coding tasks. You follow best practices. You write clean code. You test everything."},
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi there! How can I help?"},
                {"role": "user", "content": "Write a function"}
            ]
        });

        // `_stats`: this test only checks the structural output.
        let (compressed, _stats) = compress_request(
            &serde_json::to_string(&body).unwrap(),
            ApiFormat::OpenAi,
            &config,
            &engine,
        ).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&compressed).unwrap();
        assert!(parsed.get("messages").is_some());
        assert!(parsed.get("model").is_some());
        assert_eq!(parsed["model"].as_str().unwrap(), "gpt-4");
    }

    #[test]
    fn test_compress_anthropic_request() {
        let engine = crate::engine::SqzEngine::new().unwrap();
        let config = ProxyConfig::default();

        let body = serde_json::json!({
            "model": "claude-sonnet-4-20250514",
            "max_tokens": 1024,
            "system": "You are a helpful coding assistant with extensive knowledge of Rust, Python, and TypeScript.",
            "messages": [
                {"role": "user", "content": "Hello"},
                {"role": "assistant", "content": "Hi! How can I help you today?"}
            ]
        });

        // `_stats`: this test only checks the structural output.
        let (compressed, _stats) = compress_request(
            &serde_json::to_string(&body).unwrap(),
            ApiFormat::Anthropic,
            &config,
            &engine,
        ).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&compressed).unwrap();
        assert!(parsed.get("system").is_some());
        assert!(parsed.get("messages").is_some());
        assert_eq!(parsed["model"].as_str().unwrap(), "claude-sonnet-4-20250514");
    }

    #[test]
    fn test_compress_tool_results() {
        let engine = crate::engine::SqzEngine::new().unwrap();
        let config = ProxyConfig::default();

        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [
                {"role": "user", "content": "Get the data"},
                {"role": "tool", "content": "{\"id\":1,\"name\":\"Alice\",\"debug\":null,\"trace\":null,\"internal_id\":null,\"metadata\":{\"plan\":\"pro\",\"seats\":10,\"billing_cycle\":\"monthly\",\"internal_id\":null}}"}
            ]
        });

        // `_stats`: assertions below inspect the rewritten body directly.
        let (compressed, _stats) = compress_request(
            &serde_json::to_string(&body).unwrap(),
            ApiFormat::OpenAi,
            &config,
            &engine,
        ).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&compressed).unwrap();
        let tool_content = parsed["messages"][1]["content"].as_str().unwrap();
        assert!(
            !tool_content.contains("\"debug\":null") || tool_content.starts_with("TOON:"),
            "tool result should be compressed: {tool_content}"
        );
    }

    #[test]
    fn test_summarize_old_history() {
        let engine = crate::engine::SqzEngine::new().unwrap();
        let config = ProxyConfig {
            keep_recent_messages: 2,
            ..Default::default()
        };

        let long_content = "This is a very long message that contains a lot of detail about the implementation.\nIt spans multiple lines and discusses various aspects of the code.\nThe architecture is modular with clear separation of concerns.\nEach component handles a specific responsibility.\nThe database layer manages persistence and caching.\nThe API layer handles routing and validation.\nError handling is centralized for consistency.\nLogging is structured and searchable.\nThe deployment pipeline is fully automated.\nTests run on every commit to ensure quality.\nDocumentation is kept up to date with the code.\nPerformance is monitored in production.\nSecurity reviews happen before each release.\nThe team follows agile practices with two-week sprints.\nCode reviews are required for all changes.\nThe CI pipeline runs in under five minutes.\nStaging environments mirror production exactly.\nFeature flags control gradual rollouts.\nMetrics are collected for all user-facing operations.\nAlerts fire when error rates exceed thresholds.";
        let body = serde_json::json!({
            "model": "gpt-4",
            "messages": [
                {"role": "system", "content": "You are helpful."},
                {"role": "user", "content": long_content},
                {"role": "assistant", "content": long_content},
                {"role": "user", "content": "Recent message 1"},
                {"role": "assistant", "content": "Recent response 1"}
            ]
        });

        // `_stats`: assertions below inspect the rewritten body directly.
        let (compressed, _stats) = compress_request(
            &serde_json::to_string(&body).unwrap(),
            ApiFormat::OpenAi,
            &config,
            &engine,
        ).unwrap();

        let parsed: serde_json::Value = serde_json::from_str(&compressed).unwrap();
        let messages = parsed["messages"].as_array().unwrap();

        // The two most recent messages must be untouched.
        assert_eq!(messages[3]["content"].as_str().unwrap(), "Recent message 1");
        assert_eq!(messages[4]["content"].as_str().unwrap(), "Recent response 1");

        let old_msg = messages[1]["content"].as_str().unwrap();
        assert!(old_msg.len() < long_content.len(),
            "old message should be summarized: {} vs {}", old_msg.len(), long_content.len());
    }

    #[test]
    fn test_summarize_text() {
        let text = "First line of content.\nSecond line.\nThird line.\nFourth line.\nLast line.";
        let summary = summarize_text(text);
        assert!(summary.contains("First line"));
        assert!(summary.contains("Last line"));
        assert!(summary.contains("5 lines"));
    }

    #[test]
    fn test_summarize_text_short() {
        let text = "Short text.\nOnly two lines.";
        let summary = summarize_text(text);
        assert_eq!(summary, text, "short text should not be summarized");
    }

    #[test]
    fn test_estimate_tokens() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("abcd"), 1);
        assert_eq!(estimate_tokens("abcde"), 2);
    }

    #[test]
    fn test_proxy_stats() {
        let stats = ProxyStats {
            tokens_original: 1000,
            tokens_compressed: 600,
            ..Default::default()
        };
        assert_eq!(stats.tokens_saved(), 400);
        assert!((stats.reduction_pct() - 40.0).abs() < 0.1);
    }

    #[test]
    fn test_proxy_stats_zero() {
        let stats = ProxyStats::default();
        assert_eq!(stats.tokens_saved(), 0);
        assert_eq!(stats.reduction_pct(), 0.0);
    }

    #[test]
    fn test_parse_http_request() {
        let raw = b"POST /v1/messages HTTP/1.1\r\nContent-Type: application/json\r\nAuthorization: Bearer sk-test\r\n\r\n{\"model\":\"claude\"}";
        let (method, path, headers, body) = parse_http_request(raw).unwrap();
        assert_eq!(method, "POST");
        assert_eq!(path, "/v1/messages");
        assert_eq!(headers.get("content-type").unwrap(), "application/json");
        assert_eq!(headers.get("authorization").unwrap(), "Bearer sk-test");
        assert!(body.contains("claude"));
    }

    #[test]
    fn test_build_http_response() {
        let resp = build_http_response(200, "OK", &[("content-type", "application/json")], "{\"ok\":true}");
        let text = String::from_utf8(resp).unwrap();
        assert!(text.starts_with("HTTP/1.1 200 OK"));
        assert!(text.contains("content-type: application/json"));
        assert!(text.ends_with("{\"ok\":true}"));
    }

    #[test]
    fn test_config_defaults() {
        let config = ProxyConfig::default();
        assert_eq!(config.port, 8080);
        assert_eq!(config.keep_recent_messages, 10);
        assert!(config.compress_system);
        assert!(config.compress_tool_results);
        assert!(config.summarize_history);
    }

    #[test]
    fn test_compress_preserves_model_field() {
        let engine = crate::engine::SqzEngine::new().unwrap();
        let config = ProxyConfig::default();

        for format in [ApiFormat::OpenAi, ApiFormat::Anthropic] {
            let body = match format {
                ApiFormat::OpenAi => serde_json::json!({
                    "model": "gpt-4-turbo",
                    "messages": [{"role": "user", "content": "hi"}]
                }),
                ApiFormat::Anthropic => serde_json::json!({
                    "model": "claude-sonnet-4-20250514",
                    "max_tokens": 1024,
                    "messages": [{"role": "user", "content": "hi"}]
                }),
                _ => continue,
            };

            let (compressed, _) = compress_request(
                &serde_json::to_string(&body).unwrap(),
                format,
                &config,
                &engine,
            ).unwrap();

            let parsed: serde_json::Value = serde_json::from_str(&compressed).unwrap();
            assert!(parsed.get("model").is_some(), "model field must be preserved for {:?}", format);
        }
    }

    use proptest::prelude::*;

    proptest! {
        #[test]
        fn prop_compressed_output_is_valid_json(
            content in "[a-z ]{10,200}",
        ) {
            let engine = crate::engine::SqzEngine::new().unwrap();
            let config = ProxyConfig::default();
            let body = serde_json::json!({
                "model": "test",
                "messages": [{"role": "user", "content": content}]
            });

            let (compressed, _) = compress_request(
                &serde_json::to_string(&body).unwrap(),
                ApiFormat::OpenAi,
                &config,
                &engine,
            ).unwrap();

            let parsed: std::result::Result<serde_json::Value, _> = serde_json::from_str(&compressed);
            prop_assert!(parsed.is_ok(), "compressed output must be valid JSON");
        }
    }
}