systemprompt_api/services/middleware/
session.rs1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{
13 CreateAnonymousSessionInput, SessionCreationError, SessionCreationService,
14};
15use systemprompt_runtime::AppContext;
16use systemprompt_security::{HeaderExtractor, TokenExtractor};
17use systemprompt_traits::AnalyticsProvider;
18use systemprompt_users::{UserProviderImpl, UserService};
19use uuid::Uuid;
20
21use super::jwt::JwtExtractor;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25 jwt_extractor: Arc<JwtExtractor>,
26 analytics_service: Arc<AnalyticsService>,
27 session_creation_service: Arc<SessionCreationService>,
28}
29
30impl SessionMiddleware {
31 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
32 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
33 let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
34 let user_service = UserService::new(ctx.db_pool())?;
35 let session_creation_service = Arc::new(SessionCreationService::new(
36 ctx.analytics_service().clone(),
37 Arc::new(UserProviderImpl::new(user_service)),
38 ));
39
40 Ok(Self {
41 jwt_extractor,
42 analytics_service: ctx.analytics_service().clone(),
43 session_creation_service,
44 })
45 }
46
47 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
48 let headers = request.headers();
49 let uri = request.uri().clone();
50 let method = request.method().clone();
51
52 let should_skip = Self::should_skip_session_tracking(uri.path());
53
54 tracing::debug!(
55 path = %uri.path(),
56 should_skip = should_skip,
57 "Session middleware evaluating request"
58 );
59
60 let trace_id = HeaderExtractor::extract_trace_id(headers);
61
62 let (req_ctx, jwt_cookie) = if should_skip {
63 let ctx = RequestContext::new(
64 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
65 trace_id,
66 ContextId::new(String::new()),
67 AgentName::system(),
68 )
69 .with_user_id(UserId::new("anonymous".to_string()))
70 .with_user_type(UserType::Anon)
71 .with_tracked(false);
72 (ctx, None)
73 } else {
74 let analytics = self
75 .analytics_service
76 .extract_analytics(headers, Some(&uri));
77 let is_bot = AnalyticsService::is_bot(&analytics);
78
79 tracing::debug!(
80 path = %uri.path(),
81 is_bot = is_bot,
82 user_agent = ?analytics.user_agent,
83 "Session middleware bot check"
84 );
85
86 if is_bot {
87 let ctx = RequestContext::new(
88 SessionId::new(format!("bot_{}", Uuid::new_v4())),
89 trace_id,
90 ContextId::new(String::new()),
91 AgentName::system(),
92 )
93 .with_user_id(UserId::new("bot".to_string()))
94 .with_user_type(UserType::Anon)
95 .with_tracked(false);
96 (ctx, None)
97 } else {
98 let token_result = TokenExtractor::browser_only().extract(headers).ok();
99
100 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = if let Some(
101 token,
102 ) =
103 token_result
104 {
105 if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
106 let session_exists = self
107 .analytics_service
108 .find_session_by_id(&jwt_context.session_id)
109 .await
110 .ok()
111 .flatten()
112 .is_some();
113
114 if session_exists {
115 (
116 jwt_context.session_id,
117 jwt_context.user_id,
118 token,
119 None,
120 None,
121 )
122 } else {
123 tracing::info!(
124 old_session_id = %jwt_context.session_id,
125 user_id = %jwt_context.user_id,
126 "JWT valid but session missing, refreshing with new session"
127 );
128 match self
129 .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
130 .await
131 {
132 Ok((sid, uid, new_token, _, fp)) => {
133 (sid, uid, new_token.clone(), Some(new_token), Some(fp))
134 },
135 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
136 tracing::warn!(
137 user_id = %jwt_context.user_id,
138 "JWT references non-existent user, creating new anonymous session"
139 );
140 let (sid, uid, token, _, fp) =
141 self.create_new_session(headers, &uri, &method).await?;
142 (sid, uid, token.clone(), Some(token), Some(fp))
143 },
144 Err(e) => return Err(e),
145 }
146 }
147 } else {
148 let (sid, uid, token, is_new, fp) =
149 self.create_new_session(headers, &uri, &method).await?;
150 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
151 (sid, uid, token, jwt_cookie, Some(fp))
152 }
153 } else {
154 let (sid, uid, token, is_new, fp) =
155 self.create_new_session(headers, &uri, &method).await?;
156 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
157 (sid, uid, token, jwt_cookie, Some(fp))
158 };
159
160 let mut ctx = RequestContext::new(
161 session_id,
162 trace_id,
163 ContextId::new(String::new()),
164 AgentName::system(),
165 )
166 .with_user_id(user_id)
167 .with_auth_token(jwt_token)
168 .with_user_type(UserType::Anon)
169 .with_tracked(true);
170 if let Some(fp) = fingerprint_hash {
171 ctx = ctx.with_fingerprint_hash(fp);
172 }
173 (ctx, jwt_cookie)
174 }
175 };
176
177 tracing::debug!(
178 path = %uri.path(),
179 session_id = %req_ctx.session_id(),
180 "Session middleware setting context"
181 );
182
183 request.extensions_mut().insert(req_ctx);
184
185 let mut response = next.run(request).await;
186
187 if let Some(token) = jwt_cookie {
188 let cookie =
189 format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
190 if let Ok(cookie_value) = cookie.parse() {
191 response
192 .headers_mut()
193 .insert(header::SET_COOKIE, cookie_value);
194 }
195 }
196
197 Ok(response)
198 }
199
200 async fn create_new_session(
201 &self,
202 headers: &http::HeaderMap,
203 uri: &http::Uri,
204 _method: &http::Method,
205 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
206 let client_id = ClientId::new("sp_web".to_string());
207
208 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
209 tracing::error!(error = %e, "Failed to get JWT secret during session creation");
210 ApiError::internal_error("Failed to initialize session")
211 })?;
212
213 self.session_creation_service
214 .create_anonymous_session(CreateAnonymousSessionInput {
215 headers,
216 uri: Some(uri),
217 client_id: &client_id,
218 jwt_secret,
219 session_source: SessionSource::Web,
220 })
221 .await
222 .map(|session_info| {
223 (
224 session_info.session_id,
225 session_info.user_id,
226 session_info.jwt_token,
227 session_info.is_new,
228 session_info.fingerprint_hash,
229 )
230 })
231 .map_err(|e| {
232 tracing::error!(error = %e, "Failed to create anonymous session");
233 ApiError::internal_error("Service temporarily unavailable")
234 })
235 }
236
237 async fn refresh_session_for_user(
238 &self,
239 user_id: &UserId,
240 headers: &http::HeaderMap,
241 uri: &http::Uri,
242 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
243 let session_id = self
244 .session_creation_service
245 .create_authenticated_session(user_id, headers, SessionSource::Web)
246 .await
247 .map_err(|e| match e {
248 SessionCreationError::UserNotFound { ref user_id } => {
249 ApiError::not_found(format!("User not found: {}", user_id))
250 .with_error_key("user_not_found")
251 }
252 SessionCreationError::Internal(ref msg) => {
253 tracing::error!(error = %msg, user_id = %user_id, "Failed to create session for user");
254 ApiError::internal_error("Failed to refresh session")
255 }
256 })?;
257
258 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
259 tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
260 ApiError::internal_error("Failed to refresh session")
261 })?;
262
263 let config = systemprompt_models::Config::get().map_err(|e| {
264 tracing::error!(error = %e, "Failed to get config during session refresh");
265 ApiError::internal_error("Failed to refresh session")
266 })?;
267
268 let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
269 user_id.as_str(),
270 session_id.as_str(),
271 &ClientId::new("sp_web".to_string()),
272 &systemprompt_oauth::services::JwtSigningParams {
273 secret: jwt_secret,
274 issuer: &config.jwt_issuer,
275 },
276 )
277 .map_err(|e| {
278 tracing::error!(error = %e, "Failed to generate JWT during session refresh");
279 ApiError::internal_error("Failed to refresh session")
280 })?;
281
282 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
283 let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
284
285 Ok((session_id, user_id.clone(), token, true, fingerprint))
286 }
287
288 fn should_skip_session_tracking(path: &str) -> bool {
289 if path.starts_with(ApiPaths::TRACK_BASE) {
290 return false;
291 }
292
293 if path.starts_with(ApiPaths::MCP_BASE) {
294 return true;
295 }
296
297 if path.starts_with(ApiPaths::API_BASE) {
298 return true;
299 }
300
301 if path.starts_with(ApiPaths::NEXT_BASE) {
302 return true;
303 }
304
305 if path.starts_with(ApiPaths::STATIC_BASE)
306 || path.starts_with(ApiPaths::ASSETS_BASE)
307 || path.starts_with(ApiPaths::IMAGES_BASE)
308 {
309 return true;
310 }
311
312 if path == "/health" || path == "/ready" || path == "/healthz" {
313 return true;
314 }
315
316 if path == "/favicon.ico"
317 || path == "/robots.txt"
318 || path == "/sitemap.xml"
319 || path == "/manifest.json"
320 {
321 return true;
322 }
323
324 if let Some(last_segment) = path.rsplit('/').next() {
325 if last_segment.contains('.') {
326 let extension = last_segment.rsplit('.').next().unwrap_or("");
327 match extension {
328 "html" | "htm" => {},
329 _ => return true,
330 }
331 }
332 }
333
334 false
335 }
336}