Skip to main content

systemprompt_security/extraction/
header.rs

1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{
5    AgentName, ContextId, GatewayConversationId, ProviderRequestId, SessionId, TaskId, TraceId,
6    UserId, headers,
7};
8use systemprompt_models::execution::context::RequestContext;
9
/// Error returned when a value cannot be stored as an HTTP header value.
///
/// The type is a field-less marker: it does not record which header or which
/// value was rejected.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    /// Writes the fixed, human-readable description of the failure.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Header value contains invalid characters")
    }
}

// No underlying cause is carried, so the default `Error` methods are enough.
impl Error for HeaderInjectionError {}
20
21#[derive(Debug, Clone, Copy)]
22pub struct HeaderExtractor;
23
24impl HeaderExtractor {
25    pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
26        Self::extract_header(headers, headers::TRACE_ID)
27            .map_or_else(TraceId::generate, TraceId::new)
28    }
29
30    pub fn extract_context_id(headers: &HeaderMap) -> Option<ContextId> {
31        Self::extract_header(headers, headers::CONTEXT_ID)
32            .filter(|s| !s.is_empty())
33            .and_then(|s| {
34                ContextId::try_new(s)
35                    .map_err(|e| {
36                        tracing::warn!(error = %e, "Invalid context_id header value, ignoring");
37                        e
38                    })
39                    .ok()
40            })
41    }
42
43    pub fn extract_gateway_conversation_id(headers: &HeaderMap) -> Option<GatewayConversationId> {
44        Self::extract_header(headers, headers::GATEWAY_CONVERSATION_ID)
45            .filter(|s| !s.is_empty())
46            .and_then(|s| {
47                GatewayConversationId::try_new(s)
48                    .map_err(|e| {
49                        tracing::warn!(error = %e, "Invalid gateway_conversation_id header value, ignoring");
50                        e
51                    })
52                    .ok()
53            })
54    }
55
56    pub fn extract_provider_request_id(headers: &HeaderMap) -> Option<ProviderRequestId> {
57        Self::extract_header(headers, headers::PROVIDER_REQUEST_ID)
58            .filter(|s| !s.is_empty())
59            .and_then(|s| {
60                ProviderRequestId::try_new(s)
61                    .map_err(|e| {
62                        tracing::warn!(error = %e, "Invalid provider_request_id header value, ignoring");
63                        e
64                    })
65                    .ok()
66            })
67    }
68
69    pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
70        Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
71    }
72
73    pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
74        Self::extract_header(headers, headers::AGENT_NAME)
75            .map_or_else(AgentName::system, AgentName::new)
76    }
77
78    fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
79        headers
80            .get(name)
81            .and_then(|v| {
82                v.to_str()
83                    .map_err(|e| {
84                        tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
85                        e
86                    })
87                    .ok()
88            })
89            .map(str::to_owned)
90    }
91}
92
93#[derive(Debug, Clone, Copy)]
94pub struct HeaderInjector;
95
96impl HeaderInjector {
97    pub fn inject_session_id(
98        headers: &mut HeaderMap,
99        session_id: &SessionId,
100    ) -> Result<(), HeaderInjectionError> {
101        Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
102    }
103
104    pub fn inject_user_id(
105        headers: &mut HeaderMap,
106        user_id: &UserId,
107    ) -> Result<(), HeaderInjectionError> {
108        Self::inject_header(headers, headers::USER_ID, user_id.as_str())
109    }
110
111    pub fn inject_trace_id(
112        headers: &mut HeaderMap,
113        trace_id: &TraceId,
114    ) -> Result<(), HeaderInjectionError> {
115        Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
116    }
117
118    pub fn inject_context_id(
119        headers: &mut HeaderMap,
120        context_id: &ContextId,
121    ) -> Result<(), HeaderInjectionError> {
122        Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
123    }
124
125    pub fn inject_gateway_conversation_id(
126        headers: &mut HeaderMap,
127        id: &GatewayConversationId,
128    ) -> Result<(), HeaderInjectionError> {
129        Self::inject_header(headers, headers::GATEWAY_CONVERSATION_ID, id.as_str())
130    }
131
132    pub fn inject_provider_request_id(
133        headers: &mut HeaderMap,
134        id: &ProviderRequestId,
135    ) -> Result<(), HeaderInjectionError> {
136        Self::inject_header(headers, headers::PROVIDER_REQUEST_ID, id.as_str())
137    }
138
139    pub fn inject_task_id(
140        headers: &mut HeaderMap,
141        task_id: &TaskId,
142    ) -> Result<(), HeaderInjectionError> {
143        Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
144    }
145
146    pub fn inject_agent_name(
147        headers: &mut HeaderMap,
148        agent_name: &str,
149    ) -> Result<(), HeaderInjectionError> {
150        Self::inject_header(headers, headers::AGENT_NAME, agent_name)
151    }
152
153    pub fn inject_from_request_context(
154        headers: &mut HeaderMap,
155        ctx: &RequestContext,
156    ) -> Result<(), HeaderInjectionError> {
157        Self::inject_session_id(headers, &ctx.request.session_id)?;
158        Self::inject_user_id(headers, &ctx.auth.actor.user_id)?;
159        Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
160        Self::inject_context_id(headers, &ctx.execution.context_id)?;
161        Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
162        Ok(())
163    }
164
165    fn inject_header(
166        headers: &mut HeaderMap,
167        name: &'static str,
168        value: &str,
169    ) -> Result<(), HeaderInjectionError> {
170        HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
171            headers.insert(name, header_value);
172            Ok(())
173        })
174    }
175}