Skip to main content

turbomcp_protocol/context/
client.rs

1//! Client-related context types for MCP client session management.
2//!
3//! This module contains types for managing client sessions, capabilities,
4//! and identification across different transport mechanisms.
5
6use std::collections::HashMap;
7use std::sync::Arc;
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12/// Client capabilities for server-initiated requests
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ClientCapabilities {
15    /// Supports sampling/message creation
16    pub sampling: bool,
17    /// Supports roots listing
18    pub roots: bool,
19    /// Supports elicitation
20    pub elicitation: bool,
21    /// Maximum concurrent server requests
22    pub max_concurrent_requests: usize,
23    /// Supported experimental features
24    pub experimental: HashMap<String, bool>,
25}
26
27/// Client identifier types for authentication and tracking
28///
29/// `Debug` is implemented manually to redact secret-bearing variants
30/// (`Token`, `Session`) so bearer tokens / session cookies do not leak
31/// into tracing logs. The plaintext can still be retrieved via
32/// [`ClientId::as_str`].
33#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
34pub enum ClientId {
35    /// Explicit client ID from header
36    Header(String),
37    /// Bearer token from Authorization header
38    Token(String),
39    /// Session cookie
40    Session(String),
41    /// Query parameter
42    QueryParam(String),
43    /// Hash of User-Agent (fallback)
44    UserAgent(String),
45    /// Anonymous client
46    Anonymous,
47}
48
49impl std::fmt::Debug for ClientId {
50    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51        match self {
52            Self::Header(id) => f.debug_tuple("Header").field(id).finish(),
53            Self::Token(_) => f.debug_tuple("Token").field(&"<redacted>").finish(),
54            Self::Session(_) => f.debug_tuple("Session").field(&"<redacted>").finish(),
55            Self::QueryParam(id) => f.debug_tuple("QueryParam").field(id).finish(),
56            Self::UserAgent(id) => f.debug_tuple("UserAgent").field(id).finish(),
57            Self::Anonymous => f.write_str("Anonymous"),
58        }
59    }
60}
61
62impl ClientId {
63    /// Get the string representation of the client ID
64    #[must_use]
65    pub fn as_str(&self) -> &str {
66        match self {
67            Self::Header(id)
68            | Self::Token(id)
69            | Self::Session(id)
70            | Self::QueryParam(id)
71            | Self::UserAgent(id) => id,
72            Self::Anonymous => "anonymous",
73        }
74    }
75
76    /// Check if the client is authenticated
77    #[must_use]
78    pub const fn is_authenticated(&self) -> bool {
79        matches!(self, Self::Token(_) | Self::Session(_))
80    }
81
82    /// Get the authentication method
83    #[must_use]
84    pub const fn auth_method(&self) -> &'static str {
85        match self {
86            Self::Header(_) => "header",
87            Self::Token(_) => "bearer_token",
88            Self::Session(_) => "session_cookie",
89            Self::QueryParam(_) => "query_param",
90            Self::UserAgent(_) => "user_agent",
91            Self::Anonymous => "anonymous",
92        }
93    }
94}
95
96/// Client session information for tracking and analytics
97#[derive(Debug, Clone, Serialize, Deserialize)]
98pub struct ClientSession {
99    /// Unique client identifier
100    pub client_id: String,
101    /// Client name (optional, human-readable)
102    pub client_name: Option<String>,
103    /// When the client connected
104    pub connected_at: DateTime<Utc>,
105    /// Last activity timestamp
106    pub last_activity: DateTime<Utc>,
107    /// Number of requests made
108    pub request_count: usize,
109    /// Transport type (stdio, http, websocket, etc.)
110    pub transport_type: String,
111    /// Authentication status
112    pub authenticated: bool,
113    /// Client capabilities (optional)
114    pub capabilities: Option<serde_json::Value>,
115    /// Additional metadata
116    pub metadata: HashMap<String, serde_json::Value>,
117}
118
119impl ClientSession {
120    /// Create a new client session
121    #[must_use]
122    pub fn new(client_id: String, transport_type: String) -> Self {
123        let now = Utc::now();
124        Self {
125            client_id,
126            client_name: None,
127            connected_at: now,
128            last_activity: now,
129            request_count: 0,
130            transport_type,
131            authenticated: false,
132            capabilities: None,
133            metadata: HashMap::new(),
134        }
135    }
136
137    /// Update activity timestamp and increment request count
138    pub fn update_activity(&mut self) {
139        self.last_activity = Utc::now();
140        self.request_count += 1;
141    }
142
143    /// Set authentication status and client info
144    pub fn authenticate(&mut self, client_name: Option<String>) {
145        self.authenticated = true;
146        self.client_name = client_name;
147    }
148
149    /// Set client capabilities
150    pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
151        self.capabilities = Some(capabilities);
152    }
153
154    /// Get session duration
155    #[must_use]
156    pub fn session_duration(&self) -> chrono::Duration {
157        self.last_activity - self.connected_at
158    }
159
160    /// Check if session is idle (no activity for specified duration)
161    #[must_use]
162    pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
163        Utc::now() - self.last_activity > idle_threshold
164    }
165}
166
167/// Client ID extractor for authentication across different transports
168///
169/// `Debug` is implemented manually so the contents of `auth_tokens` are
170/// never written to logs — only the count is exposed.
171pub struct ClientIdExtractor {
172    /// Authentication tokens mapping token -> `client_id`
173    auth_tokens: Arc<dashmap::DashMap<String, String>>,
174}
175
176impl std::fmt::Debug for ClientIdExtractor {
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        f.debug_struct("ClientIdExtractor")
179            .field("auth_tokens_count", &self.auth_tokens.len())
180            .finish()
181    }
182}
183
184impl ClientIdExtractor {
185    /// Create a new client ID extractor
186    #[must_use]
187    pub fn new() -> Self {
188        Self {
189            auth_tokens: Arc::new(dashmap::DashMap::new()),
190        }
191    }
192
193    /// Register an authentication token for a client
194    pub fn register_token(&self, token: String, client_id: String) {
195        self.auth_tokens.insert(token, client_id);
196    }
197
198    /// Remove an authentication token
199    pub fn revoke_token(&self, token: &str) {
200        self.auth_tokens.remove(token);
201    }
202
203    /// List all registered tokens (for admin purposes)
204    #[must_use]
205    pub fn list_tokens(&self) -> Vec<(String, String)> {
206        self.auth_tokens
207            .iter()
208            .map(|entry| (entry.key().clone(), entry.value().clone()))
209            .collect()
210    }
211
212    /// Extract client ID from HTTP headers
213    #[must_use]
214    #[allow(clippy::significant_drop_tightening)]
215    pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
216        // 1. Check for explicit client ID header
217        if let Some(client_id) = headers.get("x-client-id") {
218            return ClientId::Header(client_id.clone());
219        }
220
221        // 2. Check for Authorization header with Bearer token
222        if let Some(auth) = headers.get("authorization")
223            && let Some(token) = auth.strip_prefix("Bearer ")
224        {
225            // Look up client ID from token
226            let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
227            if let Some(entry) = token_lookup {
228                let client_id = entry.value().clone();
229                drop(entry); // Explicitly drop the lock guard early
230                return ClientId::Token(client_id);
231            }
232            // Token not found - return the token itself as identifier
233            return ClientId::Token(token.to_string());
234        }
235
236        // 3. Check for session cookie
237        if let Some(cookie) = headers.get("cookie") {
238            for cookie_part in cookie.split(';') {
239                let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
240                if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
241                    return ClientId::Session(parts[1].to_string());
242                }
243            }
244        }
245
246        // 4. Use User-Agent hash as fallback
247        if let Some(user_agent) = headers.get("user-agent") {
248            use std::collections::hash_map::DefaultHasher;
249            use std::hash::{Hash, Hasher};
250            let mut hasher = DefaultHasher::new();
251            user_agent.hash(&mut hasher);
252            return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
253        }
254
255        ClientId::Anonymous
256    }
257
258    /// Extract client ID from query parameters
259    #[must_use]
260    pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
261        query_params
262            .get("client_id")
263            .map(|client_id| ClientId::QueryParam(client_id.clone()))
264    }
265
266    /// Extract client ID from multiple sources (with priority)
267    #[must_use]
268    pub fn extract_client_id(
269        &self,
270        headers: Option<&HashMap<String, String>>,
271        query_params: Option<&HashMap<String, String>>,
272    ) -> ClientId {
273        // Try query parameters first (highest priority)
274        if let Some(params) = query_params
275            && let Some(client_id) = self.extract_from_query(params)
276        {
277            return client_id;
278        }
279
280        // Try HTTP headers
281        if let Some(headers) = headers {
282            return self.extract_from_http_headers(headers);
283        }
284
285        ClientId::Anonymous
286    }
287}
288
289impl Default for ClientIdExtractor {
290    fn default() -> Self {
291        Self::new()
292    }
293}