// systemprompt_api/services/middleware/session/mod.rs
mod.rs1mod lifecycle;
9mod skip;
10
11pub use skip::should_skip_session_tracking;
12
13use axum::extract::Request;
14use axum::http::header;
15use axum::middleware::Next;
16use axum::response::Response;
17use std::sync::Arc;
18use systemprompt_analytics::AnalyticsService;
19use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
20use systemprompt_models::api::ApiError;
21use systemprompt_models::auth::UserType;
22use systemprompt_models::execution::context::RequestContext;
23use systemprompt_oauth::services::SessionCreationService;
24use systemprompt_runtime::AppContext;
25use systemprompt_security::{
26 CookieExtractor, HeaderExtractor, TokenExtractor, extract_user_context,
27};
28use systemprompt_traits::AnalyticsProvider;
29use systemprompt_users::{UserProviderImpl, UserService};
30use uuid::Uuid;
31
32#[derive(Clone, Debug)]
33pub struct SessionMiddleware {
34 analytics_service: Arc<AnalyticsService>,
35 session_creation_service: Arc<SessionCreationService>,
36}
37
38impl SessionMiddleware {
39 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
40 let user_service = UserService::new(ctx.db_pool())?;
41 let concrete = Arc::clone(ctx.analytics_service());
42 let analytics: Arc<dyn AnalyticsProvider> = concrete;
43 let session_creation_service = Arc::new(SessionCreationService::new(
44 analytics,
45 Arc::new(UserProviderImpl::new(user_service)),
46 ));
47
48 Ok(Self {
49 analytics_service: Arc::clone(ctx.analytics_service()),
50 session_creation_service,
51 })
52 }
53
54 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
55 let headers = request.headers();
56 let uri = request.uri().clone();
57 let method = request.method().clone();
58
59 let should_skip = should_skip_session_tracking(uri.path());
60
61 tracing::debug!(
62 path = %uri.path(),
63 should_skip = should_skip,
64 "Session middleware evaluating request"
65 );
66
67 let trace_id = HeaderExtractor::extract_trace_id(headers);
68
69 let (req_ctx, jwt_cookie) = if should_skip {
70 (
71 self.anonymous_context("untracked", trace_id, headers, &uri)
72 .await?,
73 None,
74 )
75 } else {
76 self.tracked_context(trace_id, headers, &uri, &method)
77 .await?
78 };
79
80 tracing::debug!(
81 path = %uri.path(),
82 session_id = %req_ctx.session_id(),
83 "Session middleware setting context"
84 );
85
86 request.extensions_mut().insert(req_ctx);
87
88 let mut response = next.run(request).await;
89
90 if let Some(token) = jwt_cookie {
91 let cookie = format!(
92 "{}={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800",
93 CookieExtractor::DEFAULT_COOKIE_NAME
94 );
95 if let Ok(cookie_value) = cookie.parse() {
96 response
97 .headers_mut()
98 .insert(header::SET_COOKIE, cookie_value);
99 }
100 }
101
102 Ok(response)
103 }
104
105 async fn anonymous_context(
110 &self,
111 session_prefix: &str,
112 trace_id: systemprompt_identifiers::TraceId,
113 headers: &http::HeaderMap,
114 uri: &http::Uri,
115 ) -> Result<RequestContext, ApiError> {
116 let (user_id, fingerprint) = self
117 .session_creation_service
118 .ensure_anonymous_user(headers, Some(uri))
119 .await
120 .map_err(|e| {
121 tracing::error!(error = %e, session_prefix, "Failed to ensure anonymous user");
122 ApiError::internal_error("Service temporarily unavailable")
123 })?;
124
125 Ok(RequestContext::new(
126 SessionId::new(format!("{session_prefix}_{}", Uuid::new_v4())),
127 trace_id,
128 ContextId::generate(),
129 AgentName::system(),
130 )
131 .with_actor(systemprompt_identifiers::Actor::anonymous(user_id))
132 .with_user_type(UserType::Anon)
133 .with_tracked(false)
134 .with_fingerprint_hash(fingerprint))
135 }
136
137 async fn tracked_context(
138 &self,
139 trace_id: systemprompt_identifiers::TraceId,
140 headers: &http::HeaderMap,
141 uri: &http::Uri,
142 method: &http::Method,
143 ) -> Result<(RequestContext, Option<String>), ApiError> {
144 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
145 let is_bot = AnalyticsService::is_bot(&analytics);
146
147 tracing::debug!(
148 path = %uri.path(),
149 is_bot = is_bot,
150 user_agent = ?analytics.user_agent,
151 "Session middleware bot check"
152 );
153
154 if is_bot {
155 return Ok((
156 self.anonymous_context("bot", trace_id, headers, uri)
157 .await?,
158 None,
159 ));
160 }
161
162 let token_result = TokenExtractor::browser_only().extract(headers).ok();
163
164 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
165 .resolve_session(token_result, headers, uri, method)
166 .await?;
167
168 let context_id =
169 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
170
171 let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
172 .with_actor(systemprompt_identifiers::Actor::user(user_id))
173 .with_auth_token(jwt_token)
174 .with_user_type(UserType::Anon)
175 .with_tracked(true);
176 if let Some(fp) = fingerprint_hash {
177 ctx = ctx.with_fingerprint_hash(fp);
178 }
179 Ok((ctx, jwt_cookie))
180 }
181
182 async fn resolve_session(
183 &self,
184 token_result: Option<String>,
185 headers: &http::HeaderMap,
186 uri: &http::Uri,
187 method: &http::Method,
188 ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
189 let Some(token) = token_result else {
190 let (sid, uid, token, is_new, fp) =
191 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
192 .await?;
193 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
194 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
195 };
196
197 let Ok(jwt_context) = extract_user_context(&token) else {
198 let (sid, uid, token, is_new, fp) =
199 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
200 .await?;
201 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
202 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
203 };
204
205 let session_exists = self
206 .analytics_service
207 .find_active_session_by_id(&jwt_context.session_id)
208 .await
209 .map_err(|e| {
210 tracing::warn!(error = %e, "find_active_session_by_id failed");
211 e
212 })
213 .ok()
214 .flatten()
215 .is_some();
216
217 if session_exists {
218 return Ok((
219 jwt_context.session_id,
220 jwt_context.user_id,
221 token,
222 None,
223 None,
224 ));
225 }
226
227 tracing::info!(
228 old_session_id = %jwt_context.session_id,
229 user_id = %jwt_context.user_id,
230 "JWT valid but session missing, refreshing with new session"
231 );
232 match lifecycle::refresh_session_for_user(
233 &self.session_creation_service,
234 &self.analytics_service,
235 &jwt_context.user_id,
236 headers,
237 uri,
238 )
239 .await
240 {
241 Ok((sid, uid, new_token, _, fp)) => {
242 Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
243 },
244 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
245 tracing::warn!(
246 user_id = %jwt_context.user_id,
247 "JWT references non-existent user, creating new anonymous session"
248 );
249 let (sid, uid, token, _, fp) = lifecycle::create_new_session(
250 &self.session_creation_service,
251 headers,
252 uri,
253 method,
254 )
255 .await?;
256 Ok((sid, uid, token.clone(), Some(token), Some(fp)))
257 },
258 Err(e) => Err(e),
259 }
260 }
261}