1use chrono::{DateTime, Utc};
7use serdes_ai_core::{identifier::generate_run_id, ModelSettings, RunUsage};
8use std::sync::Arc;
9
10#[derive(Debug, Clone)]
39pub struct RunContext<Deps = ()> {
40 pub deps: Arc<Deps>,
42
43 pub run_id: String,
45
46 pub start_time: DateTime<Utc>,
48
49 pub retry_count: u32,
51
52 pub max_retries: u32,
54
55 pub tool_name: Option<String>,
57
58 pub tool_call_id: Option<String>,
60
61 pub model_name: String,
63
64 pub model_settings: ModelSettings,
66
67 pub usage: RunUsage,
69
70 pub metadata: Option<serde_json::Value>,
72
73 pub partial_output: bool,
75}
76
77impl<Deps> RunContext<Deps> {
78 #[must_use]
80 pub fn new(deps: Deps, model_name: impl Into<String>) -> Self {
81 Self {
82 deps: Arc::new(deps),
83 run_id: generate_run_id(),
84 start_time: Utc::now(),
85 retry_count: 0,
86 max_retries: 3,
87 tool_name: None,
88 tool_call_id: None,
89 model_name: model_name.into(),
90 model_settings: ModelSettings::default(),
91 usage: RunUsage::default(),
92 metadata: None,
93 partial_output: false,
94 }
95 }
96
97 #[must_use]
99 pub fn from_arc(deps: Arc<Deps>, model_name: impl Into<String>) -> Self {
100 Self {
101 deps,
102 run_id: generate_run_id(),
103 start_time: Utc::now(),
104 retry_count: 0,
105 max_retries: 3,
106 tool_name: None,
107 tool_call_id: None,
108 model_name: model_name.into(),
109 model_settings: ModelSettings::default(),
110 usage: RunUsage::default(),
111 metadata: None,
112 partial_output: false,
113 }
114 }
115
116 #[must_use]
118 pub fn with_run_id(mut self, run_id: impl Into<String>) -> Self {
119 self.run_id = run_id.into();
120 self
121 }
122
123 #[must_use]
125 pub fn with_max_retries(mut self, max_retries: u32) -> Self {
126 self.max_retries = max_retries;
127 self
128 }
129
130 #[must_use]
132 pub fn with_model_settings(mut self, settings: ModelSettings) -> Self {
133 self.model_settings = settings;
134 self
135 }
136
137 #[must_use]
139 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
140 self.metadata = Some(metadata);
141 self
142 }
143
144 #[must_use]
146 pub fn with_tool_context(
147 mut self,
148 tool_name: impl Into<String>,
149 tool_call_id: Option<String>,
150 ) -> Self {
151 self.tool_name = Some(tool_name.into());
152 self.tool_call_id = tool_call_id;
153 self
154 }
155
156 #[must_use]
158 pub fn with_partial_output(mut self, partial: bool) -> Self {
159 self.partial_output = partial;
160 self
161 }
162
163 pub fn increment_retry(&mut self) {
165 self.retry_count += 1;
166 }
167
168 #[must_use]
170 pub fn can_retry(&self) -> bool {
171 self.retry_count < self.max_retries
172 }
173
174 #[must_use]
176 pub fn elapsed(&self) -> chrono::Duration {
177 Utc::now() - self.start_time
178 }
179
180 #[must_use]
182 pub fn elapsed_secs(&self) -> f64 {
183 self.elapsed().num_milliseconds() as f64 / 1000.0
184 }
185
186 #[must_use]
188 pub fn in_tool_call(&self) -> bool {
189 self.tool_name.is_some()
190 }
191
192 #[must_use]
194 pub fn for_tool(&self, tool_name: impl Into<String>, tool_call_id: Option<String>) -> Self {
195 Self {
196 deps: Arc::clone(&self.deps),
197 run_id: self.run_id.clone(),
198 start_time: self.start_time,
199 retry_count: 0,
200 max_retries: self.max_retries,
201 tool_name: Some(tool_name.into()),
202 tool_call_id,
203 model_name: self.model_name.clone(),
204 model_settings: self.model_settings.clone(),
205 usage: self.usage.clone(),
206 metadata: self.metadata.clone(),
207 partial_output: self.partial_output,
208 }
209 }
210
211 #[must_use]
213 pub fn with_usage(mut self, usage: RunUsage) -> Self {
214 self.usage = usage;
215 self
216 }
217
218 #[must_use]
220 pub fn with_deps<NewDeps>(self, new_deps: NewDeps) -> RunContext<NewDeps> {
221 RunContext {
222 deps: Arc::new(new_deps),
223 run_id: self.run_id,
224 start_time: self.start_time,
225 retry_count: self.retry_count,
226 max_retries: self.max_retries,
227 tool_name: self.tool_name,
228 tool_call_id: self.tool_call_id,
229 model_name: self.model_name,
230 model_settings: self.model_settings,
231 usage: self.usage,
232 metadata: self.metadata,
233 partial_output: self.partial_output,
234 }
235 }
236}
237
238impl<Deps: Default> Default for RunContext<Deps> {
239 fn default() -> Self {
240 Self::new(Deps::default(), "default")
241 }
242}
243
244impl RunContext<()> {
245 #[must_use]
247 pub fn minimal(model_name: impl Into<String>) -> Self {
248 Self::new((), model_name)
249 }
250}
251
#[cfg(test)]
mod tests {
    use super::*;

