Skip to main content

systemprompt_oauth/services/session/
mod.rs

1mod creation;
2mod lookup;
3
4use anyhow::Result;
5use http::{HeaderMap, Uri};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use systemprompt_identifiers::{ClientId, SessionId, SessionSource, UserId};
12use systemprompt_traits::{
13    AnalyticsProvider, CreateSessionInput, FingerprintProvider, SessionAnalytics, UserEvent,
14    UserEventPublisher, UserProvider,
15};
16
/// Maximum session age in seconds: 60 s * 60 min * 24 h * 7 days = 604 800 s.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// `creation` / `lookup` submodules — confirm there.
const MAX_SESSION_AGE_SECONDS: i64 = 60 * 60 * 24 * 7;
18
19#[derive(Debug, thiserror::Error)]
20pub enum SessionCreationError {
21    #[error("User not found: {user_id}")]
22    UserNotFound { user_id: String },
23
24    #[error("Session creation failed: {0}")]
25    Internal(String),
26}
27
28struct SessionCreationParams<'a> {
29    analytics: SessionAnalytics,
30    is_bot: bool,
31    fingerprint: String,
32    client_id: &'a ClientId,
33    jwt_secret: &'a str,
34    session_source: SessionSource,
35}
36
37#[derive(Debug, Clone)]
38pub struct AnonymousSessionInfo {
39    pub session_id: SessionId,
40    pub user_id: UserId,
41    pub is_new: bool,
42    pub jwt_token: String,
43    pub fingerprint_hash: String,
44}
45
46#[derive(Debug, Clone)]
47pub struct AuthenticatedSessionInfo {
48    pub session_id: SessionId,
49}
50
51#[derive(Debug)]
52pub struct CreateAnonymousSessionInput<'a> {
53    pub headers: &'a HeaderMap,
54    pub uri: Option<&'a Uri>,
55    pub client_id: &'a ClientId,
56    pub jwt_secret: &'a str,
57    pub session_source: SessionSource,
58}
59
60#[derive(Clone)]
61pub struct SessionCreationService {
62    analytics_provider: Arc<dyn AnalyticsProvider>,
63    user_provider: Arc<dyn UserProvider>,
64    fingerprint_locks: Arc<RwLock<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
65    event_publisher: Option<Arc<dyn UserEventPublisher>>,
66    fingerprint_provider: Option<Arc<dyn FingerprintProvider>>,
67}
68
69impl std::fmt::Debug for SessionCreationService {
70    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71        f.debug_struct("SessionCreationService")
72            .field("analytics_provider", &"<provider>")
73            .field(
74                "event_publisher",
75                &self.event_publisher.as_ref().map(|_| "<publisher>"),
76            )
77            .finish_non_exhaustive()
78    }
79}
80
81impl SessionCreationService {
82    pub fn new(
83        analytics_provider: Arc<dyn AnalyticsProvider>,
84        user_provider: Arc<dyn UserProvider>,
85    ) -> Self {
86        Self {
87            analytics_provider,
88            user_provider,
89            fingerprint_locks: Arc::new(RwLock::new(HashMap::new())),
90            event_publisher: None,
91            fingerprint_provider: None,
92        }
93    }
94
95    pub fn with_event_publisher(mut self, publisher: Arc<dyn UserEventPublisher>) -> Self {
96        self.event_publisher = Some(publisher);
97        self
98    }
99
100    pub fn with_fingerprint_provider(mut self, provider: Arc<dyn FingerprintProvider>) -> Self {
101        self.fingerprint_provider = Some(provider);
102        self
103    }
104
105    fn publish_event(&self, event: UserEvent) {
106        if let Some(ref publisher) = self.event_publisher {
107            publisher.publish_user_event(event);
108        }
109    }
110
111    pub async fn create_anonymous_session(
112        &self,
113        input: CreateAnonymousSessionInput<'_>,
114    ) -> Result<AnonymousSessionInfo> {
115        let analytics = self
116            .analytics_provider
117            .extract_analytics(input.headers, input.uri);
118        let is_bot = analytics.is_bot();
119        let fingerprint = analytics.compute_fingerprint();
120
121        let params = SessionCreationParams {
122            analytics,
123            is_bot,
124            fingerprint,
125            client_id: input.client_id,
126            jwt_secret: input.jwt_secret,
127            session_source: input.session_source,
128        };
129        self.create_session_internal(params).await
130    }
131
132    pub async fn create_authenticated_session(
133        &self,
134        user_id: &UserId,
135        headers: &HeaderMap,
136        session_source: SessionSource,
137    ) -> Result<SessionId, SessionCreationError> {
138        let user = self
139            .user_provider
140            .find_by_id(user_id.as_str())
141            .await
142            .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
143
144        if user.is_none() {
145            return Err(SessionCreationError::UserNotFound {
146                user_id: user_id.to_string(),
147            });
148        }
149
150        let session_id = SessionId::new(format!("sess_{}", Uuid::new_v4()));
151        let analytics = self.analytics_provider.extract_analytics(headers, None);
152        let is_bot = analytics.is_bot();
153
154        let global_config = systemprompt_models::Config::get()
155            .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
156        let expires_at = chrono::Utc::now()
157            + chrono::Duration::seconds(global_config.jwt_access_token_expiration);
158
159        self.analytics_provider
160            .create_session(CreateSessionInput {
161                session_id: &session_id,
162                user_id: Some(user_id),
163                analytics: &analytics,
164                session_source,
165                is_bot,
166                expires_at,
167            })
168            .await
169            .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
170
171        self.publish_event(UserEvent::SessionCreated {
172            user_id: user_id.to_string(),
173            session_id: session_id.to_string(),
174        });
175
176        Ok(session_id)
177    }
178
179    async fn create_session_internal(
180        &self,
181        params: SessionCreationParams<'_>,
182    ) -> Result<AnonymousSessionInfo> {
183        let _guard = self.acquire_fingerprint_lock(&params.fingerprint).await;
184
185        self.update_fingerprint_if_available(&params.fingerprint, &params.analytics)
186            .await;
187
188        if let Some(session) = self
189            .try_reuse_session_at_limit(&params.fingerprint, params.client_id, params.jwt_secret)
190            .await
191        {
192            return Ok(session);
193        }
194
195        if let Some(session) = self
196            .try_find_existing_session(&params.fingerprint, params.client_id, params.jwt_secret)
197            .await
198        {
199            return Ok(session);
200        }
201
202        self.create_new_session(params).await
203    }
204
205    async fn acquire_fingerprint_lock(
206        &self,
207        fingerprint: &str,
208    ) -> tokio::sync::OwnedMutexGuard<()> {
209        let lock = {
210            let mut locks = self.fingerprint_locks.write().await;
211            Arc::clone(
212                locks
213                    .entry(fingerprint.to_string())
214                    .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))),
215            )
216        };
217        lock.lock_owned().await
218    }
219
220    async fn update_fingerprint_if_available(
221        &self,
222        fingerprint: &str,
223        analytics: &SessionAnalytics,
224    ) {
225        let Some(ref fp_provider) = self.fingerprint_provider else {
226            return;
227        };
228
229        if let Err(e) = fp_provider
230            .upsert_fingerprint(
231                fingerprint,
232                analytics.ip_address.as_deref(),
233                analytics.user_agent.as_deref(),
234                None,
235            )
236            .await
237        {
238            tracing::warn!(error = %e, fingerprint = %fingerprint, "Failed to upsert fingerprint");
239        }
240    }
241}