systemprompt_api/services/middleware/
session.rs1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{CreateAnonymousSessionInput, SessionCreationService};
13use systemprompt_runtime::AppContext;
14use systemprompt_security::{HeaderExtractor, TokenExtractor};
15use systemprompt_traits::AnalyticsProvider;
16use systemprompt_users::{UserProviderImpl, UserService};
17use uuid::Uuid;
18
19use super::jwt::JwtExtractor;
20
21#[derive(Clone, Debug)]
22pub struct SessionMiddleware {
23 jwt_extractor: Arc<JwtExtractor>,
24 analytics_service: Arc<AnalyticsService>,
25 session_creation_service: Arc<SessionCreationService>,
26}
27
28impl SessionMiddleware {
29 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
30 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
31 let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
32 let user_service = UserService::new(ctx.db_pool())?;
33 let session_creation_service = Arc::new(SessionCreationService::new(
34 ctx.analytics_service().clone(),
35 Arc::new(UserProviderImpl::new(user_service)),
36 ));
37
38 Ok(Self {
39 jwt_extractor,
40 analytics_service: ctx.analytics_service().clone(),
41 session_creation_service,
42 })
43 }
44
45 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
46 let headers = request.headers();
47 let uri = request.uri().clone();
48 let method = request.method().clone();
49
50 let should_skip = Self::should_skip_session_tracking(uri.path());
51
52 tracing::debug!(
53 path = %uri.path(),
54 should_skip = should_skip,
55 "Session middleware evaluating request"
56 );
57
58 let trace_id = HeaderExtractor::extract_trace_id(headers);
59
60 let (req_ctx, jwt_cookie) = if should_skip {
61 let ctx = RequestContext::new(
62 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
63 trace_id,
64 ContextId::new(String::new()),
65 AgentName::system(),
66 )
67 .with_user_id(UserId::new("anonymous".to_string()))
68 .with_user_type(UserType::Anon)
69 .with_tracked(false);
70 (ctx, None)
71 } else {
72 let analytics = self
73 .analytics_service
74 .extract_analytics(headers, Some(&uri));
75 let is_bot = AnalyticsService::is_bot(&analytics);
76
77 tracing::debug!(
78 path = %uri.path(),
79 is_bot = is_bot,
80 user_agent = ?analytics.user_agent,
81 "Session middleware bot check"
82 );
83
84 if is_bot {
85 let ctx = RequestContext::new(
86 SessionId::new(format!("bot_{}", Uuid::new_v4())),
87 trace_id,
88 ContextId::new(String::new()),
89 AgentName::system(),
90 )
91 .with_user_id(UserId::new("bot".to_string()))
92 .with_user_type(UserType::Anon)
93 .with_tracked(false);
94 (ctx, None)
95 } else {
96 let token_result = TokenExtractor::browser_only().extract(headers).ok();
97
98 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) =
99 if let Some(token) = token_result {
100 if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
101 let session_exists = self
102 .analytics_service
103 .find_session_by_id(&jwt_context.session_id)
104 .await
105 .ok()
106 .flatten()
107 .is_some();
108
109 if session_exists {
110 (
111 jwt_context.session_id,
112 jwt_context.user_id,
113 token,
114 None,
115 None,
116 )
117 } else {
118 tracing::info!(
119 old_session_id = %jwt_context.session_id,
120 user_id = %jwt_context.user_id,
121 "JWT valid but session missing, refreshing with new session"
122 );
123 let (sid, uid, new_token, _, fp) = self
124 .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
125 .await?;
126 (sid, uid, new_token.clone(), Some(new_token), Some(fp))
127 }
128 } else {
129 let (sid, uid, token, is_new, fp) =
130 self.create_new_session(headers, &uri, &method).await?;
131 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
132 (sid, uid, token, jwt_cookie, Some(fp))
133 }
134 } else {
135 let (sid, uid, token, is_new, fp) =
136 self.create_new_session(headers, &uri, &method).await?;
137 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
138 (sid, uid, token, jwt_cookie, Some(fp))
139 };
140
141 let mut ctx = RequestContext::new(
142 session_id,
143 trace_id,
144 ContextId::new(String::new()),
145 AgentName::system(),
146 )
147 .with_user_id(user_id)
148 .with_auth_token(jwt_token)
149 .with_user_type(UserType::Anon)
150 .with_tracked(true);
151 if let Some(fp) = fingerprint_hash {
152 ctx = ctx.with_fingerprint_hash(fp);
153 }
154 (ctx, jwt_cookie)
155 }
156 };
157
158 tracing::debug!(
159 path = %uri.path(),
160 session_id = %req_ctx.session_id(),
161 "Session middleware setting context"
162 );
163
164 request.extensions_mut().insert(req_ctx);
165
166 let mut response = next.run(request).await;
167
168 if let Some(token) = jwt_cookie {
169 let cookie =
170 format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
171 if let Ok(cookie_value) = cookie.parse() {
172 response
173 .headers_mut()
174 .insert(header::SET_COOKIE, cookie_value);
175 }
176 }
177
178 Ok(response)
179 }
180
181 async fn create_new_session(
182 &self,
183 headers: &http::HeaderMap,
184 uri: &http::Uri,
185 _method: &http::Method,
186 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
187 let client_id = ClientId::new("sp_web".to_string());
188
189 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
190 tracing::error!(error = %e, "Failed to get JWT secret during session creation");
191 ApiError::internal_error("Failed to initialize session")
192 })?;
193
194 self.session_creation_service
195 .create_anonymous_session(CreateAnonymousSessionInput {
196 headers,
197 uri: Some(uri),
198 client_id: &client_id,
199 jwt_secret,
200 session_source: SessionSource::Web,
201 })
202 .await
203 .map(|session_info| {
204 (
205 session_info.session_id,
206 session_info.user_id,
207 session_info.jwt_token,
208 session_info.is_new,
209 session_info.fingerprint_hash,
210 )
211 })
212 .map_err(|e| {
213 tracing::error!(error = %e, "Failed to create anonymous session");
214 ApiError::internal_error("Service temporarily unavailable")
215 })
216 }
217
218 async fn refresh_session_for_user(
219 &self,
220 user_id: &UserId,
221 headers: &http::HeaderMap,
222 uri: &http::Uri,
223 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
224 let session_id = self
225 .session_creation_service
226 .create_authenticated_session(user_id, headers, SessionSource::Web)
227 .await
228 .map_err(|e| {
229 tracing::error!(error = %e, user_id = %user_id, "Failed to create session for user");
230 ApiError::internal_error("Failed to refresh session")
231 })?;
232
233 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
234 tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
235 ApiError::internal_error("Failed to refresh session")
236 })?;
237
238 let config = systemprompt_models::Config::get().map_err(|e| {
239 tracing::error!(error = %e, "Failed to get config during session refresh");
240 ApiError::internal_error("Failed to refresh session")
241 })?;
242
243 let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
244 user_id.as_str(),
245 session_id.as_str(),
246 &ClientId::new("sp_web".to_string()),
247 &systemprompt_oauth::services::JwtSigningParams {
248 secret: jwt_secret,
249 issuer: &config.jwt_issuer,
250 },
251 )
252 .map_err(|e| {
253 tracing::error!(error = %e, "Failed to generate JWT during session refresh");
254 ApiError::internal_error("Failed to refresh session")
255 })?;
256
257 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
258 let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
259
260 Ok((session_id, user_id.clone(), token, true, fingerprint))
261 }
262
263 fn should_skip_session_tracking(path: &str) -> bool {
264 if path.starts_with(ApiPaths::TRACK_BASE) {
265 return false;
266 }
267
268 if path.starts_with(ApiPaths::API_BASE) {
269 return true;
270 }
271
272 if path.starts_with(ApiPaths::NEXT_BASE) {
273 return true;
274 }
275
276 if path.starts_with(ApiPaths::STATIC_BASE)
277 || path.starts_with(ApiPaths::ASSETS_BASE)
278 || path.starts_with(ApiPaths::IMAGES_BASE)
279 {
280 return true;
281 }
282
283 if path == "/health" || path == "/ready" || path == "/healthz" {
284 return true;
285 }
286
287 if path == "/favicon.ico"
288 || path == "/robots.txt"
289 || path == "/sitemap.xml"
290 || path == "/manifest.json"
291 {
292 return true;
293 }
294
295 if let Some(last_segment) = path.rsplit('/').next() {
296 if last_segment.contains('.') {
297 let extension = last_segment.rsplit('.').next().unwrap_or("");
298 match extension {
299 "html" | "htm" => {},
300 _ => return true,
301 }
302 }
303 }
304
305 false
306 }
307}