turbomcp_protocol/context/
client.rs

1//! Client-related context types for MCP client session management.
2//!
3//! This module contains types for managing client sessions, capabilities,
4//! and identification across different transport mechanisms.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12/// Client capabilities for server-initiated requests
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ClientCapabilities {
15    /// Supports sampling/message creation
16    pub sampling: bool,
17    /// Supports roots listing
18    pub roots: bool,
19    /// Supports elicitation
20    pub elicitation: bool,
21    /// Maximum concurrent server requests
22    pub max_concurrent_requests: usize,
23    /// Supported experimental features
24    pub experimental: HashMap<String, bool>,
25}
26
27/// Client identifier types for authentication and tracking
28#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29pub enum ClientId {
30    /// Explicit client ID from header
31    Header(String),
32    /// Bearer token from Authorization header
33    Token(String),
34    /// Session cookie
35    Session(String),
36    /// Query parameter
37    QueryParam(String),
38    /// Hash of User-Agent (fallback)
39    UserAgent(String),
40    /// Anonymous client
41    Anonymous,
42}
43
44impl ClientId {
45    /// Get the string representation of the client ID
46    #[must_use]
47    pub fn as_str(&self) -> &str {
48        match self {
49            Self::Header(id)
50            | Self::Token(id)
51            | Self::Session(id)
52            | Self::QueryParam(id)
53            | Self::UserAgent(id) => id,
54            Self::Anonymous => "anonymous",
55        }
56    }
57
58    /// Check if the client is authenticated
59    #[must_use]
60    pub const fn is_authenticated(&self) -> bool {
61        matches!(self, Self::Token(_) | Self::Session(_))
62    }
63
64    /// Get the authentication method
65    #[must_use]
66    pub const fn auth_method(&self) -> &'static str {
67        match self {
68            Self::Header(_) => "header",
69            Self::Token(_) => "bearer_token",
70            Self::Session(_) => "session_cookie",
71            Self::QueryParam(_) => "query_param",
72            Self::UserAgent(_) => "user_agent",
73            Self::Anonymous => "anonymous",
74        }
75    }
76}
77
78/// Client session information for tracking and analytics
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ClientSession {
81    /// Unique client identifier
82    pub client_id: String,
83    /// Client name (optional, human-readable)
84    pub client_name: Option<String>,
85    /// When the client connected
86    pub connected_at: DateTime<Utc>,
87    /// Last activity timestamp
88    pub last_activity: DateTime<Utc>,
89    /// Number of requests made
90    pub request_count: usize,
91    /// Transport type (stdio, http, websocket, etc.)
92    pub transport_type: String,
93    /// Authentication status
94    pub authenticated: bool,
95    /// Client capabilities (optional)
96    pub capabilities: Option<serde_json::Value>,
97    /// Additional metadata
98    pub metadata: HashMap<String, serde_json::Value>,
99}
100
101impl ClientSession {
102    /// Create a new client session
103    #[must_use]
104    pub fn new(client_id: String, transport_type: String) -> Self {
105        let now = Utc::now();
106        Self {
107            client_id,
108            client_name: None,
109            connected_at: now,
110            last_activity: now,
111            request_count: 0,
112            transport_type,
113            authenticated: false,
114            capabilities: None,
115            metadata: HashMap::new(),
116        }
117    }
118
119    /// Update activity timestamp and increment request count
120    pub fn update_activity(&mut self) {
121        self.last_activity = Utc::now();
122        self.request_count += 1;
123    }
124
125    /// Set authentication status and client info
126    pub fn authenticate(&mut self, client_name: Option<String>) {
127        self.authenticated = true;
128        self.client_name = client_name;
129    }
130
131    /// Set client capabilities
132    pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
133        self.capabilities = Some(capabilities);
134    }
135
136    /// Get session duration
137    #[must_use]
138    pub fn session_duration(&self) -> chrono::Duration {
139        self.last_activity - self.connected_at
140    }
141
142    /// Check if session is idle (no activity for specified duration)
143    #[must_use]
144    pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
145        Utc::now() - self.last_activity > idle_threshold
146    }
147}
148
149/// Client ID extractor for authentication across different transports
150#[derive(Debug)]
151pub struct ClientIdExtractor {
152    /// Authentication tokens mapping token -> `client_id`
153    auth_tokens: Arc<dashmap::DashMap<String, String>>,
154}
155
156impl ClientIdExtractor {
157    /// Create a new client ID extractor
158    #[must_use]
159    pub fn new() -> Self {
160        Self {
161            auth_tokens: Arc::new(dashmap::DashMap::new()),
162        }
163    }
164
165    /// Register an authentication token for a client
166    pub fn register_token(&self, token: String, client_id: String) {
167        self.auth_tokens.insert(token, client_id);
168    }
169
170    /// Remove an authentication token
171    pub fn revoke_token(&self, token: &str) {
172        self.auth_tokens.remove(token);
173    }
174
175    /// List all registered tokens (for admin purposes)
176    #[must_use]
177    pub fn list_tokens(&self) -> Vec<(String, String)> {
178        self.auth_tokens
179            .iter()
180            .map(|entry| (entry.key().clone(), entry.value().clone()))
181            .collect()
182    }
183
184    /// Extract client ID from HTTP headers
185    #[must_use]
186    #[allow(clippy::significant_drop_tightening)]
187    pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
188        // 1. Check for explicit client ID header
189        if let Some(client_id) = headers.get("x-client-id") {
190            return ClientId::Header(client_id.clone());
191        }
192
193        // 2. Check for Authorization header with Bearer token
194        if let Some(auth) = headers.get("authorization")
195            && let Some(token) = auth.strip_prefix("Bearer ")
196        {
197            // Look up client ID from token
198            let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
199            if let Some(entry) = token_lookup {
200                let client_id = entry.value().clone();
201                drop(entry); // Explicitly drop the lock guard early
202                return ClientId::Token(client_id);
203            }
204            // Token not found - return the token itself as identifier
205            return ClientId::Token(token.to_string());
206        }
207
208        // 3. Check for session cookie
209        if let Some(cookie) = headers.get("cookie") {
210            for cookie_part in cookie.split(';') {
211                let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
212                if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
213                    return ClientId::Session(parts[1].to_string());
214                }
215            }
216        }
217
218        // 4. Use User-Agent hash as fallback
219        if let Some(user_agent) = headers.get("user-agent") {
220            use std::collections::hash_map::DefaultHasher;
221            use std::hash::{Hash, Hasher};
222            let mut hasher = DefaultHasher::new();
223            user_agent.hash(&mut hasher);
224            return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
225        }
226
227        ClientId::Anonymous
228    }
229
230    /// Extract client ID from query parameters
231    #[must_use]
232    pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
233        query_params
234            .get("client_id")
235            .map(|client_id| ClientId::QueryParam(client_id.clone()))
236    }
237
238    /// Extract client ID from multiple sources (with priority)
239    #[must_use]
240    pub fn extract_client_id(
241        &self,
242        headers: Option<&HashMap<String, String>>,
243        query_params: Option<&HashMap<String, String>>,
244    ) -> ClientId {
245        // Try query parameters first (highest priority)
246        if let Some(params) = query_params
247            && let Some(client_id) = self.extract_from_query(params)
248        {
249            return client_id;
250        }
251
252        // Try HTTP headers
253        if let Some(headers) = headers {
254            return self.extract_from_http_headers(headers);
255        }
256
257        ClientId::Anonymous
258    }
259}
260
261impl Default for ClientIdExtractor {
262    fn default() -> Self {
263        Self::new()
264    }
265}