Skip to main content

serdes_ai_tools/
context.rs

1//! Run context for tool execution.
2//!
3//! This module provides the `RunContext` type which carries contextual information
4//! to tools during execution, including dependencies, model info, and usage tracking.
5
6use chrono::{DateTime, Utc};
7use serdes_ai_core::{identifier::generate_run_id, ModelSettings, RunUsage};
8use std::sync::Arc;
9
10/// Context passed to tools during execution.
11///
12/// The `RunContext` provides tools with access to:
13/// - User-provided dependencies (database connections, API clients, etc.)
14/// - Run metadata (ID, start time, model name)
15/// - Retry information
16/// - Current usage statistics
17///
18/// # Type Parameters
19///
20/// - `Deps`: The type of dependencies available to tools. Defaults to `()`.
21///
22/// # Example
23///
24/// ```rust
25/// use serdes_ai_tools::RunContext;
26/// use std::sync::Arc;
27///
28/// struct MyDeps {
29///     api_key: String,
30/// }
31///
32/// let deps = MyDeps { api_key: "secret".into() };
33/// let ctx = RunContext::new(deps, "gpt-4");
34///
35/// // Access deps in a tool
36/// assert_eq!(ctx.deps.api_key, "secret");
37/// ```
38#[derive(Debug, Clone)]
39pub struct RunContext<Deps = ()> {
40    /// User-provided dependencies.
41    pub deps: Arc<Deps>,
42
43    /// Unique identifier for this run.
44    pub run_id: String,
45
46    /// When this run started.
47    pub start_time: DateTime<Utc>,
48
49    /// Current retry count for tool call.
50    pub retry_count: u32,
51
52    /// Maximum retries allowed.
53    pub max_retries: u32,
54
55    /// Name of the tool being called (if in a tool call).
56    pub tool_name: Option<String>,
57
58    /// Tool call ID (if in a tool call).
59    pub tool_call_id: Option<String>,
60
61    /// Name of the model being used.
62    pub model_name: String,
63
64    /// Current model settings.
65    pub model_settings: ModelSettings,
66
67    /// Usage statistics so far.
68    pub usage: RunUsage,
69
70    /// Custom metadata.
71    pub metadata: Option<serde_json::Value>,
72
73    /// Whether partial output is being generated.
74    pub partial_output: bool,
75}
76
77impl<Deps> RunContext<Deps> {
78    /// Create a new run context.
79    #[must_use]
80    pub fn new(deps: Deps, model_name: impl Into<String>) -> Self {
81        Self {
82            deps: Arc::new(deps),
83            run_id: generate_run_id(),
84            start_time: Utc::now(),
85            retry_count: 0,
86            max_retries: 3,
87            tool_name: None,
88            tool_call_id: None,
89            model_name: model_name.into(),
90            model_settings: ModelSettings::default(),
91            usage: RunUsage::default(),
92            metadata: None,
93            partial_output: false,
94        }
95    }
96
97    /// Create a context from existing Arc'd deps.
98    #[must_use]
99    pub fn from_arc(deps: Arc<Deps>, model_name: impl Into<String>) -> Self {
100        Self {
101            deps,
102            run_id: generate_run_id(),
103            start_time: Utc::now(),
104            retry_count: 0,
105            max_retries: 3,
106            tool_name: None,
107            tool_call_id: None,
108            model_name: model_name.into(),
109            model_settings: ModelSettings::default(),
110            usage: RunUsage::default(),
111            metadata: None,
112            partial_output: false,
113        }
114    }
115
116    /// Set the run ID.
117    #[must_use]
118    pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
119        self.run_id = run_id.into();
120        self
121    }
122
123    /// Set max retries.
124    #[must_use]
125    pub fn with_max_retries(mut self, max_retries: u32) -> Self {
126        self.max_retries = max_retries;
127        self
128    }
129
130    /// Set model settings.
131    #[must_use]
132    pub fn with_model_settings(mut self, settings: ModelSettings) -> Self {
133        self.model_settings = settings;
134        self
135    }
136
137    /// Set metadata.
138    #[must_use]
139    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
140        self.metadata = Some(metadata);
141        self
142    }
143
144    /// Set the tool context for a tool call.
145    #[must_use]
146    pub fn with_tool_context(
147        mut self,
148        tool_name: impl Into<String>,
149        tool_call_id: Option<String>,
150    ) -> Self {
151        self.tool_name = Some(tool_name.into());
152        self.tool_call_id = tool_call_id;
153        self
154    }
155
156    /// Set partial output mode.
157    #[must_use]
158    pub fn with_partial_output(mut self, partial: bool) -> Self {
159        self.partial_output = partial;
160        self
161    }
162
163    /// Increment the retry count.
164    pub fn increment_retry(&mut self) {
165        self.retry_count += 1;
166    }
167
168    /// Check if we can retry.
169    #[must_use]
170    pub fn can_retry(&self) -> bool {
171        self.retry_count < self.max_retries
172    }
173
174    /// Get elapsed time since start.
175    #[must_use]
176    pub fn elapsed(&self) -> chrono::Duration {
177        Utc::now() - self.start_time
178    }
179
180    /// Get elapsed time in seconds.
181    #[must_use]
182    pub fn elapsed_secs(&self) -> f64 {
183        self.elapsed().num_milliseconds() as f64 / 1000.0
184    }
185
186    /// Check if we're currently in a tool call.
187    #[must_use]
188    pub fn in_tool_call(&self) -> bool {
189        self.tool_name.is_some()
190    }
191
192    /// Create a child context for a tool call.
193    #[must_use]
194    pub fn for_tool(&self, tool_name: impl Into<String>, tool_call_id: Option<String>) -> Self {
195        Self {
196            deps: Arc::clone(&self.deps),
197            run_id: self.run_id.clone(),
198            start_time: self.start_time,
199            retry_count: 0,
200            max_retries: self.max_retries,
201            tool_name: Some(tool_name.into()),
202            tool_call_id,
203            model_name: self.model_name.clone(),
204            model_settings: self.model_settings.clone(),
205            usage: self.usage.clone(),
206            metadata: self.metadata.clone(),
207            partial_output: self.partial_output,
208        }
209    }
210
211    /// Create a copy with updated usage.
212    #[must_use]
213    pub fn with_usage(mut self, usage: RunUsage) -> Self {
214        self.usage = usage;
215        self
216    }
217
218    /// Replace dependencies (for testing).
219    #[must_use]
220    pub fn with_deps<NewDeps>(self, new_deps: NewDeps) -> RunContext<NewDeps> {
221        RunContext {
222            deps: Arc::new(new_deps),
223            run_id: self.run_id,
224            start_time: self.start_time,
225            retry_count: self.retry_count,
226            max_retries: self.max_retries,
227            tool_name: self.tool_name,
228            tool_call_id: self.tool_call_id,
229            model_name: self.model_name,
230            model_settings: self.model_settings,
231            usage: self.usage,
232            metadata: self.metadata,
233            partial_output: self.partial_output,
234        }
235    }
236}
237
238impl<Deps: Default> Default for RunContext<Deps> {
239    fn default() -> Self {
240        Self::new(Deps::default(), "default")
241    }
242}
243
244impl RunContext<()> {
245    /// Create a minimal context without dependencies.
246    #[must_use]
247    pub fn minimal(model_name: impl Into<String>) -> Self {
248        Self::new((), model_name)
249    }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    /// Small dependency payload used to verify that deps travel with the context.
    #[derive(Debug, Clone, Default)]
    struct TestDeps {
        value: i32,
    }

    /// A new context carries the deps and model name and starts with no retries.
    #[test]
    fn test_run_context_new() {
        let deps = TestDeps { value: 42 };
        let ctx = RunContext::new(deps, "gpt-4");

        assert_eq!(ctx.deps.value, 42);
        assert_eq!(ctx.model_name.as_str(), "gpt-4");
        assert!(ctx.run_id.starts_with("run_"));
        assert_eq!(ctx.retry_count, 0);
    }

    /// `minimal` only needs a model name.
    #[test]
    fn test_run_context_minimal() {
        let ctx = RunContext::minimal("claude-3");

        assert_eq!(ctx.model_name.as_str(), "claude-3");
    }

    /// The builder records both tool name and call ID.
    #[test]
    fn test_run_context_with_tool_context() {
        let call_id = Some("call_123".to_string());
        let ctx = RunContext::minimal("gpt-4").with_tool_context("my_tool", call_id);

        assert_eq!(ctx.tool_name.as_deref(), Some("my_tool"));
        assert_eq!(ctx.tool_call_id.as_deref(), Some("call_123"));
        assert!(ctx.in_tool_call());
    }

    /// Retrying is allowed until the retry count reaches the maximum.
    #[test]
    fn test_increment_retry() {
        let mut ctx = RunContext::minimal("gpt-4").with_max_retries(3);
        assert!(ctx.can_retry());

        for _ in 0..3 {
            ctx.increment_retry();
        }

        assert!(!ctx.can_retry());
    }

    /// A child tool context keeps deps and run ID but resets the retry count.
    #[test]
    fn test_for_tool() {
        let parent = RunContext::new(TestDeps { value: 10 }, "gpt-4");
        let child = parent.for_tool("test_tool", Some("id1".to_string()));

        // Same deps
        assert_eq!(child.deps.value, 10);
        // Same run ID
        assert_eq!(child.run_id, parent.run_id);
        // Tool info set
        assert_eq!(child.tool_name.as_deref(), Some("test_tool"));
        // Reset retry count
        assert_eq!(child.retry_count, 0);
    }

    /// Swapping deps yields a context carrying the new value.
    #[test]
    fn test_with_deps() {
        let replaced = RunContext::minimal("gpt-4").with_deps(TestDeps { value: 99 });

        assert_eq!(replaced.deps.value, 99);
    }

    /// Elapsed time reflects at least the time slept since creation.
    #[test]
    fn test_elapsed() {
        let ctx = RunContext::minimal("gpt-4");
        let pause = std::time::Duration::from_millis(10);
        std::thread::sleep(pause);

        let elapsed = ctx.elapsed_secs();
        assert!(elapsed >= 0.01);
    }

    /// `Default` uses default deps and the placeholder model name.
    #[test]
    fn test_default() {
        let ctx = RunContext::<TestDeps>::default();

        assert_eq!(ctx.deps.value, 0);
        assert_eq!(ctx.model_name.as_str(), "default");
    }
}