systemprompt_api/services/middleware/
session.rs1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23use super::jwt::JwtExtractor;
24
25#[derive(Clone, Debug)]
26pub struct SessionMiddleware {
27 jwt_extractor: Arc<JwtExtractor>,
28 analytics_service: Arc<AnalyticsService>,
29 session_creation_service: Arc<SessionCreationService>,
30}
31
32impl SessionMiddleware {
33 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
34 let jwt_secret = systemprompt_config::SecretsBootstrap::jwt_secret()?;
35 let jwt_extractor = Arc::new(JwtExtractor::new(jwt_secret));
36 let user_service = UserService::new(ctx.db_pool())?;
37 let concrete = Arc::clone(ctx.analytics_service());
38 let analytics: Arc<dyn AnalyticsProvider> = concrete;
39 let session_creation_service = Arc::new(SessionCreationService::new(
40 analytics,
41 Arc::new(UserProviderImpl::new(user_service)),
42 ));
43
44 Ok(Self {
45 jwt_extractor,
46 analytics_service: Arc::clone(ctx.analytics_service()),
47 session_creation_service,
48 })
49 }
50
51 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
52 let headers = request.headers();
53 let uri = request.uri().clone();
54 let method = request.method().clone();
55
56 let should_skip = should_skip_session_tracking(uri.path());
57
58 tracing::debug!(
59 path = %uri.path(),
60 should_skip = should_skip,
61 "Session middleware evaluating request"
62 );
63
64 let trace_id = HeaderExtractor::extract_trace_id(headers);
65
66 let (req_ctx, jwt_cookie) = if should_skip {
67 (Self::untracked_context(trace_id), None)
68 } else {
69 self.tracked_context(trace_id, headers, &uri, &method)
70 .await?
71 };
72
73 tracing::debug!(
74 path = %uri.path(),
75 session_id = %req_ctx.session_id(),
76 "Session middleware setting context"
77 );
78
79 request.extensions_mut().insert(req_ctx);
80
81 let mut response = next.run(request).await;
82
83 if let Some(token) = jwt_cookie {
84 let cookie =
85 format!("access_token={token}; HttpOnly; SameSite=Lax; Path=/; Max-Age=604800");
86 if let Ok(cookie_value) = cookie.parse() {
87 response
88 .headers_mut()
89 .insert(header::SET_COOKIE, cookie_value);
90 }
91 }
92
93 Ok(response)
94 }
95
96 fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
97 RequestContext::new(
98 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
99 trace_id,
100 ContextId::generate(),
101 AgentName::system(),
102 )
103 .with_user_id(UserId::new("anonymous".to_string()))
104 .with_user_type(UserType::Anon)
105 .with_tracked(false)
106 }
107
108 fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
109 RequestContext::new(
110 SessionId::new(format!("bot_{}", Uuid::new_v4())),
111 trace_id,
112 ContextId::generate(),
113 AgentName::system(),
114 )
115 .with_user_id(UserId::new("bot".to_string()))
116 .with_user_type(UserType::Anon)
117 .with_tracked(false)
118 }
119
120 async fn tracked_context(
121 &self,
122 trace_id: systemprompt_identifiers::TraceId,
123 headers: &http::HeaderMap,
124 uri: &http::Uri,
125 method: &http::Method,
126 ) -> Result<(RequestContext, Option<String>), ApiError> {
127 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
128 let is_bot = AnalyticsService::is_bot(&analytics);
129
130 tracing::debug!(
131 path = %uri.path(),
132 is_bot = is_bot,
133 user_agent = ?analytics.user_agent,
134 "Session middleware bot check"
135 );
136
137 if is_bot {
138 return Ok((Self::bot_context(trace_id), None));
139 }
140
141 let token_result = TokenExtractor::browser_only().extract(headers).ok();
142
143 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
144 .resolve_session(token_result, headers, uri, method)
145 .await?;
146
147 let context_id =
148 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
149
150 let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
151 .with_user_id(user_id)
152 .with_auth_token(jwt_token)
153 .with_user_type(UserType::Anon)
154 .with_tracked(true);
155 if let Some(fp) = fingerprint_hash {
156 ctx = ctx.with_fingerprint_hash(fp);
157 }
158 Ok((ctx, jwt_cookie))
159 }
160
161 async fn resolve_session(
162 &self,
163 token_result: Option<String>,
164 headers: &http::HeaderMap,
165 uri: &http::Uri,
166 method: &http::Method,
167 ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
168 let Some(token) = token_result else {
169 let (sid, uid, token, is_new, fp) =
170 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
171 .await?;
172 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
173 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
174 };
175
176 let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) else {
177 let (sid, uid, token, is_new, fp) =
178 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
179 .await?;
180 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
181 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
182 };
183
184 let session_exists = self
185 .analytics_service
186 .find_session_by_id(&jwt_context.session_id)
187 .await
188 .map_err(|e| {
189 tracing::warn!(error = %e, "find_session_by_id failed");
190 e
191 })
192 .ok()
193 .flatten()
194 .is_some();
195
196 if session_exists {
197 return Ok((
198 jwt_context.session_id,
199 jwt_context.user_id,
200 token,
201 None,
202 None,
203 ));
204 }
205
206 tracing::info!(
207 old_session_id = %jwt_context.session_id,
208 user_id = %jwt_context.user_id,
209 "JWT valid but session missing, refreshing with new session"
210 );
211 match lifecycle::refresh_session_for_user(
212 &self.session_creation_service,
213 &self.analytics_service,
214 &jwt_context.user_id,
215 headers,
216 uri,
217 )
218 .await
219 {
220 Ok((sid, uid, new_token, _, fp)) => {
221 Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
222 },
223 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
224 tracing::warn!(
225 user_id = %jwt_context.user_id,
226 "JWT references non-existent user, creating new anonymous session"
227 );
228 let (sid, uid, token, _, fp) = lifecycle::create_new_session(
229 &self.session_creation_service,
230 headers,
231 uri,
232 method,
233 )
234 .await?;
235 Ok((sid, uid, token.clone(), Some(token), Some(fp)))
236 },
237 Err(e) => Err(e),
238 }
239 }
240}