Skip to main content

treeship_core/session/
context.rs

//! Cross-tool and cross-host context propagation.
//!
//! Propagates session identity through environment variables, HTTP headers,
//! and CLI wrappers so that spawned processes and remote agents inherit
//! session context.
6
7use serde::{Deserialize, Serialize};
8
9use super::event::{generate_span_id, generate_trace_id};
10
/// Environment variable prefix for Treeship context (variables are `TREESHIP_<FIELD>`).
const ENV_PREFIX: &str = "TREESHIP_";

/// HTTP header prefix for Treeship context (headers are `x-treeship-<field>`).
const HEADER_PREFIX: &str = "x-treeship-";

// Canonical context field names. Environment variables append them verbatim to
// `ENV_PREFIX` (e.g. `TREESHIP_SESSION_ID`); HTTP headers use their lowercase,
// hyphenated form after `HEADER_PREFIX` (e.g. `x-treeship-session-id`).
// A plain `//` comment is used on purpose: a `///` doc comment would attach
// only to the first constant below, not to the whole group.
const FIELD_SESSION_ID: &str = "SESSION_ID";
const FIELD_TRACE_ID: &str = "TRACE_ID";
const FIELD_SPAN_ID: &str = "SPAN_ID";
const FIELD_PARENT_SPAN_ID: &str = "PARENT_SPAN_ID";
const FIELD_AGENT_ID: &str = "AGENT_ID";
const FIELD_AGENT_INSTANCE_ID: &str = "AGENT_INSTANCE_ID";
const FIELD_WORKSPACE_ID: &str = "WORKSPACE_ID";
const FIELD_MISSION_ID: &str = "MISSION_ID";
const FIELD_HOST_ID: &str = "HOST_ID";
const FIELD_TOOL_RUNTIME_ID: &str = "TOOL_RUNTIME_ID";
28
29/// Context propagated across tool and host boundaries.
30#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct PropagationContext {
32    pub session_id: String,
33    pub trace_id: String,
34    pub span_id: String,
35    #[serde(skip_serializing_if = "Option::is_none")]
36    pub parent_span_id: Option<String>,
37    pub agent_id: String,
38    pub agent_instance_id: String,
39    #[serde(skip_serializing_if = "Option::is_none")]
40    pub workspace_id: Option<String>,
41    #[serde(skip_serializing_if = "Option::is_none")]
42    pub mission_id: Option<String>,
43    pub host_id: String,
44    #[serde(skip_serializing_if = "Option::is_none")]
45    pub tool_runtime_id: Option<String>,
46}
47
48impl PropagationContext {
49    /// Read context from environment variables.
50    ///
51    /// Returns `None` if required fields (`TREESHIP_SESSION_ID`, `TREESHIP_TRACE_ID`)
52    /// are not present.
53    pub fn from_env() -> Option<Self> {
54        let session_id = std::env::var(format!("{ENV_PREFIX}{FIELD_SESSION_ID}")).ok()?;
55        let trace_id = std::env::var(format!("{ENV_PREFIX}{FIELD_TRACE_ID}"))
56            .unwrap_or_else(|_| generate_trace_id());
57
58        Some(Self {
59            session_id,
60            trace_id,
61            span_id: std::env::var(format!("{ENV_PREFIX}{FIELD_SPAN_ID}"))
62                .unwrap_or_else(|_| generate_span_id()),
63            parent_span_id: std::env::var(format!("{ENV_PREFIX}{FIELD_PARENT_SPAN_ID}")).ok(),
64            agent_id: std::env::var(format!("{ENV_PREFIX}{FIELD_AGENT_ID}"))
65                .unwrap_or_else(|_| "agent://unknown".into()),
66            agent_instance_id: std::env::var(format!("{ENV_PREFIX}{FIELD_AGENT_INSTANCE_ID}"))
67                .unwrap_or_else(|_| "ai_unknown".into()),
68            workspace_id: std::env::var(format!("{ENV_PREFIX}{FIELD_WORKSPACE_ID}")).ok(),
69            mission_id: std::env::var(format!("{ENV_PREFIX}{FIELD_MISSION_ID}")).ok(),
70            host_id: std::env::var(format!("{ENV_PREFIX}{FIELD_HOST_ID}"))
71                .unwrap_or_else(|_| default_host_id()),
72            tool_runtime_id: std::env::var(format!("{ENV_PREFIX}{FIELD_TOOL_RUNTIME_ID}")).ok(),
73        })
74    }
75
76    /// Inject context as environment variables on a Command builder.
77    pub fn inject_env(&self, cmd: &mut std::process::Command) {
78        cmd.env(format!("{ENV_PREFIX}{FIELD_SESSION_ID}"), &self.session_id);
79        cmd.env(format!("{ENV_PREFIX}{FIELD_TRACE_ID}"), &self.trace_id);
80        cmd.env(format!("{ENV_PREFIX}{FIELD_SPAN_ID}"), &self.span_id);
81        if let Some(ref psid) = self.parent_span_id {
82            cmd.env(format!("{ENV_PREFIX}{FIELD_PARENT_SPAN_ID}"), psid);
83        }
84        cmd.env(format!("{ENV_PREFIX}{FIELD_AGENT_ID}"), &self.agent_id);
85        cmd.env(format!("{ENV_PREFIX}{FIELD_AGENT_INSTANCE_ID}"), &self.agent_instance_id);
86        if let Some(ref wid) = self.workspace_id {
87            cmd.env(format!("{ENV_PREFIX}{FIELD_WORKSPACE_ID}"), wid);
88        }
89        if let Some(ref mid) = self.mission_id {
90            cmd.env(format!("{ENV_PREFIX}{FIELD_MISSION_ID}"), mid);
91        }
92        cmd.env(format!("{ENV_PREFIX}{FIELD_HOST_ID}"), &self.host_id);
93        if let Some(ref trid) = self.tool_runtime_id {
94            cmd.env(format!("{ENV_PREFIX}{FIELD_TOOL_RUNTIME_ID}"), trid);
95        }
96    }
97
98    /// Produce HTTP header pairs for outbound requests.
99    pub fn to_headers(&self) -> Vec<(String, String)> {
100        let mut h = vec![
101            (format!("{HEADER_PREFIX}session-id"), self.session_id.clone()),
102            (format!("{HEADER_PREFIX}trace-id"), self.trace_id.clone()),
103            (format!("{HEADER_PREFIX}span-id"), self.span_id.clone()),
104            (format!("{HEADER_PREFIX}agent-id"), self.agent_id.clone()),
105            (format!("{HEADER_PREFIX}agent-instance-id"), self.agent_instance_id.clone()),
106            (format!("{HEADER_PREFIX}host-id"), self.host_id.clone()),
107        ];
108        if let Some(ref psid) = self.parent_span_id {
109            h.push((format!("{HEADER_PREFIX}parent-span-id"), psid.clone()));
110        }
111        if let Some(ref wid) = self.workspace_id {
112            h.push((format!("{HEADER_PREFIX}workspace-id"), wid.clone()));
113        }
114        if let Some(ref mid) = self.mission_id {
115            h.push((format!("{HEADER_PREFIX}mission-id"), mid.clone()));
116        }
117        if let Some(ref trid) = self.tool_runtime_id {
118            h.push((format!("{HEADER_PREFIX}tool-runtime-id"), trid.clone()));
119        }
120        h
121    }
122
123    /// Parse context from HTTP header pairs.
124    pub fn from_headers(headers: &[(String, String)]) -> Option<Self> {
125        let get = |name: &str| -> Option<String> {
126            let key = format!("{HEADER_PREFIX}{name}");
127            headers.iter()
128                .find(|(k, _)| k.eq_ignore_ascii_case(&key))
129                .map(|(_, v)| v.clone())
130        };
131
132        let session_id = get("session-id")?;
133        let trace_id = get("trace-id").unwrap_or_else(generate_trace_id);
134
135        Some(Self {
136            session_id,
137            trace_id,
138            span_id: get("span-id").unwrap_or_else(generate_span_id),
139            parent_span_id: get("parent-span-id"),
140            agent_id: get("agent-id").unwrap_or_else(|| "agent://unknown".into()),
141            agent_instance_id: get("agent-instance-id").unwrap_or_else(|| "ai_unknown".into()),
142            workspace_id: get("workspace-id"),
143            mission_id: get("mission-id"),
144            host_id: get("host-id").unwrap_or_else(default_host_id),
145            tool_runtime_id: get("tool-runtime-id"),
146        })
147    }
148
149    /// Generate a child span context: new span_id, current span_id becomes parent.
150    pub fn child_span(&self) -> Self {
151        Self {
152            session_id: self.session_id.clone(),
153            trace_id: self.trace_id.clone(),
154            span_id: generate_span_id(),
155            parent_span_id: Some(self.span_id.clone()),
156            agent_id: self.agent_id.clone(),
157            agent_instance_id: self.agent_instance_id.clone(),
158            workspace_id: self.workspace_id.clone(),
159            mission_id: self.mission_id.clone(),
160            host_id: self.host_id.clone(),
161            tool_runtime_id: self.tool_runtime_id.clone(),
162        }
163    }
164
165    /// Generate a W3C traceparent header value.
166    pub fn to_traceparent(&self) -> String {
167        // Pad trace_id to 32 chars, span_id to 16 chars
168        let tid = format!("{:0>32}", &self.trace_id);
169        let sid = format!("{:0>16}", &self.span_id);
170        format!("00-{tid}-{sid}-01")
171    }
172}
173
174/// Default host ID derived from hostname.
175fn default_host_id() -> String {
176    hostname::get()
177        .ok()
178        .and_then(|h| h.into_string().ok())
179        .map(|h| format!("host_{}", h.replace('.', "_")))
180        .unwrap_or_else(|| "host_unknown".into())
181}
182
#[cfg(test)]
mod tests {
    use super::*;

