Skip to main content

systemprompt_api/services/middleware/
session.rs

1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23use super::jwt::JwtExtractor;
24
25#[derive(Clone, Debug)]
26pub struct SessionMiddleware {
27    jwt_extractor: Arc<JwtExtractor>,
28    analytics_service: Arc<AnalyticsService>,
29    session_creation_service: Arc<SessionCreationService>,
30}
31
32impl SessionMiddleware {
33    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
34        let jwt_extractor = Arc::new(JwtExtractor::new());
35        let user_service = UserService::new(ctx.db_pool())?;
36        let concrete = Arc::clone(ctx.analytics_service());
37        let analytics: Arc<dyn AnalyticsProvider> = concrete;
38        let session_creation_service = Arc::new(SessionCreationService::new(
39            analytics,
40            Arc::new(UserProviderImpl::new(user_service)),
41        ));
42
43        Ok(Self {
44            jwt_extractor,
45            analytics_service: Arc::clone(ctx.analytics_service()),
46            session_creation_service,
47        })
48    }
49
50    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
51        let headers = request.headers();
52        let uri = request.uri().clone();
53        let method = request.method().clone();
54
55        let should_skip = should_skip_session_tracking(uri.path());
56
57        tracing::debug!(
58            path = %uri.path(),
59            should_skip = should_skip,
60            "Session middleware evaluating request"
61        );
62
63        let trace_id = HeaderExtractor::extract_trace_id(headers);
64
65        let (req_ctx, jwt_cookie) = if should_skip {
66            (Self::untracked_context(trace_id), None)
67        } else {
68            self.tracked_context(trace_id, headers, &uri, &method)
69                .await?
70        };
71
72        tracing::debug!(
73            path = %uri.path(),
74            session_id = %req_ctx.session_id(),
75            "Session middleware setting context"
76        );
77
78        request.extensions_mut().insert(req_ctx);
79
80        let mut response = next.run(request).await;
81
82        if let Some(token) = jwt_cookie {
83            let cookie =
84                format!("access_token={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800");
85            if let Ok(cookie_value) = cookie.parse() {
86                response
87                    .headers_mut()
88                    .insert(header::SET_COOKIE, cookie_value);
89            }
90        }
91
92        Ok(response)
93    }
94
95    fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
96        RequestContext::new(
97            SessionId::new(format!("untracked_{}", Uuid::new_v4())),
98            trace_id,
99            ContextId::generate(),
100            AgentName::system(),
101        )
102        .with_actor(systemprompt_identifiers::Actor::anonymous(
103            systemprompt_identifiers::bootstrap::anonymous(),
104        ))
105        .with_user_type(UserType::Anon)
106        .with_tracked(false)
107    }
108
109    fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
110        RequestContext::new(
111            SessionId::new(format!("bot_{}", Uuid::new_v4())),
112            trace_id,
113            ContextId::generate(),
114            AgentName::system(),
115        )
116        .with_actor(systemprompt_identifiers::Actor::user(
117            systemprompt_identifiers::bootstrap::bot(),
118        ))
119        .with_user_type(UserType::Anon)
120        .with_tracked(false)
121    }
122
123    async fn tracked_context(
124        &self,
125        trace_id: systemprompt_identifiers::TraceId,
126        headers: &http::HeaderMap,
127        uri: &http::Uri,
128        method: &http::Method,
129    ) -> Result<(RequestContext, Option<String>), ApiError> {
130        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
131        let is_bot = AnalyticsService::is_bot(&analytics);
132
133        tracing::debug!(
134            path = %uri.path(),
135            is_bot = is_bot,
136            user_agent = ?analytics.user_agent,
137            "Session middleware bot check"
138        );
139
140        if is_bot {
141            return Ok((Self::bot_context(trace_id), None));
142        }
143
144        let token_result = TokenExtractor::browser_only().extract(headers).ok();
145
146        let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
147            .resolve_session(token_result, headers, uri, method)
148            .await?;
149
150        let context_id =
151            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
152
153        let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
154            .with_actor(systemprompt_identifiers::Actor::user(user_id))
155            .with_auth_token(jwt_token)
156            .with_user_type(UserType::Anon)
157            .with_tracked(true);
158        if let Some(fp) = fingerprint_hash {
159            ctx = ctx.with_fingerprint_hash(fp);
160        }
161        Ok((ctx, jwt_cookie))
162    }
163
164    async fn resolve_session(
165        &self,
166        token_result: Option<String>,
167        headers: &http::HeaderMap,
168        uri: &http::Uri,
169        method: &http::Method,
170    ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
171        let Some(token) = token_result else {
172            let (sid, uid, token, is_new, fp) =
173                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
174                    .await?;
175            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
176            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
177        };
178
179        let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) else {
180            let (sid, uid, token, is_new, fp) =
181                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
182                    .await?;
183            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
184            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
185        };
186
187        let session_exists = self
188            .analytics_service
189            .find_active_session_by_id(&jwt_context.session_id)
190            .await
191            .map_err(|e| {
192                tracing::warn!(error = %e, "find_active_session_by_id failed");
193                e
194            })
195            .ok()
196            .flatten()
197            .is_some();
198
199        if session_exists {
200            return Ok((
201                jwt_context.session_id,
202                jwt_context.user_id,
203                token,
204                None,
205                None,
206            ));
207        }
208
209        tracing::info!(
210            old_session_id = %jwt_context.session_id,
211            user_id = %jwt_context.user_id,
212            "JWT valid but session missing, refreshing with new session"
213        );
214        match lifecycle::refresh_session_for_user(
215            &self.session_creation_service,
216            &self.analytics_service,
217            &jwt_context.user_id,
218            headers,
219            uri,
220        )
221        .await
222        {
223            Ok((sid, uid, new_token, _, fp)) => {
224                Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
225            },
226            Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
227                tracing::warn!(
228                    user_id = %jwt_context.user_id,
229                    "JWT references non-existent user, creating new anonymous session"
230                );
231                let (sid, uid, token, _, fp) = lifecycle::create_new_session(
232                    &self.session_creation_service,
233                    headers,
234                    uri,
235                    method,
236                )
237                .await?;
238                Ok((sid, uid, token.clone(), Some(token), Some(fp)))
239            },
240            Err(e) => Err(e),
241        }
242    }
243}