1use std::collections::HashMap;
16use std::sync::atomic::Ordering;
17use std::sync::{Arc, Mutex as StdMutex};
18use std::time::{Duration, Instant};
19
20use async_trait::async_trait;
21use bytes::Bytes;
22use serde_json;
23use tokio::sync::{Mutex as TokioMutex, mpsc};
24use tracing::{debug, error, info, trace, warn};
25use uuid::Uuid;
26
27use crate::core::{
28 AtomicMetrics, Transport, TransportCapabilities, TransportError, TransportEventEmitter,
29 TransportMessage, TransportMetrics, TransportResult, TransportState, TransportType,
30};
31use turbomcp_protocol::MessageId;
32
/// Identifier assigned to a session.
///
/// [`SessionInfo::new`] fills it with the string form of a random v4 UUID.
pub type SessionId = String;
35
/// Bookkeeping for a single client session tracked by [`SessionManager`].
#[derive(Debug, Clone)]
pub struct SessionInfo {
    /// Unique session identifier (a random v4 UUID string when built via
    /// [`SessionInfo::new`]).
    pub id: SessionId,

    /// Moment the session was created; [`SessionInfo::duration`] measures from here.
    pub created_at: Instant,

    /// Moment of the most recent activity; [`SessionInfo::is_expired`] compares
    /// the time elapsed since this against the idle timeout.
    pub last_activity: Instant,

    /// Remote peer address, if known. Starts as `None`; nothing in the visible
    /// code sets it.
    pub remote_addr: Option<String>,

    /// Client user-agent string, if known. Starts as `None`; nothing in the
    /// visible code sets it.
    pub user_agent: Option<String>,

    /// Free-form key/value metadata, e.g. written by
    /// [`SessionManager::update_session_metadata`].
    pub metadata: HashMap<String, String>,
}
57
58impl SessionInfo {
59 pub fn new() -> Self {
61 let now = Instant::now();
62 Self {
63 id: Uuid::new_v4().to_string(),
64 created_at: now,
65 last_activity: now,
66 remote_addr: None,
67 user_agent: None,
68 metadata: HashMap::new(),
69 }
70 }
71}
72
73impl Default for SessionInfo {
74 fn default() -> Self {
75 Self::new()
76 }
77}
78
79impl SessionInfo {
80 pub fn touch(&mut self) {
82 self.last_activity = Instant::now();
83 }
84
85 pub fn is_expired(&self, timeout: Duration) -> bool {
87 self.last_activity.elapsed() > timeout
88 }
89
90 pub fn duration(&self) -> Duration {
92 self.created_at.elapsed()
93 }
94}
95
/// Registry of [`SessionInfo`] records with an idle timeout and a capacity limit.
///
/// The session table sits behind an `Arc<Mutex<..>>`, so clones of a manager
/// share the same sessions and it can be used from multiple threads.
#[derive(Debug, Clone)]
pub struct SessionManager {
    /// Session table keyed by session id.
    sessions: Arc<StdMutex<HashMap<SessionId, SessionInfo>>>,

    /// Idle duration after which a session is considered expired.
    session_timeout: Duration,

