turbomcp_protocol/context/
client.rs1use std::collections::HashMap;
7use std::sync::Arc;
8
9use chrono::{DateTime, Utc};
10use serde::{Deserialize, Serialize};
11
12#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ClientCapabilities {
15 pub sampling: bool,
17 pub roots: bool,
19 pub elicitation: bool,
21 pub max_concurrent_requests: usize,
23 pub experimental: HashMap<String, bool>,
25}
26
27#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
29pub enum ClientId {
30 Header(String),
32 Token(String),
34 Session(String),
36 QueryParam(String),
38 UserAgent(String),
40 Anonymous,
42}
43
44impl ClientId {
45 #[must_use]
47 pub fn as_str(&self) -> &str {
48 match self {
49 Self::Header(id)
50 | Self::Token(id)
51 | Self::Session(id)
52 | Self::QueryParam(id)
53 | Self::UserAgent(id) => id,
54 Self::Anonymous => "anonymous",
55 }
56 }
57
58 #[must_use]
60 pub const fn is_authenticated(&self) -> bool {
61 matches!(self, Self::Token(_) | Self::Session(_))
62 }
63
64 #[must_use]
66 pub const fn auth_method(&self) -> &'static str {
67 match self {
68 Self::Header(_) => "header",
69 Self::Token(_) => "bearer_token",
70 Self::Session(_) => "session_cookie",
71 Self::QueryParam(_) => "query_param",
72 Self::UserAgent(_) => "user_agent",
73 Self::Anonymous => "anonymous",
74 }
75 }
76}
77
78#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct ClientSession {
81 pub client_id: String,
83 pub client_name: Option<String>,
85 pub connected_at: DateTime<Utc>,
87 pub last_activity: DateTime<Utc>,
89 pub request_count: usize,
91 pub transport_type: String,
93 pub authenticated: bool,
95 pub capabilities: Option<serde_json::Value>,
97 pub metadata: HashMap<String, serde_json::Value>,
99}
100
101impl ClientSession {
102 #[must_use]
104 pub fn new(client_id: String, transport_type: String) -> Self {
105 let now = Utc::now();
106 Self {
107 client_id,
108 client_name: None,
109 connected_at: now,
110 last_activity: now,
111 request_count: 0,
112 transport_type,
113 authenticated: false,
114 capabilities: None,
115 metadata: HashMap::new(),
116 }
117 }
118
119 pub fn update_activity(&mut self) {
121 self.last_activity = Utc::now();
122 self.request_count += 1;
123 }
124
125 pub fn authenticate(&mut self, client_name: Option<String>) {
127 self.authenticated = true;
128 self.client_name = client_name;
129 }
130
131 pub fn set_capabilities(&mut self, capabilities: serde_json::Value) {
133 self.capabilities = Some(capabilities);
134 }
135
136 #[must_use]
138 pub fn session_duration(&self) -> chrono::Duration {
139 self.last_activity - self.connected_at
140 }
141
142 #[must_use]
144 pub fn is_idle(&self, idle_threshold: chrono::Duration) -> bool {
145 Utc::now() - self.last_activity > idle_threshold
146 }
147}
148
149#[derive(Debug)]
151pub struct ClientIdExtractor {
152 auth_tokens: Arc<dashmap::DashMap<String, String>>,
154}
155
156impl ClientIdExtractor {
157 #[must_use]
159 pub fn new() -> Self {
160 Self {
161 auth_tokens: Arc::new(dashmap::DashMap::new()),
162 }
163 }
164
165 pub fn register_token(&self, token: String, client_id: String) {
167 self.auth_tokens.insert(token, client_id);
168 }
169
170 pub fn revoke_token(&self, token: &str) {
172 self.auth_tokens.remove(token);
173 }
174
175 #[must_use]
177 pub fn list_tokens(&self) -> Vec<(String, String)> {
178 self.auth_tokens
179 .iter()
180 .map(|entry| (entry.key().clone(), entry.value().clone()))
181 .collect()
182 }
183
184 #[must_use]
186 #[allow(clippy::significant_drop_tightening)]
187 pub fn extract_from_http_headers(&self, headers: &HashMap<String, String>) -> ClientId {
188 if let Some(client_id) = headers.get("x-client-id") {
190 return ClientId::Header(client_id.clone());
191 }
192
193 if let Some(auth) = headers.get("authorization")
195 && let Some(token) = auth.strip_prefix("Bearer ")
196 {
197 let token_lookup = self.auth_tokens.iter().find(|e| e.key() == token);
199 if let Some(entry) = token_lookup {
200 let client_id = entry.value().clone();
201 drop(entry); return ClientId::Token(client_id);
203 }
204 return ClientId::Token(token.to_string());
206 }
207
208 if let Some(cookie) = headers.get("cookie") {
210 for cookie_part in cookie.split(';') {
211 let parts: Vec<&str> = cookie_part.trim().splitn(2, '=').collect();
212 if parts.len() == 2 && (parts[0] == "session_id" || parts[0] == "sessionid") {
213 return ClientId::Session(parts[1].to_string());
214 }
215 }
216 }
217
218 if let Some(user_agent) = headers.get("user-agent") {
220 use std::collections::hash_map::DefaultHasher;
221 use std::hash::{Hash, Hasher};
222 let mut hasher = DefaultHasher::new();
223 user_agent.hash(&mut hasher);
224 return ClientId::UserAgent(format!("ua_{:x}", hasher.finish()));
225 }
226
227 ClientId::Anonymous
228 }
229
230 #[must_use]
232 pub fn extract_from_query(&self, query_params: &HashMap<String, String>) -> Option<ClientId> {
233 query_params
234 .get("client_id")
235 .map(|client_id| ClientId::QueryParam(client_id.clone()))
236 }
237
238 #[must_use]
240 pub fn extract_client_id(
241 &self,
242 headers: Option<&HashMap<String, String>>,
243 query_params: Option<&HashMap<String, String>>,
244 ) -> ClientId {
245 if let Some(params) = query_params
247 && let Some(client_id) = self.extract_from_query(params)
248 {
249 return client_id;
250 }
251
252 if let Some(headers) = headers {
254 return self.extract_from_http_headers(headers);
255 }
256
257 ClientId::Anonymous
258 }
259}
260
261impl Default for ClientIdExtractor {
262 fn default() -> Self {
263 Self::new()
264 }
265}