    /// Context with W3C-sized hex IDs and no optional fields, shared by tests.
    fn sample() -> PropagationContext {
        PropagationContext {
            session_id: "ssn_001".into(),
            trace_id: "abcd1234abcd1234abcd1234abcd1234".into(),
            span_id: "1111222233334444".into(),
            parent_span_id: None,
            agent_id: "agent://test".into(),
            agent_instance_id: "ai_1".into(),
            workspace_id: None,
            mission_id: None,
            host_id: "host_local".into(),
            tool_runtime_id: None,
        }
    }

    #[test]
    fn child_span_preserves_trace() {
        let ctx = sample();
        let child = ctx.child_span();
        assert_eq!(child.session_id, ctx.session_id);
        assert_eq!(child.trace_id, ctx.trace_id);
        assert_eq!(child.parent_span_id.as_deref(), Some("1111222233334444"));
        assert_ne!(child.span_id, ctx.span_id);
    }

    #[test]
    fn headers_roundtrip() {
        let ctx = PropagationContext {
            session_id: "ssn_002".into(),
            trace_id: "abcd".into(),
            span_id: "ef01".into(),
            parent_span_id: Some("0000".into()),
            agent_id: "agent://claude".into(),
            agent_instance_id: "ai_cc_1".into(),
            workspace_id: Some("ws_1".into()),
            mission_id: None,
            host_id: "host_mac".into(),
            tool_runtime_id: Some("rt_1".into()),
        };

        let back = PropagationContext::from_headers(&ctx.to_headers()).unwrap();
        assert_eq!(back.session_id, "ssn_002");
        assert_eq!(back.trace_id, "abcd");
        assert_eq!(back.span_id, "ef01");
        assert_eq!(back.parent_span_id.as_deref(), Some("0000"));
        assert_eq!(back.agent_id, "agent://claude");
        assert_eq!(back.agent_instance_id, "ai_cc_1");
        assert_eq!(back.workspace_id.as_deref(), Some("ws_1"));
        assert!(back.mission_id.is_none());
        assert_eq!(back.host_id, "host_mac");
        assert_eq!(back.tool_runtime_id.as_deref(), Some("rt_1"));
    }

