Skip to main content

systemprompt_api/services/middleware/
session.rs

1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23use super::jwt::JwtExtractor;
24
25#[derive(Clone, Debug)]
26pub struct SessionMiddleware {
27    jwt_extractor: Arc<JwtExtractor>,
28    analytics_service: Arc<AnalyticsService>,
29    session_creation_service: Arc<SessionCreationService>,
30}
31
32impl SessionMiddleware {
33    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
34        let jwt_secret = systemprompt_config::SecretsBootstrap::jwt_secret()?;
35        let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
36        let user_service = UserService::new(ctx.db_pool())?;
37        let concrete = Arc::clone(ctx.analytics_service());
38        let analytics: Arc<dyn AnalyticsProvider> = concrete;
39        let session_creation_service = Arc::new(SessionCreationService::new(
40            analytics,
41            Arc::new(UserProviderImpl::new(user_service)),
42        ));
43
44        Ok(Self {
45            jwt_extractor,
46            analytics_service: Arc::clone(ctx.analytics_service()),
47            session_creation_service,
48        })
49    }
50
51    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
52        let headers = request.headers();
53        let uri = request.uri().clone();
54        let method = request.method().clone();
55
56        let should_skip = should_skip_session_tracking(uri.path());
57
58        tracing::debug!(
59            path = %uri.path(),
60            should_skip = should_skip,
61            "Session middleware evaluating request"
62        );
63
64        let trace_id = HeaderExtractor::extract_trace_id(headers);
65
66        let (req_ctx, jwt_cookie) = if should_skip {
67            (Self::untracked_context(trace_id), None)
68        } else {
69            self.tracked_context(trace_id, headers, &uri, &method)
70                .await?
71        };
72
73        tracing::debug!(
74            path = %uri.path(),
75            session_id = %req_ctx.session_id(),
76            "Session middleware setting context"
77        );
78
79        request.extensions_mut().insert(req_ctx);
80
81        let mut response = next.run(request).await;
82
83        if let Some(token) = jwt_cookie {
84            let cookie =
85                format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
86            if let Ok(cookie_value) = cookie.parse() {
87                response
88                    .headers_mut()
89                    .insert(header::SET_COOKIE, cookie_value);
90            }
91        }
92
93        Ok(response)
94    }
95
96    fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
97        RequestContext::new(
98            SessionId::new(format!("untracked_{}", Uuid::new_v4())),
99            trace_id,
100            ContextId::generate(),
101            AgentName::system(),
102        )
103        .with_user_id(UserId::new("anonymous".to_string()))
104        .with_user_type(UserType::Anon)
105        .with_tracked(false)
106    }
107
108    fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
109        RequestContext::new(
110            SessionId::new(format!("bot_{}", Uuid::new_v4())),
111            trace_id,
112            ContextId::generate(),
113            AgentName::system(),
114        )
115        .with_user_id(UserId::new("bot".to_string()))
116        .with_user_type(UserType::Anon)
117        .with_tracked(false)
118    }
119
120    async fn tracked_context(
121        &self,
122        trace_id: systemprompt_identifiers::TraceId,
123        headers: &http::HeaderMap,
124        uri: &http::Uri,
125        method: &http::Method,
126    ) -> Result<(RequestContext, Option<String>), ApiError> {
127        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
128        let is_bot = AnalyticsService::is_bot(&analytics);
129
130        tracing::debug!(
131            path = %uri.path(),
132            is_bot = is_bot,
133            user_agent = ?analytics.user_agent,
134            "Session middleware bot check"
135        );
136
137        if is_bot {
138            return Ok((Self::bot_context(trace_id), None));
139        }
140
141        let token_result = TokenExtractor::browser_only().extract(headers).ok();
142
143        let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
144            .resolve_session(token_result, headers, uri, method)
145            .await?;
146
147        let context_id =
148            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
149
150        let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
151            .with_user_id(user_id)
152            .with_auth_token(jwt_token)
153            .with_user_type(UserType::Anon)
154            .with_tracked(true);
155        if let Some(fp) = fingerprint_hash {
156            ctx = ctx.with_fingerprint_hash(fp);
157        }
158        Ok((ctx, jwt_cookie))
159    }
160
161    async fn resolve_session(
162        &self,
163        token_result: Option<String>,
164        headers: &http::HeaderMap,
165        uri: &http::Uri,
166        method: &http::Method,
167    ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
168        let Some(token) = token_result else {
169            let (sid, uid, token, is_new, fp) =
170                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
171                    .await?;
172            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
173            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
174        };
175
176        let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) else {
177            let (sid, uid, token, is_new, fp) =
178                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
179                    .await?;
180            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
181            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
182        };
183
184        let session_exists = self
185            .analytics_service
186            .find_session_by_id(&jwt_context.session_id)
187            .await
188            .map_err(|e| {
189                tracing::warn!(error = %e, "find_session_by_id failed");
190                e
191            })
192            .ok()
193            .flatten()
194            .is_some();
195
196        if session_exists {
197            return Ok((
198                jwt_context.session_id,
199                jwt_context.user_id,
200                token,
201                None,
202                None,
203            ));
204        }
205
206        tracing::info!(
207            old_session_id = %jwt_context.session_id,
208            user_id = %jwt_context.user_id,
209            "JWT valid but session missing, refreshing with new session"
210        );
211        match lifecycle::refresh_session_for_user(
212            &self.session_creation_service,
213            &self.analytics_service,
214            &jwt_context.user_id,
215            headers,
216            uri,
217        )
218        .await
219        {
220            Ok((sid, uid, new_token, _, fp)) => {
221                Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
222            },
223            Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
224                tracing::warn!(
225                    user_id = %jwt_context.user_id,
226                    "JWT references non-existent user, creating new anonymous session"
227                );
228                let (sid, uid, token, _, fp) = lifecycle::create_new_session(
229                    &self.session_creation_service,
230                    headers,
231                    uri,
232                    method,
233                )
234                .await?;
235                Ok((sid, uid, token.clone(), Some(token), Some(fp)))
236            },
237            Err(e) => Err(e),
238        }
239    }
240}