systemprompt_api/services/middleware/context/extractors/
header_extractor.rs1use async_trait::async_trait;
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_models::auth::UserType;
5use systemprompt_models::execution::{ContextExtractionError, RequestContext};
6
7use super::traits::ContextExtractor;
8
/// Stateless extractor that assembles a request context from HTTP headers.
///
/// The type carries no data, so it is trivially `Copy` and can be created
/// freely wherever a `ContextExtractor` implementation is needed.
#[derive(Clone, Copy, Debug)]
pub struct HeaderContextExtractor;
11
12impl HeaderContextExtractor {
13 pub const fn new() -> Self {
14 Self
15 }
16
17 fn extract_required_header(
18 headers: &HeaderMap,
19 name: &str,
20 ) -> Result<String, ContextExtractionError> {
21 headers
22 .get(name)
23 .ok_or_else(|| ContextExtractionError::MissingHeader(name.to_string()))?
24 .to_str()
25 .map(ToString::to_string)
26 .map_err(|e| ContextExtractionError::InvalidHeaderValue {
27 header: name.to_string(),
28 reason: e.to_string(),
29 })
30 }
31
32 fn extract_optional_header(headers: &HeaderMap, name: &str) -> Option<String> {
33 headers
34 .get(name)
35 .and_then(|v| v.to_str().ok())
36 .map(ToString::to_string)
37 }
38}
39
40impl Default for HeaderContextExtractor {
41 fn default() -> Self {
42 Self::new()
43 }
44}
45
46#[async_trait]
47impl ContextExtractor for HeaderContextExtractor {
48 async fn extract_from_headers(
49 &self,
50 headers: &HeaderMap,
51 ) -> Result<RequestContext, ContextExtractionError> {
52 let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
53 let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
54 let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
55 let context_id_str = Self::extract_required_header(headers, "x-context-id")?;
56 let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
57
58 let mut context = RequestContext::new(
59 SessionId::new(session_id_str),
60 TraceId::new(trace_id_str),
61 ContextId::new(context_id_str),
62 AgentName::new(agent_name_str),
63 )
64 .with_user_id(UserId::new(user_id_str));
65
66 if let Some(task_id_str) = Self::extract_optional_header(headers, "x-task-id") {
67 context = context.with_task_id(TaskId::new(task_id_str));
68 }
69
70 Ok(context)
71 }
72
73 async fn extract_user_only(
74 &self,
75 headers: &HeaderMap,
76 ) -> Result<RequestContext, ContextExtractionError> {
77 let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
78 let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
79 let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
80 let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
81
82 let context = RequestContext::new(
83 SessionId::new(session_id_str),
84 TraceId::new(trace_id_str),
85 ContextId::new(String::new()),
86 AgentName::new(agent_name_str),
87 )
88 .with_user_id(UserId::new(user_id_str))
89 .with_user_type(UserType::User);
90
91 Ok(context)
92 }
93}