Skip to main content

solo_api/
mcp_session.rs

1// SPDX-License-Identifier: Apache-2.0
2
3//! v0.11.0 P1 — MCP `Mcp-Session-Id` session store + middleware.
4//!
5//! v0.10.2 shipped `/mcp` as a one-shot request/response surface — every
6//! POST opened a fresh dispatcher with no cross-request state. v0.11.0
7//! lifts that to the full MCP Streamable HTTP spec, and this module is
8//! the foundation: a `DashMap`-backed [`SessionStore`] of
9//! [`SessionState`] entries keyed by [`SessionId`], plus an
10//! [`mcp_session_middleware`] Axum middleware that validates the
11//! `Mcp-Session-Id` request header against the store.
12//!
13//! ## Locked design (plan §3)
14//!
15//! - **Decision A — In-memory storage.** A
16//!   `DashMap<SessionId, SessionState>` gives lock-free per-session
17//!   reads on the dispatch hot path. Solo runs as a single-process
18//!   daemon today; cross-process session persistence is deferred to a
19//!   future release (clients re-`initialize` on daemon restart).
20//! - **Decision D — TTL.** 30 min inactivity + 4 hr absolute cap.
21//!   Background cleanup task runs every 60 s and removes expired
22//!   sessions; a lazy expiry check on every `get` is the safety net
23//!   for the window between sweeps.
24//! - **Expired session → 404.** The middleware returns 404 Not Found
25//!   with a body that includes a `re-initialize` instruction so
26//!   clients can distinguish "session expired" from "server down".
27//!
28//! ## Dispatcher integration (Option B — session-agnostic)
29//!
30//! The brief leaves "does the dispatcher learn about sessions?" open.
31//! P1 picks **Option B**: [`crate::mcp_dispatch::McpDispatcher`] stays
32//! session-agnostic. Sessions are purely an HTTP-transport concern;
33//! the dispatcher receives the resolved tenant + audit principal and
34//! has no knowledge of `SessionId`. The stdio path (which has no
35//! sessions) keeps working unchanged. v0.11.0 P3 will route per-tool
36//! progress events through the session's notification channel by
37//! reading the session out of the request extension before building
38//! the per-request `ProgressEmitter` — without baking sessions into
39//! the dispatcher itself.
40//!
41//! ## v0.11.0 P2 — event buffer + publish API
42//!
43//! P2 grows [`SessionState`] with the two fields the resumable GET
44//! stream rides on top of:
45//!
46//!   - `event_tx: broadcast::Sender<McpStreamEvent>` (capacity
47//!     [`MCP_SESSION_EVENT_BUFFER_CAPACITY`] = 256 per Decision E).
48//!   - `next_event_id: AtomicU64` (monotonic per-session event id).
49//!
50//! A new [`SessionState::publish_event`] helper allocates the next id,
51//! constructs an [`McpStreamEvent`], and fans it out to every live
52//! subscriber. P3 (per-tool progress) and P4 (notifications/message
53//! bridge) call this method on the same session record the HTTP POST
54//! handler resolved; the GET handler in `http.rs` consumes via
55//! [`SessionState::subscribe_events`].
56//!
57//! ## What this module does NOT do
58//!
59//! - **No tenant/principal binding check.** Cross-request tenant
60//!   mismatch (`409 Conflict`) and cross-principal access
61//!   (`401 Unauthorized`) still TBD — P2 wires the broadcast channel
62//!   but leaves the auth-binding policy decisions to a follow-up
63//!   priority (Plan §9 Q4).
64//! - **No audit emission on session open/close.** Plan §9 Q3 still
65//!   open — P2 keeps the store as pure in-memory plumbing.
66
67use std::collections::VecDeque;
68use std::sync::Arc;
69use std::sync::atomic::AtomicU64;
70use std::time::Duration;
71
72use axum::extract::{Request, State};
73use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
74use axum::middleware::Next;
75use axum::response::{IntoResponse, Response};
76use dashmap::DashMap;
77use serde::{Deserialize, Serialize};
78use solo_core::TenantId;
79use tokio::sync::broadcast;
80use uuid::Uuid;
81
82use crate::auth::AuthenticatedPrincipal;
83
84/// HTTP header name carrying the `Mcp-Session-Id` per the MCP
85/// Streamable HTTP transport spec. Lowercase because HTTP headers are
86/// case-insensitive on the wire and axum stores them lowercased.
87pub const MCP_SESSION_ID_HEADER: &str = "mcp-session-id";
88
89/// Inactivity TTL (milliseconds). A session whose `last_accessed_at_ms`
90/// is more than this old is considered expired (Decision D).
91pub const MCP_SESSION_INACTIVITY_TTL_MS: u64 = 30 * 60 * 1000;
92
93/// Absolute TTL (milliseconds). A session is unconditionally expired
94/// this long after its `created_at_ms`, regardless of activity
95/// (Decision D — bounds worst-case memory growth for orphaned
96/// sessions).
97pub const MCP_SESSION_ABSOLUTE_TTL_MS: u64 = 4 * 60 * 60 * 1000;
98
99/// Cadence (seconds) for the background sweep task that removes
100/// expired sessions. Lazy expiry on every `get` is the primary safety
101/// net; the sweep keeps total memory bounded between accesses for
102/// idle sessions.
103pub const MCP_SESSION_SWEEP_INTERVAL_SECS: u64 = 60;
104
105/// HTTP header name carrying the `Last-Event-ID` per the SSE
106/// specification. v0.11.0 P2 reads this on `GET /mcp` to resume an
107/// interrupted stream from a known event id (Decision E).
108pub const MCP_LAST_EVENT_ID_HEADER: &str = "last-event-id";
109
110/// Capacity of the per-session `tokio::sync::broadcast` channel that
111/// carries server-initiated SSE events (init, message, progress,
112/// heartbeat, lagged). Per plan §3 Decision E. A subscriber that
113/// drifts further behind than this sees a `RecvError::Lagged(n)` on
114/// its next `recv`, at which point the GET handler emits one
115/// `event: lagged` and resumes from the current cursor.
116pub const MCP_SESSION_EVENT_BUFFER_CAPACITY: usize = 256;
117
118/// Opaque session id assigned by the server. v7 UUID — time-ordered
119/// for sortability per Solo's `memory_id` discipline; printed as a
120/// regular hyphenated UUID string on the wire.
121///
122/// Server-assigned (not client-proposed) per Decision A — keeps
123/// correctness on the server.
// The wrapped string is not guaranteed to be a UUID: `SessionId::parse`
// accepts any non-blank string, so only ids minted by `SessionId::new`
// are known to carry the v7 UUID format.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct SessionId(String);
126
127impl SessionId {
128    /// Generate a fresh server-assigned id.
129    pub fn new() -> Self {
130        Self(Uuid::now_v7().to_string())
131    }
132
133    /// Parse a wire-format id. Returns `None` for empty / non-UUID
134    /// strings; the middleware treats `None` as "unknown session" →
135    /// 404 (rather than 400) so clients see a single re-init code
136    /// path regardless of header malformation.
137    pub fn parse(raw: &str) -> Option<Self> {
138        // Reject empty / whitespace strings. We don't reject
139        // arbitrary UUID strings here because the `DashMap` lookup
140        // does the real validation: an id we never assigned simply
141        // isn't in the store.
142        let s = raw.trim();
143        if s.is_empty() {
144            return None;
145        }
146        Some(Self(s.to_string()))
147    }
148
149    /// String representation suitable for the `Mcp-Session-Id` response
150    /// header value.
151    pub fn as_str(&self) -> &str {
152        &self.0
153    }
154}
155
156impl Default for SessionId {
157    fn default() -> Self {
158        Self::new()
159    }
160}
161
162impl std::fmt::Display for SessionId {
163    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164        f.write_str(&self.0)
165    }
166}
167
168/// Discriminator for the per-session SSE event stream. The wire `event:`
169/// field is rendered from the kebab-case spelling of each variant — kept
170/// in lock-step with the constants below so handlers + clients agree
171/// on the literal string.
172///
173/// Variants:
174///
175///   - [`Init`] — emitted once when a subscriber connects (`event: init`).
176///     The payload includes the session id + tenant + connect ts.
177///   - [`Message`] — JSON-RPC `notifications/message` from the P4 bridge
178///     (`event: message`).
179///   - [`Progress`] — JSON-RPC `notifications/progress` from P3 long-running
180///     tool handlers (`event: progress`).
181///   - [`Lagged`] — synthetic event emitted by the GET handler when a
182///     subscriber falls past the broadcast buffer's capacity, OR when
183///     a `Last-Event-ID` is older than the buffer's oldest retained
184///     event (`event: lagged`). Carries `{dropped: <count>}` so clients
185///     know whether to resync state.
186///   - [`Heartbeat`] — synthetic event emitted by the heartbeat tick to
187///     keep proxies + clients aware the stream is alive
188///     (`event: heartbeat`).
189///
190/// [`Init`]: McpEventKind::Init
191/// [`Message`]: McpEventKind::Message
192/// [`Progress`]: McpEventKind::Progress
193/// [`Lagged`]: McpEventKind::Lagged
194/// [`Heartbeat`]: McpEventKind::Heartbeat
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum McpEventKind {
    /// Subscriber-connect handshake; rendered as `event: init`.
    Init,
    /// JSON-RPC `notifications/message`; rendered as `event: message`.
    Message,
    /// JSON-RPC `notifications/progress`; rendered as `event: progress`.
    Progress,
    /// Subscriber drifted past the buffer; rendered as `event: lagged`.
    Lagged,
    /// Periodic liveness ping; rendered as `event: heartbeat`.
    Heartbeat,
}

