Skip to main content

systemprompt_security/extraction/
header.rs

1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{
5    AgentName, ContextId, GatewayConversationId, ProviderRequestId, SessionId, TaskId, TraceId,
6    UserId, headers,
7};
8use systemprompt_models::execution::context::RequestContext;
9
/// Error produced when a value cannot be encoded as an HTTP header value.
///
/// It carries no payload: callers only learn that the value was rejected,
/// not which character caused the rejection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Header value contains invalid characters")
    }
}

/// No underlying cause is tracked, so the default `source()` (`None`) applies.
impl Error for HeaderInjectionError {}
20
21#[derive(Debug, Clone, Copy)]
22pub struct HeaderExtractor;
23
24impl HeaderExtractor {
25    pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
26        Self::extract_header(headers, headers::TRACE_ID)
27            .map_or_else(TraceId::generate, TraceId::new)
28    }
29
30    pub fn extract_context_id(headers: &HeaderMap) -> Option<ContextId> {
31        Self::extract_header(headers, headers::CONTEXT_ID)
32            .filter(|s| !s.is_empty())
33            .and_then(|s| {
34                ContextId::try_new(s)
35                    .map_err(|e| {
36                        tracing::warn!(error = %e, "Invalid context_id header value, ignoring");
37                        e
38                    })
39                    .ok()
40            })
41    }
42
43    pub fn extract_gateway_conversation_id(headers: &HeaderMap) -> Option<GatewayConversationId> {
44        Self::extract_header(headers, headers::GATEWAY_CONVERSATION_ID)
45            .filter(|s| !s.is_empty())
46            .and_then(|s| GatewayConversationId::try_new(s).ok())
47    }
48
49    pub fn extract_provider_request_id(headers: &HeaderMap) -> Option<ProviderRequestId> {
50        Self::extract_header(headers, headers::PROVIDER_REQUEST_ID)
51            .filter(|s| !s.is_empty())
52            .and_then(|s| ProviderRequestId::try_new(s).ok())
53    }
54
55    pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
56        Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
57    }
58
59    pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
60        Self::extract_header(headers, headers::AGENT_NAME)
61            .map_or_else(AgentName::system, AgentName::new)
62    }
63
64    fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
65        headers
66            .get(name)
67            .and_then(|v| {
68                v.to_str()
69                    .map_err(|e| {
70                        tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
71                        e
72                    })
73                    .ok()
74            })
75            .map(ToString::to_string)
76    }
77}
78
79#[derive(Debug, Clone, Copy)]
80pub struct HeaderInjector;
81
82impl HeaderInjector {
83    pub fn inject_session_id(
84        headers: &mut HeaderMap,
85        session_id: &SessionId,
86    ) -> Result<(), HeaderInjectionError> {
87        Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
88    }
89
90    pub fn inject_user_id(
91        headers: &mut HeaderMap,
92        user_id: &UserId,
93    ) -> Result<(), HeaderInjectionError> {
94        Self::inject_header(headers, headers::USER_ID, user_id.as_str())
95    }
96
97    pub fn inject_trace_id(
98        headers: &mut HeaderMap,
99        trace_id: &TraceId,
100    ) -> Result<(), HeaderInjectionError> {
101        Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
102    }
103
104    pub fn inject_context_id(
105        headers: &mut HeaderMap,
106        context_id: &ContextId,
107    ) -> Result<(), HeaderInjectionError> {
108        Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
109    }
110
111    pub fn inject_gateway_conversation_id(
112        headers: &mut HeaderMap,
113        id: &GatewayConversationId,
114    ) -> Result<(), HeaderInjectionError> {
115        Self::inject_header(headers, headers::GATEWAY_CONVERSATION_ID, id.as_str())
116    }
117
118    pub fn inject_provider_request_id(
119        headers: &mut HeaderMap,
120        id: &ProviderRequestId,
121    ) -> Result<(), HeaderInjectionError> {
122        Self::inject_header(headers, headers::PROVIDER_REQUEST_ID, id.as_str())
123    }
124
125    pub fn inject_task_id(
126        headers: &mut HeaderMap,
127        task_id: &TaskId,
128    ) -> Result<(), HeaderInjectionError> {
129        Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
130    }
131
132    pub fn inject_agent_name(
133        headers: &mut HeaderMap,
134        agent_name: &str,
135    ) -> Result<(), HeaderInjectionError> {
136        Self::inject_header(headers, headers::AGENT_NAME, agent_name)
137    }
138
139    pub fn inject_from_request_context(
140        headers: &mut HeaderMap,
141        ctx: &RequestContext,
142    ) -> Result<(), HeaderInjectionError> {
143        Self::inject_session_id(headers, &ctx.request.session_id)?;
144        Self::inject_user_id(headers, &ctx.auth.user_id)?;
145        Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
146        Self::inject_context_id(headers, &ctx.execution.context_id)?;
147        Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
148        Ok(())
149    }
150
151    fn inject_header(
152        headers: &mut HeaderMap,
153        name: &'static str,
154        value: &str,
155    ) -> Result<(), HeaderInjectionError> {
156        HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
157            headers.insert(name, header_value);
158            Ok(())
159        })
160    }
161}