Skip to main content

systemprompt_api/services/middleware/
session.rs

1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{
13    CreateAnonymousSessionInput, SessionCreationError, SessionCreationService,
14};
15use systemprompt_runtime::AppContext;
16use systemprompt_security::{HeaderExtractor, TokenExtractor};
17use systemprompt_traits::AnalyticsProvider;
18use systemprompt_users::{UserProviderImpl, UserService};
19use uuid::Uuid;
20
21use super::jwt::JwtExtractor;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25    jwt_extractor: Arc<JwtExtractor>,
26    analytics_service: Arc<AnalyticsService>,
27    session_creation_service: Arc<SessionCreationService>,
28}
29
30impl SessionMiddleware {
31    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
32        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
33        let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
34        let user_service = UserService::new(ctx.db_pool())?;
35        let session_creation_service = Arc::new(SessionCreationService::new(
36            ctx.analytics_service().clone(),
37            Arc::new(UserProviderImpl::new(user_service)),
38        ));
39
40        Ok(Self {
41            jwt_extractor,
42            analytics_service: ctx.analytics_service().clone(),
43            session_creation_service,
44        })
45    }
46
47    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
48        let headers = request.headers();
49        let uri = request.uri().clone();
50        let method = request.method().clone();
51
52        let should_skip = Self::should_skip_session_tracking(uri.path());
53
54        tracing::debug!(
55            path = %uri.path(),
56            should_skip = should_skip,
57            "Session middleware evaluating request"
58        );
59
60        let trace_id = HeaderExtractor::extract_trace_id(headers);
61
62        let (req_ctx, jwt_cookie) = if should_skip {
63            let ctx = RequestContext::new(
64                SessionId::new(format!("untracked_{}", Uuid::new_v4())),
65                trace_id,
66                ContextId::new(String::new()),
67                AgentName::system(),
68            )
69            .with_user_id(UserId::new("anonymous".to_string()))
70            .with_user_type(UserType::Anon)
71            .with_tracked(false);
72            (ctx, None)
73        } else {
74            let analytics = self
75                .analytics_service
76                .extract_analytics(headers, Some(&uri));
77            let is_bot = AnalyticsService::is_bot(&analytics);
78
79            tracing::debug!(
80                path = %uri.path(),
81                is_bot = is_bot,
82                user_agent = ?analytics.user_agent,
83                "Session middleware bot check"
84            );
85
86            if is_bot {
87                let ctx = RequestContext::new(
88                    SessionId::new(format!("bot_{}", Uuid::new_v4())),
89                    trace_id,
90                    ContextId::new(String::new()),
91                    AgentName::system(),
92                )
93                .with_user_id(UserId::new("bot".to_string()))
94                .with_user_type(UserType::Anon)
95                .with_tracked(false);
96                (ctx, None)
97            } else {
98                let token_result = TokenExtractor::browser_only().extract(headers).ok();
99
100                let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = if let Some(
101                    token,
102                ) =
103                    token_result
104                {
105                    if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
106                        let session_exists = self
107                            .analytics_service
108                            .find_session_by_id(&jwt_context.session_id)
109                            .await
110                            .ok()
111                            .flatten()
112                            .is_some();
113
114                        if session_exists {
115                            (
116                                jwt_context.session_id,
117                                jwt_context.user_id,
118                                token,
119                                None,
120                                None,
121                            )
122                        } else {
123                            tracing::info!(
124                                old_session_id = %jwt_context.session_id,
125                                user_id = %jwt_context.user_id,
126                                "JWT valid but session missing, refreshing with new session"
127                            );
128                            match self
129                                .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
130                                .await
131                            {
132                                Ok((sid, uid, new_token, _, fp)) => {
133                                    (sid, uid, new_token.clone(), Some(new_token), Some(fp))
134                                },
135                                Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
136                                    tracing::warn!(
137                                        user_id = %jwt_context.user_id,
138                                        "JWT references non-existent user, creating new anonymous session"
139                                    );
140                                    let (sid, uid, token, _, fp) =
141                                        self.create_new_session(headers, &uri, &method).await?;
142                                    (sid, uid, token.clone(), Some(token), Some(fp))
143                                },
144                                Err(e) => return Err(e),
145                            }
146                        }
147                    } else {
148                        let (sid, uid, token, is_new, fp) =
149                            self.create_new_session(headers, &uri, &method).await?;
150                        let jwt_cookie = if is_new { Some(token.clone()) } else { None };
151                        (sid, uid, token, jwt_cookie, Some(fp))
152                    }
153                } else {
154                    let (sid, uid, token, is_new, fp) =
155                        self.create_new_session(headers, &uri, &method).await?;
156                    let jwt_cookie = if is_new { Some(token.clone()) } else { None };
157                    (sid, uid, token, jwt_cookie, Some(fp))
158                };
159
160                let mut ctx = RequestContext::new(
161                    session_id,
162                    trace_id,
163                    ContextId::new(String::new()),
164                    AgentName::system(),
165                )
166                .with_user_id(user_id)
167                .with_auth_token(jwt_token)
168                .with_user_type(UserType::Anon)
169                .with_tracked(true);
170                if let Some(fp) = fingerprint_hash {
171                    ctx = ctx.with_fingerprint_hash(fp);
172                }
173                (ctx, jwt_cookie)
174            }
175        };
176
177        tracing::debug!(
178            path = %uri.path(),
179            session_id = %req_ctx.session_id(),
180            "Session middleware setting context"
181        );
182
183        request.extensions_mut().insert(req_ctx);
184
185        let mut response = next.run(request).await;
186
187        if let Some(token) = jwt_cookie {
188            let cookie =
189                format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
190            if let Ok(cookie_value) = cookie.parse() {
191                response
192                    .headers_mut()
193                    .insert(header::SET_COOKIE, cookie_value);
194            }
195        }
196
197        Ok(response)
198    }
199
200    async fn create_new_session(
201        &self,
202        headers: &http::HeaderMap,
203        uri: &http::Uri,
204        _method: &http::Method,
205    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
206        let client_id = ClientId::new("sp_web".to_string());
207
208        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
209            tracing::error!(error = %e, "Failed to get JWT secret during session creation");
210            ApiError::internal_error("Failed to initialize session")
211        })?;
212
213        self.session_creation_service
214            .create_anonymous_session(CreateAnonymousSessionInput {
215                headers,
216                uri: Some(uri),
217                client_id: &client_id,
218                jwt_secret,
219                session_source: SessionSource::Web,
220            })
221            .await
222            .map(|session_info| {
223                (
224                    session_info.session_id,
225                    session_info.user_id,
226                    session_info.jwt_token,
227                    session_info.is_new,
228                    session_info.fingerprint_hash,
229                )
230            })
231            .map_err(|e| {
232                tracing::error!(error = %e, "Failed to create anonymous session");
233                ApiError::internal_error("Service temporarily unavailable")
234            })
235    }
236
237    async fn refresh_session_for_user(
238        &self,
239        user_id: &UserId,
240        headers: &http::HeaderMap,
241        uri: &http::Uri,
242    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
243        let session_id = self
244            .session_creation_service
245            .create_authenticated_session(user_id, headers, SessionSource::Web)
246            .await
247            .map_err(|e| match e {
248                SessionCreationError::UserNotFound { ref user_id } => {
249                    ApiError::not_found(format!("User not found: {}", user_id))
250                        .with_error_key("user_not_found")
251                }
252                SessionCreationError::Internal(ref msg) => {
253                    tracing::error!(error = %msg, user_id = %user_id, "Failed to create session for user");
254                    ApiError::internal_error("Failed to refresh session")
255                }
256            })?;
257
258        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
259            tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
260            ApiError::internal_error("Failed to refresh session")
261        })?;
262
263        let config = systemprompt_models::Config::get().map_err(|e| {
264            tracing::error!(error = %e, "Failed to get config during session refresh");
265            ApiError::internal_error("Failed to refresh session")
266        })?;
267
268        let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
269            user_id.as_str(),
270            session_id.as_str(),
271            &ClientId::new("sp_web".to_string()),
272            &systemprompt_oauth::services::JwtSigningParams {
273                secret: jwt_secret,
274                issuer: &config.jwt_issuer,
275            },
276        )
277        .map_err(|e| {
278            tracing::error!(error = %e, "Failed to generate JWT during session refresh");
279            ApiError::internal_error("Failed to refresh session")
280        })?;
281
282        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
283        let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
284
285        Ok((session_id, user_id.clone(), token, true, fingerprint))
286    }
287
288    fn should_skip_session_tracking(path: &str) -> bool {
289        if path.starts_with(ApiPaths::TRACK_BASE) {
290            return false;
291        }
292
293        if path.starts_with(ApiPaths::MCP_BASE) {
294            return true;
295        }
296
297        if path.starts_with(ApiPaths::API_BASE) {
298            return true;
299        }
300
301        if path.starts_with(ApiPaths::NEXT_BASE) {
302            return true;
303        }
304
305        if path.starts_with(ApiPaths::STATIC_BASE)
306            || path.starts_with(ApiPaths::ASSETS_BASE)
307            || path.starts_with(ApiPaths::IMAGES_BASE)
308        {
309            return true;
310        }
311
312        if path == "/health" || path == "/ready" || path == "/healthz" {
313            return true;
314        }
315
316        if path == "/favicon.ico"
317            || path == "/robots.txt"
318            || path == "/sitemap.xml"
319            || path == "/manifest.json"
320        {
321            return true;
322        }
323
324        if let Some(last_segment) = path.rsplit('/').next() {
325            if last_segment.contains('.') {
326                let extension = last_segment.rsplit('.').next().unwrap_or("");
327                match extension {
328                    "html" | "htm" => {},
329                    _ => return true,
330                }
331            }
332        }
333
334        false
335    }
336}