systemprompt_oauth/services/session/
mod.rs1mod creation;
2mod lookup;
3
4use anyhow::Result;
5use http::{HeaderMap, Uri};
6use std::collections::HashMap;
7use std::sync::Arc;
8use tokio::sync::RwLock;
9use uuid::Uuid;
10
11use systemprompt_identifiers::{ClientId, SessionId, SessionSource, UserId};
12use systemprompt_traits::{
13 AnalyticsProvider, CreateSessionInput, FingerprintProvider, SessionAnalytics, UserEvent,
14 UserEventPublisher, UserProvider,
15};
16
/// Maximum session age in seconds: seven days of 86 400 s each.
///
/// NOTE(review): not referenced in this file; presumably consumed by the
/// `creation` / `lookup` submodules — confirm before changing.
const MAX_SESSION_AGE_SECONDS: i64 = 7 * 86_400;
18
19#[derive(Debug, thiserror::Error)]
20pub enum SessionCreationError {
21 #[error("User not found: {user_id}")]
22 UserNotFound { user_id: String },
23
24 #[error("Session creation failed: {0}")]
25 Internal(String),
26}
27
28struct SessionCreationParams<'a> {
29 analytics: SessionAnalytics,
30 is_bot: bool,
31 fingerprint: String,
32 client_id: &'a ClientId,
33 jwt_secret: &'a str,
34 session_source: SessionSource,
35}
36
37#[derive(Debug, Clone)]
38pub struct AnonymousSessionInfo {
39 pub session_id: SessionId,
40 pub user_id: UserId,
41 pub is_new: bool,
42 pub jwt_token: String,
43 pub fingerprint_hash: String,
44}
45
46#[derive(Debug, Clone)]
47pub struct AuthenticatedSessionInfo {
48 pub session_id: SessionId,
49}
50
51#[derive(Debug)]
52pub struct CreateAnonymousSessionInput<'a> {
53 pub headers: &'a HeaderMap,
54 pub uri: Option<&'a Uri>,
55 pub client_id: &'a ClientId,
56 pub jwt_secret: &'a str,
57 pub session_source: SessionSource,
58}
59
60#[derive(Clone)]
61pub struct SessionCreationService {
62 analytics_provider: Arc<dyn AnalyticsProvider>,
63 user_provider: Arc<dyn UserProvider>,
64 fingerprint_locks: Arc<RwLock<HashMap<String, Arc<tokio::sync::Mutex<()>>>>>,
65 event_publisher: Option<Arc<dyn UserEventPublisher>>,
66 fingerprint_provider: Option<Arc<dyn FingerprintProvider>>,
67}
68
69impl std::fmt::Debug for SessionCreationService {
70 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
71 f.debug_struct("SessionCreationService")
72 .field("analytics_provider", &"<provider>")
73 .field(
74 "event_publisher",
75 &self.event_publisher.as_ref().map(|_| "<publisher>"),
76 )
77 .finish_non_exhaustive()
78 }
79}
80
81impl SessionCreationService {
82 pub fn new(
83 analytics_provider: Arc<dyn AnalyticsProvider>,
84 user_provider: Arc<dyn UserProvider>,
85 ) -> Self {
86 Self {
87 analytics_provider,
88 user_provider,
89 fingerprint_locks: Arc::new(RwLock::new(HashMap::new())),
90 event_publisher: None,
91 fingerprint_provider: None,
92 }
93 }
94
95 pub fn with_event_publisher(mut self, publisher: Arc<dyn UserEventPublisher>) -> Self {
96 self.event_publisher = Some(publisher);
97 self
98 }
99
100 pub fn with_fingerprint_provider(mut self, provider: Arc<dyn FingerprintProvider>) -> Self {
101 self.fingerprint_provider = Some(provider);
102 self
103 }
104
105 fn publish_event(&self, event: UserEvent) {
106 if let Some(ref publisher) = self.event_publisher {
107 publisher.publish_user_event(event);
108 }
109 }
110
111 pub async fn create_anonymous_session(
112 &self,
113 input: CreateAnonymousSessionInput<'_>,
114 ) -> Result<AnonymousSessionInfo> {
115 let analytics = self
116 .analytics_provider
117 .extract_analytics(input.headers, input.uri);
118 let is_bot = analytics.is_bot();
119 let fingerprint = analytics.compute_fingerprint();
120
121 let params = SessionCreationParams {
122 analytics,
123 is_bot,
124 fingerprint,
125 client_id: input.client_id,
126 jwt_secret: input.jwt_secret,
127 session_source: input.session_source,
128 };
129 self.create_session_internal(params).await
130 }
131
132 pub async fn create_authenticated_session(
133 &self,
134 user_id: &UserId,
135 headers: &HeaderMap,
136 session_source: SessionSource,
137 ) -> Result<SessionId, SessionCreationError> {
138 let user = self
139 .user_provider
140 .find_by_id(user_id.as_str())
141 .await
142 .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
143
144 if user.is_none() {
145 return Err(SessionCreationError::UserNotFound {
146 user_id: user_id.to_string(),
147 });
148 }
149
150 let session_id = SessionId::new(format!("sess_{}", Uuid::new_v4()));
151 let analytics = self.analytics_provider.extract_analytics(headers, None);
152 let is_bot = analytics.is_bot();
153
154 let global_config = systemprompt_models::Config::get()
155 .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
156 let expires_at = chrono::Utc::now()
157 + chrono::Duration::seconds(global_config.jwt_access_token_expiration);
158
159 self.analytics_provider
160 .create_session(CreateSessionInput {
161 session_id: &session_id,
162 user_id: Some(user_id),
163 analytics: &analytics,
164 session_source,
165 is_bot,
166 expires_at,
167 })
168 .await
169 .map_err(|e| SessionCreationError::Internal(e.to_string()))?;
170
171 self.publish_event(UserEvent::SessionCreated {
172 user_id: user_id.to_string(),
173 session_id: session_id.to_string(),
174 });
175
176 Ok(session_id)
177 }
178
179 async fn create_session_internal(
180 &self,
181 params: SessionCreationParams<'_>,
182 ) -> Result<AnonymousSessionInfo> {
183 let _guard = self.acquire_fingerprint_lock(¶ms.fingerprint).await;
184
185 self.update_fingerprint_if_available(¶ms.fingerprint, ¶ms.analytics)
186 .await;
187
188 if let Some(session) = self
189 .try_reuse_session_at_limit(¶ms.fingerprint, params.client_id, params.jwt_secret)
190 .await
191 {
192 return Ok(session);
193 }
194
195 if let Some(session) = self
196 .try_find_existing_session(¶ms.fingerprint, params.client_id, params.jwt_secret)
197 .await
198 {
199 return Ok(session);
200 }
201
202 self.create_new_session(params).await
203 }
204
205 async fn acquire_fingerprint_lock(
206 &self,
207 fingerprint: &str,
208 ) -> tokio::sync::OwnedMutexGuard<()> {
209 let lock = {
210 let mut locks = self.fingerprint_locks.write().await;
211 Arc::clone(
212 locks
213 .entry(fingerprint.to_string())
214 .or_insert_with(|| Arc::new(tokio::sync::Mutex::new(()))),
215 )
216 };
217 lock.lock_owned().await
218 }
219
220 async fn update_fingerprint_if_available(
221 &self,
222 fingerprint: &str,
223 analytics: &SessionAnalytics,
224 ) {
225 let Some(ref fp_provider) = self.fingerprint_provider else {
226 return;
227 };
228
229 if let Err(e) = fp_provider
230 .upsert_fingerprint(
231 fingerprint,
232 analytics.ip_address.as_deref(),
233 analytics.user_agent.as_deref(),
234 None,
235 )
236 .await
237 {
238 tracing::warn!(error = %e, fingerprint = %fingerprint, "Failed to upsert fingerprint");
239 }
240 }
241}