1use std::{
9 collections::HashMap,
10 fmt,
11 sync::{Arc, Mutex},
12 time::{Duration, SystemTime},
13};
14
15use axum::http::{HeaderMap, HeaderName};
16use serde_json::Value;
17use thiserror::Error;
18use uuid::Uuid;
19
20use crate::config::{SessionConfig, SessionFallbackScope};
21
/// How widely a session is shared between requests.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionScope {
    /// Session keyed by an explicit identifier from the request, or by the manager-wide
    /// fallback identifier; later requests with the same identifier reuse it.
    Agent,
    /// Session keyed by an identifier generated for a single request when no explicit
    /// identifier was supplied and the fallback scope is per-request.
    Request,
}
30
31impl SessionScope {
32 pub fn as_str(self) -> &'static str {
34 match self {
35 Self::Agent => "agent",
36 Self::Request => "request",
37 }
38 }
39}
40
41impl fmt::Display for SessionScope {
42 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
44 f.write_str(self.as_str())
45 }
46}
47
/// Why a stored session is no longer considered usable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionExpirationReason {
    /// The time since the session was last used reached the configured idle TTL.
    IdleTtl,
    /// The current time reached the session's absolute `expires_at` deadline.
    MaxTtl,
    /// The session's request count reached the configured maximum.
    MaxRequests,
}
55
/// State tracked for one session held by the session manager.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionContext {
    /// Lookup key of the session: `"{model_id}:{agent_session_id}"`.
    pub session_key: String,
    /// Model id of the request that created the session.
    pub model_id: String,
    /// Explicit or generated session identifier the session was resolved from.
    pub agent_session_id: String,
    /// Whether the session is shared across requests or was generated for a single request.
    pub scope: SessionScope,
    /// Time at which the session was created.
    pub created_at: SystemTime,
    /// Time of the most recent request resolved to this session.
    pub last_used_at: SystemTime,
    /// Absolute deadline, computed at creation as creation time plus the configured max TTL.
    pub expires_at: SystemTime,
    /// Number of requests resolved to this session; starts at 1 when the session is created.
    pub request_count: u64,
    /// Attested model public key; `None` until attestation state is stored on the session.
    pub attested_model_public_key: Option<String>,
    /// Attestation report; `None` until attestation state is stored on the session.
    pub attestation_report: Option<Value>,
    /// Verification time supplied with the attestation state; `None` until it is stored.
    pub verified_at: Option<SystemTime>,
}
71
/// Outcome of resolving a request to a session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct SessionResolution {
    /// Copy of the session state as it is after this request was counted.
    pub session: SessionContext,
    /// `true` when a new session was created, `false` when an existing one was reused.
    pub created: bool,
    /// Set when the new session replaced a stored one that had expired, with the reason.
    pub replaced_expired: Option<SessionExpirationReason>,
}
79
/// Attestation data to attach to an existing session.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AttestedModelState {
    /// Stored on the session as `attested_model_public_key`.
    pub model_public_key: String,
    /// Stored on the session as `attestation_report`.
    pub attestation_report: Value,
    /// Stored on the session as `verified_at`.
    pub verified_at: SystemTime,
}
87
/// Borrowed view of the request data used to resolve a session.
#[derive(Debug, Clone, Copy)]
pub struct SessionRequest<'a> {
    /// Model the request targets; becomes the first half of the session key.
    pub model_id: &'a str,
    /// Request headers, searched for the configured incoming session-id header.
    pub headers: &'a HeaderMap,
    /// Optional JSON body; `metadata.session_id` is consulted when the header is absent.
    pub body: Option<&'a Value>,
}
95
96impl<'a> SessionRequest<'a> {
97 pub fn new(model_id: &'a str, headers: &'a HeaderMap) -> Self {
99 Self {
100 model_id,
101 headers,
102 body: None,
103 }
104 }
105
106 pub fn with_body(mut self, body: &'a Value) -> Self {
108 self.body = Some(body);
109 self
110 }
111}
112
/// In-memory registry of sessions keyed by `"{model_id}:{agent_session_id}"`.
///
/// The session table and the fallback identifier are held behind `Arc`, so clones of a
/// manager share the same table and the same fallback identifier.
#[derive(Debug, Clone)]
pub struct SessionManager {
    /// TTL, request-count, fallback and header settings applied to every session.
    config: SessionConfig,
    /// Session table, guarded by a mutex.
    sessions: Arc<Mutex<HashMap<String, SessionContext>>>,
    /// UUID generated when the manager is constructed; used as the session identifier for
    /// requests without an explicit one when the fallback scope is `Agent`.
    agent_fallback_session_id: Arc<str>,
}
120
121impl SessionManager {
122 pub fn new(config: SessionConfig) -> Self {
124 Self {
125 config,
126 sessions: Arc::new(Mutex::new(HashMap::new())),
127 agent_fallback_session_id: Arc::from(Uuid::new_v4().to_string()),
128 }
129 }
130
131 pub fn get_or_create(
134 &self,
135 request: SessionRequest<'_>,
136 ) -> Result<SessionResolution, SessionError> {
137 self.get_or_create_at(request, SystemTime::now())
138 }
139
140 pub fn get_or_create_at(
142 &self,
143 request: SessionRequest<'_>,
144 now: SystemTime,
145 ) -> Result<SessionResolution, SessionError> {
146 if request.model_id.trim().is_empty() {
147 return Err(SessionError::InvalidModelId);
148 }
149
150 let resolved = self.resolve_identifier(request)?;
151 let session_key = session_key(request.model_id, &resolved.agent_session_id);
152 let mut sessions = self.lock_sessions();
153 let replaced_expired = match sessions.get(&session_key) {
154 Some(existing) => self.expiration_reason(existing, now),
155 None => None,
156 };
157
158 if replaced_expired.is_some() {
159 sessions.remove(&session_key);
160 }
161
162 if let Some(existing) = sessions.get_mut(&session_key) {
163 existing.request_count += 1;
164 existing.last_used_at = now;
165 return Ok(SessionResolution {
166 session: existing.clone(),
167 created: false,
168 replaced_expired: None,
169 });
170 }
171
172 let context = SessionContext::new(
173 request.model_id,
174 resolved.agent_session_id,
175 resolved.scope,
176 now,
177 &self.config,
178 );
179 sessions.insert(session_key, context.clone());
180
181 Ok(SessionResolution {
182 session: context,
183 created: true,
184 replaced_expired,
185 })
186 }
187
188 pub fn set_attested_model_state(
190 &self,
191 session_key: &str,
192 state: AttestedModelState,
193 ) -> Result<SessionContext, SessionError> {
194 self.set_attested_model_state_at(session_key, state, SystemTime::now())
195 }
196
197 pub fn set_attested_model_state_at(
199 &self,
200 session_key: &str,
201 state: AttestedModelState,
202 now: SystemTime,
203 ) -> Result<SessionContext, SessionError> {
204 let mut sessions = self.lock_sessions();
205 let expired = sessions
206 .get(session_key)
207 .and_then(|session| self.expiration_reason(session, now));
208
209 if let Some(reason) = expired {
210 sessions.remove(session_key);
211 return Err(SessionError::SessionExpired { reason });
212 }
213
214 let session =
215 sessions
216 .get_mut(session_key)
217 .ok_or_else(|| SessionError::SessionNotFound {
218 session_key: session_key.to_owned(),
219 })?;
220 session.attested_model_public_key = Some(state.model_public_key);
221 session.attestation_report = Some(state.attestation_report);
222 session.verified_at = Some(state.verified_at);
223
224 Ok(session.clone())
225 }
226
227 pub fn cleanup_expired(&self) -> usize {
229 self.cleanup_expired_at(SystemTime::now())
230 }
231
232 pub fn cleanup_expired_at(&self, now: SystemTime) -> usize {
234 let mut sessions = self.lock_sessions();
235 let before = sessions.len();
236 sessions.retain(|_, session| self.expiration_reason(session, now).is_none());
237 before - sessions.len()
238 }
239
240 pub fn len(&self) -> usize {
242 self.lock_sessions().len()
243 }
244
245 pub fn is_empty(&self) -> bool {
247 self.len() == 0
248 }
249
250 fn resolve_identifier(
252 &self,
253 request: SessionRequest<'_>,
254 ) -> Result<ResolvedSessionIdentifier, SessionError> {
255 if let Some(value) = self.explicit_identifier(&request)? {
256 return Ok(ResolvedSessionIdentifier::agent(value));
257 }
258
259 match self.config.fallback_scope {
260 SessionFallbackScope::Agent => Ok(ResolvedSessionIdentifier::agent(
261 self.agent_fallback_session_id.to_string(),
262 )),
263 SessionFallbackScope::Request => Ok(ResolvedSessionIdentifier {
264 agent_session_id: Uuid::new_v4().to_string(),
265 scope: SessionScope::Request,
266 }),
267 SessionFallbackScope::Disabled => Err(SessionError::MissingSessionIdentifier),
268 }
269 }
270
271 fn explicit_identifier(
273 &self,
274 request: &SessionRequest<'_>,
275 ) -> Result<Option<String>, SessionError> {
276 if let Some(value) =
277 header_identifier(request.headers, &self.config.headers.incoming_session_id)?
278 {
279 return Ok(Some(value));
280 }
281
282 Ok(metadata_identifier(request.body, "session_id"))
283 }
284
285 fn expiration_reason(
287 &self,
288 session: &SessionContext,
289 now: SystemTime,
290 ) -> Option<SessionExpirationReason> {
291 if session.request_count >= self.config.max_requests {
292 return Some(SessionExpirationReason::MaxRequests);
293 }
294
295 if now >= session.expires_at {
296 return Some(SessionExpirationReason::MaxTtl);
297 }
298
299 if elapsed_since(session.last_used_at, now) >= self.config.idle_ttl {
300 return Some(SessionExpirationReason::IdleTtl);
301 }
302
303 None
304 }
305
306 fn lock_sessions(&self) -> std::sync::MutexGuard<'_, HashMap<String, SessionContext>> {
308 self.sessions
309 .lock()
310 .unwrap_or_else(std::sync::PoisonError::into_inner)
311 }
312}
313
/// Session identifier chosen for a request, together with the scope it implies.
#[derive(Debug, Clone)]
struct ResolvedSessionIdentifier {
    /// Identifier that forms the second half of the session key.
    agent_session_id: String,
    /// Scope recorded on a session created from this identifier.
    scope: SessionScope,
}
320
321impl ResolvedSessionIdentifier {
322 fn agent(agent_session_id: String) -> Self {
324 Self {
325 agent_session_id,
326 scope: SessionScope::Agent,
327 }
328 }
329}
330
331impl SessionContext {
332 fn new(
334 model_id: &str,
335 agent_session_id: String,
336 scope: SessionScope,
337 now: SystemTime,
338 config: &SessionConfig,
339 ) -> Self {
340 let session_key = session_key(model_id, &agent_session_id);
341 Self {
342 session_key,
343 model_id: model_id.to_owned(),
344 agent_session_id,
345 scope,
346 created_at: now,
347 last_used_at: now,
348 expires_at: now + config.max_ttl,
349 request_count: 1,
350 attested_model_public_key: None,
351 attestation_report: None,
352 verified_at: None,
353 }
354 }
355}
356
/// Errors produced while resolving sessions or storing attestation state.
#[derive(Debug, Error, PartialEq, Eq)]
pub enum SessionError {
    /// The request's model id was empty or contained only whitespace.
    #[error("request model id must not be empty")]
    InvalidModelId,
    /// The request carried no session identifier and the fallback scope is disabled.
    #[error("request does not include a session identifier and session fallback is disabled")]
    MissingSessionIdentifier,
    /// The configured incoming session-id header name could not be parsed as a header name.
    #[error("configured session header name {header:?} is invalid")]
    InvalidHeaderName { header: String },
    /// The session header was present but its value could not be read as a string.
    #[error("session header {header} contains non-UTF-8 data")]
    InvalidHeaderValue { header: String },
    /// No session is stored under the given key.
    #[error("session {session_key} was not found")]
    SessionNotFound { session_key: String },
    /// The session had already expired when attestation state was about to be stored.
    #[error("session expired before attestation state could be stored: {reason:?}")]
    SessionExpired { reason: SessionExpirationReason },
}
373
374fn header_identifier(
376 headers: &HeaderMap,
377 configured_name: &str,
378) -> Result<Option<String>, SessionError> {
379 let name = HeaderName::from_bytes(configured_name.as_bytes()).map_err(|_| {
380 SessionError::InvalidHeaderName {
381 header: configured_name.to_owned(),
382 }
383 })?;
384
385 let Some(value) = headers.get(&name) else {
386 return Ok(None);
387 };
388 let value = value
389 .to_str()
390 .map_err(|_| SessionError::InvalidHeaderValue {
391 header: configured_name.to_owned(),
392 })?;
393 Ok(non_empty_string(value))
394}
395
396fn metadata_identifier(body: Option<&Value>, key: &str) -> Option<String> {
398 body.and_then(|body| body.get("metadata"))
399 .and_then(|metadata| metadata.get(key))
400 .and_then(Value::as_str)
401 .and_then(non_empty_string)
402}
403
/// Trims `value` and returns it as an owned string, or `None` when nothing is left.
fn non_empty_string(value: &str) -> Option<String> {
    match value.trim() {
        "" => None,
        kept => Some(kept.to_owned()),
    }
}
409
/// Builds the session table key by joining the model id and session identifier with `:`.
fn session_key(model_id: &str, agent_session_id: &str) -> String {
    [model_id, agent_session_id].join(":")
}
414
/// Returns the time from `start` to `now`, or zero when `now` is earlier than `start`.
fn elapsed_since(start: SystemTime, now: SystemTime) -> Duration {
    match now.duration_since(start) {
        Ok(elapsed) => elapsed,
        // `now` is before `start`; treat it as no time elapsed.
        Err(_) => Duration::ZERO,
    }
}
419
420#[cfg(test)]
421mod tests {
422 use super::*;
423 use axum::http::HeaderValue;
424 use serde_json::json;
425
426 fn test_config() -> SessionConfig {
427 SessionConfig {
428 idle_ttl: Duration::from_secs(10),
429 max_ttl: Duration::from_secs(30),
430 max_requests: 3,
431 fallback_scope: SessionFallbackScope::Request,
432 headers: Default::default(),
433 }
434 }
435
436 fn manager() -> SessionManager {
437 SessionManager::new(test_config())
438 }
439
440 fn now(seconds: u64) -> SystemTime {
441 SystemTime::UNIX_EPOCH + Duration::from_secs(seconds)
442 }
443
444 fn request<'a>(model_id: &'a str, headers: &'a HeaderMap) -> SessionRequest<'a> {
445 SessionRequest::new(model_id, headers)
446 }
447
448 #[test]
449 fn creates_new_agent_session_from_incoming_session_id_header() {
450 let manager = manager();
451 let mut headers = HeaderMap::new();
452 headers.insert(
453 "X-Venice-Proxy-Session-Id",
454 HeaderValue::from_static("chat-1"),
455 );
456
457 let resolved = manager
458 .get_or_create_at(request("model-a", &headers), now(0))
459 .expect("session should resolve");
460
461 assert!(resolved.created);
462 assert_eq!(resolved.replaced_expired, None);
463 assert_eq!(resolved.session.session_key, "model-a:chat-1");
464 assert_eq!(resolved.session.model_id, "model-a");
465 assert_eq!(resolved.session.agent_session_id, "chat-1");
466 assert_eq!(resolved.session.scope, SessionScope::Agent);
467 assert_eq!(resolved.session.request_count, 1);
468 }
469
470 #[test]
471 fn reuses_existing_session_from_configured_header() {
472 let mut config = test_config();
473 config.headers.incoming_session_id = "X-Custom-Session-Id".to_owned();
474 let manager = SessionManager::new(config);
475 let mut headers = HeaderMap::new();
476 headers.insert(
477 "X-Custom-Session-Id",
478 HeaderValue::from_static("configured-chat"),
479 );
480
481 let first = manager
482 .get_or_create_at(request("model-a", &headers), now(0))
483 .expect("first request should create");
484 let second = manager
485 .get_or_create_at(request("model-a", &headers), now(5))
486 .expect("second request should reuse");
487
488 assert!(first.created);
489 assert!(!second.created);
490 assert_eq!(second.session.session_key, first.session.session_key);
491 assert_eq!(second.session.request_count, 2);
492 assert_eq!(second.session.last_used_at, now(5));
493 assert_eq!(manager.len(), 1);
494 }
495
496 #[test]
497 fn configured_header_wins_over_metadata() {
498 let manager = manager();
499 let mut headers = HeaderMap::new();
500 headers.insert(
501 "X-Venice-Proxy-Session-Id",
502 HeaderValue::from_static("header-session"),
503 );
504 let body = json!({ "metadata": { "session_id": "body-session" } });
505
506 let resolved = manager
507 .get_or_create_at(
508 SessionRequest::new("model-a", &headers).with_body(&body),
509 now(0),
510 )
511 .expect("session should resolve");
512
513 assert_eq!(resolved.session.session_key, "model-a:header-session");
514 }
515
516 #[test]
517 fn metadata_session_id_is_used_when_headers_are_missing() {
518 let manager = manager();
519 let headers = HeaderMap::new();
520 let body = json!({ "metadata": { "session_id": "metadata-session" } });
521
522 let resolved = manager
523 .get_or_create_at(
524 SessionRequest::new("model-a", &headers).with_body(&body),
525 now(0),
526 )
527 .expect("session should resolve");
528
529 assert_eq!(resolved.session.session_key, "model-a:metadata-session");
530 assert_eq!(resolved.session.scope, SessionScope::Agent);
531 }
532
533 #[test]
534 fn idle_ttl_expiration_discards_old_session_and_creates_fresh_one() {
535 let manager = manager();
536 let mut headers = HeaderMap::new();
537 headers.insert(
538 "X-Venice-Proxy-Session-Id",
539 HeaderValue::from_static("chat-1"),
540 );
541
542 let first = manager
543 .get_or_create_at(request("model-a", &headers), now(0))
544 .expect("first request should create");
545 let second = manager
546 .get_or_create_at(request("model-a", &headers), now(10))
547 .expect("idle-expired request should recreate");
548
549 assert!(second.created);
550 assert_eq!(
551 second.replaced_expired,
552 Some(SessionExpirationReason::IdleTtl)
553 );
554 assert_eq!(second.session.session_key, first.session.session_key);
555 assert_eq!(second.session.request_count, 1);
556 assert_eq!(second.session.created_at, now(10));
557 }
558
559 #[test]
560 fn max_ttl_expiration_discards_old_session_and_creates_fresh_one() {
561 let mut config = test_config();
562 config.idle_ttl = Duration::from_secs(20);
563 config.max_ttl = Duration::from_secs(30);
564 let manager = SessionManager::new(config);
565 let mut headers = HeaderMap::new();
566 headers.insert(
567 "X-Venice-Proxy-Session-Id",
568 HeaderValue::from_static("chat-1"),
569 );
570
571 let first = manager
572 .get_or_create_at(request("model-a", &headers), now(0))
573 .expect("first request should create");
574 manager
575 .get_or_create_at(request("model-a", &headers), now(15))
576 .expect("within idle ttl should reuse");
577 let third = manager
578 .get_or_create_at(request("model-a", &headers), now(30))
579 .expect("max-ttl-expired request should recreate");
580
581 assert!(third.created);
582 assert_eq!(
583 third.replaced_expired,
584 Some(SessionExpirationReason::MaxTtl)
585 );
586 assert_eq!(third.session.session_key, first.session.session_key);
587 assert_eq!(third.session.request_count, 1);
588 assert_eq!(third.session.created_at, now(30));
589 }
590
591 #[test]
592 fn max_request_expiration_discards_old_session_and_creates_fresh_one() {
593 let manager = manager();
594 let mut headers = HeaderMap::new();
595 headers.insert(
596 "X-Venice-Proxy-Session-Id",
597 HeaderValue::from_static("chat-1"),
598 );
599
600 manager
601 .get_or_create_at(request("model-a", &headers), now(0))
602 .expect("first request should create");
603 manager
604 .get_or_create_at(request("model-a", &headers), now(1))
605 .expect("second request should reuse");
606 let third = manager
607 .get_or_create_at(request("model-a", &headers), now(2))
608 .expect("third request should reuse and reach max");
609 let fourth = manager
610 .get_or_create_at(request("model-a", &headers), now(3))
611 .expect("fourth request should recreate");
612
613 assert!(!third.created);
614 assert_eq!(third.session.request_count, 3);
615 assert!(fourth.created);
616 assert_eq!(
617 fourth.replaced_expired,
618 Some(SessionExpirationReason::MaxRequests)
619 );
620 assert_eq!(fourth.session.request_count, 1);
621 }
622
623 #[test]
624 fn request_fallback_creates_distinct_request_scoped_sessions() {
625 let manager = manager();
626 let headers = HeaderMap::new();
627
628 let first = manager
629 .get_or_create_at(request("model-a", &headers), now(0))
630 .expect("fallback should create");
631 let second = manager
632 .get_or_create_at(request("model-a", &headers), now(1))
633 .expect("fallback should create again");
634
635 assert!(first.created);
636 assert!(second.created);
637 assert_eq!(first.session.scope, SessionScope::Request);
638 assert_eq!(second.session.scope, SessionScope::Request);
639 assert_ne!(
640 first.session.agent_session_id,
641 second.session.agent_session_id
642 );
643 assert_eq!(manager.len(), 2);
644 }
645
646 #[test]
647 fn agent_fallback_reuses_generated_agent_scoped_session() {
648 let mut config = test_config();
649 config.fallback_scope = SessionFallbackScope::Agent;
650 let manager = SessionManager::new(config);
651 let headers = HeaderMap::new();
652
653 let first = manager
654 .get_or_create_at(request("model-a", &headers), now(0))
655 .expect("fallback should create");
656 let second = manager
657 .get_or_create_at(request("model-a", &headers), now(1))
658 .expect("fallback should reuse");
659
660 assert!(first.created);
661 assert!(!second.created);
662 assert_eq!(first.session.scope, SessionScope::Agent);
663 assert_eq!(
664 first.session.agent_session_id,
665 second.session.agent_session_id
666 );
667 assert_eq!(second.session.request_count, 2);
668 }
669
670 #[test]
671 fn disabled_fallback_returns_clear_error_without_creating_session() {
672 let mut config = test_config();
673 config.fallback_scope = SessionFallbackScope::Disabled;
674 let manager = SessionManager::new(config);
675 let headers = HeaderMap::new();
676
677 let error = manager
678 .get_or_create_at(request("model-a", &headers), now(0))
679 .expect_err("missing session id should fail when fallback is disabled");
680
681 assert_eq!(error, SessionError::MissingSessionIdentifier);
682 assert_eq!(
683 error.to_string(),
684 "request does not include a session identifier and session fallback is disabled"
685 );
686 assert!(manager.is_empty());
687 }
688
689 #[test]
690 fn cleanup_removes_expired_sessions_and_keeps_valid_sessions() {
691 let manager = manager();
692 let mut headers_a = HeaderMap::new();
693 headers_a.insert(
694 "X-Venice-Proxy-Session-Id",
695 HeaderValue::from_static("chat-a"),
696 );
697 let mut headers_b = HeaderMap::new();
698 headers_b.insert(
699 "X-Venice-Proxy-Session-Id",
700 HeaderValue::from_static("chat-b"),
701 );
702
703 manager
704 .get_or_create_at(request("model-a", &headers_a), now(0))
705 .expect("session a should create");
706 manager
707 .get_or_create_at(request("model-a", &headers_b), now(15))
708 .expect("session b should create");
709
710 let removed = manager.cleanup_expired_at(now(20));
711
712 assert_eq!(removed, 1);
713 assert_eq!(manager.len(), 1);
714 let reused_b = manager
715 .get_or_create_at(request("model-a", &headers_b), now(21))
716 .expect("session b should remain valid");
717 assert!(!reused_b.created);
718 }
719
720 #[test]
721 fn stores_attested_model_state_on_existing_unexpired_session() {
722 let manager = manager();
723 let mut headers = HeaderMap::new();
724 headers.insert(
725 "X-Venice-Proxy-Session-Id",
726 HeaderValue::from_static("chat-1"),
727 );
728 let session = manager
729 .get_or_create_at(request("model-a", &headers), now(0))
730 .expect("session should create")
731 .session;
732
733 let updated = manager
734 .set_attested_model_state_at(
735 &session.session_key,
736 AttestedModelState {
737 model_public_key: "model-public-key".to_owned(),
738 attestation_report: json!({ "verified": true }),
739 verified_at: now(1),
740 },
741 now(1),
742 )
743 .expect("attestation state should update");
744
745 assert_eq!(
746 updated.attested_model_public_key.as_deref(),
747 Some("model-public-key")
748 );
749 assert_eq!(
750 updated.attestation_report,
751 Some(json!({ "verified": true }))
752 );
753 assert_eq!(updated.verified_at, Some(now(1)));
754 }
755}