turbomcp_protocol/context/
client.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
/// Capability flags and limits describing what a connected client supports.
///
/// NOTE(review): the field names mirror MCP client capabilities (sampling, roots,
/// elicitation); confirm exact semantics against the code that populates this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientCapabilities {
    /// Flag for the `sampling` capability.
    pub sampling: bool,
    /// Flag for the `roots` capability.
    pub roots: bool,
    /// Flag for the `elicitation` capability.
    pub elicitation: bool,
    /// Upper bound on concurrently outstanding requests for this client.
    /// This type only stores the value; enforcement (if any) happens elsewhere.
    pub max_concurrent_requests: usize,
    /// Named experimental features, each mapped to an on/off flag.
    pub experimental: HashMap<String, bool>,
}
26
/// Identity of a client, tagged with the mechanism it was derived from.
///
/// The hand-written `Debug` impl redacts the payload of `Token` and `Session`, but
/// `as_str` and the derived `Serialize` expose the raw stored strings.
#[derive(Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClientId {
    /// Value of the `x-client-id` request header.
    Header(String),
    /// Derived from an `Authorization: Bearer` token: the client id registered for the
    /// token, or the raw token itself when it is not registered.
    Token(String),
    /// Value of a `session_id` / `sessionid` cookie.
    Session(String),
    /// Value of the `client_id` query parameter.
    QueryParam(String),
    /// `ua_`-prefixed hexadecimal hash of the `User-Agent` header.
    UserAgent(String),
    /// No identifying information was available.
    Anonymous,
}
48
49impl std::fmt::Debug for ClientId {
50 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
51 match self {
52 Self::Header(id) => f.debug_tuple("Header").field(id).finish(),
53 Self::Token(_) => f.debug_tuple("Token").field(&"<redacted>").finish(),
54 Self::Session(_) => f.debug_tuple("Session").field(&"<redacted>").finish(),
55 Self::QueryParam(id) => f.debug_tuple("QueryParam").field(id).finish(),
56 Self::UserAgent(id) => f.debug_tuple("UserAgent").field(id).finish(),
57 Self::Anonymous => f.write_str("Anonymous"),
58 }
59 }
60}
61
62impl ClientId {
63 #[must_use]
65 pub fn as_str(&self) -> &str {
66 match self {
67 Self::Header(id)
68 | Self::Token(id)
69 | Self::Session(id)
70 | Self::QueryParam(id)
71 | Self::UserAgent(id) => id,
72 Self::Anonymous => "anonymous",
73 }
74 }
75
76 #[must_use]
78 pub const fn is_authenticated(&self) -> bool {
79 matches!(self, Self::Token(_) | Self::Session(_))
80 }
81
82 #[must_use]
84 pub const fn auth_method(&self) -> &'static str {
85 match self {
86 Self::Header(_) => "header",
87 Self::Token(_) => "bearer_token",
88 Self::Session(_) => "session_cookie",
89 Self::QueryParam(_) => "query_param",
90 Self::UserAgent(_) => "user_agent",
91 Self::Anonymous => "anonymous",
92 }
93 }
94}
95
/// Book-keeping for a single connected client: identity, timing, request count,
/// authentication state and free-form extras.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClientSession {
    /// Identifier of the client owning this session.
    pub client_id: String,
    /// Display name supplied when the session is authenticated; `None` until then.
    pub client_name: Option<String>,
    /// UTC instant the session was created.
    pub connected_at: DateTime<Utc>,
    /// UTC instant of the most recently recorded request (initially `connected_at`).
    pub last_activity: DateTime<Utc>,
    /// Number of requests recorded via `update_activity`.
    pub request_count: usize,
    /// Free-form transport label. Presumably values like "stdio"/"http" — confirm with callers.
    pub transport_type: String,
    /// Set to `true` by `authenticate`; sessions start unauthenticated.
    pub authenticated: bool,
    /// Client capabilities as raw JSON, if they have been reported.
    pub capabilities: Option<serde_json::Value>,
    /// Arbitrary per-session key/value data.
    pub metadata: HashMap<String, serde_json::Value>,
}
118
119impl ClientSession {
120 #[must_use]
122 pub fn new(client_id: String, transport_type: String) -> Self {
123 let now = Utc::now();
124 Self {
125 client_id,
126 client_name: None,
127 connected_at: now,
128 last_activity: now,
129 request_count: 0,
130 transport_type,
131 authenticated: false,
132 capabilities: None,
133 metadata: HashMap::new(),
134 }
135 }
136
137 pub fn update_activity(&mut self) {
139 self.last_activity = Utc::now();
140 self.request_count += 1;
141 }
142
143 pub fn authenticate(&mut self, client_name: Option<String>) {
145 self.authenticated = true;
146 self.client_name = client_name;
147 }
148
149 pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
151 self.capabilities = Some(capabilities);
152 }
153
154 #[must_use]
156 pub fn session_duration(&self) -> chrono::Duration {
157 self.last_activity - self.connected_at
158 }
159
160 #[must_use]
162 pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
163 Utc::now() - self.last_activity > idle_threshold
164 }
165}
166
/// Derives [`ClientId`]s from request headers and query parameters, backed by a registry
/// that maps bearer tokens to client ids.
pub struct ClientIdExtractor {
    /// Registered bearer token -> client id. Being a concurrent `DashMap`, it is mutated
    /// through `&self` by `register_token` / `revoke_token`.
    auth_tokens: Arc<dashmap::DashMap<String, String>>,
}
175
176impl std::fmt::Debug for ClientIdExtractor {
177 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178 f.debug_struct("ClientIdExtractor")
179 .field("auth_tokens_count", &self.auth_tokens.len())
180 .finish()
181 }
182}
183
184impl ClientIdExtractor {
185 #[must_use]
187 pub fn new() -> Self {
188 Self {
189 auth_tokens: Arc::new(dashmap::DashMap::new()),
190 }
191 }
192
193 pub fn register_token(&self, token: String, client_id: String) {
195 self.auth_tokens.insert(token, client_id);
196 }
197
198 pub fn revoke_token(&self, token: &str) {
200 self.auth_tokens.remove(token);
201 }
202
203 #[must_use]
205 pub fn list_tokens(&self) -> Vec<(String, String)> {
206 self.auth_tokens
207 .iter()
208 .map(|entry| (entry.key().clone(), entry.value().clone()))
209 .collect()
210 }
211
212 #[must_use]
214 #[allow(clippy::significant_drop_tightening)]
215 pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
216 if let Some(client_id) = headers.get("x-client-id") {
218 return ClientId::Header(client_id.clone());
219 }
220
221 if let Some(auth) = headers.get("authorization")
223 && let Some(token) = auth.strip_prefix("Bearer ")
224 {
225 let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
227 if let Some(entry) = token_lookup {
228 let client_id = entry.value().clone();
229 drop(entry); return ClientId::Token(client_id);
231 }
232 return ClientId::Token(token.to_string());
234 }
235
236 if let Some(cookie) = headers.get("cookie") {
238 for cookie_part in cookie.split(';') {
239 let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
240 if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
241 return ClientId::Session(parts[1].to_string());
242 }
243 }
244 }
245
246 if let Some(user_agent) = headers.get("user-agent") {
248 use std::collections::hash_map::DefaultHasher;
249 use std::hash::{Hash, Hasher};
250 let mut hasher = DefaultHasher::new();
251 user_agent.hash(&mut hasher);
252 return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
253 }
254
255 ClientId::Anonymous
256 }
257
258 #[must_use]
260 pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
261 query_params
262 .get("client_id")
263 .map(|client_id| ClientId::QueryParam(client_id.clone()))
264 }
265
266 #[must_use]
268 pub fn extract_client_id(
269 &self,
270 headers: Option<&HashMap<String, String>>,
271 query_params: Option<&HashMap<String, String>>,
272 ) -> ClientId {
273 if let Some(params) = query_params
275 && let Some(client_id) = self.extract_from_query(params)
276 {
277 return client_id;
278 }
279
280 if let Some(headers) = headers {
282 return self.extract_from_http_headers(headers);
283 }
284
285 ClientId::Anonymous
286 }
287}
288
289impl Default for ClientIdExtractor {
290 fn default() -> Self {
291 Self::new()
292 }
293}