// Source path: st/mcp/session.rs

//! MCP Session-Aware Context Negotiation
//!
//! Smart compression negotiation that adapts to the AI client's preferences.
//! No more redundant compression hints: negotiate once, compress always!
5
6// use anyhow::Result; // TODO: Use when needed
7use serde::{Deserialize, Serialize};
8use std::path::PathBuf;
9use std::sync::Arc;
10use tokio::sync::RwLock;
11
12/// Compression modes supported by Smart Tree
13#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
14#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
15pub enum CompressionMode {
16    /// No compression - raw output
17    None,
18    /// Light compression - readable with some optimization
19    Light,
20    /// Standard compression - balanced
21    Standard,
22    /// Quantum compression - maximum token reduction
23    Quantum,
24    /// Quantum-semantic - ultimate compression with meaning
25    QuantumSemantic,
26    /// Auto-detect based on context size
27    Auto,
28}
29
30impl CompressionMode {
31    /// Get mode from environment variable or default
32    pub fn from_env() -> Self {
33        std::env::var("ST_COMPRESSION")
34            .ok()
35            .and_then(|s| match s.to_lowercase().as_str() {
36                "none" | "raw" => Some(Self::None),
37                "light" => Some(Self::Light),
38                "standard" | "normal" => Some(Self::Standard),
39                "quantum" => Some(Self::Quantum),
40                "quantum-semantic" | "max" => Some(Self::QuantumSemantic),
41                "auto" => Some(Self::Auto),
42                _ => None,
43            })
44            .unwrap_or(Self::Auto)
45    }
46
47    /// Select optimal mode based on file count
48    pub fn auto_select(file_count: usize) -> Self {
49        match file_count {
50            0..=50 => Self::None,        // Small projects: raw is fine
51            51..=200 => Self::Light,     // Medium: light compression
52            201..=500 => Self::Standard, // Large: standard compression
53            501..=1000 => Self::Quantum, // Huge: quantum compression
54            _ => Self::QuantumSemantic,  // Massive: maximum compression
55        }
56    }
57
58    /// Convert to Smart Tree output mode
59    pub fn to_output_mode(&self) -> &'static str {
60        match self {
61            Self::None => "classic",
62            Self::Light => "ai",
63            Self::Standard => "summary-ai",
64            Self::Quantum => "quantum",
65            Self::QuantumSemantic => "quantum-semantic",
66            Self::Auto => "auto",
67        }
68    }
69}
70
71/// Session preferences from the AI client
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct SessionPreferences {
74    /// Preferred compression format
75    pub format: CompressionMode,
76    /// Traversal depth preference
77    pub depth: DepthMode,
78    /// Which tools to advertise
79    pub tools: ToolAdvertisement,
80    /// Project context path
81    pub project_path: Option<PathBuf>,
82}
83
84impl Default for SessionPreferences {
85    fn default() -> Self {
86        Self {
87            format: CompressionMode::Auto,
88            depth: DepthMode::Adaptive,
89            tools: ToolAdvertisement::Lazy,
90            project_path: None,
91        }
92    }
93}
94
95/// Depth traversal modes
96#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
97#[serde(rename_all = "snake_case")]
98pub enum DepthMode {
99    /// Shallow - 1-2 levels
100    Shallow,
101    /// Standard - 3-4 levels
102    Standard,
103    /// Deep - 5+ levels
104    Deep,
105    /// Adaptive based on directory size
106    Adaptive,
107}
108
109impl DepthMode {
110    pub fn to_depth(&self, dir_count: usize) -> usize {
111        match self {
112            Self::Shallow => 2,
113            Self::Standard => 4,
114            Self::Deep => 10,
115            Self::Adaptive => {
116                // Smart depth based on directory count
117                match dir_count {
118                    0..=10 => 10,  // Small: show everything
119                    11..=50 => 5,  // Medium: reasonable depth
120                    51..=100 => 4, // Large: moderate depth
121                    _ => 3,        // Huge: shallow to avoid overwhelm
122                }
123            }
124        }
125    }
126}
127
128/// Tool advertisement strategy
129#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
130#[serde(rename_all = "snake_case")]
131pub enum ToolAdvertisement {
132    /// Advertise all tools immediately
133    All,
134    /// Only advertise core tools, reveal others on demand
135    Lazy,
136    /// Advertise based on project type
137    ContextAware,
138    /// Minimal - only essential tools
139    Minimal,
140}
141
142/// MCP Session Context
143#[derive(Debug, Clone)]
144pub struct McpSession {
145    /// Unique session ID
146    pub id: String,
147    /// Negotiated preferences
148    pub preferences: SessionPreferences,
149    /// Project context path (inferred or explicit)
150    pub project_path: PathBuf,
151    /// Whether negotiation is complete
152    pub negotiated: bool,
153    /// Session start time
154    pub started_at: std::time::SystemTime,
155}
156
157impl Default for McpSession {
158    fn default() -> Self {
159        Self::new()
160    }
161}
162
163impl McpSession {
164    /// Create new session with defaults
165    pub fn new() -> Self {
166        Self {
167            id: format!("STX-{:x}", rand::random::<u32>()),
168            preferences: SessionPreferences::default(),
169            project_path: std::env::current_dir().unwrap_or_else(|_| PathBuf::from(".")),
170            negotiated: false,
171            started_at: std::time::SystemTime::now(),
172        }
173    }
174
175    /// Create session from initial context
176    pub fn from_context(initial_path: Option<PathBuf>) -> Self {
177        let mut session = Self::new();
178
179        // Try to infer project path
180        if let Some(path) = initial_path {
181            session.project_path = path;
182        } else if let Ok(cwd) = std::env::current_dir() {
183            // Look for project markers
184            if cwd.join("Cargo.toml").exists()
185                || cwd.join("package.json").exists()
186                || cwd.join("pyproject.toml").exists()
187                || cwd.join(".git").exists()
188            {
189                session.project_path = cwd;
190            }
191        }
192
193        // Check environment for preferences
194        session.preferences.format = CompressionMode::from_env();
195
196        session
197    }
198
199    /// Negotiate compression with client
200    pub fn negotiate(&mut self, client_prefs: Option<SessionPreferences>) -> NegotiationResponse {
201        if let Some(prefs) = client_prefs {
202            self.preferences = prefs;
203            self.negotiated = true;
204
205            NegotiationResponse {
206                session_id: self.id.clone(),
207                accepted: true,
208                format: self.preferences.format,
209                project_path: self.project_path.clone(),
210                tools_available: self.get_available_tools(),
211            }
212        } else {
213            // Client didn't provide preferences, suggest defaults
214            NegotiationResponse {
215                session_id: self.id.clone(),
216                accepted: false,
217                format: self.preferences.format,
218                project_path: self.project_path.clone(),
219                tools_available: vec!["overview".to_string(), "find".to_string()],
220            }
221        }
222    }
223
224    /// Get tools to advertise based on preferences
225    pub fn get_available_tools(&self) -> Vec<String> {
226        match self.preferences.tools {
227            ToolAdvertisement::All => {
228                // All 30+ tools
229                vec![
230                    "overview",
231                    "find",
232                    "search",
233                    "analyze",
234                    "edit",
235                    "history",
236                    "context",
237                    "memory",
238                    "compare",
239                    "feedback",
240                    "server_info",
241                    "verify_permissions",
242                    "sse",
243                    // ... all tools
244                ]
245                .into_iter()
246                .map(String::from)
247                .collect()
248            }
249            ToolAdvertisement::Lazy => {
250                // Start with essentials
251                vec!["overview", "find", "search"]
252                    .into_iter()
253                    .map(String::from)
254                    .collect()
255            }
256            ToolAdvertisement::ContextAware => {
257                // Based on project type
258                let mut tools = vec!["overview", "find", "search"];
259
260                // Add project-specific tools
261                if self.project_path.join("Cargo.toml").exists() {
262                    tools.push("analyze"); // Code analysis for Rust
263                }
264                if self.project_path.join(".git").exists() {
265                    tools.push("history"); // Git history
266                }
267
268                tools.into_iter().map(String::from).collect()
269            }
270            ToolAdvertisement::Minimal => {
271                // Absolute minimum
272                vec!["overview"].into_iter().map(String::from).collect()
273            }
274        }
275    }
276
277    /// Apply session context to a tool call
278    pub fn apply_context(&self, tool_name: &str, params: &mut serde_json::Value) {
279        // Auto-inject project path if not specified
280        if let Some(obj) = params.as_object_mut() {
281            if !obj.contains_key("path") {
282                obj.insert(
283                    "path".to_string(),
284                    serde_json::Value::String(self.project_path.to_string_lossy().to_string()),
285                );
286            }
287
288            // Apply compression preference
289            if tool_name == "overview" && !obj.contains_key("mode") {
290                obj.insert(
291                    "mode".to_string(),
292                    serde_json::Value::String(self.preferences.format.to_output_mode().to_string()),
293                );
294            }
295        }
296    }
297}
298
299/// Response to negotiation request
300#[derive(Debug, Serialize, Deserialize)]
301pub struct NegotiationResponse {
302    pub session_id: String,
303    pub accepted: bool,
304    pub format: CompressionMode,
305    pub project_path: PathBuf,
306    pub tools_available: Vec<String>,
307}
308
309/// Negotiation request from client
310#[derive(Debug, Serialize, Deserialize)]
311pub struct NegotiationRequest {
312    pub session_prefs: Option<SessionPreferences>,
313    pub capabilities: Vec<String>,
314}
315
316/// Session manager for multiple concurrent sessions
317pub struct SessionManager {
318    sessions: Arc<RwLock<std::collections::HashMap<String, McpSession>>>,
319}
320
321impl Default for SessionManager {
322    fn default() -> Self {
323        Self::new()
324    }
325}
326
327impl SessionManager {
328    pub fn new() -> Self {
329        Self {
330            sessions: Arc::new(RwLock::new(std::collections::HashMap::new())),
331        }
332    }
333
334    /// Create or get session
335    pub async fn get_or_create(&self, session_id: Option<String>) -> McpSession {
336        let mut sessions = self.sessions.write().await;
337
338        if let Some(id) = session_id {
339            if let Some(session) = sessions.get(&id) {
340                return session.clone();
341            }
342        }
343
344        // Create new session
345        let session = McpSession::from_context(None);
346        sessions.insert(session.id.clone(), session.clone());
347        session
348    }
349
350    /// Update session after negotiation
351    pub async fn update(&self, session: McpSession) {
352        let mut sessions = self.sessions.write().await;
353        sessions.insert(session.id.clone(), session);
354    }
355
356    /// Clean up old sessions (>1 hour)
357    pub async fn cleanup(&self) {
358        let mut sessions = self.sessions.write().await;
359        let now = std::time::SystemTime::now();
360
361        sessions.retain(|_, session| {
362            if let Ok(duration) = now.duration_since(session.started_at) {
363                duration.as_secs() < 3600 // Keep sessions less than 1 hour old
364            } else {
365                true
366            }
367        });
368    }
369}
370
371// Add rand for session IDs (already in dependencies)
372use rand;
373
#[cfg(test)]
mod tests {
    use super::*;

