systemprompt_security/extraction/
header.rs1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{
5 AgentName, ContextId, GatewayConversationId, ProviderRequestId, SessionId, TaskId, TraceId,
6 UserId, headers,
7};
8use systemprompt_models::execution::context::RequestContext;
9
/// Error produced when a value cannot be written into an HTTP header.
///
/// It carries no payload: the only information it conveys is that the value
/// was rejected as a header value. [`HeaderInjector`] returns it when
/// `HeaderValue::from_str` refuses the value.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl Error for HeaderInjectionError {}

impl fmt::Display for HeaderInjectionError {
    fn fmt(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
        formatter.write_str("Header value contains invalid characters")
    }
}
20
21#[derive(Debug, Clone, Copy)]
22pub struct HeaderExtractor;
23
24impl HeaderExtractor {
25 pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
26 Self::extract_header(headers, headers::TRACE_ID)
27 .map_or_else(TraceId::generate, TraceId::new)
28 }
29
30 pub fn extract_context_id(headers: &HeaderMap) -> Option<ContextId> {
31 Self::extract_header(headers, headers::CONTEXT_ID)
32 .filter(|s| !s.is_empty())
33 .and_then(|s| {
34 ContextId::try_new(s)
35 .map_err(|e| {
36 tracing::warn!(error = %e, "Invalid context_id header value, ignoring");
37 e
38 })
39 .ok()
40 })
41 }
42
43 pub fn extract_gateway_conversation_id(headers: &HeaderMap) -> Option<GatewayConversationId> {
44 Self::extract_header(headers, headers::GATEWAY_CONVERSATION_ID)
45 .filter(|s| !s.is_empty())
46 .and_then(|s| {
47 GatewayConversationId::try_new(s)
48 .map_err(|e| {
49 tracing::warn!(error = %e, "Invalid gateway_conversation_id header value, ignoring");
50 e
51 })
52 .ok()
53 })
54 }
55
56 pub fn extract_provider_request_id(headers: &HeaderMap) -> Option<ProviderRequestId> {
57 Self::extract_header(headers, headers::PROVIDER_REQUEST_ID)
58 .filter(|s| !s.is_empty())
59 .and_then(|s| {
60 ProviderRequestId::try_new(s)
61 .map_err(|e| {
62 tracing::warn!(error = %e, "Invalid provider_request_id header value, ignoring");
63 e
64 })
65 .ok()
66 })
67 }
68
69 pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
70 Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
71 }
72
73 pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
74 Self::extract_header(headers, headers::AGENT_NAME)
75 .map_or_else(AgentName::system, AgentName::new)
76 }
77
78 fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
79 headers
80 .get(name)
81 .and_then(|v| {
82 v.to_str()
83 .map_err(|e| {
84 tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
85 e
86 })
87 .ok()
88 })
89 .map(str::to_owned)
90 }
91}
92
93#[derive(Debug, Clone, Copy)]
94pub struct HeaderInjector;
95
96impl HeaderInjector {
97 pub fn inject_session_id(
98 headers: &mut HeaderMap,
99 session_id: &SessionId,
100 ) -> Result<(), HeaderInjectionError> {
101 Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
102 }
103
104 pub fn inject_user_id(
105 headers: &mut HeaderMap,
106 user_id: &UserId,
107 ) -> Result<(), HeaderInjectionError> {
108 Self::inject_header(headers, headers::USER_ID, user_id.as_str())
109 }
110
111 pub fn inject_trace_id(
112 headers: &mut HeaderMap,
113 trace_id: &TraceId,
114 ) -> Result<(), HeaderInjectionError> {
115 Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
116 }
117
118 pub fn inject_context_id(
119 headers: &mut HeaderMap,
120 context_id: &ContextId,
121 ) -> Result<(), HeaderInjectionError> {
122 Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
123 }
124
125 pub fn inject_gateway_conversation_id(
126 headers: &mut HeaderMap,
127 id: &GatewayConversationId,
128 ) -> Result<(), HeaderInjectionError> {
129 Self::inject_header(headers, headers::GATEWAY_CONVERSATION_ID, id.as_str())
130 }
131
132 pub fn inject_provider_request_id(
133 headers: &mut HeaderMap,
134 id: &ProviderRequestId,
135 ) -> Result<(), HeaderInjectionError> {
136 Self::inject_header(headers, headers::PROVIDER_REQUEST_ID, id.as_str())
137 }
138
139 pub fn inject_task_id(
140 headers: &mut HeaderMap,
141 task_id: &TaskId,
142 ) -> Result<(), HeaderInjectionError> {
143 Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
144 }
145
146 pub fn inject_agent_name(
147 headers: &mut HeaderMap,
148 agent_name: &str,
149 ) -> Result<(), HeaderInjectionError> {
150 Self::inject_header(headers, headers::AGENT_NAME, agent_name)
151 }
152
153 pub fn inject_from_request_context(
154 headers: &mut HeaderMap,
155 ctx: &RequestContext,
156 ) -> Result<(), HeaderInjectionError> {
157 Self::inject_session_id(headers, &ctx.request.session_id)?;
158 Self::inject_user_id(headers, &ctx.auth.actor.user_id)?;
159 Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
160 Self::inject_context_id(headers, &ctx.execution.context_id)?;
161 Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
162 Ok(())
163 }
164
165 fn inject_header(
166 headers: &mut HeaderMap,
167 name: &'static str,
168 value: &str,
169 ) -> Result<(), HeaderInjectionError> {
170 HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
171 headers.insert(name, header_value);
172 Ok(())
173 })
174 }
175}