1use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_strategy::ContextStrategy;
9use crate::traits::execution_strategy::ExecutionStrategy;
10use crate::traits::guard::Guard;
11use crate::traits::hint::Hint;
12use crate::traits::memory::Memory;
13use crate::traits::output_processor::OutputProcessor;
14use crate::traits::provider::Provider;
15use crate::traits::tool::ErasedTool;
16use crate::traits::tracker::Tracker;
17use crate::types::message::Message;
18use crate::Result;
19
/// Usage statistics collected for a single agent run.
///
/// All counters start at zero via [`Default`]. Equality is derived so that whole
/// usage snapshots can be compared directly instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunUsage {
    /// Token count recorded for the run (presumably prompt + completion; confirm in the runtime).
    pub tokens: usize,
    /// Number of iterations the run performed.
    pub iterations: usize,
    /// Time spent on the run.
    pub duration: std::time::Duration,
}
30
31#[derive(Debug, Clone)]
36#[non_exhaustive]
37pub struct AgentOutput {
38 pub content: AgentOutputContent,
40 pub usage: RunUsage,
42}
43
44#[derive(Debug, Clone)]
46#[non_exhaustive]
47pub enum AgentOutputContent {
48 Text(String),
50 Structured(serde_json::Value),
52 Error(String),
54}
55
56impl AgentOutput {
57 #[must_use]
59 pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
60 Self {
61 content: AgentOutputContent::Text(text),
62 usage,
63 }
64 }
65
66 #[must_use]
71 pub fn text(&self) -> &str {
72 match &self.content {
73 AgentOutputContent::Text(t) => t,
74 _ => "",
75 }
76 }
77
78 #[must_use]
80 pub fn error_message(&self) -> Option<&str> {
81 match &self.content {
82 AgentOutputContent::Error(e) => Some(e),
83 _ => None,
84 }
85 }
86
87 #[must_use]
89 pub fn structured(&self) -> Option<&serde_json::Value> {
90 match &self.content {
91 AgentOutputContent::Structured(v) => Some(v),
92 _ => None,
93 }
94 }
95
96 #[must_use]
98 pub fn is_error(&self) -> bool {
99 matches!(&self.content, AgentOutputContent::Error(_))
100 }
101}
102
103impl std::fmt::Display for AgentOutput {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 match &self.content {
106 AgentOutputContent::Text(t) => write!(f, "{t}"),
107 AgentOutputContent::Structured(v) => write!(f, "{v}"),
108 AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
109 }
110 }
111}
112
113pub struct Agent {
120 pub(crate) provider: Arc<dyn Provider>,
121 pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
122 pub(crate) memory: Arc<dyn Memory>,
123 pub(crate) guards: Vec<Arc<dyn Guard>>,
124 pub(crate) hints: Vec<Arc<dyn Hint>>,
125 pub(crate) tracker: Arc<dyn Tracker>,
126 pub(crate) context_strategy: Arc<dyn ContextStrategy>,
127 pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
128 pub(crate) output_processor: Arc<dyn OutputProcessor>,
129 pub(crate) config: AgentConfig,
130}
131
132impl std::fmt::Debug for Agent {
133 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134 f.debug_struct("Agent")
135 .field("model", &self.provider.model_info().name)
136 .field("tools", &self.tools.len())
137 .field("guards", &self.guards.len())
138 .field("hints", &self.hints.len())
139 .field("config", &self.config)
140 .finish_non_exhaustive()
141 }
142}
143
144impl Agent {
145 #[must_use]
147 pub fn builder() -> AgentBuilder {
148 AgentBuilder::new()
149 }
150
151 #[allow(clippy::too_many_arguments)]
153 pub(crate) fn new(
154 provider: Arc<dyn Provider>,
155 tools: Vec<Arc<dyn ErasedTool>>,
156 memory: Arc<dyn Memory>,
157 guards: Vec<Arc<dyn Guard>>,
158 hints: Vec<Arc<dyn Hint>>,
159 tracker: Arc<dyn Tracker>,
160 context_strategy: Arc<dyn ContextStrategy>,
161 execution_strategy: Arc<dyn ExecutionStrategy>,
162 output_processor: Arc<dyn OutputProcessor>,
163 config: AgentConfig,
164 ) -> Self {
165 Self {
166 provider,
167 tools,
168 memory,
169 guards,
170 hints,
171 tracker,
172 context_strategy,
173 execution_strategy,
174 output_processor,
175 config,
176 }
177 }
178
179 #[must_use]
193 pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
194 AgentSession {
195 agent: self,
196 session_id: id.into(),
197 }
198 }
199
200 #[must_use]
204 pub fn session_auto(&self) -> AgentSession<'_> {
205 AgentSession {
206 agent: self,
207 session_id: uuid::Uuid::new_v4().to_string(),
208 }
209 }
210
211 pub async fn run(&self, input: &str) -> Result<AgentOutput> {
221 crate::runtime::run_agent(self, input, "default").await
222 }
223
224 #[must_use]
233 pub fn stream(&self, input: &str) -> AgentStream {
234 crate::streaming::stream_agent(self, input.to_string(), "default".to_string())
235 }
236
237 pub async fn run_structured<T>(&self, input: &str) -> Result<T>
258 where
259 T: serde::de::DeserializeOwned + schemars::JsonSchema,
260 {
261 let model_info = self.provider.model_info();
262 let schema = schemars::schema_for!(T);
263 let schema_json = serde_json::to_value(&schema)
264 .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
265
266 let uses_native = model_info.supports_structured;
267
268 let mut messages = vec![];
269 if let Some(ref system_prompt) = self.config.system_prompt {
270 messages.push(Message::system(system_prompt));
271 }
272
273 if !uses_native {
275 let schema_str = serde_json::to_string_pretty(&schema_json)
276 .unwrap_or_else(|_| schema_json.to_string());
277 messages.push(Message::system(format!(
278 "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
279 )));
280 }
281
282 messages.push(Message::user(input));
283
284 let max_retries = 3;
285 let mut last_error = String::new();
286
287 for attempt in 0..=max_retries {
288 if attempt > 0 {
289 messages.push(Message::system(format!(
291 "Your previous response was not valid JSON. Error: {last_error}\n\
292 Please try again. Respond ONLY with valid JSON."
293 )));
294 }
295
296 let response_format = if uses_native {
297 Some(crate::types::completion::ResponseFormat::JsonSchema {
298 json_schema: schema_json.clone(),
299 })
300 } else {
301 None
302 };
303
304 let request = crate::types::completion::CompletionRequest {
305 model: model_info.name.clone(),
306 messages: messages.clone(),
307 tools: vec![],
308 max_tokens: self.config.max_tokens,
309 temperature: self.config.temperature,
310 response_format,
311 stream: false,
312 };
313
314 let response = self.provider.complete(request).await?;
315
316 let text = match response.content {
317 crate::types::completion::ResponseContent::Text(t) => t,
318 crate::types::completion::ResponseContent::ToolCalls(_) => {
319 last_error = "Model returned tool calls instead of JSON".into();
320 messages.push(Message::assistant("[tool calls returned]"));
321 continue;
322 }
323 };
324
325 match serde_json::from_str::<T>(&text) {
326 Ok(value) => return Ok(value),
327 Err(e) => {
328 last_error = format!("{e}");
329 messages.push(Message::assistant(&text));
330 }
331 }
332 }
333
334 Err(crate::Error::Runtime(format!(
335 "Structured output failed after {max_retries} retries. Last error: {last_error}"
336 )))
337 }
338}
339
340pub struct AgentSession<'a> {
347 agent: &'a Agent,
348 session_id: String,
350}
351
352impl AgentSession<'_> {
353 pub async fn say(&self, input: &str) -> Result<AgentOutput> {
362 crate::runtime::run_agent(self.agent, input, &self.session_id).await
363 }
364
365 #[must_use]
369 pub fn stream(&self, input: &str) -> AgentStream {
370 crate::streaming::stream_agent(self.agent, input.to_string(), self.session_id.clone())
371 }
372
373 #[must_use]
375 pub fn id(&self) -> &str {
376 &self.session_id
377 }
378}
379
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Wraps `content` in an output that carries default (all-zero) usage.
    fn output_with(content: AgentOutputContent) -> AgentOutput {
        AgentOutput {
            content,
            usage: RunUsage::default(),
        }
    }

    #[test]
    fn test_run_usage_default() {
        let usage = RunUsage::default();
        assert_eq!(usage.tokens, 0);
        assert_eq!(usage.iterations, 0);
        assert_eq!(usage.duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let output = AgentOutput::text_with_usage("Hello".into(), RunUsage::default());
        assert_eq!(output.text(), "Hello");
        assert!(!output.is_error());
        assert!(output.structured().is_none());
        assert!(output.error_message().is_none());
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let output = output_with(AgentOutputContent::Structured(
            serde_json::json!({"key": "val"}),
        ));
        assert_eq!(output.text(), "");
        let value = output.structured().expect("structured payload should be present");
        assert_eq!(value["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let output = output_with(AgentOutputContent::Error("boom".into()));
        assert_eq!(output.text(), "");
        assert!(output.is_error());
        assert_eq!(output.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let output = AgentOutput::text_with_usage("hi".into(), RunUsage::default());
        assert_eq!(output.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let output = output_with(AgentOutputContent::Structured(serde_json::json!(42)));
        assert_eq!(output.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let output = output_with(AgentOutputContent::Error("fail".into()));
        assert_eq!(output.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let usage = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let output = AgentOutput::text_with_usage("x".into(), usage);
        assert_eq!(output.usage.tokens, 100);
        assert_eq!(output.usage.iterations, 5);
        assert_eq!(output.usage.duration.as_millis(), 500);
    }
}