systemprompt_api/services/middleware/
session.rs1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{
13 CreateAnonymousSessionInput, SessionCreationError, SessionCreationService,
14};
15use systemprompt_runtime::AppContext;
16use systemprompt_security::{HeaderExtractor, TokenExtractor};
17use systemprompt_traits::AnalyticsProvider;
18use systemprompt_users::{UserProviderImpl, UserService};
19use uuid::Uuid;
20
21use super::jwt::JwtExtractor;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25 jwt_extractor: Arc<JwtExtractor>,
26 analytics_service: Arc<AnalyticsService>,
27 session_creation_service: Arc<SessionCreationService>,
28}
29
30impl SessionMiddleware {
31 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
32 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
33 let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
34 let user_service = UserService::new(ctx.db_pool())?;
35 let concrete = Arc::clone(ctx.analytics_service());
36 let analytics: Arc<dyn AnalyticsProvider> = concrete;
37 let session_creation_service = Arc::new(SessionCreationService::new(
38 analytics,
39 Arc::new(UserProviderImpl::new(user_service)),
40 ));
41
42 Ok(Self {
43 jwt_extractor,
44 analytics_service: Arc::clone(ctx.analytics_service()),
45 session_creation_service,
46 })
47 }
48
49 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
50 let headers = request.headers();
51 let uri = request.uri().clone();
52 let method = request.method().clone();
53
54 let should_skip = Self::should_skip_session_tracking(uri.path());
55
56 tracing::debug!(
57 path = %uri.path(),
58 should_skip = should_skip,
59 "Session middleware evaluating request"
60 );
61
62 let trace_id = HeaderExtractor::extract_trace_id(headers);
63
64 let (req_ctx, jwt_cookie) = if should_skip {
65 let ctx = RequestContext::new(
66 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
67 trace_id,
68 ContextId::new(String::new()),
69 AgentName::system(),
70 )
71 .with_user_id(UserId::new("anonymous".to_string()))
72 .with_user_type(UserType::Anon)
73 .with_tracked(false);
74 (ctx, None)
75 } else {
76 let analytics = self
77 .analytics_service
78 .extract_analytics(headers, Some(&uri));
79 let is_bot = AnalyticsService::is_bot(&analytics);
80
81 tracing::debug!(
82 path = %uri.path(),
83 is_bot = is_bot,
84 user_agent = ?analytics.user_agent,
85 "Session middleware bot check"
86 );
87
88 if is_bot {
89 let ctx = RequestContext::new(
90 SessionId::new(format!("bot_{}", Uuid::new_v4())),
91 trace_id,
92 ContextId::new(String::new()),
93 AgentName::system(),
94 )
95 .with_user_id(UserId::new("bot".to_string()))
96 .with_user_type(UserType::Anon)
97 .with_tracked(false);
98 (ctx, None)
99 } else {
100 let token_result = TokenExtractor::browser_only().extract(headers).ok();
101
102 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = if let Some(
103 token,
104 ) =
105 token_result
106 {
107 if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
108 let session_exists = self
109 .analytics_service
110 .find_session_by_id(&jwt_context.session_id)
111 .await
112 .ok()
113 .flatten()
114 .is_some();
115
116 if session_exists {
117 (
118 jwt_context.session_id,
119 jwt_context.user_id,
120 token,
121 None,
122 None,
123 )
124 } else {
125 tracing::info!(
126 old_session_id = %jwt_context.session_id,
127 user_id = %jwt_context.user_id,
128 "JWT valid but session missing, refreshing with new session"
129 );
130 match self
131 .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
132 .await
133 {
134 Ok((sid, uid, new_token, _, fp)) => {
135 (sid, uid, new_token.clone(), Some(new_token), Some(fp))
136 },
137 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
138 tracing::warn!(
139 user_id = %jwt_context.user_id,
140 "JWT references non-existent user, creating new anonymous session"
141 );
142 let (sid, uid, token, _, fp) =
143 self.create_new_session(headers, &uri, &method).await?;
144 (sid, uid, token.clone(), Some(token), Some(fp))
145 },
146 Err(e) => return Err(e),
147 }
148 }
149 } else {
150 let (sid, uid, token, is_new, fp) =
151 self.create_new_session(headers, &uri, &method).await?;
152 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
153 (sid, uid, token, jwt_cookie, Some(fp))
154 }
155 } else {
156 let (sid, uid, token, is_new, fp) =
157 self.create_new_session(headers, &uri, &method).await?;
158 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
159 (sid, uid, token, jwt_cookie, Some(fp))
160 };
161
162 let mut ctx = RequestContext::new(
163 session_id,
164 trace_id,
165 ContextId::new(String::new()),
166 AgentName::system(),
167 )
168 .with_user_id(user_id)
169 .with_auth_token(jwt_token)
170 .with_user_type(UserType::Anon)
171 .with_tracked(true);
172 if let Some(fp) = fingerprint_hash {
173 ctx = ctx.with_fingerprint_hash(fp);
174 }
175 (ctx, jwt_cookie)
176 }
177 };
178
179 tracing::debug!(
180 path = %uri.path(),
181 session_id = %req_ctx.session_id(),
182 "Session middleware setting context"
183 );
184
185 request.extensions_mut().insert(req_ctx);
186
187 let mut response = next.run(request).await;
188
189 if let Some(token) = jwt_cookie {
190 let cookie =
191 format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
192 if let Ok(cookie_value) = cookie.parse() {
193 response
194 .headers_mut()
195 .insert(header::SET_COOKIE, cookie_value);
196 }
197 }
198
199 Ok(response)
200 }
201
202 async fn create_new_session(
203 &self,
204 headers: &http::HeaderMap,
205 uri: &http::Uri,
206 _method: &http::Method,
207 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
208 let client_id = ClientId::new("sp_web".to_string());
209
210 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
211 tracing::error!(error = %e, "Failed to get JWT secret during session creation");
212 ApiError::internal_error("Failed to initialize session")
213 })?;
214
215 self.session_creation_service
216 .create_anonymous_session(CreateAnonymousSessionInput {
217 headers,
218 uri: Some(uri),
219 client_id: &client_id,
220 jwt_secret,
221 session_source: SessionSource::Web,
222 })
223 .await
224 .map(|session_info| {
225 (
226 session_info.session_id,
227 session_info.user_id,
228 session_info.jwt_token,
229 session_info.is_new,
230 session_info.fingerprint_hash,
231 )
232 })
233 .map_err(|e| {
234 tracing::error!(error = %e, "Failed to create anonymous session");
235 ApiError::internal_error("Service temporarily unavailable")
236 })
237 }
238
239 async fn refresh_session_for_user(
240 &self,
241 user_id: &UserId,
242 headers: &http::HeaderMap,
243 uri: &http::Uri,
244 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
245 let session_id = self
246 .session_creation_service
247 .create_authenticated_session(user_id, headers, SessionSource::Web)
248 .await
249 .map_err(|e| match e {
250 SessionCreationError::UserNotFound { ref user_id } => {
251 ApiError::not_found(format!("User not found: {}", user_id))
252 .with_error_key("user_not_found")
253 }
254 SessionCreationError::Internal(ref msg) => {
255 tracing::error!(error = %msg, user_id = %user_id, "Failed to create session for user");
256 ApiError::internal_error("Failed to refresh session")
257 }
258 })?;
259
260 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
261 tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
262 ApiError::internal_error("Failed to refresh session")
263 })?;
264
265 let config = systemprompt_models::Config::get().map_err(|e| {
266 tracing::error!(error = %e, "Failed to get config during session refresh");
267 ApiError::internal_error("Failed to refresh session")
268 })?;
269
270 let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
271 user_id,
272 &session_id,
273 &ClientId::new("sp_web".to_string()),
274 &systemprompt_oauth::services::JwtSigningParams {
275 secret: jwt_secret,
276 issuer: &config.jwt_issuer,
277 },
278 )
279 .map_err(|e| {
280 tracing::error!(error = %e, "Failed to generate JWT during session refresh");
281 ApiError::internal_error("Failed to refresh session")
282 })?;
283
284 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
285 let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
286
287 Ok((session_id, user_id.clone(), token, true, fingerprint))
288 }
289
290 pub fn should_skip_session_tracking(path: &str) -> bool {
291 if path.starts_with(ApiPaths::TRACK_BASE) {
292 return false;
293 }
294
295 if path.starts_with(ApiPaths::MCP_BASE) {
296 return true;
297 }
298
299 if path.starts_with(ApiPaths::API_BASE) {
300 return true;
301 }
302
303 if path.starts_with(ApiPaths::NEXT_BASE) {
304 return true;
305 }
306
307 if path.starts_with(ApiPaths::STATIC_BASE)
308 || path.starts_with(ApiPaths::ASSETS_BASE)
309 || path.starts_with(ApiPaths::IMAGES_BASE)
310 {
311 return true;
312 }
313
314 if path == "/health" || path == "/ready" || path == "/healthz" {
315 return true;
316 }
317
318 if path == "/favicon.ico"
319 || path == "/robots.txt"
320 || path == "/sitemap.xml"
321 || path == "/manifest.json"
322 {
323 return true;
324 }
325
326 if let Some(last_segment) = path.rsplit('/').next() {
327 if last_segment.contains('.') {
328 let extension = last_segment.rsplit('.').next().unwrap_or("");
329 match extension {
330 "html" | "htm" => {},
331 _ => return true,
332 }
333 }
334 }
335
336 false
337 }
338}