systemprompt_api/services/middleware/session/
mod.rs1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor, extract_user_context};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25 analytics_service: Arc<AnalyticsService>,
26 session_creation_service: Arc<SessionCreationService>,
27}
28
29impl SessionMiddleware {
30 pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
31 let user_service = UserService::new(ctx.db_pool())?;
32 let concrete = Arc::clone(ctx.analytics_service());
33 let analytics: Arc<dyn AnalyticsProvider> = concrete;
34 let session_creation_service = Arc::new(SessionCreationService::new(
35 analytics,
36 Arc::new(UserProviderImpl::new(user_service)),
37 ));
38
39 Ok(Self {
40 analytics_service: Arc::clone(ctx.analytics_service()),
41 session_creation_service,
42 })
43 }
44
45 pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
46 let headers = request.headers();
47 let uri = request.uri().clone();
48 let method = request.method().clone();
49
50 let should_skip = should_skip_session_tracking(uri.path());
51
52 tracing::debug!(
53 path = %uri.path(),
54 should_skip = should_skip,
55 "Session middleware evaluating request"
56 );
57
58 let trace_id = HeaderExtractor::extract_trace_id(headers);
59
60 let (req_ctx, jwt_cookie) = if should_skip {
61 (Self::untracked_context(trace_id), None)
62 } else {
63 self.tracked_context(trace_id, headers, &uri, &method)
64 .await?
65 };
66
67 tracing::debug!(
68 path = %uri.path(),
69 session_id = %req_ctx.session_id(),
70 "Session middleware setting context"
71 );
72
73 request.extensions_mut().insert(req_ctx);
74
75 let mut response = next.run(request).await;
76
77 if let Some(token) = jwt_cookie {
78 let cookie =
79 format!("access_token={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800");
80 if let Ok(cookie_value) = cookie.parse() {
81 response
82 .headers_mut()
83 .insert(header::SET_COOKIE, cookie_value);
84 }
85 }
86
87 Ok(response)
88 }
89
90 fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
91 RequestContext::new(
92 SessionId::new(format!("untracked_{}", Uuid::new_v4())),
93 trace_id,
94 ContextId::generate(),
95 AgentName::system(),
96 )
97 .with_actor(systemprompt_identifiers::Actor::anonymous(
98 systemprompt_identifiers::bootstrap::anonymous(),
99 ))
100 .with_user_type(UserType::Anon)
101 .with_tracked(false)
102 }
103
104 fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
105 RequestContext::new(
106 SessionId::new(format!("bot_{}", Uuid::new_v4())),
107 trace_id,
108 ContextId::generate(),
109 AgentName::system(),
110 )
111 .with_actor(systemprompt_identifiers::Actor::user(
112 systemprompt_identifiers::bootstrap::bot(),
113 ))
114 .with_user_type(UserType::Anon)
115 .with_tracked(false)
116 }
117
118 async fn tracked_context(
119 &self,
120 trace_id: systemprompt_identifiers::TraceId,
121 headers: &http::HeaderMap,
122 uri: &http::Uri,
123 method: &http::Method,
124 ) -> Result<(RequestContext, Option<String>), ApiError> {
125 let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
126 let is_bot = AnalyticsService::is_bot(&analytics);
127
128 tracing::debug!(
129 path = %uri.path(),
130 is_bot = is_bot,
131 user_agent = ?analytics.user_agent,
132 "Session middleware bot check"
133 );
134
135 if is_bot {
136 return Ok((Self::bot_context(trace_id), None));
137 }
138
139 let token_result = TokenExtractor::browser_only().extract(headers).ok();
140
141 let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
142 .resolve_session(token_result, headers, uri, method)
143 .await?;
144
145 let context_id =
146 HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
147
148 let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
149 .with_actor(systemprompt_identifiers::Actor::user(user_id))
150 .with_auth_token(jwt_token)
151 .with_user_type(UserType::Anon)
152 .with_tracked(true);
153 if let Some(fp) = fingerprint_hash {
154 ctx = ctx.with_fingerprint_hash(fp);
155 }
156 Ok((ctx, jwt_cookie))
157 }
158
159 async fn resolve_session(
160 &self,
161 token_result: Option<String>,
162 headers: &http::HeaderMap,
163 uri: &http::Uri,
164 method: &http::Method,
165 ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
166 let Some(token) = token_result else {
167 let (sid, uid, token, is_new, fp) =
168 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
169 .await?;
170 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
171 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
172 };
173
174 let Ok(jwt_context) = extract_user_context(&token) else {
175 let (sid, uid, token, is_new, fp) =
176 lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
177 .await?;
178 let jwt_cookie = if is_new { Some(token.clone()) } else { None };
179 return Ok((sid, uid, token, jwt_cookie, Some(fp)));
180 };
181
182 let session_exists = self
183 .analytics_service
184 .find_active_session_by_id(&jwt_context.session_id)
185 .await
186 .map_err(|e| {
187 tracing::warn!(error = %e, "find_active_session_by_id failed");
188 e
189 })
190 .ok()
191 .flatten()
192 .is_some();
193
194 if session_exists {
195 return Ok((
196 jwt_context.session_id,
197 jwt_context.user_id,
198 token,
199 None,
200 None,
201 ));
202 }
203
204 tracing::info!(
205 old_session_id = %jwt_context.session_id,
206 user_id = %jwt_context.user_id,
207 "JWT valid but session missing, refreshing with new session"
208 );
209 match lifecycle::refresh_session_for_user(
210 &self.session_creation_service,
211 &self.analytics_service,
212 &jwt_context.user_id,
213 headers,
214 uri,
215 )
216 .await
217 {
218 Ok((sid, uid, new_token, _, fp)) => {
219 Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
220 },
221 Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
222 tracing::warn!(
223 user_id = %jwt_context.user_id,
224 "JWT references non-existent user, creating new anonymous session"
225 );
226 let (sid, uid, token, _, fp) = lifecycle::create_new_session(
227 &self.session_creation_service,
228 headers,
229 uri,
230 method,
231 )
232 .await?;
233 Ok((sid, uid, token.clone(), Some(token), Some(fp)))
234 },
235 Err(e) => Err(e),
236 }
237 }
238}