systemprompt_api/services/middleware/
session.rs1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23use super::jwt::JwtExtractor;
24
25#[derive(Clone, Debug)]
26pub struct SessionMiddleware {
27 jwt_extractor: Arc<JwtExtractor>,
28 analytics_service: Arc<AnalyticsService>,
29 session_creation_service: Arc<SessionCreationService>,
30}
31
32impl SessionMiddleware {
33 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
34 let jwt_extractor = Arc::new(JwtExtractor::new());
35 let user_service = UserService::new(ctx.db_pool())?;
36 let concrete = Arc::clone(ctx.analytics_service());
37 let analytics: Arc<dyn AnalyticsProvider> = concrete;
38 let session_creation_service = Arc::new(SessionCreationService::new(
39 analytics,
40 Arc::new(UserProviderImpl::new(user_service)),
41 ));
42
43 Ok(Self {
44 jwt_extractor,
45 analytics_service: Arc::clone(ctx.analytics_service()),
46 session_creation_service,
47 })
48 }
49
50 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
51 let headers = request.headers();
52 let uri = request.uri().clone();
53 let method = request.method().clone();
54
55 let should_skip = should_skip_session_tracking(uri.path());
56
57 tracing::debug!(
58 path = %uri.path(),
59 should_skip = should_skip,
60 "Session middleware evaluating request"
61 );
62
63 let trace_id = HeaderExtractor::extract_trace_id(headers);
64
65 let (req_ctx, jwt_cookie) = if should_skip {
66 (Self::untracked_context(trace_id), None)
67 } else {
68 self.tracked_context(trace_id, headers, &uri, &method)
69 .await?
70 };
71
72 tracing::debug!(
73 path = %uri.path(),
74 session_id = %req_ctx.session_id(),
75 "Session middleware setting context"
76 );
77
78 request.extensions_mut().insert(req_ctx);
79
80 let mut response = next.run(request).await;
81
82 if let Some(token) = jwt_cookie {
83 let cookie =
84 format!("access_token={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800");
85 if let Ok(cookie_value) = cookie.parse() {
86 response
87 .headers_mut()
88 .insert(header::SET_COOKIE, cookie_value);
89 }
90 }
91
92 Ok(response)
93 }
94
95 fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
96 RequestContext::new(
97 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
98 trace_id,
99 ContextId::generate(),
100 AgentName::system(),
101 )
102 .with_actor(systemprompt_identifiers::Actor::anonymous(
103 systemprompt_identifiers::bootstrap::anonymous(),
104 ))
105 .with_user_type(UserType::Anon)
106 .with_tracked(false)
107 }
108
109 fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
110 RequestContext::new(
111 SessionId::new(format!("bot_{}", Uuid::new_v4())),
112 trace_id,
113 ContextId::generate(),
114 AgentName::system(),
115 )
116 .with_actor(systemprompt_identifiers::Actor::user(
117 systemprompt_identifiers::bootstrap::bot(),
118 ))
119 .with_user_type(UserType::Anon)
120 .with_tracked(false)
121 }
122
123 async fn tracked_context(
124 &self,
125 trace_id: systemprompt_identifiers::TraceId,
126 headers: &http::HeaderMap,
127 uri: &http::Uri,
128 method: &http::Method,
129 ) -> Result<(RequestContext, Option<String>), ApiError> {
130 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
131 let is_bot = AnalyticsService::is_bot(&analytics);
132
133 tracing::debug!(
134 path = %uri.path(),
135 is_bot = is_bot,
136 user_agent = ?analytics.user_agent,
137 "Session middleware bot check"
138 );
139
140 if is_bot {
141 return Ok((Self::bot_context(trace_id), None));
142 }
143
144 let token_result = TokenExtractor::browser_only().extract(headers).ok();
145
146 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
147 .resolve_session(token_result, headers, uri, method)
148 .await?;
149
150 let context_id =
151 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
152
153 let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
154 .with_actor(systemprompt_identifiers::Actor::user(user_id))
155 .with_auth_token(jwt_token)
156 .with_user_type(UserType::Anon)
157 .with_tracked(true);
158 if let Some(fp) = fingerprint_hash {
159 ctx = ctx.with_fingerprint_hash(fp);
160 }
161 Ok((ctx, jwt_cookie))
162 }
163
164 async fn resolve_session(
165 &self,
166 token_result: Option<String>,
167 headers: &http::HeaderMap,
168 uri: &http::Uri,
169 method: &http::Method,
170 ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
171 let Some(token) = token_result else {
172 let (sid, uid, token, is_new, fp) =
173 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
174 .await?;
175 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
176 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
177 };
178
179 let Ok(jwt_context) = self.jwt_extractor.extract_user_context(&token) else {
180 let (sid, uid, token, is_new, fp) =
181 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
182 .await?;
183 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
184 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
185 };
186
187 let session_exists = self
188 .analytics_service
189 .find_active_session_by_id(&jwt_context.session_id)
190 .await
191 .map_err(|e| {
192 tracing::warn!(error = %e, "find_active_session_by_id failed");
193 e
194 })
195 .ok()
196 .flatten()
197 .is_some();
198
199 if session_exists {
200 return Ok((
201 jwt_context.session_id,
202 jwt_context.user_id,
203 token,
204 None,
205 None,
206 ));
207 }
208
209 tracing::info!(
210 old_session_id = %jwt_context.session_id,
211 user_id = %jwt_context.user_id,
212 "JWT valid but session missing, refreshing with new session"
213 );
214 match lifecycle::refresh_session_for_user(
215 &self.session_creation_service,
216 &self.analytics_service,
217 &jwt_context.user_id,
218 headers,
219 uri,
220 )
221 .await
222 {
223 Ok((sid, uid, new_token, _, fp)) => {
224 Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
225 },
226 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
227 tracing::warn!(
228 user_id = %jwt_context.user_id,
229 "JWT references non-existent user, creating new anonymous session"
230 );
231 let (sid, uid, token, _, fp) = lifecycle::create_new_session(
232 &self.session_creation_service,
233 headers,
234 uri,
235 method,
236 )
237 .await?;
238 Ok((sid, uid, token.clone(), Some(token), Some(fp)))
239 },
240 Err(e) => Err(e),
241 }
242 }
243}