Skip to main content

systemprompt_api/services/middleware/
session.rs

1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{CreateAnonymousSessionInput, SessionCreationService};
13use systemprompt_runtime::AppContext;
14use systemprompt_security::{HeaderExtractor, TokenExtractor};
15use systemprompt_traits::AnalyticsProvider;
16use systemprompt_users::{UserProviderImpl, UserService};
17use uuid::Uuid;
18
19use super::jwt::JwtExtractor;
20
21#[derive(Clone, Debug)]
22pub struct SessionMiddleware {
23    jwt_extractor: Arc<JwtExtractor>,
24    analytics_service: Arc<AnalyticsService>,
25    session_creation_service: Arc<SessionCreationService>,
26}
27
28impl SessionMiddleware {
29    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
30        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
31        let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
32        let user_service = UserService::new(ctx.db_pool())?;
33        let session_creation_service = Arc::new(SessionCreationService::new(
34            ctx.analytics_service().clone(),
35            Arc::new(UserProviderImpl::new(user_service)),
36        ));
37
38        Ok(Self {
39            jwt_extractor,
40            analytics_service: ctx.analytics_service().clone(),
41            session_creation_service,
42        })
43    }
44
45    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
46        let headers = request.headers();
47        let uri = request.uri().clone();
48        let method = request.method().clone();
49
50        let should_skip = Self::should_skip_session_tracking(uri.path());
51
52        tracing::debug!(
53            path = %uri.path(),
54            should_skip = should_skip,
55            "Session middleware evaluating request"
56        );
57
58        let trace_id = HeaderExtractor::extract_trace_id(headers);
59
60        let (req_ctx, jwt_cookie) = if should_skip {
61            let ctx = RequestContext::new(
62                SessionId::new(format!("untracked_{}", Uuid::new_v4())),
63                trace_id,
64                ContextId::new(String::new()),
65                AgentName::system(),
66            )
67            .with_user_id(UserId::new("anonymous".to_string()))
68            .with_user_type(UserType::Anon)
69            .with_tracked(false);
70            (ctx, None)
71        } else {
72            let analytics = self
73                .analytics_service
74                .extract_analytics(headers, Some(&uri));
75            let is_bot = AnalyticsService::is_bot(&analytics);
76
77            tracing::debug!(
78                path = %uri.path(),
79                is_bot = is_bot,
80                user_agent = ?analytics.user_agent,
81                "Session middleware bot check"
82            );
83
84            if is_bot {
85                let ctx = RequestContext::new(
86                    SessionId::new(format!("bot_{}", Uuid::new_v4())),
87                    trace_id,
88                    ContextId::new(String::new()),
89                    AgentName::system(),
90                )
91                .with_user_id(UserId::new("bot".to_string()))
92                .with_user_type(UserType::Anon)
93                .with_tracked(false);
94                (ctx, None)
95            } else {
96                let token_result = TokenExtractor::browser_only().extract(headers).ok();
97
98                let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) =
99                    if let Some(token) = token_result {
100                        if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
101                            let session_exists = self
102                                .analytics_service
103                                .find_session_by_id(&jwt_context.session_id)
104                                .await
105                                .ok()
106                                .flatten()
107                                .is_some();
108
109                            if session_exists {
110                                (
111                                    jwt_context.session_id,
112                                    jwt_context.user_id,
113                                    token,
114                                    None,
115                                    None,
116                                )
117                            } else {
118                                tracing::info!(
119                                    old_session_id = %jwt_context.session_id,
120                                    user_id = %jwt_context.user_id,
121                                    "JWT valid but session missing, refreshing with new session"
122                                );
123                                let (sid, uid, new_token, _, fp) = self
124                                    .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
125                                    .await?;
126                                (sid, uid, new_token.clone(), Some(new_token), Some(fp))
127                            }
128                        } else {
129                            let (sid, uid, token, is_new, fp) =
130                                self.create_new_session(headers, &uri, &method).await?;
131                            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
132                            (sid, uid, token, jwt_cookie, Some(fp))
133                        }
134                    } else {
135                        let (sid, uid, token, is_new, fp) =
136                            self.create_new_session(headers, &uri, &method).await?;
137                        let jwt_cookie = if is_new { Some(token.clone()) } else { None };
138                        (sid, uid, token, jwt_cookie, Some(fp))
139                    };
140
141                let mut ctx = RequestContext::new(
142                    session_id,
143                    trace_id,
144                    ContextId::new(String::new()),
145                    AgentName::system(),
146                )
147                .with_user_id(user_id)
148                .with_auth_token(jwt_token)
149                .with_user_type(UserType::Anon)
150                .with_tracked(true);
151                if let Some(fp) = fingerprint_hash {
152                    ctx = ctx.with_fingerprint_hash(fp);
153                }
154                (ctx, jwt_cookie)
155            }
156        };
157
158        tracing::debug!(
159            path = %uri.path(),
160            session_id = %req_ctx.session_id(),
161            "Session middleware setting context"
162        );
163
164        request.extensions_mut().insert(req_ctx);
165
166        let mut response = next.run(request).await;
167
168        if let Some(token) = jwt_cookie {
169            let cookie =
170                format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
171            if let Ok(cookie_value) = cookie.parse() {
172                response
173                    .headers_mut()
174                    .insert(header::SET_COOKIE, cookie_value);
175            }
176        }
177
178        Ok(response)
179    }
180
181    async fn create_new_session(
182        &self,
183        headers: &http::HeaderMap,
184        uri: &http::Uri,
185        _method: &http::Method,
186    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
187        let client_id = ClientId::new("sp_web".to_string());
188
189        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
190            tracing::error!(error = %e, "Failed to get JWT secret during session creation");
191            ApiError::internal_error("Failed to initialize session")
192        })?;
193
194        self.session_creation_service
195            .create_anonymous_session(CreateAnonymousSessionInput {
196                headers,
197                uri: Some(uri),
198                client_id: &client_id,
199                jwt_secret,
200                session_source: SessionSource::Web,
201            })
202            .await
203            .map(|session_info| {
204                (
205                    session_info.session_id,
206                    session_info.user_id,
207                    session_info.jwt_token,
208                    session_info.is_new,
209                    session_info.fingerprint_hash,
210                )
211            })
212            .map_err(|e| {
213                tracing::error!(error = %e, "Failed to create anonymous session");
214                ApiError::internal_error("Service temporarily unavailable")
215            })
216    }
217
218    async fn refresh_session_for_user(
219        &self,
220        user_id: &UserId,
221        headers: &http::HeaderMap,
222        uri: &http::Uri,
223    ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
224        let session_id = self
225            .session_creation_service
226            .create_authenticated_session(user_id, headers, SessionSource::Web)
227            .await
228            .map_err(|e| {
229                tracing::error!(error = %e, user_id = %user_id, "Failed to create session for user");
230                ApiError::internal_error("Failed to refresh session")
231            })?;
232
233        let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
234            tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
235            ApiError::internal_error("Failed to refresh session")
236        })?;
237
238        let config = systemprompt_models::Config::get().map_err(|e| {
239            tracing::error!(error = %e, "Failed to get config during session refresh");
240            ApiError::internal_error("Failed to refresh session")
241        })?;
242
243        let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
244            user_id.as_str(),
245            session_id.as_str(),
246            &ClientId::new("sp_web".to_string()),
247            &systemprompt_oauth::services::JwtSigningParams {
248                secret: jwt_secret,
249                issuer: &config.jwt_issuer,
250            },
251        )
252        .map_err(|e| {
253            tracing::error!(error = %e, "Failed to generate JWT during session refresh");
254            ApiError::internal_error("Failed to refresh session")
255        })?;
256
257        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
258        let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
259
260        Ok((session_id, user_id.clone(), token, true, fingerprint))
261    }
262
263    fn should_skip_session_tracking(path: &str) -> bool {
264        if path.starts_with(ApiPaths::TRACK_BASE) {
265            return false;
266        }
267
268        if path.starts_with(ApiPaths::API_BASE) {
269            return true;
270        }
271
272        if path.starts_with(ApiPaths::NEXT_BASE) {
273            return true;
274        }
275
276        if path.starts_with(ApiPaths::STATIC_BASE)
277            || path.starts_with(ApiPaths::ASSETS_BASE)
278            || path.starts_with(ApiPaths::IMAGES_BASE)
279        {
280            return true;
281        }
282
283        if path == "/health" || path == "/ready" || path == "/healthz" {
284            return true;
285        }
286
287        if path == "/favicon.ico"
288            || path == "/robots.txt"
289            || path == "/sitemap.xml"
290            || path == "/manifest.json"
291        {
292            return true;
293        }
294
295        if let Some(last_segment) = path.rsplit('/').next() {
296            if last_segment.contains('.') {
297                let extension = last_segment.rsplit('.').next().unwrap_or("");
298                match extension {
299                    "html" | "htm" => {},
300                    _ => return true,
301                }
302            }
303        }
304
305        false
306    }
307}