turbomcp_protocol/context/
request.rs1use std::collections::HashMap;
7use std::fmt;
8use std::sync::Arc;
9use std::time::Instant;
10
11use chrono::{DateTime, Utc};
12use serde::{Deserialize, Serialize};
13use tokio_util::sync::CancellationToken;
14use uuid::Uuid;
15
16use super::capabilities::ServerToClientRequests;
17use crate::types::Timestamp;
18
19#[derive(Clone)]
25pub struct RequestContext {
26 pub request_id: String,
28
29 pub user_id: Option<String>,
31
32 pub session_id: Option<String>,
34
35 pub client_id: Option<String>,
37
38 pub timestamp: Timestamp,
40
41 pub start_time: Instant,
43
44 pub metadata: Arc<HashMap<String, serde_json::Value>>,
46
47 #[cfg(feature = "tracing")]
49 pub span: Option<tracing::Span>,
50
51 pub cancellation_token: Option<Arc<CancellationToken>>,
53
54 #[doc(hidden)]
57 pub(crate) server_to_client: Option<Arc<dyn ServerToClientRequests>>,
58}
59
60impl fmt::Debug for RequestContext {
61 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
62 f.debug_struct("RequestContext")
63 .field("request_id", &self.request_id)
64 .field("user_id", &self.user_id)
65 .field("session_id", &self.session_id)
66 .field("client_id", &self.client_id)
67 .field("timestamp", &self.timestamp)
68 .field("metadata", &self.metadata)
69 .field("server_to_client", &self.server_to_client.is_some())
70 .finish()
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct ResponseContext {
77 pub request_id: String,
79
80 pub timestamp: Timestamp,
82
83 pub duration: std::time::Duration,
85
86 pub status: ResponseStatus,
88
89 pub metadata: Arc<HashMap<String, serde_json::Value>>,
91}
92
93#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
95pub enum ResponseStatus {
96 Success,
98 Error {
100 code: i32,
102 message: String,
104 },
105 Partial,
107 Cancelled,
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
113pub struct RequestInfo {
114 pub timestamp: DateTime<Utc>,
116 pub client_id: String,
118 pub method_name: String,
120 pub parameters: serde_json::Value,
122 pub response_time_ms: Option<u64>,
124 pub success: bool,
126 pub error_message: Option<String>,
128 pub status_code: Option<u16>,
130 pub metadata: HashMap<String, serde_json::Value>,
132}
133
134impl RequestContext {
135 #[must_use]
137 pub fn new() -> Self {
138 Self {
139 request_id: Uuid::new_v4().to_string(),
140 user_id: None,
141 session_id: None,
142 client_id: None,
143 timestamp: Timestamp::now(),
144 start_time: Instant::now(),
145 metadata: Arc::new(HashMap::new()),
146 #[cfg(feature = "tracing")]
147 span: None,
148 cancellation_token: None,
149 server_to_client: None,
150 }
151 }
152
153 pub fn with_id(id: impl Into<String>) -> Self {
155 Self {
156 request_id: id.into(),
157 ..Self::new()
158 }
159 }
160
161 #[must_use]
170 pub fn with_user_id(mut self, user_id: impl Into<String>) -> Self {
171 self.user_id = Some(user_id.into());
172 self
173 }
174
175 #[must_use]
177 pub fn with_session_id(mut self, session_id: impl Into<String>) -> Self {
178 self.session_id = Some(session_id.into());
179 self
180 }
181
182 #[must_use]
184 pub fn with_client_id(mut self, client_id: impl Into<String>) -> Self {
185 self.client_id = Some(client_id.into());
186 self
187 }
188
189 #[must_use]
199 pub fn with_metadata(
200 mut self,
201 key: impl Into<String>,
202 value: impl Into<serde_json::Value>,
203 ) -> Self {
204 Arc::make_mut(&mut self.metadata).insert(key.into(), value.into());
205 self
206 }
207
208 #[must_use]
210 pub fn get_metadata(&self, key: &str) -> Option<&serde_json::Value> {
211 self.metadata.get(key)
212 }
213
214 #[must_use]
216 pub fn elapsed(&self) -> std::time::Duration {
217 self.start_time.elapsed()
218 }
219
220 #[must_use]
222 pub fn is_cancelled(&self) -> bool {
223 self.cancellation_token
224 .as_ref()
225 .is_some_and(|token| token.is_cancelled())
226 }
227
228 #[must_use]
234 pub fn with_server_to_client(mut self, capabilities: Arc<dyn ServerToClientRequests>) -> Self {
235 self.server_to_client = Some(capabilities);
236 self
237 }
238
239 #[must_use]
242 pub fn with_cancellation_token(mut self, token: Arc<CancellationToken>) -> Self {
243 self.cancellation_token = Some(token);
244 self
245 }
246
247 #[must_use]
249 pub fn user(&self) -> Option<&str> {
250 self.user_id.as_deref()
251 }
252
253 #[must_use]
256 pub fn is_authenticated(&self) -> bool {
257 self.get_metadata("client_authenticated")
258 .and_then(|v| v.as_bool())
259 .unwrap_or(false)
260 }
261
262 #[must_use]
265 pub fn roles(&self) -> Vec<String> {
266 self.get_metadata("auth")
267 .and_then(|auth| auth.get("roles"))
268 .and_then(|roles| roles.as_array())
269 .map(|roles| {
270 roles
271 .iter()
272 .filter_map(|role| role.as_str().map(ToString::to_string))
273 .collect()
274 })
275 .unwrap_or_default()
276 }
277
278 pub fn has_any_role<S: AsRef<str>>(&self, required: &[S]) -> bool {
281 if required.is_empty() {
282 return true; }
284
285 let user_roles = self.roles();
286 required
287 .iter()
288 .any(|required_role| user_roles.contains(&required_role.as_ref().to_string()))
289 }
290
291 #[doc(hidden)]
296 pub fn server_to_client(&self) -> Option<&Arc<dyn ServerToClientRequests>> {
297 self.server_to_client.as_ref()
298 }
299}
300
301impl Default for RequestContext {
302 fn default() -> Self {
303 Self::new()
304 }
305}
306
307impl ResponseContext {
308 pub fn success(request_id: impl Into<String>, duration: std::time::Duration) -> Self {
310 Self {
311 request_id: request_id.into(),
312 timestamp: Timestamp::now(),
313 duration,
314 status: ResponseStatus::Success,
315 metadata: Arc::new(HashMap::new()),
316 }
317 }
318
319 pub fn error(
321 request_id: impl Into<String>,
322 duration: std::time::Duration,
323 code: i32,
324 message: impl Into<String>,
325 ) -> Self {
326 Self {
327 request_id: request_id.into(),
328 timestamp: Timestamp::now(),
329 duration,
330 status: ResponseStatus::Error {
331 code,
332 message: message.into(),
333 },
334 metadata: Arc::new(HashMap::new()),
335 }
336 }
337}
338
339impl RequestInfo {
340 #[must_use]
342 pub fn new(client_id: String, method_name: String, parameters: serde_json::Value) -> Self {
343 Self {
344 timestamp: Utc::now(),
345 client_id,
346 method_name,
347 parameters,
348 response_time_ms: None,
349 success: false,
350 error_message: None,
351 status_code: None,
352 metadata: HashMap::new(),
353 }
354 }
355
356 #[must_use]
358 pub const fn complete_success(mut self, response_time_ms: u64) -> Self {
359 self.response_time_ms = Some(response_time_ms);
360 self.success = true;
361 self.status_code = Some(200);
362 self
363 }
364
365 #[must_use]
367 pub fn complete_error(mut self, response_time_ms: u64, error: String) -> Self {
368 self.response_time_ms = Some(response_time_ms);
369 self.success = false;
370 self.error_message = Some(error);
371 self.status_code = Some(500);
372 self
373 }
374
375 #[must_use]
377 pub const fn with_status_code(mut self, code: u16) -> Self {
378 self.status_code = Some(code);
379 self
380 }
381
382 #[must_use]
384 pub fn with_metadata(mut self, key: String, value: serde_json::Value) -> Self {
385 self.metadata.insert(key, value);
386 self
387 }
388}
389
390pub trait RequestContextExt {
392 #[must_use]
394 fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self;
395
396 #[must_use]
398 fn extract_client_id(
399 self,
400 extractor: &super::client::ClientIdExtractor,
401 headers: Option<&HashMap<String, String>>,
402 query_params: Option<&HashMap<String, String>>,
403 ) -> Self;
404
405 fn get_enhanced_client_id(&self) -> Option<super::client::ClientId>;
407}
408
409impl RequestContextExt for RequestContext {
410 fn with_enhanced_client_id(self, client_id: super::client::ClientId) -> Self {
411 self.with_client_id(client_id.as_str())
412 .with_metadata(
413 "client_id_method".to_string(),
414 serde_json::Value::String(client_id.auth_method().to_string()),
415 )
416 .with_metadata(
417 "client_authenticated".to_string(),
418 serde_json::Value::Bool(client_id.is_authenticated()),
419 )
420 }
421
422 fn extract_client_id(
423 self,
424 extractor: &super::client::ClientIdExtractor,
425 headers: Option<&HashMap<String, String>>,
426 query_params: Option<&HashMap<String, String>>,
427 ) -> Self {
428 let client_id = extractor.extract_client_id(headers, query_params);
429 self.with_enhanced_client_id(client_id)
430 }
431
432 fn get_enhanced_client_id(&self) -> Option<super::client::ClientId> {
433 self.client_id.as_ref().map(|id| {
434 let method = self
435 .get_metadata("client_id_method")
436 .and_then(|v| v.as_str())
437 .unwrap_or("header");
438
439 match method {
440 "bearer_token" => super::client::ClientId::Token(id.clone()),
441 "session_cookie" => super::client::ClientId::Session(id.clone()),
442 "query_param" => super::client::ClientId::QueryParam(id.clone()),
443 "user_agent" => super::client::ClientId::UserAgent(id.clone()),
444 "anonymous" => super::client::ClientId::Anonymous,
445 _ => super::client::ClientId::Header(id.clone()), }
447 })
448 }
449}