1use std::collections::BTreeSet;
2
3use sim_kernel::{CapabilityName, Expr};
4#[cfg(feature = "stream")]
5use sim_lib_stream_core::StreamPacket;
6
7use crate::{McpNativeCard, McpProfile};
8
/// Protocol version string that [`McpSession::new`] stores in
/// `McpSession::protocol_version` for every freshly created session.
pub const DEFAULT_PROTOCOL_VERSION: &str = "2025-03-26";
11
/// Identifies which session boundary check made
/// [`McpSession::admit_request`] refuse a request.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub(crate) enum McpBoundaryLimit {
    /// The session's `deadline_ms` is `Some(0)`.
    Deadline,
    /// The count of requests admitted so far has reached `rate_limit`.
    Rate,
    /// The count of currently active requests has reached `active_request_limit`.
    ActiveRequests,
}
18
19#[derive(Clone)]
25pub struct McpSession {
26 pub id: String,
28 pub initialized: bool,
30 pub client_info: Option<Expr>,
32 pub protocol_version: String,
34 pub profile: McpProfile,
36 pub native_cards: Vec<McpNativeCard>,
38 pub granted_capabilities: Vec<CapabilityName>,
40 pub deadline_ms: Option<u64>,
42 pub rate_limit: Option<usize>,
44 pub active_request_limit: Option<usize>,
46 requests_seen: usize,
47 pub active_requests: BTreeSet<String>,
49 #[cfg(feature = "cassette")]
50 pub(crate) cassette: Option<crate::McpCassette>,
51 #[cfg(feature = "stream")]
52 pub(crate) cancelled_requests: BTreeSet<String>,
53 #[cfg(feature = "stream")]
54 stream_packets: Vec<StreamPacket>,
55 pub shutdown_requested: bool,
57}
58
59impl McpSession {
60 pub fn new(id: impl Into<String>, profile: McpProfile) -> Self {
62 Self {
63 id: id.into(),
64 initialized: false,
65 client_info: None,
66 protocol_version: DEFAULT_PROTOCOL_VERSION.to_owned(),
67 profile,
68 native_cards: Vec::new(),
69 granted_capabilities: Vec::new(),
70 deadline_ms: None,
71 rate_limit: None,
72 active_request_limit: None,
73 requests_seen: 0,
74 active_requests: BTreeSet::new(),
75 #[cfg(feature = "cassette")]
76 cassette: None,
77 #[cfg(feature = "stream")]
78 cancelled_requests: BTreeSet::new(),
79 #[cfg(feature = "stream")]
80 stream_packets: Vec::new(),
81 shutdown_requested: false,
82 }
83 }
84
85 pub fn fixture() -> Self {
87 Self::new("fixture", McpProfile::all())
88 }
89
90 pub fn with_native_cards(mut self, cards: Vec<McpNativeCard>) -> Self {
92 self.native_cards = cards;
93 self
94 }
95
96 pub fn with_granted_capability(mut self, capability: CapabilityName) -> Self {
98 self.granted_capabilities.push(capability);
99 self
100 }
101
102 pub fn with_deadline_ms(mut self, deadline_ms: u64) -> Self {
104 self.deadline_ms = Some(deadline_ms);
105 self
106 }
107
108 pub fn with_rate_limit(mut self, limit: usize) -> Self {
110 self.rate_limit = Some(limit);
111 self
112 }
113
114 pub fn with_active_request_limit(mut self, limit: usize) -> Self {
116 self.active_request_limit = Some(limit);
117 self
118 }
119
120 #[cfg(feature = "cassette")]
122 pub fn with_cassette(mut self, cassette: crate::McpCassette) -> Self {
123 self.cassette = Some(cassette);
124 self
125 }
126
127 #[cfg(feature = "cassette")]
129 pub fn cassette(&self) -> Option<&crate::McpCassette> {
130 self.cassette.as_ref()
131 }
132
133 #[cfg(feature = "cassette")]
135 pub fn cassette_mut(&mut self) -> Option<&mut crate::McpCassette> {
136 self.cassette.as_mut()
137 }
138
139 pub(crate) fn admit_request(&mut self, id: &Expr) -> std::result::Result<(), McpBoundaryLimit> {
140 if self.deadline_ms == Some(0) {
141 return Err(McpBoundaryLimit::Deadline);
142 }
143 if self
144 .active_request_limit
145 .is_some_and(|limit| self.active_requests.len() >= limit)
146 {
147 return Err(McpBoundaryLimit::ActiveRequests);
148 }
149 if self
150 .rate_limit
151 .is_some_and(|limit| self.requests_seen >= limit)
152 {
153 return Err(McpBoundaryLimit::Rate);
154 }
155 self.requests_seen += 1;
156 self.begin_request(id);
157 Ok(())
158 }
159
160 pub(crate) fn begin_request(&mut self, id: &Expr) {
161 self.active_requests.insert(request_key(id));
162 }
163
164 pub(crate) fn end_request(&mut self, id: &Expr) {
165 let key = request_key(id);
166 self.active_requests.remove(&key);
167 #[cfg(feature = "stream")]
168 self.cancelled_requests.remove(&key);
169 }
170
171 #[cfg(feature = "stream")]
172 pub(crate) fn request_is_active(&self, id: &Expr) -> bool {
173 self.active_requests.contains(&request_key(id))
174 }
175
176 #[cfg(feature = "stream")]
177 pub(crate) fn mark_request_cancelled(&mut self, id: &Expr) {
178 self.cancelled_requests.insert(request_key(id));
179 }
180
181 #[cfg(feature = "stream")]
183 pub fn request_cancelled(&self, id: &Expr) -> bool {
184 self.cancelled_requests.contains(&request_key(id))
185 }
186
187 #[cfg(feature = "stream")]
188 pub(crate) fn record_stream_packet(&mut self, packet: StreamPacket) {
189 self.stream_packets.push(packet);
190 }
191
192 #[cfg(feature = "stream")]
193 pub(crate) fn record_stream_packets(&mut self, packets: Vec<StreamPacket>) {
194 self.stream_packets.extend(packets);
195 }
196
197 #[cfg(feature = "stream")]
199 pub fn stream_packets(&self) -> &[StreamPacket] {
200 &self.stream_packets
201 }
202}
203
204fn request_key(id: &Expr) -> String {
205 match id {
206 Expr::String(value) => value.clone(),
207 Expr::Number(number) => format!("{}:{}", number.domain, number.canonical),
208 Expr::Nil => "nil".to_owned(),
209 _ => format!("{id:?}"),
210 }
211}