Skip to main content

systemprompt_api/services/middleware/context/extractors/
header_extractor.rs

1use async_trait::async_trait;
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_models::auth::UserType;
5use systemprompt_models::execution::{ContextExtractionError, RequestContext};
6
7use super::traits::ContextExtractor;
8
/// Stateless extractor that assembles a `RequestContext` from `x-*` request
/// headers.
///
/// The type carries no data, so it is freely copyable.
#[derive(Clone, Copy, Debug)]
pub struct HeaderContextExtractor;
11
12impl HeaderContextExtractor {
13    pub const fn new() -> Self {
14        Self
15    }
16
17    fn extract_required_header(
18        headers: &HeaderMap,
19        name: &str,
20    ) -> Result<String, ContextExtractionError> {
21        headers
22            .get(name)
23            .ok_or_else(|| ContextExtractionError::MissingHeader(name.to_string()))?
24            .to_str()
25            .map(ToString::to_string)
26            .map_err(|e| ContextExtractionError::InvalidHeaderValue {
27                header: name.to_string(),
28                reason: e.to_string(),
29            })
30    }
31
32    fn extract_optional_header(headers: &HeaderMap, name: &str) -> Option<String> {
33        headers
34            .get(name)
35            .and_then(|v| v.to_str().ok())
36            .map(ToString::to_string)
37    }
38}
39
40impl Default for HeaderContextExtractor {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46#[async_trait]
47impl ContextExtractor for HeaderContextExtractor {
48    async fn extract_from_headers(
49        &self,
50        headers: &HeaderMap,
51    ) -> Result<RequestContext, ContextExtractionError> {
52        let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
53        let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
54        let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
55        let context_id_str = Self::extract_required_header(headers, "x-context-id")?;
56        let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
57
58        let context_id = ContextId::try_new(context_id_str).map_err(|e| {
59            ContextExtractionError::InvalidHeaderValue {
60                header: "x-context-id".to_string(),
61                reason: e.to_string(),
62            }
63        })?;
64
65        let mut context = RequestContext::new(
66            SessionId::new(session_id_str),
67            TraceId::new(trace_id_str),
68            context_id,
69            AgentName::new(agent_name_str),
70        )
71        .with_user_id(UserId::new(user_id_str));
72
73        if let Some(task_id_str) = Self::extract_optional_header(headers, "x-task-id") {
74            context = context.with_task_id(TaskId::new(task_id_str));
75        }
76
77        Ok(context)
78    }
79
80    async fn extract_user_only(
81        &self,
82        headers: &HeaderMap,
83    ) -> Result<RequestContext, ContextExtractionError> {
84        let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
85        let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
86        let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
87        let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
88
89        let context = RequestContext::new(
90            SessionId::new(session_id_str),
91            TraceId::new(trace_id_str),
92            ContextId::generate(),
93            AgentName::new(agent_name_str),
94        )
95        .with_user_id(UserId::new(user_id_str))
96        .with_user_type(UserType::User);
97
98        Ok(context)
99    }
100}