// sim_lib_mcp/session.rs

1use std::collections::BTreeSet;
2
3use sim_kernel::{CapabilityName, Expr};
4#[cfg(feature = "stream")]
5use sim_lib_stream_core::StreamPacket;
6
7use crate::{McpNativeCard, McpProfile};
8
/// Protocol revision assigned to `protocol_version` when an [`McpSession`]
/// is constructed.
pub const DEFAULT_PROTOCOL_VERSION: &str = "2025-03-26";
11
/// Which session boundary caused [`McpSession::admit_request`] to refuse a
/// request.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum McpBoundaryLimit {
    /// The per-request deadline budget is zero.
    Deadline,
    /// The cap on the total number of admitted requests has been reached.
    Rate,
    /// The cap on concurrently active requests has been reached.
    ActiveRequests,
}
18
19/// Mutable per-connection MCP session state.
20///
21/// Tracks the handshake, the visibility profile, granted capabilities, and the
22/// in-flight request bookkeeping used to enforce deadline, rate, and
23/// concurrency boundaries.
24#[derive(Clone)]
25pub struct McpSession {
26    /// Stable session identifier.
27    pub id: String,
28    /// Whether the `initialize` handshake has completed.
29    pub initialized: bool,
30    /// Client info reported during initialization, if any.
31    pub client_info: Option<Expr>,
32    /// Negotiated MCP protocol version.
33    pub protocol_version: String,
34    /// Visibility profile filtering the surface for this session.
35    pub profile: McpProfile,
36    /// Native cards exposed through this session.
37    pub native_cards: Vec<McpNativeCard>,
38    /// Capabilities granted to this session.
39    pub granted_capabilities: Vec<CapabilityName>,
40    /// Optional per-request deadline budget, in milliseconds.
41    pub deadline_ms: Option<u64>,
42    /// Optional cap on the total number of requests admitted.
43    pub rate_limit: Option<usize>,
44    /// Optional cap on concurrently active requests.
45    pub active_request_limit: Option<usize>,
46    requests_seen: usize,
47    /// Identifiers of currently in-flight requests.
48    pub active_requests: BTreeSet<String>,
49    #[cfg(feature = "cassette")]
50    pub(crate) cassette: Option<crate::McpCassette>,
51    #[cfg(feature = "stream")]
52    pub(crate) cancelled_requests: BTreeSet<String>,
53    #[cfg(feature = "stream")]
54    stream_packets: Vec<StreamPacket>,
55    /// Whether the peer has requested shutdown.
56    pub shutdown_requested: bool,
57}
58
59impl McpSession {
60    /// Creates a session with the given `id` and visibility `profile`.
61    pub fn new(id: impl Into<String>, profile: McpProfile) -> Self {
62        Self {
63            id: id.into(),
64            initialized: false,
65            client_info: None,
66            protocol_version: DEFAULT_PROTOCOL_VERSION.to_owned(),
67            profile,
68            native_cards: Vec::new(),
69            granted_capabilities: Vec::new(),
70            deadline_ms: None,
71            rate_limit: None,
72            active_request_limit: None,
73            requests_seen: 0,
74            active_requests: BTreeSet::new(),
75            #[cfg(feature = "cassette")]
76            cassette: None,
77            #[cfg(feature = "stream")]
78            cancelled_requests: BTreeSet::new(),
79            #[cfg(feature = "stream")]
80            stream_packets: Vec::new(),
81            shutdown_requested: false,
82        }
83    }
84
85    /// Creates a permissive session for tests and fixtures.
86    pub fn fixture() -> Self {
87        Self::new("fixture", McpProfile::all())
88    }
89
90    /// Returns the session with its native cards replaced by `cards`.
91    pub fn with_native_cards(mut self, cards: Vec<McpNativeCard>) -> Self {
92        self.native_cards = cards;
93        self
94    }
95
96    /// Returns the session with `capability` added to the granted set.
97    pub fn with_granted_capability(mut self, capability: CapabilityName) -> Self {
98        self.granted_capabilities.push(capability);
99        self
100    }
101
102    /// Returns the session with a per-request `deadline_ms` budget.
103    pub fn with_deadline_ms(mut self, deadline_ms: u64) -> Self {
104        self.deadline_ms = Some(deadline_ms);
105        self
106    }
107
108    /// Returns the session with a total request `limit`.
109    pub fn with_rate_limit(mut self, limit: usize) -> Self {
110        self.rate_limit = Some(limit);
111        self
112    }
113
114    /// Returns the session with a concurrent active-request `limit`.
115    pub fn with_active_request_limit(mut self, limit: usize) -> Self {
116        self.active_request_limit = Some(limit);
117        self
118    }
119
120    /// Returns the session with a recording/replay `cassette` attached.
121    #[cfg(feature = "cassette")]
122    pub fn with_cassette(mut self, cassette: crate::McpCassette) -> Self {
123        self.cassette = Some(cassette);
124        self
125    }
126
127    /// Returns the attached cassette, if any.
128    #[cfg(feature = "cassette")]
129    pub fn cassette(&self) -> Option<&crate::McpCassette> {
130        self.cassette.as_ref()
131    }
132
133    /// Returns a mutable reference to the attached cassette, if any.
134    #[cfg(feature = "cassette")]
135    pub fn cassette_mut(&mut self) -> Option<&mut crate::McpCassette> {
136        self.cassette.as_mut()
137    }
138
139    pub(crate) fn admit_request(&mut self, id: &Expr) -> std::result::Result<(), McpBoundaryLimit> {
140        if self.deadline_ms == Some(0) {
141            return Err(McpBoundaryLimit::Deadline);
142        }
143        if self
144            .active_request_limit
145            .is_some_and(|limit| self.active_requests.len() >= limit)
146        {
147            return Err(McpBoundaryLimit::ActiveRequests);
148        }
149        if self
150            .rate_limit
151            .is_some_and(|limit| self.requests_seen >= limit)
152        {
153            return Err(McpBoundaryLimit::Rate);
154        }
155        self.requests_seen += 1;
156        self.begin_request(id);
157        Ok(())
158    }
159
160    pub(crate) fn begin_request(&mut self, id: &Expr) {
161        self.active_requests.insert(request_key(id));
162    }
163
164    pub(crate) fn end_request(&mut self, id: &Expr) {
165        let key = request_key(id);
166        self.active_requests.remove(&key);
167        #[cfg(feature = "stream")]
168        self.cancelled_requests.remove(&key);
169    }
170
171    #[cfg(feature = "stream")]
172    pub(crate) fn request_is_active(&self, id: &Expr) -> bool {
173        self.active_requests.contains(&request_key(id))
174    }
175
176    #[cfg(feature = "stream")]
177    pub(crate) fn mark_request_cancelled(&mut self, id: &Expr) {
178        self.cancelled_requests.insert(request_key(id));
179    }
180
181    /// Reports whether the request identified by `id` has been cancelled.
182    #[cfg(feature = "stream")]
183    pub fn request_cancelled(&self, id: &Expr) -> bool {
184        self.cancelled_requests.contains(&request_key(id))
185    }
186
187    #[cfg(feature = "stream")]
188    pub(crate) fn record_stream_packet(&mut self, packet: StreamPacket) {
189        self.stream_packets.push(packet);
190    }
191
192    #[cfg(feature = "stream")]
193    pub(crate) fn record_stream_packets(&mut self, packets: Vec<StreamPacket>) {
194        self.stream_packets.extend(packets);
195    }
196
197    /// Returns the stream packets recorded during this session.
198    #[cfg(feature = "stream")]
199    pub fn stream_packets(&self) -> &[StreamPacket] {
200        &self.stream_packets
201    }
202}
203
204fn request_key(id: &Expr) -> String {
205    match id {
206        Expr::String(value) => value.clone(),
207        Expr::Number(number) => format!("{}:{}", number.domain, number.canonical),
208        Expr::Nil => "nil".to_owned(),
209        _ => format!("{id:?}"),
210    }
211}