Skip to main content

systemprompt_models/execution/context/
mod.rs

1//! Request context for execution tracking.
2
3mod call_source;
4mod context_error;
5mod context_types;
6
7pub use call_source::CallSource;
8pub use context_error::{ContextExtractionError, ContextIdSource, TASK_BASED_CONTEXT_MARKER};
9pub use context_types::{
10    AuthContext, ExecutionContext, ExecutionSettings, RequestMetadata, UserInteractionMode,
11};
12
13use crate::ai::ToolModelConfig;
14use crate::auth::{AuthenticatedUser, RateLimitTier, UserType};
15use anyhow::anyhow;
16use http::{HeaderMap, HeaderValue};
17use serde::{Deserialize, Serialize};
18use std::str::FromStr;
19use std::time::{Duration, Instant};
20use systemprompt_identifiers::{
21    headers, AgentName, AiToolCallId, ClientId, ContextId, JwtToken, McpExecutionId, SessionId,
22    TaskId, TraceId, UserId,
23};
24use systemprompt_traits::{ContextPropagation, InjectContextHeaders};
25
/// Per-request state carried through execution: caller identity, request
/// metadata, execution identifiers and the settings that apply to the request.
///
/// Only `auth`, `request`, `execution` and `settings` are serialized. `user`
/// and `start_time` are skipped by serde: after deserialization `user` is
/// `None` and `start_time` is set to the moment of deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RequestContext {
    /// Caller identity: auth token, user id and user type.
    pub auth: AuthContext,
    /// Request-level metadata: session id, creation timestamp, optional client
    /// id and the `is_tracked` flag.
    pub request: RequestMetadata,
    /// Execution identifiers: trace id, context id, optional task / AI tool
    /// call / MCP execution ids, optional call source, agent name and optional
    /// tool model config.
    pub execution: ExecutionContext,
    /// Execution settings such as the budget (`max_budget_cents`) and the
    /// user interaction mode.
    pub settings: ExecutionSettings,

    /// The authenticated user, if one was attached (e.g. via `with_user`).
    /// Not serialized, so it is always `None` after deserialization.
    #[serde(skip)]
    pub user: Option<AuthenticatedUser>,

    /// When this context was created; `elapsed()` measures from here.
    /// Not serialized; deserialization resets it via `Instant::now`.
    #[serde(skip, default = "Instant::now")]
    pub start_time: Instant,
}
39
40impl RequestContext {
41    /// Creates a new `RequestContext` - the ONLY way to construct a context.
42    ///
43    /// This is the single constructor for `RequestContext`. All contexts must
44    /// be created through this method, ensuring consistent initialization.
45    ///
46    /// # Required Fields
47    /// - `session_id`: Identifies the user session
48    /// - `trace_id`: For distributed tracing
49    /// - `context_id`: Conversation/execution context (empty string for
50    ///   user-level contexts)
51    /// - `agent_name`: The agent handling this request (use
52    ///   `AgentName::system()` for system operations)
53    ///
54    /// # Optional Fields
55    /// Use builder methods to set optional fields:
56    /// - `.with_user_id()` - Set the authenticated user
57    /// - `.with_auth_token()` - Set the JWT token
58    /// - `.with_user_type()` - Set user type (Admin, Standard, Anon)
59    /// - `.with_task_id()` - Set task ID for AI operations
60    /// - `.with_client_id()` - Set client ID
61    /// - `.with_call_source()` - Set call source (Agentic, Direct, Ephemeral)
62    ///
63    /// # Example
64    /// ```
65    /// # use systemprompt_models::execution::context::RequestContext;
66    /// # use systemprompt_identifiers::{SessionId, TraceId, ContextId, AgentName, UserId};
67    /// # use systemprompt_models::auth::UserType;
68    /// let ctx = RequestContext::new(
69    ///     SessionId::new("sess_123".to_string()),
70    ///     TraceId::new("trace_456".to_string()),
71    ///     ContextId::new("ctx_789".to_string()),
72    ///     AgentName::new("my-agent".to_string()),
73    /// )
74    /// .with_user_id(UserId::new("user_123".to_string()))
75    /// .with_auth_token("jwt_token_here")
76    /// .with_user_type(UserType::User);
77    /// ```
78    pub fn new(
79        session_id: SessionId,
80        trace_id: TraceId,
81        context_id: ContextId,
82        agent_name: AgentName,
83    ) -> Self {
84        Self {
85            auth: AuthContext {
86                auth_token: JwtToken::new(""),
87                user_id: UserId::anonymous(),
88                user_type: UserType::Anon,
89            },
90            request: RequestMetadata {
91                session_id,
92                timestamp: Instant::now(),
93                client_id: None,
94                is_tracked: true,
95            },
96            execution: ExecutionContext {
97                trace_id,
98                context_id,
99                task_id: None,
100                ai_tool_call_id: None,
101                mcp_execution_id: None,
102                call_source: None,
103                agent_name,
104                tool_model_config: None,
105            },
106            settings: ExecutionSettings::default(),
107            user: None,
108            start_time: Instant::now(),
109        }
110    }
111
112    pub fn with_user(mut self, user: AuthenticatedUser) -> Self {
113        self.auth.user_id = UserId::new(user.id.to_string());
114        self.user = Some(user);
115        self
116    }
117
118    pub fn with_user_id(mut self, user_id: UserId) -> Self {
119        self.auth.user_id = user_id;
120        self
121    }
122
123    pub fn with_agent_name(mut self, agent_name: AgentName) -> Self {
124        self.execution.agent_name = agent_name;
125        self
126    }
127
128    pub fn with_context_id(mut self, context_id: ContextId) -> Self {
129        self.execution.context_id = context_id;
130        self
131    }
132
133    pub fn with_task_id(mut self, task_id: TaskId) -> Self {
134        self.execution.task_id = Some(task_id);
135        self
136    }
137
138    pub fn with_task(mut self, task_id: TaskId, call_source: CallSource) -> Self {
139        self.execution.task_id = Some(task_id);
140        self.execution.call_source = Some(call_source);
141        self
142    }
143
144    pub fn with_ai_tool_call_id(mut self, ai_tool_call_id: AiToolCallId) -> Self {
145        self.execution.ai_tool_call_id = Some(ai_tool_call_id);
146        self
147    }
148
149    pub fn with_mcp_execution_id(mut self, mcp_execution_id: McpExecutionId) -> Self {
150        self.execution.mcp_execution_id = Some(mcp_execution_id);
151        self
152    }
153
154    pub fn with_client_id(mut self, client_id: ClientId) -> Self {
155        self.request.client_id = Some(client_id);
156        self
157    }
158
159    pub const fn with_user_type(mut self, user_type: UserType) -> Self {
160        self.auth.user_type = user_type;
161        self
162    }
163
164    pub fn with_auth_token(mut self, token: impl Into<String>) -> Self {
165        self.auth.auth_token = JwtToken::new(token.into());
166        self
167    }
168
169    pub const fn with_call_source(mut self, call_source: CallSource) -> Self {
170        self.execution.call_source = Some(call_source);
171        self
172    }
173
174    pub const fn with_budget(mut self, cents: i32) -> Self {
175        self.settings.max_budget_cents = Some(cents);
176        self
177    }
178
179    pub const fn with_interaction_mode(mut self, mode: UserInteractionMode) -> Self {
180        self.settings.user_interaction_mode = Some(mode);
181        self
182    }
183
184    pub const fn with_tracked(mut self, is_tracked: bool) -> Self {
185        self.request.is_tracked = is_tracked;
186        self
187    }
188
189    pub fn with_tool_model_config(mut self, config: ToolModelConfig) -> Self {
190        self.execution.tool_model_config = Some(config);
191        self
192    }
193
194    pub const fn tool_model_config(&self) -> Option<&ToolModelConfig> {
195        self.execution.tool_model_config.as_ref()
196    }
197
198    pub const fn session_id(&self) -> &SessionId {
199        &self.request.session_id
200    }
201
202    pub const fn user_id(&self) -> &UserId {
203        &self.auth.user_id
204    }
205
206    pub const fn trace_id(&self) -> &TraceId {
207        &self.execution.trace_id
208    }
209
210    pub const fn context_id(&self) -> &ContextId {
211        &self.execution.context_id
212    }
213
214    pub const fn agent_name(&self) -> &AgentName {
215        &self.execution.agent_name
216    }
217
218    pub const fn auth_token(&self) -> &JwtToken {
219        &self.auth.auth_token
220    }
221
222    pub const fn user_type(&self) -> UserType {
223        self.auth.user_type
224    }
225
226    pub const fn rate_limit_tier(&self) -> RateLimitTier {
227        self.auth.user_type.rate_tier()
228    }
229
230    pub const fn task_id(&self) -> Option<&TaskId> {
231        self.execution.task_id.as_ref()
232    }
233
234    pub const fn client_id(&self) -> Option<&ClientId> {
235        self.request.client_id.as_ref()
236    }
237
238    pub const fn ai_tool_call_id(&self) -> Option<&AiToolCallId> {
239        self.execution.ai_tool_call_id.as_ref()
240    }
241
242    pub const fn mcp_execution_id(&self) -> Option<&McpExecutionId> {
243        self.execution.mcp_execution_id.as_ref()
244    }
245
246    pub const fn call_source(&self) -> Option<CallSource> {
247        self.execution.call_source
248    }
249
250    pub const fn is_authenticated(&self) -> bool {
251        self.user.is_some()
252    }
253
254    pub fn is_system(&self) -> bool {
255        self.auth.user_id.is_system() && self.execution.context_id.is_system()
256    }
257
258    pub fn elapsed(&self) -> Duration {
259        self.start_time.elapsed()
260    }
261
262    pub fn validate_task_execution(&self) -> Result<(), String> {
263        if self.execution.task_id.is_none() {
264            return Err("Missing task_id for task execution".to_string());
265        }
266        if self.execution.context_id.as_str().is_empty() {
267            return Err("Missing context_id for task execution".to_string());
268        }
269        Ok(())
270    }
271
272    pub fn validate_authenticated(&self) -> Result<(), String> {
273        if self.auth.auth_token.as_str().is_empty() {
274            return Err("Missing authentication token".to_string());
275        }
276        if self.auth.user_id.is_anonymous() {
277            return Err("User is not authenticated".to_string());
278        }
279        Ok(())
280    }
281}
282
283fn insert_header(headers: &mut HeaderMap, name: &'static str, value: &str) {
284    if let Ok(val) = HeaderValue::from_str(value) {
285        headers.insert(name, val);
286    }
287}
288
289fn insert_header_if_present(headers: &mut HeaderMap, name: &'static str, value: Option<&str>) {
290    if let Some(v) = value {
291        insert_header(headers, name, v);
292    }
293}
294
295impl InjectContextHeaders for RequestContext {
296    fn inject_headers(&self, hdrs: &mut HeaderMap) {
297        insert_header(hdrs, headers::SESSION_ID, self.request.session_id.as_str());
298        insert_header(hdrs, headers::TRACE_ID, self.execution.trace_id.as_str());
299        insert_header(hdrs, headers::USER_ID, self.auth.user_id.as_str());
300        insert_header(hdrs, headers::USER_TYPE, self.auth.user_type.as_str());
301        insert_header(
302            hdrs,
303            headers::AGENT_NAME,
304            self.execution.agent_name.as_str(),
305        );
306
307        let context_id = self.execution.context_id.as_str();
308        if !context_id.is_empty() {
309            insert_header(hdrs, headers::CONTEXT_ID, context_id);
310        }
311
312        insert_header_if_present(
313            hdrs,
314            headers::TASK_ID,
315            self.execution.task_id.as_ref().map(TaskId::as_str),
316        );
317        insert_header_if_present(
318            hdrs,
319            headers::AI_TOOL_CALL_ID,
320            self.execution.ai_tool_call_id.as_ref().map(AsRef::as_ref),
321        );
322        insert_header_if_present(
323            hdrs,
324            headers::CALL_SOURCE,
325            self.execution.call_source.as_ref().map(CallSource::as_str),
326        );
327        insert_header_if_present(
328            hdrs,
329            headers::CLIENT_ID,
330            self.request.client_id.as_ref().map(ClientId::as_str),
331        );
332
333        let auth_token = self.auth.auth_token.as_str();
334        if auth_token.is_empty() {
335            tracing::trace!(user_id = %self.auth.user_id, "No auth_token to inject - Authorization header not added");
336        } else {
337            let auth_value = format!("Bearer {}", auth_token);
338            insert_header(hdrs, headers::AUTHORIZATION, &auth_value);
339            tracing::trace!(user_id = %self.auth.user_id, "Injected Authorization header for proxy");
340        }
341    }
342}
343
344impl ContextPropagation for RequestContext {
345    fn from_headers(hdrs: &HeaderMap) -> anyhow::Result<Self> {
346        let session_id = hdrs
347            .get(headers::SESSION_ID)
348            .and_then(|v| v.to_str().ok())
349            .ok_or_else(|| anyhow!("Missing {} header", headers::SESSION_ID))?;
350
351        let trace_id = hdrs
352            .get(headers::TRACE_ID)
353            .and_then(|v| v.to_str().ok())
354            .ok_or_else(|| anyhow!("Missing {} header", headers::TRACE_ID))?;
355
356        let user_id = hdrs
357            .get(headers::USER_ID)
358            .and_then(|v| v.to_str().ok())
359            .ok_or_else(|| anyhow!("Missing {} header", headers::USER_ID))?;
360
361        let context_id = hdrs
362            .get(headers::CONTEXT_ID)
363            .and_then(|v| v.to_str().ok())
364            .map_or_else(
365                || ContextId::new(String::new()),
366                |s| ContextId::new(s.to_string()),
367            );
368
369        let agent_name = hdrs
370            .get(headers::AGENT_NAME)
371            .and_then(|v| v.to_str().ok())
372            .ok_or_else(|| {
373                anyhow!(
374                    "Missing {} header - all requests must have agent context",
375                    headers::AGENT_NAME
376                )
377            })?;
378
379        let task_id = hdrs
380            .get(headers::TASK_ID)
381            .and_then(|v| v.to_str().ok())
382            .map(|s| TaskId::new(s.to_string()));
383
384        let ai_tool_call_id = hdrs
385            .get(headers::AI_TOOL_CALL_ID)
386            .and_then(|v| v.to_str().ok())
387            .map(|s| AiToolCallId::from(s.to_string()));
388
389        let call_source = hdrs
390            .get(headers::CALL_SOURCE)
391            .and_then(|v| v.to_str().ok())
392            .and_then(|s| CallSource::from_str(s).ok());
393
394        let client_id = hdrs
395            .get(headers::CLIENT_ID)
396            .and_then(|v| v.to_str().ok())
397            .map(|s| ClientId::new(s.to_string()));
398
399        let mut ctx = Self::new(
400            SessionId::new(session_id.to_string()),
401            TraceId::new(trace_id.to_string()),
402            context_id,
403            AgentName::new(agent_name.to_string()),
404        )
405        .with_user_id(UserId::new(user_id.to_string()));
406
407        if let Some(tid) = task_id {
408            ctx = ctx.with_task_id(tid);
409        }
410
411        if let Some(ai_id) = ai_tool_call_id {
412            ctx = ctx.with_ai_tool_call_id(ai_id);
413        }
414
415        if let Some(cs) = call_source {
416            ctx = ctx.with_call_source(cs);
417        }
418
419        if let Some(cid) = client_id {
420            ctx = ctx.with_client_id(cid);
421        }
422
423        Ok(ctx)
424    }
425
426    fn to_headers(&self) -> HeaderMap {
427        let mut headers = HeaderMap::new();
428        self.inject_headers(&mut headers);
429        headers
430    }
431}