    /// Maximum number of stored sessions before `create_session` refuses new ones.
    max_sessions: usize,
}
111
112impl SessionManager {
113 pub fn new() -> Self {
115 Self {
116 sessions: Arc::new(StdMutex::new(HashMap::new())),
117 session_timeout: Duration::from_secs(300), max_sessions: 1000, }
120 }
121
122 pub fn with_config(session_timeout: Duration, max_sessions: usize) -> Self {
124 Self {
125 sessions: Arc::new(StdMutex::new(HashMap::new())),
126 session_timeout,
127 max_sessions,
128 }
129 }
130
131 pub async fn create_session(&self) -> TransportResult<SessionInfo> {
133 let mut sessions = self.sessions.lock().expect("sessions mutex poisoned");
134
135 if sessions.len() >= self.max_sessions {
137 self.cleanup_expired_sessions_locked(&mut sessions);
139
140 if sessions.len() >= self.max_sessions {
142 return Err(TransportError::RateLimitExceeded);
143 }
144 }
145
146 let session = SessionInfo::new();
147 let session_id = session.id.clone();
148 sessions.insert(session_id, session.clone());
149
150 debug!("Created new session: {}", session.id);
151 Ok(session)
152 }
153
154 pub fn get_session(&self, session_id: &str) -> Option<SessionInfo> {
156 let mut sessions = self.sessions.lock().expect("sessions mutex poisoned");
157
158 if let Some(session) = sessions.get_mut(session_id) {
159 session.touch();
161 Some(session.clone())
162 } else {
163 None
164 }
165 }
166
167 pub fn update_session_metadata(&self, session_id: &str, key: String, value: String) {
169 let mut sessions = self.sessions.lock().expect("sessions mutex poisoned");
170
171 if let Some(session) = sessions.get_mut(session_id) {
172 session.metadata.insert(key, value);
173 session.touch();
174 }
175 }
176
177 pub fn remove_session(&self, session_id: &str) -> bool {
179 let mut sessions = self.sessions.lock().expect("sessions mutex poisoned");
180 let removed = sessions.remove(session_id).is_some();
181
182 if removed {
183 debug!("Removed session: {}", session_id);
184 }
185
186 removed
187 }
188
189 pub async fn active_session_count(&self) -> usize {
191 self.sessions.lock().expect("sessions mutex poisoned").len()
192 }
193
194 pub async fn cleanup_expired_sessions(&self) -> usize {
196 let mut sessions = self.sessions.lock().expect("sessions mutex poisoned");
197 self.cleanup_expired_sessions_locked(&mut sessions)
198 }
199
200 fn cleanup_expired_sessions_locked(
201 &self,
202 sessions: &mut HashMap<SessionId, SessionInfo>,
203 ) -> usize {
204 let initial_count = sessions.len();
205
206 sessions.retain(|_id, session| !session.is_expired(self.session_timeout));
207
208 let removed = initial_count - sessions.len();
209
210 if removed > 0 {
211 debug!("Cleaned up {} expired sessions", removed);
212 }
213
214 removed
215 }
216
217 pub async fn list_sessions(&self) -> Vec<SessionInfo> {
219 self.sessions
220 .lock()
221 .expect("sessions mutex poisoned")
222 .values()
223 .cloned()
224 .collect()
225 }
226}
227
228impl Default for SessionManager {
229 fn default() -> Self {
230 Self::new()
231 }
232}
233
/// [`Transport`] implementation backing the Tower integration.
///
/// It reports itself as [`TransportType::Http`] with endpoint `tower://adapter`.
/// `send` and `receive` share one in-process `mpsc` channel created by
/// `initialize`, so a message passed to `send` comes back out of `receive`.
/// Requests handed to `process_message` are answered directly with a reply.
#[derive(Debug)]
pub struct TowerTransportAdapter {
    /// Capabilities advertised to callers; `max_message_size` is also enforced
    /// by `process_message`.
    capabilities: TransportCapabilities,

    /// Current connection state, shared behind a `std` mutex.
    state: Arc<StdMutex<TransportState>>,

    /// Counters for messages, bytes, connections and latency.
    metrics: Arc<AtomicMetrics>,

    /// Publishes connection, message and error events.
    event_emitter: TransportEventEmitter,

    /// Tracks sessions; its count is reported as `active_connections`.
    session_manager: SessionManager,

    /// Receiving half of the internal channel; `None` until `initialize` runs
    /// and again after `disconnect`.
    receiver: Arc<TokioMutex<Option<mpsc::Receiver<TransportMessage>>>>,

    /// Sending half of the internal channel; `None` until `initialize` runs
    /// and again after `disconnect`.
    sender: Arc<TokioMutex<Option<mpsc::Sender<TransportMessage>>>>,