    #[derive(Debug, Clone, Default)]
    struct TestDeps {
        value: i32,
    }

    #[test]
    fn test_run_context_new() {
        let ctx = RunContext::new(TestDeps { value: 42 }, "gpt-4");
        assert_eq!(ctx.deps.value, 42);
        assert_eq!(ctx.model_name, "gpt-4");
        assert!(ctx.run_id.starts_with("run_"));
        assert_eq!(ctx.retry_count, 0);
        assert_eq!(ctx.max_retries, 3);
        assert!(!ctx.in_tool_call());
        assert!(!ctx.partial_output);
    }

    #[test]
    fn test_from_arc_shares_deps() {
        let deps = Arc::new(TestDeps { value: 7 });
        let ctx = RunContext::from_arc(Arc::clone(&deps), "gpt-4");
        // The context must hold the caller's Arc, not a copy of the value.
        assert!(Arc::ptr_eq(&ctx.deps, &deps));
        assert_eq!(ctx.max_retries, 3);
        assert!(!ctx.in_tool_call());
    }

    #[test]
    fn test_run_context_minimal() {
        let ctx = RunContext::minimal("claude-3");
        assert_eq!(ctx.model_name, "claude-3");
    }

    #[test]
    fn test_run_context_with_tool_context() {
        let ctx =
            RunContext::minimal("gpt-4").with_tool_context("my_tool", Some("call_123".to_string()));
        assert_eq!(ctx.tool_name, Some("my_tool".to_string()));
        assert_eq!(ctx.tool_call_id, Some("call_123".to_string()));
        assert!(ctx.in_tool_call());
    }

    #[test]
    fn test_increment_retry() {
        let mut ctx = RunContext::minimal("gpt-4").with_max_retries(3);
        assert!(ctx.can_retry());
        ctx.increment_retry();
        ctx.increment_retry();
        ctx.increment_retry();
        assert_eq!(ctx.retry_count, 3);
        assert!(!ctx.can_retry());
    }

    #[test]
    fn test_zero_retry_budget() {
        let ctx = RunContext::minimal("gpt-4").with_max_retries(0);
        assert!(!ctx.can_retry());
    }

    #[test]
    fn test_for_tool() {
        let mut ctx = RunContext::new(TestDeps { value: 10 }, "gpt-4");
        // Consume a retry on the parent so the child's reset is actually observable.
        ctx.increment_retry();
        assert_eq!(ctx.retry_count, 1);

        let tool_ctx = ctx.for_tool("test_tool", Some("id1".to_string()));

        assert_eq!(tool_ctx.deps.value, 10);
        assert!(Arc::ptr_eq(&tool_ctx.deps, &ctx.deps));
        assert_eq!(tool_ctx.run_id, ctx.run_id);
        assert_eq!(tool_ctx.start_time, ctx.start_time);
        assert_eq!(tool_ctx.tool_name, Some("test_tool".to_string()));
        assert_eq!(tool_ctx.tool_call_id, Some("id1".to_string()));
        assert_eq!(tool_ctx.retry_count, 0);
        assert_eq!(tool_ctx.max_retries, ctx.max_retries);

        // The parent context is left untouched.
        assert_eq!(ctx.retry_count, 1);
        assert!(!ctx.in_tool_call());
    }

    #[test]
    fn test_with_deps() {
        let ctx = RunContext::minimal("gpt-4").with_tool_context("t", None);
        let run_id = ctx.run_id.clone();
        let new_ctx = ctx.with_deps(TestDeps { value: 99 });
        assert_eq!(new_ctx.deps.value, 99);
        // Everything other than the deps is carried over.
        assert_eq!(new_ctx.run_id, run_id);
        assert_eq!(new_ctx.model_name, "gpt-4");
        assert_eq!(new_ctx.tool_name.as_deref(), Some("t"));
    }

    #[test]
    fn test_elapsed() {
        let ctx = RunContext::minimal("gpt-4");
        std::thread::sleep(std::time::Duration::from_millis(10));
        let elapsed = ctx.elapsed_secs();
        assert!(elapsed >= 0.01);
    }

    #[test]
    fn test_default() {
        let ctx: RunContext<TestDeps> = RunContext::default();
        assert_eq!(ctx.deps.value, 0);
        assert_eq!(ctx.model_name, "default");
        assert_eq!(ctx.retry_count, 0);
        assert!(ctx.metadata.is_none());
    }
}