    #[test]
    fn from_headers_requires_session_id() {
        let headers = vec![("x-treeship-trace-id".to_string(), "abcd".to_string())];
        assert!(PropagationContext::from_headers(&headers).is_none());
    }

    #[test]
    fn from_headers_matches_names_case_insensitively() {
        let headers = vec![
            ("X-Treeship-Session-Id".to_string(), "ssn_003".to_string()),
            ("X-TREESHIP-TRACE-ID".to_string(), "abcd".to_string()),
        ];
        let ctx = PropagationContext::from_headers(&headers).unwrap();
        assert_eq!(ctx.session_id, "ssn_003");
        assert_eq!(ctx.trace_id, "abcd");
        // Absent fields take their documented defaults.
        assert_eq!(ctx.agent_id, "agent://unknown");
        assert_eq!(ctx.agent_instance_id, "ai_unknown");
        assert!(ctx.parent_span_id.is_none());
    }

    #[test]
    fn traceparent_format() {
        let tp = sample().to_traceparent();
        assert_eq!(tp, "00-abcd1234abcd1234abcd1234abcd1234-1111222233334444-01");
    }

    #[test]
    fn traceparent_left_pads_short_ids() {
        let mut ctx = sample();
        ctx.trace_id = "abcd".into();
        ctx.span_id = "ef01".into();
        let expected = format!("00-{}abcd-{}ef01-01", "0".repeat(28), "0".repeat(12));
        assert_eq!(ctx.to_traceparent(), expected);
    }
}