systemprompt_api/services/middleware/context/extractors/
header_extractor.rs1use async_trait::async_trait;
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_models::auth::UserType;
5use systemprompt_models::execution::{ContextExtractionError, RequestContext};
6
7use super::traits::ContextExtractor;
8
/// Builds a [`RequestContext`] from `x-*` request headers.
///
/// The type carries no fields, so copying or cloning it is trivial.
#[derive(Clone, Copy, Debug)]
pub struct HeaderContextExtractor;
11
12impl HeaderContextExtractor {
13 pub const fn new() -> Self {
14 Self
15 }
16
17 fn extract_required_header(
18 headers: &HeaderMap,
19 name: &str,
20 ) -> Result<String, ContextExtractionError> {
21 headers
22 .get(name)
23 .ok_or_else(|| ContextExtractionError::MissingHeader(name.to_string()))?
24 .to_str()
25 .map(ToString::to_string)
26 .map_err(|e| ContextExtractionError::InvalidHeaderValue {
27 header: name.to_string(),
28 reason: e.to_string(),
29 })
30 }
31
32 fn extract_optional_header(headers: &HeaderMap, name: &str) -> Option<String> {
33 headers
34 .get(name)
35 .and_then(|v| v.to_str().ok())
36 .map(ToString::to_string)
37 }
38}
39
40impl Default for HeaderContextExtractor {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46#[async_trait]
47impl ContextExtractor for HeaderContextExtractor {
48 async fn extract_from_headers(
49 &self,
50 headers: &HeaderMap,
51 ) -> Result<RequestContext, ContextExtractionError> {
52 let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
53 let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
54 let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
55 let context_id_str = Self::extract_required_header(headers, "x-context-id")?;
56 let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
57
58 let context_id = ContextId::try_new(context_id_str).map_err(|e| {
59 ContextExtractionError::InvalidHeaderValue {
60 header: "x-context-id".to_string(),
61 reason: e.to_string(),
62 }
63 })?;
64
65 let mut context = RequestContext::new(
66 SessionId::new(session_id_str),
67 TraceId::new(trace_id_str),
68 context_id,
69 AgentName::new(agent_name_str),
70 )
71 .with_user_id(UserId::new(user_id_str));
72
73 if let Some(task_id_str) = Self::extract_optional_header(headers, "x-task-id") {
74 context = context.with_task_id(TaskId::new(task_id_str));
75 }
76
77 Ok(context)
78 }
79
80 async fn extract_user_only(
81 &self,
82 headers: &HeaderMap,
83 ) -> Result<RequestContext, ContextExtractionError> {
84 let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
85 let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
86 let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
87 let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
88
89 let context = RequestContext::new(
90 SessionId::new(session_id_str),
91 TraceId::new(trace_id_str),
92 ContextId::generate(),
93 AgentName::new(agent_name_str),
94 )
95 .with_user_id(UserId::new(user_id_str))
96 .with_user_type(UserType::User);
97
98 Ok(context)
99 }
100}