Skip to main content

systemprompt_api/services/middleware/session/
mod.rs

1//! Session-establishment middleware.
2//!
3//! [`SessionMiddleware`] resolves or mints the per-request session: it skips
4//! untracked paths, short-circuits detected bots into anonymous contexts,
5//! validates an existing JWT session, and refreshes or recreates the session
6//! when the token is stale, issuing a `Set-Cookie` for newly minted tokens.
7
8mod lifecycle;
9mod skip;
10
11pub use skip::should_skip_session_tracking;
12
13use axum::extract::Request;
14use axum::http::header;
15use axum::middleware::Next;
16use axum::response::Response;
17use std::sync::Arc;
18use systemprompt_analytics::AnalyticsService;
19use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
20use systemprompt_models::api::ApiError;
21use systemprompt_models::auth::UserType;
22use systemprompt_models::execution::context::RequestContext;
23use systemprompt_oauth::services::SessionCreationService;
24use systemprompt_runtime::AppContext;
25use systemprompt_security::{
26    CookieExtractor, HeaderExtractor, TokenExtractor, extract_user_context,
27};
28use systemprompt_traits::AnalyticsProvider;
29use systemprompt_users::{UserProviderImpl, UserService};
30use uuid::Uuid;
31
32#[derive(Clone, Debug)]
33pub struct SessionMiddleware {
34    analytics_service: Arc<AnalyticsService>,
35    session_creation_service: Arc<SessionCreationService>,
36}
37
38impl SessionMiddleware {
39    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
40        let user_service = UserService::new(ctx.db_pool())?;
41        let concrete = Arc::clone(ctx.analytics_service());
42        let analytics: Arc<dyn AnalyticsProvider> = concrete;
43        let session_creation_service = Arc::new(SessionCreationService::new(
44            analytics,
45            Arc::new(UserProviderImpl::new(user_service)),
46        ));
47
48        Ok(Self {
49            analytics_service: Arc::clone(ctx.analytics_service()),
50            session_creation_service,
51        })
52    }
53
54    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
55        let headers = request.headers();
56        let uri = request.uri().clone();
57        let method = request.method().clone();
58
59        let should_skip = should_skip_session_tracking(uri.path());
60
61        tracing::debug!(
62            path = %uri.path(),
63            should_skip = should_skip,
64            "Session middleware evaluating request"
65        );
66
67        let trace_id = HeaderExtractor::extract_trace_id(headers);
68
69        let (req_ctx, jwt_cookie) = if should_skip {
70            (
71                self.anonymous_context("untracked", trace_id, headers, &uri)
72                    .await?,
73                None,
74            )
75        } else {
76            self.tracked_context(trace_id, headers, &uri, &method)
77                .await?
78        };
79
80        tracing::debug!(
81            path = %uri.path(),
82            session_id = %req_ctx.session_id(),
83            "Session middleware setting context"
84        );
85
86        request.extensions_mut().insert(req_ctx);
87
88        let mut response = next.run(request).await;
89
90        if let Some(token) = jwt_cookie {
91            let cookie = format!(
92                "{}={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800",
93                CookieExtractor::DEFAULT_COOKIE_NAME
94            );
95            if let Ok(cookie_value) = cookie.parse() {
96                response
97                    .headers_mut()
98                    .insert(header::SET_COOKIE, cookie_value);
99            }
100        }
101
102        Ok(response)
103    }
104
105    /// Builds an untracked anonymous context. `session_prefix` distinguishes
106    /// the synthetic session id (`untracked_*` for skip-tracking paths,
107    /// `bot_*` for detected crawlers) so the two cases stay legible in logs
108    /// and analytics without two near-identical constructors.
109    async fn anonymous_context(
110        &self,
111        session_prefix: &str,
112        trace_id: systemprompt_identifiers::TraceId,
113        headers: &http::HeaderMap,
114        uri: &http::Uri,
115    ) -> Result<RequestContext, ApiError> {
116        let (user_id, fingerprint) = self
117            .session_creation_service
118            .ensure_anonymous_user(headers, Some(uri))
119            .await
120            .map_err(|e| {
121                tracing::error!(error = %e, session_prefix, "Failed to ensure anonymous user");
122                ApiError::internal_error("Service temporarily unavailable")
123            })?;
124
125        Ok(RequestContext::new(
126            SessionId::new(format!("{session_prefix}_{}", Uuid::new_v4())),
127            trace_id,
128            ContextId::generate(),
129            AgentName::system(),
130        )
131        .with_actor(systemprompt_identifiers::Actor::anonymous(user_id))
132        .with_user_type(UserType::Anon)
133        .with_tracked(false)
134        .with_fingerprint_hash(fingerprint))
135    }
136
137    async fn tracked_context(
138        &self,
139        trace_id: systemprompt_identifiers::TraceId,
140        headers: &http::HeaderMap,
141        uri: &http::Uri,
142        method: &http::Method,
143    ) -> Result<(RequestContext, Option<String>), ApiError> {
144        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
145        let is_bot = AnalyticsService::is_bot(&analytics);
146
147        tracing::debug!(
148            path = %uri.path(),
149            is_bot = is_bot,
150            user_agent = ?analytics.user_agent,
151            "Session middleware bot check"
152        );
153
154        if is_bot {
155            return Ok((
156                self.anonymous_context("bot", trace_id, headers, uri)
157                    .await?,
158                None,
159            ));
160        }
161
162        let token_result = TokenExtractor::browser_only().extract(headers).ok();
163
164        let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
165            .resolve_session(token_result, headers, uri, method)
166            .await?;
167
168        let context_id =
169            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
170
171        let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
172            .with_actor(systemprompt_identifiers::Actor::user(user_id))
173            .with_auth_token(jwt_token)
174            .with_user_type(UserType::Anon)
175            .with_tracked(true);
176        if let Some(fp) = fingerprint_hash {
177            ctx = ctx.with_fingerprint_hash(fp);
178        }
179        Ok((ctx, jwt_cookie))
180    }
181
182    async fn resolve_session(
183        &self,
184        token_result: Option<String>,
185        headers: &http::HeaderMap,
186        uri: &http::Uri,
187        method: &http::Method,
188    ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
189        let Some(token) = token_result else {
190            let (sid, uid, token, is_new, fp) =
191                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
192                    .await?;
193            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
194            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
195        };
196
197        let Ok(jwt_context) = extract_user_context(&token) else {
198            let (sid, uid, token, is_new, fp) =
199                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
200                    .await?;
201            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
202            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
203        };
204
205        let session_exists = self
206            .analytics_service
207            .find_active_session_by_id(&jwt_context.session_id)
208            .await
209            .map_err(|e| {
210                tracing::warn!(error = %e, "find_active_session_by_id failed");
211                e
212            })
213            .ok()
214            .flatten()
215            .is_some();
216
217        if session_exists {
218            return Ok((
219                jwt_context.session_id,
220                jwt_context.user_id,
221                token,
222                None,
223                None,
224            ));
225        }
226
227        tracing::info!(
228            old_session_id = %jwt_context.session_id,
229            user_id = %jwt_context.user_id,
230            "JWT valid but session missing, refreshing with new session"
231        );
232        match lifecycle::refresh_session_for_user(
233            &self.session_creation_service,
234            &self.analytics_service,
235            &jwt_context.user_id,
236            headers,
237            uri,
238        )
239        .await
240        {
241            Ok((sid, uid, new_token, _, fp)) => {
242                Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
243            },
244            Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
245                tracing::warn!(
246                    user_id = %jwt_context.user_id,
247                    "JWT references non-existent user, creating new anonymous session"
248                );
249                let (sid, uid, token, _, fp) = lifecycle::create_new_session(
250                    &self.session_creation_service,
251                    headers,
252                    uri,
253                    method,
254                )
255                .await?;
256                Ok((sid, uid, token.clone(), Some(token), Some(fp)))
257            },
258            Err(e) => Err(e),
259        }
260    }
261}