1use std::{
9 collections::HashMap,
10 fmt,
11 sync::{Arc, Mutex},
12 time::{Duration, SystemTime},
13};
14
15use axum::http::{HeaderMap, HeaderName};
16use serde_json::Value;
17use thiserror::Error;
18use uuid::Uuid;
19
20use crate::config::{SessionConfig, SessionFallbackScope};
21
/// How the identifier behind a session was scoped when it was resolved.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionScope {
    /// The identifier can be presented again by later requests: it either
    /// came from the request itself or is the manager-wide fallback ID.
    Agent,
    /// The identifier was freshly generated for one request because no
    /// explicit identifier was supplied and the fallback scope is `Request`.
    Request,
}

impl SessionScope {
    /// Returns the stable lowercase label for this scope (`"agent"` or
    /// `"request"`).
    pub fn as_str(self) -> &'static str {
        match self {
            Self::Agent => "agent",
            Self::Request => "request",
        }
    }
}

impl fmt::Display for SessionScope {
    /// Writes the same label as [`SessionScope::as_str`].
    ///
    /// Uses `Formatter::pad` rather than `write_str` so that width, fill and
    /// alignment flags (for example `{:>10}` in aligned log output) are
    /// honoured; plain `{}` output is unchanged.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
47
/// Why a stored session is no longer considered usable.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SessionExpirationReason {
    /// The time since the session was last used reached the configured idle TTL.
    IdleTtl,
    /// The current time reached the session's absolute expiry (`expires_at`).
    MaxTtl,
    /// The session's request count reached the configured request cap.
    MaxRequests,
}
55
56#[derive(Debug, Clone, PartialEq, Eq)]
58pub struct SessionContext {
59 pub session_key: String,
60 pub model_id: String,
61 pub agent_session_id: String,
62 pub scope: SessionScope,
63 pub created_at: SystemTime,
64 pub last_used_at: SystemTime,
65 pub expires_at: SystemTime,
66 pub request_count: u64,
67 pub attested_model_public_key: Option<String>,
68 pub attestation_report: Option<Value>,
69 pub verified_at: Option<SystemTime>,
70}
71
72#[derive(Debug, Clone, PartialEq, Eq)]
74pub struct SessionResolution {
75 pub session: SessionContext,
76 pub created: bool,
77 pub replaced_expired: Option<SessionExpirationReason>,
78}
79
80#[derive(Debug, Clone, PartialEq, Eq)]
82pub struct AttestedModelState {
83 pub model_public_key: String,
84 pub attestation_report: Value,
85 pub verified_at: SystemTime,
86}
87
88#[derive(Debug, Clone, Copy)]
90pub struct SessionRequest<'a> {
91 pub model_id: &'a str,
92 pub headers: &'a HeaderMap,
93 pub body: Option<&'a Value>,
94}
95
96impl<'a> SessionRequest<'a> {
97 pub fn new(model_id: &'a str, headers: &'a HeaderMap) -> Self {
99 Self {
100 model_id,
101 headers,
102 body: None,
103 }
104 }
105
106 pub fn with_body(mut self, body: &'a Value) -> Self {
108 self.body = Some(body);
109 self
110 }
111}
112
113#[derive(Debug, Clone)]
115pub struct SessionManager {
116 config: SessionConfig,
117 sessions: Arc<Mutex<HashMap<String, SessionContext>>>,
118 agent_fallback_session_id: Arc<str>,
119}
120
121impl SessionManager {
122 pub fn new(config: SessionConfig) -> Self {
124 Self {
125 config,
126 sessions: Arc::new(Mutex::new(HashMap::new())),
127 agent_fallback_session_id: Arc::from(Uuid::new_v4().to_string()),
128 }
129 }
130
131 pub fn get_or_create(
134 &self,
135 request: SessionRequest<'_>,
136 ) -> Result<SessionResolution, SessionError> {
137 self.get_or_create_at(request, SystemTime::now())
138 }
139
140 pub fn get_or_create_at(
142 &self,
143 request: SessionRequest<'_>,
144 now: SystemTime,
145 ) -> Result<SessionResolution, SessionError> {
146 if request.model_id.trim().is_empty() {
147 return Err(SessionError::InvalidModelId);
148 }
149
150 let resolved = self.resolve_identifier(request)?;
151 let session_key = session_key(request.model_id, &resolved.agent_session_id);
152 let mut sessions = self.lock_sessions();
153 let replaced_expired = match sessions.get(&session_key) {
154 Some(existing) => self.expiration_reason(existing, now),
155 None => None,
156 };
157
158 if replaced_expired.is_some() {
159 sessions.remove(&session_key);
160 }
161
162 if let Some(existing) = sessions.get_mut(&session_key) {
163 existing.request_count += 1;
164 existing.last_used_at = now;
165 return Ok(SessionResolution {
166 session: existing.clone(),
167 created: false,
168 replaced_expired: None,
169 });
170 }
171
172 let context = SessionContext::new(
173 request.model_id,
174 resolved.agent_session_id,
175 resolved.scope,
176 now,
177 &self.config,
178 );
179 sessions.insert(session_key, context.clone());
180
181 Ok(SessionResolution {
182 session: context,
183 created: true,
184 replaced_expired,
185 })
186 }
187
188 pub fn set_attested_model_state(
190 &self,
191 session_key: &str,
192 state: AttestedModelState,
193 ) -> Result<SessionContext, SessionError> {
194 self.set_attested_model_state_at(session_key, state, SystemTime::now())
195 }
196
197 pub fn set_attested_model_state_at(
199 &self,
200 session_key: &str,
201 state: AttestedModelState,
202 now: SystemTime,
203 ) -> Result<SessionContext, SessionError> {
204 let mut sessions = self.lock_sessions();
205 let expired = sessions
206 .get(session_key)
207 .and_then(|session| self.expiration_reason(session, now));
208
209 if let Some(reason) = expired {
210 sessions.remove(session_key);
211 return Err(SessionError::SessionExpired { reason });
212 }
213
214 let session =
215 sessions
216 .get_mut(session_key)
217 .ok_or_else(|| SessionError::SessionNotFound {
218 session_key: session_key.to_owned(),
219 })?;
220 session.attested_model_public_key = Some(state.model_public_key);
221 session.attestation_report = Some(state.attestation_report);
222 session.verified_at = Some(state.verified_at);
223
224 Ok(session.clone())
225 }
226
227 pub fn cleanup_expired(&self) -> usize {
229 self.cleanup_expired_at(SystemTime::now())
230 }
231
232 pub fn cleanup_expired_at(&self, now: SystemTime) -> usize {
234 let mut sessions = self.lock_sessions();
235 let before = sessions.len();
236 sessions.retain(|_, session| self.expiration_reason(session, now).is_none());
237 before - sessions.len()
238 }
239
240 pub fn len(&self) -> usize {
242 self.lock_sessions().len()
243 }
244
245 pub fn is_empty(&self) -> bool {
247 self.len() == 0
248 }
249
250 fn resolve_identifier(
252 &self,
253 request: SessionRequest<'_>,
254 ) -> Result<ResolvedSessionIdentifier, SessionError> {
255 if let Some(value) = self.explicit_identifier(&request)? {
256 return Ok(ResolvedSessionIdentifier::agent(value));
257 }
258
259 match self.config.fallback_scope {
260 SessionFallbackScope::Agent => Ok(ResolvedSessionIdentifier::agent(
261 self.agent_fallback_session_id.to_string(),
262 )),
263 SessionFallbackScope::Request => Ok(ResolvedSessionIdentifier {
264 agent_session_id: Uuid::new_v4().to_string(),
265 scope: SessionScope::Request,
266 }),
267 SessionFallbackScope::Disabled => Err(SessionError::MissingSessionIdentifier),
268 }
269 }
270
271 fn explicit_identifier(
273 &self,
274 request: &SessionRequest<'_>,
275 ) -> Result<Option<String>, SessionError> {
276 if let Some(value) = header_identifier(request.headers, &self.config.headers.preferred)? {
277 return Ok(Some(value));
278 }
279
280 if let Some(value) = header_identifier(request.headers, &self.config.headers.open_webui)? {
281 return Ok(Some(value));
282 }
283
284 Ok(metadata_identifier(request.body, "session_id")
285 .or_else(|| metadata_identifier(request.body, "chat_id")))
286 }
287
288 fn expiration_reason(
290 &self,
291 session: &SessionContext,
292 now: SystemTime,
293 ) -> Option<SessionExpirationReason> {
294 if session.request_count >= self.config.max_requests {
295 return Some(SessionExpirationReason::MaxRequests);
296 }
297
298 if now >= session.expires_at {
299 return Some(SessionExpirationReason::MaxTtl);
300 }
301
302 if elapsed_since(session.last_used_at, now) >= self.config.idle_ttl {
303 return Some(SessionExpirationReason::IdleTtl);
304 }
305
306 None
307 }
308
309 fn lock_sessions(&self) -> std::sync::MutexGuard<'_, HashMap<String, SessionContext>> {
311 self.sessions
312 .lock()
313 .unwrap_or_else(std::sync::PoisonError::into_inner)
314 }
315}
316
317#[derive(Debug, Clone)]
319struct ResolvedSessionIdentifier {
320 agent_session_id: String,
321 scope: SessionScope,
322}
323
324impl ResolvedSessionIdentifier {
325 fn agent(agent_session_id: String) -> Self {
327 Self {
328 agent_session_id,
329 scope: SessionScope::Agent,
330 }
331 }
332}
333
334impl SessionContext {
335 fn new(
337 model_id: &str,
338 agent_session_id: String,
339 scope: SessionScope,
340 now: SystemTime,
341 config: &SessionConfig,
342 ) -> Self {
343 let session_key = session_key(model_id, &agent_session_id);
344 Self {
345 session_key,
346 model_id: model_id.to_owned(),
347 agent_session_id,
348 scope,
349 created_at: now,
350 last_used_at: now,
351 expires_at: now + config.max_ttl,
352 request_count: 1,
353 attested_model_public_key: None,
354 attestation_report: None,
355 verified_at: None,
356 }
357 }
358}
359
360#[derive(Debug, Error, PartialEq, Eq)]
362pub enum SessionError {
363 #[error("request model id must not be empty")]
364 InvalidModelId,
365 #[error("request does not include a session identifier and session fallback is disabled")]
366 MissingSessionIdentifier,
367 #[error("configured session header name {header:?} is invalid")]
368 InvalidHeaderName { header: String },
369 #[error("session header {header} contains non-UTF-8 data")]
370 InvalidHeaderValue { header: String },
371 #[error("session {session_key} was not found")]
372 SessionNotFound { session_key: String },
373 #[error("session expired before attestation state could be stored: {reason:?}")]
374 SessionExpired { reason: SessionExpirationReason },
375}
376
377fn header_identifier(
379 headers: &HeaderMap,
380 configured_name: &str,
381) -> Result<Option<String>, SessionError> {
382 let name = HeaderName::from_bytes(configured_name.as_bytes()).map_err(|_| {
383 SessionError::InvalidHeaderName {
384 header: configured_name.to_owned(),
385 }
386 })?;
387
388 let Some(value) = headers.get(&name) else {
389 return Ok(None);
390 };
391 let value = value
392 .to_str()
393 .map_err(|_| SessionError::InvalidHeaderValue {
394 header: configured_name.to_owned(),
395 })?;
396 Ok(non_empty_string(value))
397}
398
399fn metadata_identifier(body: Option<&Value>, key: &str) -> Option<String> {
401 body.and_then(|body| body.get("metadata"))
402 .and_then(|metadata| metadata.get(key))
403 .and_then(Value::as_str)
404 .and_then(non_empty_string)
405}
406
/// Trims `value` and returns it as an owned string, or `None` when nothing
/// but whitespace remains.
fn non_empty_string(value: &str) -> Option<String> {
    match value.trim() {
        "" => None,
        text => Some(String::from(text)),
    }
}
412
/// Builds the map key for the session identified by `model_id` and
/// `agent_session_id`, in the form `<model>:<session>`.
///
/// Model IDs may themselves contain `:` (for example `llama3:8b`), and the
/// session part can come straight from a client header. Joining the two
/// verbatim lets distinct pairs such as (`llama3`, `8b:x`) and
/// (`llama3:8b`, `x`) share one key, and therefore one stored session and its
/// attestation state. To keep the key unambiguous, `%` and `:` in the model
/// part are escaped as `%25` and `%3A`, so the first `:` in the key is always
/// the separator. Keys for model IDs without those characters are unchanged.
fn session_key(model_id: &str, agent_session_id: &str) -> String {
    let mut key = String::with_capacity(model_id.len() + 1 + agent_session_id.len());
    for ch in model_id.chars() {
        match ch {
            // `%` must be escaped too, otherwise a literal "%3A" in a model
            // ID would collide with an escaped ':'.
            '%' => key.push_str("%25"),
            ':' => key.push_str("%3A"),
            other => key.push(other),
        }
    }
    key.push(':');
    key.push_str(agent_session_id);
    key
}
417
/// Returns the time elapsed from `start` to `now`, or zero when `now` is
/// earlier than `start` (for example after the system clock moved backwards).
fn elapsed_since(start: SystemTime, now: SystemTime) -> Duration {
    match now.duration_since(start) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    }
}
422
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::HeaderValue;
    use serde_json::json;

    /// Header the tests send as the preferred session identifier; expected to
    /// match the default header configuration.
    const PREFERRED_HEADER: &str = "X-Venice-Proxy-Session-Id";
    /// Header the tests send as the Open WebUI chat identifier; expected to
    /// match the default header configuration.
    const OPEN_WEBUI_HEADER: &str = "X-OpenWebUI-Chat-Id";

    /// Small limits so that each expiry rule can be reached in a few steps.
    fn test_config() -> SessionConfig {
        SessionConfig {
            idle_ttl: Duration::from_secs(10),
            max_ttl: Duration::from_secs(30),
            max_requests: 3,
            fallback_scope: SessionFallbackScope::Request,
            headers: Default::default(),
        }
    }

    fn manager() -> SessionManager {
        SessionManager::new(test_config())
    }

    /// Deterministic timestamp `seconds` after the Unix epoch.
    fn now(seconds: u64) -> SystemTime {
        SystemTime::UNIX_EPOCH + Duration::from_secs(seconds)
    }

    fn request<'a>(model_id: &'a str, headers: &'a HeaderMap) -> SessionRequest<'a> {
        SessionRequest::new(model_id, headers)
    }

    /// Header map holding exactly one header.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, HeaderValue::from_static(value));
        headers
    }

    /// Resolves a body-less request for `model-a` at `seconds` past the epoch.
    fn resolve_at(
        sessions: &SessionManager,
        headers: &HeaderMap,
        seconds: u64,
    ) -> Result<SessionResolution, SessionError> {
        sessions.get_or_create_at(request("model-a", headers), now(seconds))
    }

    #[test]
    fn creates_new_agent_session_from_preferred_header() {
        let sessions = manager();
        let headers = single_header(PREFERRED_HEADER, "chat-1");

        let SessionResolution {
            session,
            created,
            replaced_expired,
        } = resolve_at(&sessions, &headers, 0).expect("session should resolve");

        assert!(created);
        assert_eq!(replaced_expired, None);
        assert_eq!(session.session_key, "model-a:chat-1");
        assert_eq!(session.model_id, "model-a");
        assert_eq!(session.agent_session_id, "chat-1");
        assert_eq!(session.scope, SessionScope::Agent);
        assert_eq!(session.request_count, 1);
    }

    #[test]
    fn reuses_existing_session_from_configured_headers() {
        let sessions = manager();
        let headers = single_header(OPEN_WEBUI_HEADER, "open-webui-chat");

        let first = resolve_at(&sessions, &headers, 0).expect("first request should create");
        let second = resolve_at(&sessions, &headers, 5).expect("second request should reuse");

        assert!(first.created);
        assert!(!second.created);
        assert_eq!(second.session.session_key, first.session.session_key);
        assert_eq!(second.session.request_count, 2);
        assert_eq!(second.session.last_used_at, now(5));
        assert_eq!(sessions.len(), 1);
    }

    #[test]
    fn preferred_header_wins_over_open_webui_and_metadata() {
        let sessions = manager();
        let mut headers = single_header(PREFERRED_HEADER, "preferred");
        headers.insert(OPEN_WEBUI_HEADER, HeaderValue::from_static("open-webui"));
        let body = json!({ "metadata": { "session_id": "body-session", "chat_id": "body-chat" } });

        let incoming = SessionRequest::new("model-a", &headers).with_body(&body);
        let resolved = sessions
            .get_or_create_at(incoming, now(0))
            .expect("session should resolve");

        assert_eq!(resolved.session.session_key, "model-a:preferred");
    }

    #[test]
    fn metadata_session_id_is_used_when_headers_are_missing() {
        let sessions = manager();
        let headers = HeaderMap::new();
        let body = json!({ "metadata": { "session_id": "metadata-session" } });

        let incoming = SessionRequest::new("model-a", &headers).with_body(&body);
        let resolved = sessions
            .get_or_create_at(incoming, now(0))
            .expect("session should resolve");

        assert_eq!(resolved.session.session_key, "model-a:metadata-session");
        assert_eq!(resolved.session.scope, SessionScope::Agent);
    }

    #[test]
    fn idle_ttl_expiration_discards_old_session_and_creates_fresh_one() {
        let sessions = manager();
        let headers = single_header(PREFERRED_HEADER, "chat-1");

        let first = resolve_at(&sessions, &headers, 0).expect("first request should create");
        // Exactly `idle_ttl` seconds later: the inclusive check expires it.
        let second =
            resolve_at(&sessions, &headers, 10).expect("idle-expired request should recreate");

        assert!(second.created);
        assert_eq!(
            second.replaced_expired,
            Some(SessionExpirationReason::IdleTtl)
        );
        assert_eq!(second.session.session_key, first.session.session_key);
        assert_eq!(second.session.request_count, 1);
        assert_eq!(second.session.created_at, now(10));
    }

    #[test]
    fn max_ttl_expiration_discards_old_session_and_creates_fresh_one() {
        let config = SessionConfig {
            idle_ttl: Duration::from_secs(20),
            max_ttl: Duration::from_secs(30),
            ..test_config()
        };
        let sessions = SessionManager::new(config);
        let headers = single_header(PREFERRED_HEADER, "chat-1");

        let first = resolve_at(&sessions, &headers, 0).expect("first request should create");
        // Keeps the session from idling out before the absolute limit hits.
        resolve_at(&sessions, &headers, 15).expect("within idle ttl should reuse");
        let third =
            resolve_at(&sessions, &headers, 30).expect("max-ttl-expired request should recreate");

        assert!(third.created);
        assert_eq!(
            third.replaced_expired,
            Some(SessionExpirationReason::MaxTtl)
        );
        assert_eq!(third.session.session_key, first.session.session_key);
        assert_eq!(third.session.request_count, 1);
        assert_eq!(third.session.created_at, now(30));
    }

    #[test]
    fn max_request_expiration_discards_old_session_and_creates_fresh_one() {
        let sessions = manager();
        let headers = single_header(PREFERRED_HEADER, "chat-1");

        resolve_at(&sessions, &headers, 0).expect("first request should create");
        resolve_at(&sessions, &headers, 1).expect("second request should reuse");
        let third =
            resolve_at(&sessions, &headers, 2).expect("third request should reuse and reach max");
        let fourth = resolve_at(&sessions, &headers, 3).expect("fourth request should recreate");

        assert!(!third.created);
        assert_eq!(third.session.request_count, 3);
        assert!(fourth.created);
        assert_eq!(
            fourth.replaced_expired,
            Some(SessionExpirationReason::MaxRequests)
        );
        assert_eq!(fourth.session.request_count, 1);
    }

    #[test]
    fn request_fallback_creates_distinct_request_scoped_sessions() {
        let sessions = manager();
        let headers = HeaderMap::new();

        let first = resolve_at(&sessions, &headers, 0).expect("fallback should create");
        let second = resolve_at(&sessions, &headers, 1).expect("fallback should create again");

        assert!(first.created);
        assert!(second.created);
        assert_eq!(first.session.scope, SessionScope::Request);
        assert_eq!(second.session.scope, SessionScope::Request);
        assert_ne!(
            first.session.agent_session_id,
            second.session.agent_session_id
        );
        assert_eq!(sessions.len(), 2);
    }

    #[test]
    fn agent_fallback_reuses_generated_agent_scoped_session() {
        let config = SessionConfig {
            fallback_scope: SessionFallbackScope::Agent,
            ..test_config()
        };
        let sessions = SessionManager::new(config);
        let headers = HeaderMap::new();

        let first = resolve_at(&sessions, &headers, 0).expect("fallback should create");
        let second = resolve_at(&sessions, &headers, 1).expect("fallback should reuse");

        assert!(first.created);
        assert!(!second.created);
        assert_eq!(first.session.scope, SessionScope::Agent);
        assert_eq!(
            first.session.agent_session_id,
            second.session.agent_session_id
        );
        assert_eq!(second.session.request_count, 2);
    }

    #[test]
    fn disabled_fallback_returns_clear_error_without_creating_session() {
        let config = SessionConfig {
            fallback_scope: SessionFallbackScope::Disabled,
            ..test_config()
        };
        let sessions = SessionManager::new(config);
        let headers = HeaderMap::new();

        let error = resolve_at(&sessions, &headers, 0)
            .expect_err("missing session id should fail when fallback is disabled");

        assert_eq!(error, SessionError::MissingSessionIdentifier);
        assert_eq!(
            error.to_string(),
            "request does not include a session identifier and session fallback is disabled"
        );
        assert!(sessions.is_empty());
    }

    #[test]
    fn cleanup_removes_expired_sessions_and_keeps_valid_sessions() {
        let sessions = manager();
        let headers_a = single_header(PREFERRED_HEADER, "chat-a");
        let headers_b = single_header(PREFERRED_HEADER, "chat-b");

        resolve_at(&sessions, &headers_a, 0).expect("session a should create");
        resolve_at(&sessions, &headers_b, 15).expect("session b should create");

        // At t=20 session a has been idle for 20s (limit 10s); b for only 5s.
        let removed = sessions.cleanup_expired_at(now(20));

        assert_eq!(removed, 1);
        assert_eq!(sessions.len(), 1);
        let reused_b =
            resolve_at(&sessions, &headers_b, 21).expect("session b should remain valid");
        assert!(!reused_b.created);
    }

    #[test]
    fn stores_attested_model_state_on_existing_unexpired_session() {
        let sessions = manager();
        let headers = single_header(PREFERRED_HEADER, "chat-1");
        let session = resolve_at(&sessions, &headers, 0)
            .expect("session should create")
            .session;

        let state = AttestedModelState {
            model_public_key: "model-public-key".to_owned(),
            attestation_report: json!({ "verified": true }),
            verified_at: now(1),
        };
        let updated = sessions
            .set_attested_model_state_at(&session.session_key, state, now(1))
            .expect("attestation state should update");

        assert_eq!(
            updated.attested_model_public_key.as_deref(),
            Some("model-public-key")
        );
        assert_eq!(
            updated.attestation_report,
            Some(json!({ "verified": true }))
        );
        assert_eq!(updated.verified_at, Some(now(1)));
    }
}