systemprompt_security/extraction/
header.rs1use axum::http::{HeaderMap, HeaderValue};
2use std::error::Error;
3use std::fmt;
4use systemprompt_identifiers::{headers, AgentName, ContextId, SessionId, TaskId, TraceId, UserId};
5use systemprompt_models::execution::context::RequestContext;
6
/// Error returned when a value cannot be stored as an HTTP header value.
///
/// The type carries no payload; its `Display` output is a fixed message.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HeaderInjectionError;

impl fmt::Display for HeaderInjectionError {
    /// Writes the fixed, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Header value contains invalid characters")
    }
}

// No underlying cause is exposed, so the default `Error` methods suffice.
impl Error for HeaderInjectionError {}
17
18#[derive(Debug, Clone, Copy)]
19pub struct HeaderExtractor;
20
21impl HeaderExtractor {
22 pub fn extract_trace_id(headers: &HeaderMap) -> TraceId {
23 Self::extract_header(headers, headers::TRACE_ID)
24 .map_or_else(TraceId::generate, TraceId::new)
25 }
26
27 pub fn extract_context_id(headers: &HeaderMap) -> ContextId {
28 Self::extract_header(headers, headers::CONTEXT_ID)
29 .filter(|s| !s.is_empty())
30 .map_or_else(ContextId::empty, ContextId::new)
31 }
32
33 pub fn extract_task_id(headers: &HeaderMap) -> Option<TaskId> {
34 Self::extract_header(headers, headers::TASK_ID).map(TaskId::new)
35 }
36
37 pub fn extract_agent_name(headers: &HeaderMap) -> AgentName {
38 Self::extract_header(headers, headers::AGENT_NAME)
39 .map_or_else(AgentName::system, AgentName::new)
40 }
41
42 fn extract_header(headers: &HeaderMap, name: &str) -> Option<String> {
43 headers
44 .get(name)
45 .and_then(|v| {
46 v.to_str()
47 .map_err(|e| {
48 tracing::debug!(error = %e, header = %name, "Header contains non-ASCII characters");
49 e
50 })
51 .ok()
52 })
53 .map(ToString::to_string)
54 }
55}
56
57#[derive(Debug, Clone, Copy)]
58pub struct HeaderInjector;
59
60impl HeaderInjector {
61 pub fn inject_session_id(
62 headers: &mut HeaderMap,
63 session_id: &SessionId,
64 ) -> Result<(), HeaderInjectionError> {
65 Self::inject_header(headers, headers::SESSION_ID, session_id.as_str())
66 }
67
68 pub fn inject_user_id(
69 headers: &mut HeaderMap,
70 user_id: &UserId,
71 ) -> Result<(), HeaderInjectionError> {
72 Self::inject_header(headers, headers::USER_ID, user_id.as_str())
73 }
74
75 pub fn inject_trace_id(
76 headers: &mut HeaderMap,
77 trace_id: &TraceId,
78 ) -> Result<(), HeaderInjectionError> {
79 Self::inject_header(headers, headers::TRACE_ID, trace_id.as_str())
80 }
81
82 pub fn inject_context_id(
83 headers: &mut HeaderMap,
84 context_id: &ContextId,
85 ) -> Result<(), HeaderInjectionError> {
86 if context_id.as_str().is_empty() {
87 return Ok(());
88 }
89 Self::inject_header(headers, headers::CONTEXT_ID, context_id.as_str())
90 }
91
92 pub fn inject_task_id(
93 headers: &mut HeaderMap,
94 task_id: &TaskId,
95 ) -> Result<(), HeaderInjectionError> {
96 Self::inject_header(headers, headers::TASK_ID, task_id.as_str())
97 }
98
99 pub fn inject_agent_name(
100 headers: &mut HeaderMap,
101 agent_name: &str,
102 ) -> Result<(), HeaderInjectionError> {
103 Self::inject_header(headers, headers::AGENT_NAME, agent_name)
104 }
105
106 pub fn inject_from_request_context(
107 headers: &mut HeaderMap,
108 ctx: &RequestContext,
109 ) -> Result<(), HeaderInjectionError> {
110 Self::inject_session_id(headers, &ctx.request.session_id)?;
111 Self::inject_user_id(headers, &ctx.auth.user_id)?;
112 Self::inject_trace_id(headers, &ctx.execution.trace_id)?;
113 Self::inject_context_id(headers, &ctx.execution.context_id)?;
114 Self::inject_agent_name(headers, ctx.execution.agent_name.as_str())?;
115 Ok(())
116 }
117
118 fn inject_header(
119 headers: &mut HeaderMap,
120 name: &'static str,
121 value: &str,
122 ) -> Result<(), HeaderInjectionError> {
123 HeaderValue::from_str(value).map_or(Err(HeaderInjectionError), |header_value| {
124 headers.insert(name, header_value);
125 Ok(())
126 })
127 }
128}