Skip to main content

systemprompt_api/services/middleware/context/extractors/
header_extractor.rs

1use async_trait::async_trait;
2use axum::http::HeaderMap;
3use systemprompt_identifiers::{AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
4use systemprompt_models::auth::UserType;
5use systemprompt_models::execution::{ContextExtractionError, RequestContext};
6
7use super::traits::ContextExtractor;
8
/// Stateless [`ContextExtractor`] that builds a [`RequestContext`] from `x-*` request
/// headers (`x-session-id`, `x-trace-id`, `x-user-id`, `x-agent-name`, ...).
///
/// The type carries no data, so it is `Copy`. All instances compare equal and hash
/// identically.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct HeaderContextExtractor;
11
12impl HeaderContextExtractor {
13    pub const fn new() -> Self {
14        Self
15    }
16
17    fn extract_required_header(
18        headers: &HeaderMap,
19        name: &str,
20    ) -> Result<String, ContextExtractionError> {
21        headers
22            .get(name)
23            .ok_or_else(|| ContextExtractionError::MissingHeader(name.to_string()))?
24            .to_str()
25            .map(ToString::to_string)
26            .map_err(|e| ContextExtractionError::InvalidHeaderValue {
27                header: name.to_string(),
28                reason: e.to_string(),
29            })
30    }
31
32    fn extract_optional_header(headers: &HeaderMap, name: &str) -> Option<String> {
33        headers
34            .get(name)
35            .and_then(|v| v.to_str().ok())
36            .map(ToString::to_string)
37    }
38}
39
40impl Default for HeaderContextExtractor {
41    fn default() -> Self {
42        Self::new()
43    }
44}
45
46#[async_trait]
47impl ContextExtractor for HeaderContextExtractor {
48    async fn extract_from_headers(
49        &self,
50        headers: &HeaderMap,
51    ) -> Result<RequestContext, ContextExtractionError> {
52        let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
53        let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
54        let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
55        let context_id_str = Self::extract_required_header(headers, "x-context-id")?;
56        let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
57
58        let mut context = RequestContext::new(
59            SessionId::new(session_id_str),
60            TraceId::new(trace_id_str),
61            ContextId::new(context_id_str),
62            AgentName::new(agent_name_str),
63        )
64        .with_user_id(UserId::new(user_id_str));
65
66        if let Some(task_id_str) = Self::extract_optional_header(headers, "x-task-id") {
67            context = context.with_task_id(TaskId::new(task_id_str));
68        }
69
70        Ok(context)
71    }
72
73    async fn extract_user_only(
74        &self,
75        headers: &HeaderMap,
76    ) -> Result<RequestContext, ContextExtractionError> {
77        let session_id_str = Self::extract_required_header(headers, "x-session-id")?;
78        let trace_id_str = Self::extract_required_header(headers, "x-trace-id")?;
79        let user_id_str = Self::extract_required_header(headers, "x-user-id")?;
80        let agent_name_str = Self::extract_required_header(headers, "x-agent-name")?;
81
82        let context = RequestContext::new(
83            SessionId::new(session_id_str),
84            TraceId::new(trace_id_str),
85            ContextId::new(String::new()),
86            AgentName::new(agent_name_str),
87        )
88        .with_user_id(UserId::new(user_id_str))
89        .with_user_type(UserType::User);
90
91        Ok(context)
92    }
93}