systemprompt_api/services/middleware/
session.rs1use axum::extract::Request;
2use axum::http::header;
3use axum::middleware::Next;
4use axum::response::Response;
5use std::sync::Arc;
6use systemprompt_analytics::AnalyticsService;
7use systemprompt_identifiers::{AgentName, ClientId, ContextId, SessionId, SessionSource, UserId};
8use systemprompt_models::api::ApiError;
9use systemprompt_models::auth::UserType;
10use systemprompt_models::execution::context::RequestContext;
11use systemprompt_models::modules::ApiPaths;
12use systemprompt_oauth::services::{CreateAnonymousSessionInput, SessionCreationService};
13use systemprompt_runtime::AppContext;
14use systemprompt_security::{HeaderExtractor, TokenExtractor};
15use systemprompt_traits::AnalyticsProvider;
16use systemprompt_users::{UserProviderImpl, UserService};
17use uuid::Uuid;
18
19use super::jwt::JwtExtractor;
20
21#[derive(Clone, Debug)]
22pub struct SessionMiddleware {
23 jwt_extractor: Arc<JwtExtractor>,
24 analytics_service: Arc<AnalyticsService>,
25 session_creation_service: Arc<SessionCreationService>,
26}
27
28impl SessionMiddleware {
29 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
30 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret()?;
31 let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
32 let user_service = UserService::new(ctx.db_pool())?;
33 let session_creation_service = Arc::new(SessionCreationService::new(
34 ctx.analytics_service().clone(),
35 Arc::new(UserProviderImpl::new(user_service)),
36 ));
37
38 Ok(Self {
39 jwt_extractor,
40 analytics_service: ctx.analytics_service().clone(),
41 session_creation_service,
42 })
43 }
44
45 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
46 let headers = request.headers();
47 let uri = request.uri().clone();
48 let method = request.method().clone();
49
50 let should_skip = Self::should_skip_session_tracking(uri.path());
51
52 let trace_id = HeaderExtractor::extract_trace_id(headers);
53
54 let (req_ctx, jwt_cookie) = if should_skip {
55 let ctx = RequestContext::new(
56 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
57 trace_id,
58 ContextId::new(String::new()),
59 AgentName::system(),
60 )
61 .with_user_id(UserId::new("anonymous".to_string()))
62 .with_user_type(UserType::Anon)
63 .with_tracked(false);
64 (ctx, None)
65 } else {
66 let analytics = self
67 .analytics_service
68 .extract_analytics(headers, Some(&uri));
69 let is_bot = AnalyticsService::is_bot(&analytics);
70
71 if is_bot {
72 let ctx = RequestContext::new(
73 SessionId::new(format!("bot_{}", Uuid::new_v4())),
74 trace_id,
75 ContextId::new(String::new()),
76 AgentName::system(),
77 )
78 .with_user_id(UserId::new("bot".to_string()))
79 .with_user_type(UserType::Anon)
80 .with_tracked(false);
81 (ctx, None)
82 } else {
83 let token_result = TokenExtractor::browser_only().extract(headers).ok();
84
85 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) =
86 if let Some(token) = token_result {
87 if let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) {
88 let session_exists = self
89 .analytics_service
90 .find_session_by_id(&jwt_context.session_id)
91 .await
92 .ok()
93 .flatten()
94 .is_some();
95
96 if session_exists {
97 (
98 jwt_context.session_id,
99 jwt_context.user_id,
100 token,
101 None,
102 None,
103 )
104 } else {
105 tracing::info!(
106 old_session_id = %jwt_context.session_id,
107 user_id = %jwt_context.user_id,
108 "JWT valid but session missing, refreshing with new session"
109 );
110 let (sid, uid, new_token, _, fp) = self
111 .refresh_session_for_user(&jwt_context.user_id, headers, &uri)
112 .await?;
113 (sid, uid, new_token.clone(), Some(new_token), Some(fp))
114 }
115 } else {
116 let (sid, uid, token, is_new, fp) =
117 self.create_new_session(headers, &uri, &method).await?;
118 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
119 (sid, uid, token, jwt_cookie, Some(fp))
120 }
121 } else {
122 let (sid, uid, token, is_new, fp) =
123 self.create_new_session(headers, &uri, &method).await?;
124 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
125 (sid, uid, token, jwt_cookie, Some(fp))
126 };
127
128 let mut ctx = RequestContext::new(
129 session_id,
130 trace_id,
131 ContextId::new(String::new()),
132 AgentName::system(),
133 )
134 .with_user_id(user_id)
135 .with_auth_token(jwt_token)
136 .with_user_type(UserType::Anon)
137 .with_tracked(true);
138 if let Some(fp) = fingerprint_hash {
139 ctx = ctx.with_fingerprint_hash(fp);
140 }
141 (ctx, jwt_cookie)
142 }
143 };
144
145 request.extensions_mut().insert(req_ctx);
146
147 let mut response = next.run(request).await;
148
149 if let Some(token) = jwt_cookie {
150 let cookie =
151 format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
152 if let Ok(cookie_value) = cookie.parse() {
153 response
154 .headers_mut()
155 .insert(header::SET_COOKIE, cookie_value);
156 }
157 }
158
159 Ok(response)
160 }
161
162 async fn create_new_session(
163 &self,
164 headers: &http::HeaderMap,
165 uri: &http::Uri,
166 _method: &http::Method,
167 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
168 let client_id = ClientId::new("sp_web".to_string());
169
170 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
171 tracing::error!(error = %e, "Failed to get JWT secret during session creation");
172 ApiError::internal_error("Failed to initialize session")
173 })?;
174
175 self.session_creation_service
176 .create_anonymous_session(CreateAnonymousSessionInput {
177 headers,
178 uri: Some(uri),
179 client_id: &client_id,
180 jwt_secret,
181 session_source: SessionSource::Web,
182 })
183 .await
184 .map(|session_info| {
185 (
186 session_info.session_id,
187 session_info.user_id,
188 session_info.jwt_token,
189 session_info.is_new,
190 session_info.fingerprint_hash,
191 )
192 })
193 .map_err(|e| {
194 tracing::error!(error = %e, "Failed to create anonymous session");
195 ApiError::internal_error("Service temporarily unavailable")
196 })
197 }
198
199 async fn refresh_session_for_user(
200 &self,
201 user_id: &UserId,
202 headers: &http::HeaderMap,
203 uri: &http::Uri,
204 ) -> Result<(SessionId, UserId, String, bool, String), ApiError> {
205 let session_id = self
206 .session_creation_service
207 .create_authenticated_session(user_id, headers, SessionSource::Web)
208 .await
209 .map_err(|e| {
210 tracing::error!(error = %e, user_id = %user_id, "Failed to create session for user");
211 ApiError::internal_error("Failed to refresh session")
212 })?;
213
214 let jwt_secret = systemprompt_models::SecretsBootstrap::jwt_secret().map_err(|e| {
215 tracing::error!(error = %e, "Failed to get JWT secret during session refresh");
216 ApiError::internal_error("Failed to refresh session")
217 })?;
218
219 let config = systemprompt_models::Config::get().map_err(|e| {
220 tracing::error!(error = %e, "Failed to get config during session refresh");
221 ApiError::internal_error("Failed to refresh session")
222 })?;
223
224 let token = systemprompt_oauth::services::generation::generate_anonymous_jwt(
225 user_id.as_str(),
226 session_id.as_str(),
227 &ClientId::new("sp_web".to_string()),
228 &systemprompt_oauth::services::JwtSigningParams {
229 secret: jwt_secret,
230 issuer: &config.jwt_issuer,
231 },
232 )
233 .map_err(|e| {
234 tracing::error!(error = %e, "Failed to generate JWT during session refresh");
235 ApiError::internal_error("Failed to refresh session")
236 })?;
237
238 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
239 let fingerprint = AnalyticsService::compute_fingerprint(&analytics);
240
241 Ok((session_id, user_id.clone(), token, true, fingerprint))
242 }
243
244 fn should_skip_session_tracking(path: &str) -> bool {
245 if path.starts_with(ApiPaths::API_BASE) {
246 return true;
247 }
248
249 if path.starts_with(ApiPaths::NEXT_BASE) {
250 return true;
251 }
252
253 if path.starts_with(ApiPaths::STATIC_BASE)
254 || path.starts_with(ApiPaths::ASSETS_BASE)
255 || path.starts_with(ApiPaths::IMAGES_BASE)
256 {
257 return true;
258 }
259
260 if path == "/health" || path == "/ready" || path == "/healthz" {
261 return true;
262 }
263
264 if path == "/favicon.ico"
265 || path == "/robots.txt"
266 || path == "/sitemap.xml"
267 || path == "/manifest.json"
268 {
269 return true;
270 }
271
272 if let Some(last_segment) = path.rsplit('/').next() {
273 if last_segment.contains('.') {
274 let extension = last_segment.rsplit('.').next().unwrap_or("");
275 match extension {
276 "html" | "htm" => {},
277 _ => return true,
278 }
279 }
280 }
281
282 false
283 }
284}