    /// Handle to the periodic session-cleanup task spawned by `initialize`;
    /// aborted by `disconnect`.
    _cleanup_task: Arc<TokioMutex<Option<tokio::task::JoinHandle<()>>>>,
}
272
273impl TowerTransportAdapter {
274 pub fn new() -> Self {
276 let (event_emitter, _) = TransportEventEmitter::new();
277
278 Self {
279 capabilities: TransportCapabilities {
280 max_message_size: Some(16 * 1024 * 1024), supports_compression: true,
282 supports_streaming: true,
283 supports_bidirectional: true,
284 supports_multiplexing: true,
285 compression_algorithms: vec![
286 "gzip".to_string(),
287 "deflate".to_string(),
288 "br".to_string(),
289 ],
290 custom: HashMap::new(),
291 },
292 state: Arc::new(StdMutex::new(TransportState::Disconnected)),
293 metrics: Arc::new(AtomicMetrics::default()),
294 event_emitter,
295 session_manager: SessionManager::new(),
296 receiver: Arc::new(TokioMutex::new(None)),
297 sender: Arc::new(TokioMutex::new(None)),
298 _cleanup_task: Arc::new(TokioMutex::new(None)),
299 }
300 }
301}
302
303impl Default for TowerTransportAdapter {
304 fn default() -> Self {
305 Self::new()
306 }
307}
308
309impl TowerTransportAdapter {
310 pub fn with_session_manager(session_manager: SessionManager) -> Self {
312 let mut adapter = Self::new();
313 adapter.session_manager = session_manager;
314 adapter
315 }
316
317 pub async fn initialize(&self) -> McpResult<()> {
319 let (tx, rx) = mpsc::channel(1000); *self.sender.lock().await = Some(tx);
321 *self.receiver.lock().await = Some(rx);
322
323 let session_manager = self.session_manager.clone();
325 let cleanup_task = tokio::spawn(async move {
326 let mut interval = tokio::time::interval(Duration::from_secs(60)); loop {
329 interval.tick().await;
330 let cleaned = session_manager.cleanup_expired_sessions().await;
331
332 if cleaned > 0 {
333 trace!("Session cleanup: removed {} expired sessions", cleaned);
334 }
335 }
336 });
337
338 *self._cleanup_task.lock().await = Some(cleanup_task);
339 self.set_state(TransportState::Connected);
340
341 info!("Tower transport adapter initialized");
342 Ok(())
343 }
344
345 pub fn session_manager(&self) -> &SessionManager {
347 &self.session_manager
348 }
349
350 pub async fn process_message(
352 &self,
353 message: TransportMessage,
354 session_info: &SessionInfo,
355 ) -> TransportResult<Option<TransportMessage>> {
356 let start_time = Instant::now();
357
358 self.metrics
360 .messages_received
361 .fetch_add(1, Ordering::Relaxed);
362 self.metrics
363 .bytes_received
364 .fetch_add(message.size() as u64, Ordering::Relaxed);
365
366 self.event_emitter
368 .emit_message_received(message.id.clone(), message.size());
369
370 if message.size() > self.capabilities.max_message_size.unwrap_or(usize::MAX) {
372 let error = TransportError::ProtocolError("Message too large".to_string());
373 self.event_emitter
374 .emit_error(error.clone(), Some("message validation".to_string()));
375 return Err(error);
376 }
377
378 let json_value: serde_json::Value = serde_json::from_slice(&message.payload)
380 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?;
381
382 trace!(
383 "Processing message from session {}: {:?}",
384 session_info.id, json_value
385 );
386
387 let response_payload = serde_json::json!({
390 "jsonrpc": "2.0",
391 "id": json_value.get("id").unwrap_or(&serde_json::Value::Null),
392 "result": {
393 "echo": json_value,
394 "session": session_info.id,
395 "processed_at": chrono::Utc::now().to_rfc3339()
396 }
397 });
398
399 let response_bytes = Bytes::from(
400 serde_json::to_vec(&response_payload)
401 .map_err(|e| TransportError::SerializationFailed(e.to_string()))?,
402 );
403
404 let response_message =
405 TransportMessage::new(MessageId::from(Uuid::new_v4()), response_bytes);
406
407 let processing_time = start_time.elapsed();
409 self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
410 self.metrics
411 .bytes_sent
412 .fetch_add(response_message.size() as u64, Ordering::Relaxed);
413
414 self.metrics
416 .update_latency_us(processing_time.as_micros() as u64);
417
418 self.event_emitter
420 .emit_message_sent(response_message.id.clone(), response_message.size());
421
422 Ok(Some(response_message))
423 }
424
425 fn set_state(&self, new_state: TransportState) {
427 let mut state = self.state.lock().expect("state mutex poisoned");
429 if *state != new_state {
430 trace!("Tower transport state: {:?} -> {:?}", *state, new_state);
431 *state = new_state.clone();
432
433 match new_state {
435 TransportState::Connected => {
436 self.event_emitter
437 .emit_connected(TransportType::Http, "tower://adapter".to_string());
438 }
439 TransportState::Disconnected => {
440 self.event_emitter.emit_disconnected(
441 TransportType::Http,
442 "tower://adapter".to_string(),
443 None,
444 );
445 }
446 TransportState::Failed { reason } => {
447 self.event_emitter.emit_disconnected(
448 TransportType::Http,
449 "tower://adapter".to_string(),
450 Some(reason),
451 );
452 }
453 _ => {}
454 }
455 }
456 }
457}
458
459#[async_trait]
460impl Transport for TowerTransportAdapter {
461 fn transport_type(&self) -> TransportType {
462 TransportType::Http
463 }
464
465 fn capabilities(&self) -> &TransportCapabilities {
466 &self.capabilities
467 }
468
469 async fn state(&self) -> TransportState {
470 self.state.lock().expect("state mutex poisoned").clone()
472 }
473
474 async fn connect(&self) -> TransportResult<()> {
475 if matches!(self.state().await, TransportState::Connected) {
476 return Ok(());
477 }
478
479 self.set_state(TransportState::Connecting);
480
481 match self.initialize().await {
482 Ok(()) => {
483 self.metrics.connections.fetch_add(1, Ordering::Relaxed);
485 info!("Tower transport adapter connected");
486 Ok(())
487 }
488 Err(e) => {
489 self.metrics
491 .failed_connections
492 .fetch_add(1, Ordering::Relaxed);
493 self.set_state(TransportState::Failed {
494 reason: e.to_string(),
495 });
496 error!("Failed to connect Tower transport adapter: {}", e);
497 Err(TransportError::ConnectionFailed(e.to_string()))
498 }
499 }
500 }
501
502 async fn disconnect(&self) -> TransportResult<()> {
503 if matches!(self.state().await, TransportState::Disconnected) {
504 return Ok(());
505 }
506
507 self.set_state(TransportState::Disconnecting);
508
509 *self.sender.lock().await = None;
511 *self.receiver.lock().await = None;
512
513 if let Some(handle) = self._cleanup_task.lock().await.take() {
515 handle.abort();
516 }
517
518 self.set_state(TransportState::Disconnected);
519 info!("Tower transport adapter disconnected");
520 Ok(())
521 }
522
523 async fn send(&self, message: TransportMessage) -> TransportResult<()> {
524 let state = self.state().await;
525 if !matches!(state, TransportState::Connected) {
526 return Err(TransportError::ConnectionFailed(format!(
527 "Tower transport not connected: {state}"
528 )));
529 }
530
531 let sender_guard = self.sender.lock().await;
532 if let Some(sender) = sender_guard.as_ref() {
533 let message_id = message.id.clone();
534 let message_size = message.size();
535
536 match sender.try_send(message) {
538 Ok(()) => {}
539 Err(mpsc::error::TrySendError::Full(_)) => {
540 return Err(TransportError::SendFailed(
541 "Transport channel full, applying backpressure".to_string(),
542 ));
543 }
544 Err(mpsc::error::TrySendError::Closed(_)) => {
545 return Err(TransportError::SendFailed(
546 "Transport channel closed".to_string(),
547 ));
548 }
549 }
550
551 self.metrics.messages_sent.fetch_add(1, Ordering::Relaxed);
553 self.metrics
554 .bytes_sent
555 .fetch_add(message_size as u64, Ordering::Relaxed);
556
557 self.event_emitter
559 .emit_message_sent(message_id, message_size);
560
561 trace!("Sent message via Tower transport: {} bytes", message_size);
562 Ok(())
563 } else {
564 Err(TransportError::SendFailed(
565 "Sender not available".to_string(),
566 ))
567 }
568 }
569
570 async fn receive(&self) -> TransportResult<Option<TransportMessage>> {
571 let state = self.state().await;
572 if !matches!(state, TransportState::Connected) {
573 return Err(TransportError::ConnectionFailed(format!(
574 "Tower transport not connected: {state}"
575 )));
576 }
577
578 let mut receiver_guard = self.receiver.lock().await;
579 if let Some(ref mut receiver) = receiver_guard.as_mut() {
580 match receiver.recv().await {
581 Some(message) => {
582 trace!(
583 "Received message via Tower transport: {} bytes",
584 message.size()
585 );
586 Ok(Some(message))
587 }
588 None => {
589 warn!("Tower transport receiver disconnected");
590 self.set_state(TransportState::Failed {
591 reason: "Receiver channel disconnected".to_string(),
592 });
593 Err(TransportError::ReceiveFailed(
594 "Channel disconnected".to_string(),
595 ))
596 }
597 }
598 } else {
599 Err(TransportError::ReceiveFailed(
600 "Receiver not available".to_string(),
601 ))
602 }
603 }
604
605 async fn metrics(&self) -> TransportMetrics {
606 let mut metrics = self.metrics.snapshot();
608
609 metrics.active_connections = self.session_manager.active_session_count().await as u64;
611
612 metrics
613 }
614
615 fn endpoint(&self) -> Option<String> {
616 Some("tower://adapter".to_string())
617 }
618}
619
620use turbomcp_protocol::Result as McpResult;
622
#[cfg(test)]
mod tests {
    // Unit tests for session bookkeeping and the Tower transport adapter.
    use super::*;
    use pretty_assertions::assert_eq;

