Skip to main content

systemprompt_api/services/middleware/
session.rs

1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{
13    CreateAnonymousSessionInput, SessionCreationError, SessionCreationService,
14};
15use systemprompt_runtime::AppContext;
16use systemprompt_security::{HeaderExtractor, TokenExtractor};
17use systemprompt_traits::AnalyticsProvider;
18use systemprompt_users::{UserProviderImpl, UserService};
19use uuid::Uuid;
20
21use super::jwt::JwtExtractor;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25    jwt_extractor: Arc<JwtExtractor>,
26    analytics_service: Arc<AnalyticsService>,
27    session_creation_service: Arc<SessionCreationService>,
28}
29
30impl SessionMiddleware {
31    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
32        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
33        let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
34        let user_service = UserService::new(ctx.db_pool())?;
35        let concrete = Arc::clone(ctx.analytics_service());
36        let analytics: Arc<dyn AnalyticsProvider> = concrete;
37        let session_creation_service = Arc::new(SessionCreationService::new(
38            analytics,
39            Arc::new(UserProviderImpl::new(user_service)),
40        ));
41
42        Ok(Self {
43            jwt_extractor,
44            analytics_service: Arc::clone(ctx.analytics_service()),
45            session_creation_service,
46        })
47    }
48
49    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
50        let headers = request.headers();
51        let uri = request.uri().clone();
52        let method = request.method().clone();
53
54        let should_skip = Self::should_skip_session_tracking(uri.path());
55
56        tracing::debug!(
57            path = %uri.path(),
58            should_skip = should_skip,
59            "Session middleware evaluating request"
60        );
61
62        let trace_id = HeaderExtractor::extract_trace_id(headers);
63
64        let (req_ctx, jwt_cookie) = if should_skip {
65            let ctx = RequestContext::new(
66                SessionId::new(format!("untracked_{}", Uuid::new_v4())),
67                trace_id,
68                ContextId::new(String::new()),
69                AgentName::system(),
70            )
71            .with_user_id(UserId::new("anonymous".to_string()))
72            .with_user_type(UserType::Anon)
73            .with_tracked(false);
74            (ctx, None)
75        } else {
76            let analytics = self
77                .analytics_service
78                .extract_analytics(headers, Some(&uri));
79            let is_bot = AnalyticsService::is_bot(&analytics);
80
81            tracing::debug!(
82                path = %uri.path(),
83                is_bot = is_bot,
84                user_agent = ?analytics.user_agent,
85                "Session middleware bot check"
86            );
87
88            if is_bot {
89                let ctx = RequestContext::new(
90                    SessionId::new(format!("bot_{}", Uuid::new_v4())),
91                    trace_id,
92                    ContextId::new(String::new()),
93                    AgentName::system(),
94                )
95                .with_user_id(UserId::new("bot".to_string()))
96                .with_user_type(UserType::Anon)
97                .with_tracked(false);
98                (ctx, None)
99            } else {
100                let token_result = TokenExtractor::browser_only().extract(headers).ok();
101
102                let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = if let Some(
103                    token,
104                ) =
105                    token_result
106                {
107                    if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
108                        let session_exists = self
109                            .analytics_service
110                            .find_session_by_id(&jwt_context.session_id)
111                            .await
112                            .ok()
113                            .flatten()
114                            .is_some();
115
116                        if session_exists {
117                            (
118                                jwt_context.session_id,
119                                jwt_context.user_id,
120                                token,
121                                None,
122                                None,
123                            )
124                        } else {
125                            tracing::info!(
126                                old_session_id = %jwt_context.session_id,
127                                user_id = %jwt_context.user_id,
128                                "JWT valid but session missing, refreshing with new session"
129                            );
130                            match self
131                                .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
132                                .await
133                            {
134                                Ok((sid, uid, new_token, _, fp)) => {
135                                    (sid, uid, new_token.clone(), Some(new_token), Some(fp))
136                                },
137                                Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
138                                    tracing::warn!(
139                                        user_id = %jwt_context.user_id,
140                                        "JWT references non-existent user, creating new anonymous session"
141                                    );
142                                    let (sid, uid, token, _, fp) =
143                                        self.create_new_session(headers, &uri, &method).await?;
144                                    (sid, uid, token.clone(), Some(token), Some(fp))
145                                },
146                                Err(e) => return Err(e),
147                            }
148                        }
149                    } else {
150                        let (sid, uid, token, is_new, fp) =
151                            self.create_new_session(headers, &uri, &method).await?;
152                        let jwt_cookie = if is_new { Some(token.clone()) } else { None };
153                        (sid, uid, token, jwt_cookie, Some(fp))
154                    }
155                } else {
156                    let (sid, uid, token, is_new, fp) =
157                        self.create_new_session(headers, &uri, &method).await?;
158                    let jwt_cookie = if is_new { Some(token.clone()) } else { None };
159                    (sid, uid, token, jwt_cookie, Some(fp))
160                };
161
162                let mut ctx = RequestContext::new(
163                    session_id,
164                    trace_id,
165                    ContextId::new(String::new()),
166                    AgentName::system(),
167                )
168                .with_user_id(user_id)
169                .with_auth_token(jwt_token)
170                .with_user_type(UserType::Anon)
171                .with_tracked(true);
172                if let Some(fp) = fingerprint_hash {
173                    ctx = ctx.with_fingerprint_hash(fp);
174                }
175                (ctx, jwt_cookie)
176            }
177        };
178
179        tracing::debug!(
180            path = %uri.path(),
181            session_id = %req_ctx.session_id(),
182            "Session middleware setting context"
183        );
184
185        request.extensions_mut().insert(req_ctx);
186
187        let mut response = next.run(request).await;
188
189        if let Some(token) = jwt_cookie {
190            let cookie =
191                format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
192            if let Ok(cookie_value) = cookie.parse() {
193                response
194                    .headers_mut()
195                    .insert(header::SET_COOKIE, cookie_value);
196            }
197        }
198
199        Ok(response)
200    }
201
202    async fn create_new_session(
203        &self,
204        headers: &http::HeaderMap,
205        uri: &http::Uri,
206        _method: &http::Method,
207    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
208        let client_id = ClientId::new("sp_web".to_string());
209
210        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
211            tracing::error!(error = %e, "Failed to get JWT secret during session creation");
212            ApiError::internal_error("Failed to initialize session")
213        })?;
214
215        self.session_creation_service
216            .create_anonymous_session(CreateAnonymousSessionInput {
217                headers,
218                uri: Some(uri),
219                client_id: &client_id,
220                jwt_secret,
221                session_source: SessionSource::Web,
222            })
223            .await
224            .map(|session_info| {
225                (
226                    session_info.session_id,
227                    session_info.user_id,
228                    session_info.jwt_token,
229                    session_info.is_new,
230                    session_info.fingerprint_hash,
231                )
232            })
233            .map_err(|e| {
234                tracing::error!(error = %e, "Failed to create anonymous session");
235                ApiError::internal_error("Service temporarily unavailable")
236            })
237    }
238
239    async fn refresh_session_for_user(
240        &self,
241        user_id: &UserId,
242        headers: &http::HeaderMap,
243        uri: &http::Uri,
244    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
245        let session_id = self
246            .session_creation_service
247            .create_authenticated_session(user_id, headers, SessionSource::Web)
248            .await
249            .map_err(|e| match e {
250                SessionCreationError::UserNotFound { ref user_id } => {
251                    ApiError::not_found(format!("User not found: {}", user_id))
252                        .with_error_key("user_not_found")
253                }
254                SessionCreationError::Internal(ref msg) => {
255                    tracing::error!(error = %msg, user_id = %user_id, "Failed to create session for user");
256                    ApiError::internal_error("Failed to refresh session")
257                }
258            })?;
259
260        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
261            tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
262            ApiError::internal_error("Failed to refresh session")
263        })?;
264
265        let config = systemprompt_models::Config::get().map_err(|e| {
266            tracing::error!(error = %e, "Failed to get config during session refresh");
267            ApiError::internal_error("Failed to refresh session")
268        })?;
269
270        let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
271            user_id,
272            &session_id,
273            &ClientId::new("sp_web".to_string()),
274            &systemprompt_oauth::services::JwtSigningParams {
275                secret: jwt_secret,
276                issuer: &config.jwt_issuer,
277            },
278        )
279        .map_err(|e| {
280            tracing::error!(error = %e, "Failed to generate JWT during session refresh");
281            ApiError::internal_error("Failed to refresh session")
282        })?;
283
284        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
285        let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
286
287        Ok((session_id, user_id.clone(), token, true, fingerprint))
288    }
289
290    pub fn should_skip_session_tracking(path: &str) -> bool {
291        if path.starts_with(ApiPaths::TRACK_BASE) {
292            return false;
293        }
294
295        if path.starts_with(ApiPaths::MCP_BASE) {
296            return true;
297        }
298
299        if path.starts_with(ApiPaths::API_BASE) {
300            return true;
301        }
302
303        if path.starts_with(ApiPaths::NEXT_BASE) {
304            return true;
305        }
306
307        if path.starts_with(ApiPaths::STATIC_BASE)
308            || path.starts_with(ApiPaths::ASSETS_BASE)
309            || path.starts_with(ApiPaths::IMAGES_BASE)
310        {
311            return true;
312        }
313
314        if path == "/health" || path == "/ready" || path == "/healthz" {
315            return true;
316        }
317
318        if path == "/favicon.ico"
319            || path == "/robots.txt"
320            || path == "/sitemap.xml"
321            || path == "/manifest.json"
322        {
323            return true;
324        }
325
326        if let Some(last_segment) = path.rsplit('/').next() {
327            if last_segment.contains('.') {
328                let extension = last_segment.rsplit('.').next().unwrap_or("");
329                match extension {
330                    "html" | "htm" => {},
331                    _ => return true,
332                }
333            }
334        }
335
336        false
337    }
338}