/// Wire name of the session-connect event (`event: init`).
pub const MCP_STREAM_EVENT_INIT_NAME: &str = "init";
/// Wire name of the event carrying a JSON-RPC `notifications/message`.
pub const MCP_STREAM_EVENT_MESSAGE_NAME: &str = "message";
/// Wire name of the event carrying a JSON-RPC `notifications/progress`.
pub const MCP_STREAM_EVENT_PROGRESS_NAME: &str = "progress";
/// Wire name of the event sent when a subscriber lags past the buffer.
pub const MCP_STREAM_EVENT_LAGGED_NAME: &str = "lagged";
/// Wire name of the periodic heartbeat event.
pub const MCP_STREAM_EVENT_HEARTBEAT_NAME: &str = "heartbeat";

impl McpEventKind {
    /// The literal used in the SSE `event:` field for this kind.
    ///
    /// Each variant maps to its `MCP_STREAM_EVENT_*_NAME` constant so
    /// the enum and the public constants cannot drift apart. The match
    /// is exhaustive on purpose: adding a variant without a wire name
    /// is a compile error.
    pub fn as_str(&self) -> &'static str {
        use McpEventKind::{Heartbeat, Init, Lagged, Message, Progress};
        match *self {
            Init => MCP_STREAM_EVENT_INIT_NAME,
            Message => MCP_STREAM_EVENT_MESSAGE_NAME,
            Progress => MCP_STREAM_EVENT_PROGRESS_NAME,
            Lagged => MCP_STREAM_EVENT_LAGGED_NAME,
            Heartbeat => MCP_STREAM_EVENT_HEARTBEAT_NAME,
        }
    }
}
234
235/// One event on a session's SSE stream. Cloneable so the broadcast
236/// channel can fan out one event to N concurrent subscribers.
237///
238/// The `id` is monotonic per session — clients carry the last seen id
239/// in a `Last-Event-ID` header on reconnect to request a replay of
240/// missed events. Heartbeats CARRY ids too (no second id space) so a
241/// reconnecting client never sees a gap.
#[derive(Debug, Clone)]
pub struct McpStreamEvent {
    /// Monotonic per-session event id. First event is `1` (the init
    /// event); `0` is reserved as the sentinel "I have never seen
    /// anything" value clients send on first connect.
    ///
    /// Assigned by `SessionState::publish_event` from the session's
    /// `next_event_id` counter.
    pub id: u64,
    /// Wire-event discriminator (`init` / `message` / `progress` /
    /// `lagged` / `heartbeat`).
    pub event: McpEventKind,
    /// JSON payload. For `message` / `progress` this is the full
    /// JSON-RPC envelope minus the transport `id`. For `init` /
    /// `lagged` / `heartbeat` it's an event-specific Solo-shaped
    /// object documented at each call site.
    ///
    /// Stored and forwarded as-is; `publish_event` does not inspect it.
    pub data: serde_json::Value,
}
257
258/// One session's state. v0.11.0 P2 grows this from P1's minimal
259/// "tenant + timestamps" record by adding the broadcast event channel
260/// + monotonic event-id counter the resumable GET stream rides on.
261///
262/// **Not `Clone`** — atomics + `broadcast::Sender` make `Clone` a
263/// surprising contract. The store hands out `Arc<SessionState>` so
264/// concurrent requests observe each other's `touch()` calls + share
265/// one event channel. Callers that want a snapshot of the timestamps
266/// can read them via the public fields directly.
#[derive(Debug)]
pub struct SessionState {
    /// Tenant the session is bound to. Set on session create from the
    /// extractor-resolved tenant; a future priority will refuse to
    /// reuse a session under a different tenant.
    pub tenant_id: TenantId,
    /// Authenticated principal at session create time. `None` for
    /// unauthenticated loopback deployments (the daemon default).
    /// A future cross-principal access check uses this to refuse a
    /// session presented with a different bearer / OIDC subject.
    pub principal: Option<AuthenticatedPrincipal>,
    /// Wall-clock millis at session create. Compared against
    /// `MCP_SESSION_ABSOLUTE_TTL_MS`.
    pub created_at_ms: i64,
    /// Wall-clock millis updated on every successful `SessionStore::get`.
    /// Compared against `MCP_SESSION_INACTIVITY_TTL_MS`. Stored as
    /// `AtomicI64` so reads via `Arc<SessionState>` can refresh without
    /// re-inserting into the DashMap shard.
    ///
    /// Loaded and stored with `Ordering::Relaxed`: it is a standalone
    /// timestamp and is not used to publish other memory.
    pub last_accessed_at_ms: std::sync::atomic::AtomicI64,
    /// v0.11.0 P2: broadcast channel fed by `publish_event`. The GET
    /// handler subscribes to this on connect; P3 (progress) and P4
    /// (notifications/message) publish into it. Capacity bounded by
    /// [`MCP_SESSION_EVENT_BUFFER_CAPACITY`] per Decision E.
    ///
    /// Note: `broadcast::channel` does NOT backfill freshly-subscribed
    /// receivers with previously-sent events. To support the
    /// `Last-Event-ID` resume contract we also keep a ring buffer
    /// (`event_replay_buffer`) which the GET handler reads on connect
    /// before tailing this channel for live events.
    pub event_tx: broadcast::Sender<McpStreamEvent>,
    /// v0.11.0 P2: monotonic per-session event id counter. Allocated
    /// via `fetch_add(1, SeqCst)` from `publish_event`; first event
    /// has id `1` (id `0` is the "never seen" sentinel clients send on
    /// the first `Last-Event-ID` header).
    pub next_event_id: AtomicU64,
    /// v0.11.0 P2: bounded ring buffer of recent events for
    /// `Last-Event-ID` replay. Capacity matches the broadcast channel
    /// ([`MCP_SESSION_EVENT_BUFFER_CAPACITY`]); oldest entry evicted
    /// on insert past the cap. `std::sync::Mutex` rather than
    /// `tokio::sync::Mutex` because the critical sections are tiny
    /// (push one event / clone a Vec out) — no `await` inside the
    /// lock. Wrapping in `Arc<Mutex<...>>` keeps `SessionState` cheap
    /// to share across the broadcast subscribers + the publisher.
    ///
    /// The cap is enforced by `publish_event`, not by the `VecDeque`
    /// itself; code that pushes here directly must evict manually.
    pub event_replay_buffer: Arc<std::sync::Mutex<VecDeque<McpStreamEvent>>>,
}
312
313impl SessionState {
314    /// Build a fresh session-state record. Used by [`SessionStore::insert`]
315    /// and the session-extractor path. Allocates a fresh broadcast
316    /// channel (capacity [`MCP_SESSION_EVENT_BUFFER_CAPACITY`]) for the
317    /// session's SSE stream and a matching-capacity replay ring buffer.
318    pub fn new(tenant_id: TenantId, principal: Option<AuthenticatedPrincipal>) -> Self {
319        let now_ms = now_ms();
320        let (event_tx, _) = broadcast::channel(MCP_SESSION_EVENT_BUFFER_CAPACITY);
321        let event_replay_buffer = Arc::new(std::sync::Mutex::new(VecDeque::with_capacity(
322            MCP_SESSION_EVENT_BUFFER_CAPACITY,
323        )));
324        Self {
325            tenant_id,
326            principal,
327            created_at_ms: now_ms,
328            last_accessed_at_ms: std::sync::atomic::AtomicI64::new(now_ms),
329            event_tx,
330            // Start at 1 so the first allocated id is `1` — `0` is
331            // reserved for "client has never seen an event" on the
332            // `Last-Event-ID` header.
333            next_event_id: AtomicU64::new(1),
334            event_replay_buffer,
335        }
336    }
337
338    /// True iff this session is past either TTL.
339    fn is_expired(&self, now_ms: i64) -> bool {
340        let absolute_deadline = self
341            .created_at_ms
342            .saturating_add(MCP_SESSION_ABSOLUTE_TTL_MS as i64);
343        if now_ms >= absolute_deadline {
344            return true;
345        }
346        let last = self
347            .last_accessed_at_ms
348            .load(std::sync::atomic::Ordering::Relaxed);
349        let inactivity_deadline = last.saturating_add(MCP_SESSION_INACTIVITY_TTL_MS as i64);
350        now_ms >= inactivity_deadline
351    }
352
353    /// Bump `last_accessed_at_ms` to "now". Called on every successful
354    /// `SessionStore::get`.
355    fn touch(&self) {
356        self.last_accessed_at_ms
357            .store(now_ms(), std::sync::atomic::Ordering::Relaxed);
358    }
359
360    /// Allocate the next event id, construct an [`McpStreamEvent`],
361    /// and (a) push it onto the replay ring buffer + (b) broadcast it
362    /// to every live subscriber. Returns the assigned id so callers
363    /// (P3/P4) can correlate their write with the resulting stream
364    /// entry.
365    ///
366    /// Lossy on the broadcast side by design: if there are no live
367    /// receivers (or every receiver has been dropped)
368    /// `broadcast::Sender::send` returns `Err(SendError)` — the event
369    /// is silently dropped from the live channel but STILL appended
370    /// to the replay buffer so a future subscriber's `Last-Event-ID`
371    /// replay observes it.
372    ///
373    /// Replay buffer is bounded at [`MCP_SESSION_EVENT_BUFFER_CAPACITY`]
374    /// entries; pushing past the cap evicts the oldest entry. This
375    /// matches the broadcast channel's capacity so the two stay in
376    /// lock-step — a subscriber that subscribed before any events
377    /// were published and then lags past 256 events sees the same
378    /// "buffer overrun" semantics whether it observes them via the
379    /// broadcast lagged-error path or via a `Last-Event-ID` resume.
380    pub fn publish_event(&self, kind: McpEventKind, data: serde_json::Value) -> u64 {
381        let id = self
382            .next_event_id
383            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
384        let event = McpStreamEvent {
385            id,
386            event: kind,
387            data,
388        };
389        // Push to the replay buffer first — guarantees a subscriber
390        // that races a publish-then-subscribe sequence either sees
391        // the event via the buffer snapshot or via the broadcast
392        // channel (never neither).
393        if let Ok(mut buf) = self.event_replay_buffer.lock() {
394            if buf.len() >= MCP_SESSION_EVENT_BUFFER_CAPACITY {
395                buf.pop_front();
396            }
397            buf.push_back(event.clone());
398        }
399        // Ignore the result: `SendError` means "no live receivers",
400        // which is fine — sessions can exist without an open GET
401        // stream, and the replay buffer above carries forward.
402        let _ = self.event_tx.send(event);
403        id
404    }
405
406    /// Subscribe to the session's event stream. Returns a fresh
407    /// `broadcast::Receiver` that observes every event published from
408    /// this call forward. Combined with [`SessionState::snapshot_replay_buffer`]
409    /// + `Last-Event-ID` replay logic in the GET handler, this gives
410    /// the spec's resume-from-missed-event semantics.
411    pub fn subscribe_events(&self) -> broadcast::Receiver<McpStreamEvent> {
412        self.event_tx.subscribe()
413    }
414
415    /// Snapshot the current replay buffer. Returns a `Vec<McpStreamEvent>`
416    /// in monotonically increasing id order. The GET handler calls
417    /// this once on connect AFTER calling [`Self::subscribe_events`]
418    /// (so any event published during the snapshot lands in the live
419    /// receiver — the handler dedupes the overlap by id).
420    pub fn snapshot_replay_buffer(&self) -> Vec<McpStreamEvent> {
421        match self.event_replay_buffer.lock() {
422            Ok(buf) => buf.iter().cloned().collect(),
423            // Poisoned lock: return empty rather than panicking. The
424            // GET handler treats this as "no buffered events" which
425            // is harmless — the subscriber falls through to the live
426            // broadcast receiver.
427            Err(poisoned) => poisoned.into_inner().iter().cloned().collect(),
428        }
429    }
430}
431
432/// In-memory, lock-free session store keyed by [`SessionId`]. The
433/// `Arc<SessionState>` value lets the middleware hand a cheap clone to
434/// each request without holding a `DashMap` shard lock for the whole
435/// dispatch.
436///
437/// Cloning a `SessionStore` is cheap — internally it's an
438/// `Arc<Inner>`. Pass a clone to the per-process `SoloHttpState`; the
439/// background sweep task holds another clone via `Arc::downgrade`.
#[derive(Clone)]
pub struct SessionStore {
    /// Shared state behind every clone of this store: the session map
    /// plus the sweep-task handle.
    inner: Arc<SessionStoreInner>,
}
444
/// State shared by all clones of a `SessionStore`. Its `Drop` impl
/// aborts the background sweep task.
struct SessionStoreInner {
    /// Live sessions keyed by id. Values are `Arc`s so a lookup can
    /// clone the handle out and release the shard guard immediately.
    sessions: DashMap<SessionId, Arc<SessionState>>,
    /// Sweep task handle. Held so it can be aborted when the last
    /// store reference drops. Wrapped in `Mutex` so `new` and `Drop`
    /// can both touch it; never contended on the hot path.
    sweep_task: std::sync::Mutex<Option<tokio::task::JoinHandle<()>>>,
}
452
453impl SessionStore {
454    /// Build a fresh store and spawn the background sweep task on the
455    /// current tokio runtime. Panics if called outside a tokio runtime
456    /// context — callers should construct this from inside an `async`
457    /// context (e.g. the daemon's startup, or
458    /// `tokio::runtime::Handle::current()`).
459    pub fn new() -> Self {
460        let inner = Arc::new(SessionStoreInner {
461            sessions: DashMap::new(),
462            sweep_task: std::sync::Mutex::new(None),
463        });
464        let weak = Arc::downgrade(&inner);
465        let sweep = tokio::spawn(async move {
466            let mut tick =
467                tokio::time::interval(Duration::from_secs(MCP_SESSION_SWEEP_INTERVAL_SECS));
468            // Skip the immediate first tick — interval::tick fires once
469            // immediately by default which would sweep a freshly-empty
470            // store on startup (harmless but noisy in tests).
471            tick.tick().await;
472            loop {
473                tick.tick().await;
474                let Some(inner) = weak.upgrade() else {
475                    // Store dropped; exit the loop. The `Drop` impl
476                    // aborts us anyway; this is the cooperative path
477                    // for graceful shutdown.
478                    return;
479                };
480                sweep_once(&inner.sessions);
481            }
482        });
483        // Survive a poisoned mutex during startup: extract the inner guard
484        // rather than panicking. The mutex protects an Option<JoinHandle>
485        // that we only ever assign in `new` and read in `Drop`; poisoning
486        // would only happen if a prior `Drop` panicked, which is
487        // recoverable here.
488        *inner
489            .sweep_task
490            .lock()
491            .unwrap_or_else(|p| p.into_inner()) = Some(sweep);
492        Self { inner }
493    }
494
495    /// Build a store WITHOUT a background sweep task. Used by tests that
496    /// run outside a tokio runtime context or want to drive sweep
497    /// manually via [`Self::sweep_now`].
498    #[cfg(test)]
499    pub(crate) fn new_for_tests_no_sweep() -> Self {
500        let inner = Arc::new(SessionStoreInner {
501            sessions: DashMap::new(),
502            sweep_task: std::sync::Mutex::new(None),
503        });
504        Self { inner }
505    }
506
507    /// Insert a new session and return its assigned id. Each call
508    /// produces a fresh server-assigned [`SessionId`] (UUID v7).
509    pub fn insert(&self, state: SessionState) -> SessionId {
510        let id = SessionId::new();
511        self.inner.sessions.insert(id.clone(), Arc::new(state));
512        id
513    }
514
515    /// Look up a session by id. Returns `None` if absent OR expired —
516    /// expired entries are removed lazily here so the store doesn't
517    /// hand out stale state between sweeps. The returned
518    /// `Arc<SessionState>`'s `last_accessed_at_ms` is bumped to "now"
519    /// on a successful hit.
520    pub fn get(&self, id: &SessionId) -> Option<Arc<SessionState>> {
521        let now = now_ms();
522        // Fast path: clone the Arc out of the shard, then check expiry
523        // outside the shard lock. If expired, remove and return None.
524        let cloned = self.inner.sessions.get(id).map(|r| r.clone());
525        let state = cloned?;
526        if state.is_expired(now) {
527            self.inner.sessions.remove(id);
528            return None;
529        }
530        state.touch();
531        Some(state)
532    }
533
534    /// Drop a session by id. Returns `true` if it was present.
535    pub fn delete(&self, id: &SessionId) -> bool {
536        self.inner.sessions.remove(id).is_some()
537    }
538
539    /// Current count of stored sessions. Used by tests + future
540    /// /v1/health-style readiness probes.
541    pub fn len(&self) -> usize {
542        self.inner.sessions.len()
543    }
544
545    /// True if the store has no sessions.
546    pub fn is_empty(&self) -> bool {
547        self.inner.sessions.is_empty()
548    }
549
550    /// Force one immediate sweep. Used by the background task and by
551    /// tests that want deterministic sweep behaviour without waiting
552    /// 60s for the next tick.
553    pub fn sweep_now(&self) {
554        sweep_once(&self.inner.sessions);
555    }
556}
557
558impl Default for SessionStore {
559    fn default() -> Self {
560        Self::new()
561    }
562}
563
564impl Drop for SessionStoreInner {
565    fn drop(&mut self) {
566        if let Ok(mut guard) = self.sweep_task.lock()
567            && let Some(handle) = guard.take()
568        {
569            handle.abort();
570        }
571    }
572}
573
574/// Walk the map once and remove every expired entry. Held outside
575/// `SessionStore::sweep_now` so the background task can call it
576/// against `Arc<SessionStoreInner>` directly without re-borrowing
577/// the cloneable `SessionStore` wrapper.
578fn sweep_once(sessions: &DashMap<SessionId, Arc<SessionState>>) {
579    let now = now_ms();
580    // Collect first to avoid holding shard guards while issuing
581    // removes (DashMap supports `retain` but `retain` blocks readers
582    // shard-by-shard; collect-then-remove keeps the hot path
583    // contention-free).
584    let expired: Vec<SessionId> = sessions
585        .iter()
586        .filter(|entry| entry.value().is_expired(now))
587        .map(|entry| entry.key().clone())
588        .collect();
589    for id in expired {
590        sessions.remove(&id);
591    }
592}
593
594/// Current wall-clock time in milliseconds. Centralised so tests can
595/// hook in later (none of the v0.11.0 P1 tests need to). Returns
596/// `chrono::Utc::now().timestamp_millis()` to match the rest of the
597/// crate's timestamps.
598fn now_ms() -> i64 {
599    chrono::Utc::now().timestamp_millis()
600}
601
/// Value of the `error` discriminator in the 404 body the middleware
/// returns when a request presents an unknown / expired
/// `Mcp-Session-Id`. The body's `retry: "re-initialize"` field is the
/// contract for client retry logic: drop the stale id, POST `/mcp`
/// without the header, capture the `Mcp-Session-Id` from the response.
607pub const MCP_SESSION_EXPIRED_ERROR: &str = "session_expired";
608
609/// Axum middleware that enforces the `Mcp-Session-Id` contract.
610///
611/// Behaviour:
612///   - **No `Mcp-Session-Id` header** → pass through. The downstream
613///     POST handler treats this as a session-init request and emits
614///     the assigned id in the response header.
615///   - **`Mcp-Session-Id` header present + session in store + not
616///     expired** → attach `Arc<SessionState>` + `SessionId` to the
617///     request extensions and pass through.
618///   - **`Mcp-Session-Id` header present + session unknown OR
619///     expired** → 404 with body `{"error": "session_expired", ...,
620///     "retry": "re-initialize"}`.
621pub async fn mcp_session_middleware(
622    State(store): State<SessionStore>,
623    mut req: Request,
624    next: Next,
625) -> Response {
626    let header_value = req
627        .headers()
628        .get(MCP_SESSION_ID_HEADER)
629        .and_then(|h| h.to_str().ok())
630        .map(|s| s.to_string());
631
632    if let Some(raw) = header_value {
633        let id = match SessionId::parse(&raw) {
634            Some(id) => id,
635            None => return session_expired_response(&raw),
636        };
637        match store.get(&id) {
638            Some(state) => {
639                req.extensions_mut().insert(id);
640                req.extensions_mut().insert(state);
641            }
642            None => return session_expired_response(&raw),
643        }
644    }
645    next.run(req).await
646}
647
648/// 404 + structured body. Browser MCP clients (Anthropic AI SDK's
649/// `experimental_createMCPClient`) read the `error` discriminator to
650/// route into the re-initialize path automatically.
651fn session_expired_response(presented_id: &str) -> Response {
652    let body = axum::Json(serde_json::json!({
653        "error": MCP_SESSION_EXPIRED_ERROR,
654        "status": 404,
655        "message": format!(
656            "Mcp-Session-Id `{presented_id}` is unknown or expired; \
657             re-initialize via POST /mcp without Mcp-Session-Id"
658        ),
659        "retry": "re-initialize",
660    }));
661    (StatusCode::NOT_FOUND, body).into_response()
662}
663
664/// Insert a `Mcp-Session-Id` response header so the client can echo
665/// it back on subsequent requests. Used by the POST handler when it
666/// freshly creates a session on a request that arrived without the
667/// header.
668pub fn set_session_id_header(headers: &mut HeaderMap, id: &SessionId) {
669    // SessionId::new produces a UUID-string which is always ASCII;
670    // `HeaderValue::from_str` is safe. The `expect` documents the
671    // invariant for future maintainers — we'd rather panic in CI than
672    // silently drop the header.
673    let value =
674        HeaderValue::from_str(id.as_str()).expect("SessionId is ASCII-safe (UUID) for HeaderValue");
675    headers.insert(HeaderName::from_static(MCP_SESSION_ID_HEADER), value);
676}
677
678#[cfg(test)]
679mod tests {
680    use super::*;
681    use std::sync::atomic::Ordering;
682
683    fn fake_tenant() -> TenantId {
684        TenantId::default_tenant()
685    }
686
687    fn fresh_state() -> SessionState {
688        SessionState::new(fake_tenant(), None)
689    }
690
691    #[test]
692    fn session_store_insert_returns_unique_id() {
693        let store = SessionStore::new_for_tests_no_sweep();
694        let id_a = store.insert(fresh_state());
695        let id_b = store.insert(fresh_state());
696        assert_ne!(id_a, id_b, "two inserts must produce distinct ids");
697        assert_eq!(store.len(), 2);
698    }
699
700    #[test]
701    fn session_store_get_returns_state_when_present() {
702        let store = SessionStore::new_for_tests_no_sweep();
703        let id = store.insert(fresh_state());
704        let got = store.get(&id);
705        assert!(got.is_some(), "get must return Some for a just-inserted id");
706        assert_eq!(got.unwrap().tenant_id, fake_tenant());
707    }
708
709    /// Build a state whose `created_at_ms` + `last_accessed_at_ms` are
710    /// both shifted backwards by `delta_ms`. Used by the TTL tests to
711    /// simulate an inactive / aged session without driving the wall
712    /// clock. Mirrors `SessionState::new` for the v0.11.0 P2 broadcast
713    /// channel + replay buffer + event id counter — those fields are
714    /// independent of the timestamp shift.
715    fn aged_state(
716        tenant_id: TenantId,
717        principal: Option<AuthenticatedPrincipal>,
718        delta_ms: i64,
719    ) -> SessionState {
720        let now = now_ms();
721        let shifted = now.saturating_sub(delta_ms);
722        let mut state = SessionState::new(tenant_id, principal);
723        state.created_at_ms = shifted;
724        state.last_accessed_at_ms.store(shifted, Ordering::Relaxed);
725        state
726    }
727
728    #[test]
729    fn session_store_get_returns_none_when_expired_by_inactivity() {
730        let store = SessionStore::new_for_tests_no_sweep();
731        // Hand-build a state whose `last_accessed_at_ms` is older than
732        // the inactivity TTL.
733        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
734        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
735        let id = SessionId::new();
736        store.inner.sessions.insert(id.clone(), stale);
737        assert!(
738            store.get(&id).is_none(),
739            "session inactive past TTL must read as expired"
740        );
741        // Lazy expiry also evicts the entry.
742        assert!(
743            store.inner.sessions.get(&id).is_none(),
744            "expired entry must be removed from the underlying map"
745        );
746    }
747
748    #[test]
749    fn session_store_get_returns_none_when_expired_by_absolute_ttl() {
750        let store = SessionStore::new_for_tests_no_sweep();
751        // Created past the absolute TTL but recently touched —
752        // absolute-TTL still wins.
753        let absolute_delta = MCP_SESSION_ABSOLUTE_TTL_MS as i64 + 1;
754        let state = aged_state(fake_tenant(), None, absolute_delta);
755        // Touch back to "now" so only the absolute deadline trips.
756        state.last_accessed_at_ms.store(now_ms(), Ordering::Relaxed);
757        let aged = Arc::new(state);
758        let id = SessionId::new();
759        store.inner.sessions.insert(id.clone(), aged);
760        assert!(
761            store.get(&id).is_none(),
762            "session past absolute TTL must read as expired even when recently touched"
763        );
764    }
765
766    #[test]
767    fn session_store_get_refreshes_last_accessed_on_hit() {
768        let store = SessionStore::new_for_tests_no_sweep();
769        let id = store.insert(fresh_state());
770        let before = store
771            .inner
772            .sessions
773            .get(&id)
774            .unwrap()
775            .last_accessed_at_ms
776            .load(Ordering::Relaxed);
777        // Yield long enough that the millis clock advances.
778        std::thread::sleep(std::time::Duration::from_millis(5));
779        let _ = store.get(&id).expect("session must still be present");
780        let after = store
781            .inner
782            .sessions
783            .get(&id)
784            .unwrap()
785            .last_accessed_at_ms
786            .load(Ordering::Relaxed);
787        assert!(
788            after > before,
789            "get must bump last_accessed_at_ms (before={before}, after={after})"
790        );
791    }
792
793    #[test]
794    fn session_store_delete_returns_true_when_present() {
795        let store = SessionStore::new_for_tests_no_sweep();
796        let id = store.insert(fresh_state());
797        assert!(store.delete(&id));
798        assert!(store.get(&id).is_none(), "deleted session must not read");
799    }
800
801    #[test]
802    fn session_store_delete_returns_false_when_absent() {
803        let store = SessionStore::new_for_tests_no_sweep();
804        assert!(!store.delete(&SessionId::new()));
805    }
806
807    #[test]
808    fn session_store_sweep_now_removes_expired() {
809        let store = SessionStore::new_for_tests_no_sweep();
810        // One healthy, one stale.
811        let healthy_id = store.insert(fresh_state());
812        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
813        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
814        let stale_id = SessionId::new();
815        store.inner.sessions.insert(stale_id.clone(), stale);
816        assert_eq!(store.len(), 2);
817        store.sweep_now();
818        assert_eq!(store.len(), 1, "sweep must drop the expired session");
819        assert!(
820            store.get(&healthy_id).is_some(),
821            "sweep must preserve the healthy session"
822        );
823        assert!(
824            store.inner.sessions.get(&stale_id).is_none(),
825            "stale id must be gone from the map after sweep"
826        );
827    }
828
829    #[tokio::test]
830    async fn session_store_background_sweep_removes_expired() {
831        // Spawn a store with a real sweep task on the current rt.
832        let store = SessionStore::new();
833        // Seed a stale entry directly into the inner map.
834        let stale_delta = MCP_SESSION_INACTIVITY_TTL_MS as i64 + 1;
835        let stale = Arc::new(aged_state(fake_tenant(), None, stale_delta));
836        let stale_id = SessionId::new();
837        store.inner.sessions.insert(stale_id.clone(), stale);
838        // Don't wait 60s; just call sweep_now to prove the same code
839        // path the background task drives works. The 60s-cadence
840        // background task itself is exercised by Drop semantics +
841        // the explicit `sweep_once` unit test above.
842        store.sweep_now();
843        assert!(store.inner.sessions.get(&stale_id).is_none());
844    }
845
846    #[test]
847    fn session_id_round_trips_through_string() {
848        let id = SessionId::new();
849        let s = id.as_str().to_string();
850        let parsed = SessionId::parse(&s).expect("ASCII round-trip");
851        assert_eq!(id, parsed);
852    }
853
854    #[test]
855    fn session_id_parse_rejects_empty_string() {
856        assert!(SessionId::parse("").is_none());
857        assert!(SessionId::parse("   ").is_none());
858    }
859
860    // ----------------------------------------------------------------
861    // v0.11.0 P2 — event buffer + publish_event unit tests
862    // ----------------------------------------------------------------
863
864    /// First three `publish_event` calls allocate ids 1, 2, 3 in order.
865    /// Pins the "start-at-1, monotonic increment" contract; clients
866    /// rely on a 0-sentinel for `Last-Event-ID` ("never seen anything")
867    /// so the first allocated id MUST be ≥ 1.
868    #[test]
869    fn session_state_publish_event_returns_monotonic_ids() {
870        let state = fresh_state();
871        let id1 = state.publish_event(McpEventKind::Init, serde_json::json!({"connected": true}));
872        let id2 = state.publish_event(McpEventKind::Message, serde_json::json!({"hello": 1}));
873        let id3 = state.publish_event(McpEventKind::Progress, serde_json::json!({"progress": 5}));
874        assert_eq!(
875            id1, 1,
876            "first event must allocate id 1 (id 0 reserved for client sentinel)"
877        );
878        assert_eq!(id2, 2);
879        assert_eq!(id3, 3);
880    }
881
882    /// A subscriber that called `subscribe_events()` BEFORE the publish
883    /// observes the published event on its receiver. Pins the
884    /// broadcast wiring end-to-end.
885    #[tokio::test]
886    async fn session_state_publish_event_broadcasts_to_subscribers() {
887        let state = fresh_state();
888        let mut rx = state.subscribe_events();
889        let id = state.publish_event(
890            McpEventKind::Message,
891            serde_json::json!({"jsonrpc": "2.0", "method": "notifications/message"}),
892        );
893        let received = rx
894            .recv()
895            .await
896            .expect("subscriber must observe the broadcast event");
897        assert_eq!(received.id, id);
898        assert_eq!(received.event, McpEventKind::Message);
899        assert_eq!(received.data["method"], "notifications/message");
900    }
901
902    /// Publishing past the broadcast channel's capacity (256) and the
903    /// matching ring buffer's capacity:
904    ///
905    ///   - The replay buffer retains only the last 256 events (oldest
906    ///     evicted on every push past the cap).
907    ///   - A receiver that subscribed before the publishes observes
908    ///     either every event or a `RecvError::Lagged` followed by the
909    ///     tail of the events. We assert the buffer-only side here
910    ///     (deterministic) — the receiver behaviour is covered by the
911    ///     GET-handler integration test.
912    #[test]
913    fn session_state_event_buffer_capacity_256() {
914        let state = fresh_state();
915        let total = (MCP_SESSION_EVENT_BUFFER_CAPACITY + 50) as u64; // 306
916        for _ in 0..total {
917            state.publish_event(McpEventKind::Message, serde_json::json!({}));
918        }
919        let snapshot = state.snapshot_replay_buffer();
920        assert_eq!(
921            snapshot.len(),
922            MCP_SESSION_EVENT_BUFFER_CAPACITY,
923            "replay buffer must retain exactly {} entries after overflow",
924            MCP_SESSION_EVENT_BUFFER_CAPACITY,
925        );
926        // Oldest retained id = total - capacity + 1 (since ids start at 1
927        // and we published `total` events).
928        let expected_first_id = total - MCP_SESSION_EVENT_BUFFER_CAPACITY as u64 + 1;
929        let expected_last_id = total; // last allocated id
930        assert_eq!(
931            snapshot.first().unwrap().id,
932            expected_first_id,
933            "oldest retained event id must be {expected_first_id}",
934        );
935        assert_eq!(
936            snapshot.last().unwrap().id,
937            expected_last_id,
938            "newest retained event id must be {expected_last_id}",
939        );
940        // Buffer is contiguous (each id is previous + 1).
941        for win in snapshot.windows(2) {
942            assert_eq!(
943                win[1].id,
944                win[0].id + 1,
945                "replay buffer must be contiguous (no gaps)",
946            );
947        }
948    }
949
950    /// `publish_event` with no live receivers does NOT error — the
951    /// event still lands in the replay buffer so a future subscriber
952    /// observes it via the `Last-Event-ID` replay path.
953    #[test]
954    fn session_state_publish_event_no_subscribers_is_lossless_to_buffer() {
955        let state = fresh_state();
956        let id = state.publish_event(McpEventKind::Init, serde_json::json!({"hi": true}));
957        let snapshot = state.snapshot_replay_buffer();
958        assert_eq!(snapshot.len(), 1);
959        assert_eq!(snapshot[0].id, id);
960    }
961}