    #[test]
    fn test_session_info_creation() {
        let fresh = SessionInfo::new();

        assert!(!fresh.id.is_empty());
        // Just created: far younger than 100 ms and well inside a 1 s timeout.
        assert!(fresh.duration() < Duration::from_millis(100));
        assert!(!fresh.is_expired(Duration::from_secs(1)));
    }

    #[tokio::test]
    async fn test_session_manager_creation() {
        let sessions = SessionManager::new();
        let count = sessions.active_session_count().await;
        assert_eq!(count, 0);
    }

    #[tokio::test]
    async fn test_session_lifecycle() {
        let sessions = SessionManager::new();

        let created = sessions.create_session().await.unwrap();
        assert_eq!(sessions.active_session_count().await, 1);

        let looked_up = sessions.get_session(&created.id).unwrap();
        assert_eq!(looked_up.id, created.id);

        let was_removed = sessions.remove_session(&created.id);
        assert!(was_removed);
        assert_eq!(sessions.active_session_count().await, 0);
    }

    #[tokio::test]
    async fn test_tower_transport_adapter_creation() {
        let adapter = TowerTransportAdapter::new();
        let caps = adapter.capabilities();

        assert_eq!(adapter.transport_type(), TransportType::Http);
        assert!(caps.supports_bidirectional);
        assert!(caps.supports_streaming);
        assert!(caps.supports_multiplexing);
    }