    /// One representative value inside each `auto_select` band.
    #[test]
    fn test_compression_auto_select() {
        assert_eq!(CompressionMode::auto_select(10), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(100), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(300), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(700), CompressionMode::Quantum);
        assert_eq!(
            CompressionMode::auto_select(2000),
            CompressionMode::QuantumSemantic
        );
    }

    /// Both edges of every `auto_select` band, to catch off-by-one mistakes.
    #[test]
    fn test_compression_auto_select_boundaries() {
        assert_eq!(CompressionMode::auto_select(0), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(50), CompressionMode::None);
        assert_eq!(CompressionMode::auto_select(51), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(200), CompressionMode::Light);
        assert_eq!(CompressionMode::auto_select(201), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(500), CompressionMode::Standard);
        assert_eq!(CompressionMode::auto_select(501), CompressionMode::Quantum);
        assert_eq!(CompressionMode::auto_select(1000), CompressionMode::Quantum);
        assert_eq!(
            CompressionMode::auto_select(1001),
            CompressionMode::QuantumSemantic
        );
    }

    /// Every variant maps to its documented output mode string.
    #[test]
    fn test_compression_output_mode() {
        assert_eq!(CompressionMode::None.to_output_mode(), "classic");
        assert_eq!(CompressionMode::Light.to_output_mode(), "ai");
        assert_eq!(CompressionMode::Standard.to_output_mode(), "summary-ai");
        assert_eq!(CompressionMode::Quantum.to_output_mode(), "quantum");
        assert_eq!(
            CompressionMode::QuantumSemantic.to_output_mode(),
            "quantum-semantic"
        );
        assert_eq!(CompressionMode::Auto.to_output_mode(), "auto");
    }

    /// Adaptive depth across all four bands, including their edges.
    #[test]
    fn test_depth_adaptive() {
        let depth = DepthMode::Adaptive;
        assert_eq!(depth.to_depth(5), 10); // Small: deep
        assert_eq!(depth.to_depth(30), 5); // Medium: moderate
        assert_eq!(depth.to_depth(75), 4); // Large: moderate
        assert_eq!(depth.to_depth(200), 3); // Huge: shallow

        assert_eq!(depth.to_depth(0), 10);
        assert_eq!(depth.to_depth(10), 10);
        assert_eq!(depth.to_depth(11), 5);
        assert_eq!(depth.to_depth(50), 5);
        assert_eq!(depth.to_depth(51), 4);
        assert_eq!(depth.to_depth(100), 4);
        assert_eq!(depth.to_depth(101), 3);
    }

    /// Fixed depth modes ignore the directory count.
    #[test]
    fn test_depth_fixed_modes() {
        for count in [0usize, 10, 75, 5000] {
            assert_eq!(DepthMode::Shallow.to_depth(count), 2);
            assert_eq!(DepthMode::Standard.to_depth(count), 4);
            assert_eq!(DepthMode::Deep.to_depth(count), 10);
        }
    }
}