//! Session-tracking middleware.
//!
//! Source path: `systemprompt_api/services/middleware/session/mod.rs`.

1mod lifecycle;
2mod skip;
3
4pub use skip::should_skip_session_tracking;
5
6use axum::extract::Request;
7use axum::http::header;
8use axum::middleware::Next;
9use axum::response::Response;
10use std::sync::Arc;
11use systemprompt_analytics::AnalyticsService;
12use systemprompt_identifiers::{AgentName, ContextId, SessionId, UserId};
13use systemprompt_models::api::ApiError;
14use systemprompt_models::auth::UserType;
15use systemprompt_models::execution::context::RequestContext;
16use systemprompt_oauth::services::SessionCreationService;
17use systemprompt_runtime::AppContext;
18use systemprompt_security::{HeaderExtractor, TokenExtractor, extract_user_context};
19use systemprompt_traits::AnalyticsProvider;
20use systemprompt_users::{UserProviderImpl, UserService};
21use uuid::Uuid;
22
23#[derive(Clone, Debug)]
24pub struct SessionMiddleware {
25    analytics_service: Arc<AnalyticsService>,
26    session_creation_service: Arc<SessionCreationService>,
27}
28
29impl SessionMiddleware {
30    pub fn new(ctx: &AppContext) -> anyhow::Result<Self> {
31        let user_service = UserService::new(ctx.db_pool())?;
32        let concrete = Arc::clone(ctx.analytics_service());
33        let analytics: Arc<dyn AnalyticsProvider> = concrete;
34        let session_creation_service = Arc::new(SessionCreationService::new(
35            analytics,
36            Arc::new(UserProviderImpl::new(user_service)),
37        ));
38
39        Ok(Self {
40            analytics_service: Arc::clone(ctx.analytics_service()),
41            session_creation_service,
42        })
43    }
44
45    pub async fn handle(&self, mut request: Request, next: Next) -> Result<Response, ApiError> {
46        let headers = request.headers();
47        let uri = request.uri().clone();
48        let method = request.method().clone();
49
50        let should_skip = should_skip_session_tracking(uri.path());
51
52        tracing::debug!(
53            path = %uri.path(),
54            should_skip = should_skip,
55            "Session middleware evaluating request"
56        );
57
58        let trace_id = HeaderExtractor::extract_trace_id(headers);
59
60        let (req_ctx, jwt_cookie) = if should_skip {
61            (Self::untracked_context(trace_id), None)
62        } else {
63            self.tracked_context(trace_id, headers, &uri, &method)
64                .await?
65        };
66
67        tracing::debug!(
68            path = %uri.path(),
69            session_id = %req_ctx.session_id(),
70            "Session middleware setting context"
71        );
72
73        request.extensions_mut().insert(req_ctx);
74
75        let mut response = next.run(request).await;
76
77        if let Some(token) = jwt_cookie {
78            let cookie =
79                format!("access_token={token}; HttpOnly; SameSite=Strict; Path=/; Max-Age=604800");
80            if let Ok(cookie_value) = cookie.parse() {
81                response
82                    .headers_mut()
83                    .insert(header::SET_COOKIE, cookie_value);
84            }
85        }
86
87        Ok(response)
88    }
89
90    fn untracked_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
91        RequestContext::new(
92            SessionId::new(format!("untracked_{}", Uuid::new_v4())),
93            trace_id,
94            ContextId::generate(),
95            AgentName::system(),
96        )
97        .with_actor(systemprompt_identifiers::Actor::anonymous(
98            systemprompt_identifiers::bootstrap::anonymous(),
99        ))
100        .with_user_type(UserType::Anon)
101        .with_tracked(false)
102    }
103
104    fn bot_context(trace_id: systemprompt_identifiers::TraceId) -> RequestContext {
105        RequestContext::new(
106            SessionId::new(format!("bot_{}", Uuid::new_v4())),
107            trace_id,
108            ContextId::generate(),
109            AgentName::system(),
110        )
111        .with_actor(systemprompt_identifiers::Actor::user(
112            systemprompt_identifiers::bootstrap::bot(),
113        ))
114        .with_user_type(UserType::Anon)
115        .with_tracked(false)
116    }
117
118    async fn tracked_context(
119        &self,
120        trace_id: systemprompt_identifiers::TraceId,
121        headers: &http::HeaderMap,
122        uri: &http::Uri,
123        method: &http::Method,
124    ) -> Result<(RequestContext, Option<String>), ApiError> {
125        let analytics = self.analytics_service.extract_analytics(headers, Some(uri));
126        let is_bot = AnalyticsService::is_bot(&analytics);
127
128        tracing::debug!(
129            path = %uri.path(),
130            is_bot = is_bot,
131            user_agent = ?analytics.user_agent,
132            "Session middleware bot check"
133        );
134
135        if is_bot {
136            return Ok((Self::bot_context(trace_id), None));
137        }
138
139        let token_result = TokenExtractor::browser_only().extract(headers).ok();
140
141        let (session_id, user_id, jwt_token, jwt_cookie, fingerprint_hash) = self
142            .resolve_session(token_result, headers, uri, method)
143            .await?;
144
145        let context_id =
146            HeaderExtractor::extract_context_id(headers).unwrap_or_else(ContextId::generate);
147
148        let mut ctx = RequestContext::new(session_id, trace_id, context_id, AgentName::system())
149            .with_actor(systemprompt_identifiers::Actor::user(user_id))
150            .with_auth_token(jwt_token)
151            .with_user_type(UserType::Anon)
152            .with_tracked(true);
153        if let Some(fp) = fingerprint_hash {
154            ctx = ctx.with_fingerprint_hash(fp);
155        }
156        Ok((ctx, jwt_cookie))
157    }
158
159    async fn resolve_session(
160        &self,
161        token_result: Option<String>,
162        headers: &http::HeaderMap,
163        uri: &http::Uri,
164        method: &http::Method,
165    ) -> Result<(SessionId, UserId, String, Option<String>, Option<String>), ApiError> {
166        let Some(token) = token_result else {
167            let (sid, uid, token, is_new, fp) =
168                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
169                    .await?;
170            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
171            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
172        };
173
174        let Ok(jwt_context) = extract_user_context(&token) else {
175            let (sid, uid, token, is_new, fp) =
176                lifecycle::create_new_session(&self.session_creation_service, headers, uri, method)
177                    .await?;
178            let jwt_cookie = if is_new { Some(token.clone()) } else { None };
179            return Ok((sid, uid, token, jwt_cookie, Some(fp)));
180        };
181
182        let session_exists = self
183            .analytics_service
184            .find_active_session_by_id(&jwt_context.session_id)
185            .await
186            .map_err(|e| {
187                tracing::warn!(error = %e, "find_active_session_by_id failed");
188                e
189            })
190            .ok()
191            .flatten()
192            .is_some();
193
194        if session_exists {
195            return Ok((
196                jwt_context.session_id,
197                jwt_context.user_id,
198                token,
199                None,
200                None,
201            ));
202        }
203
204        tracing::info!(
205            old_session_id = %jwt_context.session_id,
206            user_id = %jwt_context.user_id,
207            "JWT valid but session missing, refreshing with new session"
208        );
209        match lifecycle::refresh_session_for_user(
210            &self.session_creation_service,
211            &self.analytics_service,
212            &jwt_context.user_id,
213            headers,
214            uri,
215        )
216        .await
217        {
218            Ok((sid, uid, new_token, _, fp)) => {
219                Ok((sid, uid, new_token.clone(), Some(new_token), Some(fp)))
220            },
221            Err(e) if e.error_key.as_deref() == Some("user_not_found") => {
222                tracing::warn!(
223                    user_id = %jwt_context.user_id,
224                    "JWT references non-existent user, creating new anonymous session"
225                );
226                let (sid, uid, token, _, fp) = lifecycle::create_new_session(
227                    &self.session_creation_service,
228                    headers,
229                    uri,
230                    method,
231                )
232                .await?;
233                Ok((sid, uid, token.clone(), Some(token), Some(fp)))
234            },
235            Err(e) => Err(e),
236        }
237    }
238}