    #[tokio::test]
    async fn test_tower_transport_connection_lifecycle() {
        let adapter = TowerTransportAdapter::new();

        // Starts disconnected, connects, then disconnects cleanly.
        assert_eq!(adapter.state().await, TransportState::Disconnected);

        let result = adapter.connect().await;
        assert!(result.is_ok(), "Failed to connect: {result:?}");
        assert_eq!(adapter.state().await, TransportState::Connected);

        let result = adapter.disconnect().await;
        assert!(result.is_ok(), "Failed to disconnect: {result:?}");
        assert_eq!(adapter.state().await, TransportState::Disconnected);
    }

    #[tokio::test]
    async fn test_tower_transport_message_processing() {
        let adapter = TowerTransportAdapter::new();
        let session = SessionInfo::new();

        let request = serde_json::json!({
            "jsonrpc": "2.0",
            "id": "test-123",
            "method": "ping",
            "params": {}
        });
        let encoded = serde_json::to_vec(&request).unwrap();
        let message = TransportMessage::new(MessageId::from("test-123"), Bytes::from(encoded));

        let result = adapter.process_message(message, &session).await;
        assert!(result.is_ok(), "Failed to process message: {result:?}");

        let reply = result.unwrap().unwrap();
        assert!(!reply.payload.is_empty());

        let reply_json: serde_json::Value = serde_json::from_slice(&reply.payload).unwrap();
        assert_eq!(reply_json["jsonrpc"], "2.0");
        assert!(reply_json["result"].is_object());
    }
}