1use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_manager::ContextManager;
9use crate::traits::execution_strategy::ExecutionStrategy;
10use crate::traits::guard::Guard;
11use crate::traits::hint::Hint;
12use crate::traits::hook::AgentHook;
13use crate::traits::memory::Memory;
14use crate::traits::output_transformer::OutputTransformer;
15use crate::traits::provider::Provider;
16use crate::traits::strategy::{AgentRuntime, AgentStrategy};
17use crate::traits::tool::ErasedTool;
18use crate::traits::tool_registry::ToolRegistry;
19use crate::traits::tracker::Tracker;
20use crate::types::message::Message;
21use crate::Result;
22
/// Resource accounting for a single agent run.
///
/// `Default` yields an all-zero value: no tokens, no iterations, zero duration.
/// `PartialEq`/`Eq` allow usages to be compared directly (e.g. in tests).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunUsage {
    /// Tokens consumed by the run (NOTE(review): presumably prompt + completion — confirm
    /// with the strategy that fills this in).
    pub tokens: usize,
    /// Number of iterations the strategy performed.
    pub iterations: usize,
    /// Time spent on the run.
    pub duration: std::time::Duration,
}
33
/// Final result of an agent run: the produced content plus the run's usage accounting.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a struct
/// literal and has to go through constructors such as [`AgentOutput::text_with_usage`].
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AgentOutput {
    /// What the run produced: text, a structured JSON value, or an error message.
    pub content: AgentOutputContent,
    /// Resource usage recorded for the run (see [`RunUsage`]).
    pub usage: RunUsage,
}
46
47#[derive(Debug, Clone)]
49#[non_exhaustive]
50pub enum AgentOutputContent {
51 Text(String),
53 Structured(serde_json::Value),
55 Error(String),
57}
58
59impl AgentOutput {
60 #[must_use]
62 pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
63 Self {
64 content: AgentOutputContent::Text(text),
65 usage,
66 }
67 }
68
69 #[must_use]
74 pub fn text(&self) -> &str {
75 match &self.content {
76 AgentOutputContent::Text(t) => t,
77 _ => "",
78 }
79 }
80
81 #[must_use]
83 pub fn error_message(&self) -> Option<&str> {
84 match &self.content {
85 AgentOutputContent::Error(e) => Some(e),
86 _ => None,
87 }
88 }
89
90 #[must_use]
92 pub fn structured(&self) -> Option<&serde_json::Value> {
93 match &self.content {
94 AgentOutputContent::Structured(v) => Some(v),
95 _ => None,
96 }
97 }
98
99 #[must_use]
101 pub fn is_error(&self) -> bool {
102 matches!(&self.content, AgentOutputContent::Error(_))
103 }
104}
105
106impl std::fmt::Display for AgentOutput {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match &self.content {
109 AgentOutputContent::Text(t) => write!(f, "{t}"),
110 AgentOutputContent::Structured(v) => write!(f, "{v}"),
111 AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
112 }
113 }
114}
115
/// An assembled agent: a [`Provider`] plus the pluggable components an
/// [`AgentStrategy`] uses to run it.
///
/// Build one with [`Agent::builder`] or the [`Agent::with_system`] shortcut. Every
/// field except `strategy` is copied into an [`AgentRuntime`] for each run
/// (see `to_runtime`).
pub struct Agent {
    /// Model backend; forwarded to the strategy and also called directly by
    /// `run_structured`.
    pub(crate) provider: Arc<dyn Provider>,
    /// Tools forwarded to the strategy through the runtime.
    pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
    /// Memory backend; forwarded to the strategy and exposed via `memory()`.
    pub(crate) memory: Arc<dyn Memory>,
    /// Guards forwarded to the strategy through the runtime.
    pub(crate) guards: Vec<Arc<dyn Guard>>,
    /// Hints forwarded to the strategy through the runtime.
    pub(crate) hints: Vec<Arc<dyn Hint>>,
    /// Tracker forwarded to the strategy through the runtime.
    pub(crate) tracker: Arc<dyn Tracker>,
    /// Context manager forwarded to the strategy through the runtime.
    pub(crate) context_manager: Arc<dyn ContextManager>,
    /// Execution strategy forwarded to the strategy through the runtime.
    pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
    /// Output transformer forwarded to the strategy through the runtime.
    pub(crate) output_transformer: Arc<dyn OutputTransformer>,
    /// Tool registry forwarded to the strategy through the runtime.
    pub(crate) tool_registry: Arc<dyn ToolRegistry>,
    /// Drives `run`, `stream`, and session calls; not part of the runtime snapshot.
    pub(crate) strategy: Box<dyn AgentStrategy>,
    /// Lifecycle hooks forwarded to the strategy through the runtime.
    pub(crate) hooks: Vec<Arc<dyn AgentHook>>,
    /// Agent settings; `run_structured` reads the system prompt, `max_tokens`, and
    /// `temperature` from here.
    pub(crate) config: AgentConfig,
}
137
138impl std::fmt::Debug for Agent {
139 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140 f.debug_struct("Agent")
141 .field("model", &self.provider.model_info().name)
142 .field("tools", &self.tools.len())
143 .field("guards", &self.guards.len())
144 .field("hints", &self.hints.len())
145 .field("hooks", &self.hooks.len())
146 .field("config", &self.config)
147 .finish_non_exhaustive()
148 }
149}
150
151impl Agent {
152 #[must_use]
154 pub fn builder() -> AgentBuilder {
155 AgentBuilder::new()
156 }
157
158 #[must_use]
187 pub fn with_system(provider: impl Provider, system: impl Into<String>) -> Self {
188 Agent::builder()
189 .provider(provider)
190 .system(system)
191 .build()
192 .expect("Agent::with_system is infallible: provider is always set")
193 }
194
195 #[allow(clippy::too_many_arguments)]
197 pub(crate) fn new(
198 provider: Arc<dyn Provider>,
199 tools: Vec<Arc<dyn ErasedTool>>,
200 memory: Arc<dyn Memory>,
201 guards: Vec<Arc<dyn Guard>>,
202 hints: Vec<Arc<dyn Hint>>,
203 tracker: Arc<dyn Tracker>,
204 context_manager: Arc<dyn ContextManager>,
205 execution_strategy: Arc<dyn ExecutionStrategy>,
206 output_transformer: Arc<dyn OutputTransformer>,
207 tool_registry: Arc<dyn ToolRegistry>,
208 strategy: Box<dyn AgentStrategy>,
209 hooks: Vec<Arc<dyn AgentHook>>,
210 config: AgentConfig,
211 ) -> Self {
212 Self {
213 provider,
214 tools,
215 memory,
216 guards,
217 hints,
218 tracker,
219 context_manager,
220 execution_strategy,
221 output_transformer,
222 tool_registry,
223 strategy,
224 hooks,
225 config,
226 }
227 }
228
229 fn to_runtime(&self) -> AgentRuntime {
233 AgentRuntime {
234 provider: Arc::clone(&self.provider),
235 tools: self.tools.clone(),
236 memory: Arc::clone(&self.memory),
237 guards: self.guards.clone(),
238 hints: self.hints.clone(),
239 tracker: Arc::clone(&self.tracker),
240 context_manager: Arc::clone(&self.context_manager),
241 execution_strategy: Arc::clone(&self.execution_strategy),
242 output_transformer: Arc::clone(&self.output_transformer),
243 tool_registry: Arc::clone(&self.tool_registry),
244 hooks: self.hooks.clone(),
245 config: self.config.clone(),
246 }
247 }
248
249 #[must_use]
263 pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
264 AgentSession {
265 agent: self,
266 session_id: id.into(),
267 }
268 }
269
270 #[must_use]
274 pub fn session_auto(&self) -> AgentSession<'_> {
275 AgentSession {
276 agent: self,
277 session_id: uuid::Uuid::new_v4().to_string(),
278 }
279 }
280
281 #[must_use]
286 pub fn memory(&self) -> &dyn Memory {
287 &*self.memory
288 }
289
290 pub async fn run(&self, input: &str) -> Result<AgentOutput> {
300 let runtime = self.to_runtime();
301 self.strategy.execute(&runtime, input, "default").await
302 }
303
304 #[must_use]
311 pub fn stream(&self, input: &str) -> AgentStream {
312 self.stream_with_session(input, "default")
313 }
314
315 pub async fn run_structured<T>(&self, input: &str) -> Result<T>
336 where
337 T: serde::de::DeserializeOwned + schemars::JsonSchema,
338 {
339 let model_info = self.provider.model_info();
340 let schema = schemars::schema_for!(T);
341 let schema_json = serde_json::to_value(&schema)
342 .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
343
344 let uses_native = model_info.supports_structured;
345
346 let mut messages = vec![];
347 if let Some(ref system_prompt) = self.config.system_prompt {
348 messages.push(Message::system(system_prompt));
349 }
350
351 if !uses_native {
353 let schema_str = serde_json::to_string_pretty(&schema_json)
354 .unwrap_or_else(|_| schema_json.to_string());
355 messages.push(Message::system(format!(
356 "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
357 )));
358 }
359
360 messages.push(Message::user(input));
361
362 let max_retries = 3;
363 let mut last_error = String::new();
364
365 for attempt in 0..=max_retries {
366 if attempt > 0 {
367 messages.push(Message::system(format!(
369 "Your previous response was not valid JSON. Error: {last_error}\n\
370 Please try again. Respond ONLY with valid JSON."
371 )));
372 }
373
374 let response_format = if uses_native {
375 Some(crate::types::completion::ResponseFormat::JsonSchema {
376 json_schema: schema_json.clone(),
377 })
378 } else {
379 None
380 };
381
382 let request = crate::types::completion::CompletionRequest {
383 model: model_info.name.clone(),
384 messages: messages.clone(),
385 tools: vec![],
386 max_tokens: self.config.max_tokens,
387 temperature: self.config.temperature,
388 response_format,
389 stream: false,
390 };
391
392 let response = self.provider.complete(request).await?;
393
394 let text = match response.content {
395 crate::types::completion::ResponseContent::Text(t) => t,
396 crate::types::completion::ResponseContent::ToolCalls(_) => {
397 last_error = "Model returned tool calls instead of JSON".into();
398 messages.push(Message::assistant("[tool calls returned]"));
399 continue;
400 }
401 };
402
403 match serde_json::from_str::<T>(&text) {
404 Ok(value) => return Ok(value),
405 Err(e) => {
406 last_error = format!("{e}");
407 messages.push(Message::assistant(&text));
408 }
409 }
410 }
411
412 Err(crate::Error::Runtime(format!(
413 "Structured output failed after {max_retries} retries. Last error: {last_error}"
414 )))
415 }
416
417 pub(crate) fn stream_with_session(&self, input: &str, session_id: &str) -> AgentStream {
419 let runtime = crate::traits::strategy::AgentRuntime {
420 provider: Arc::clone(&self.provider),
421 tools: self.tools.clone(),
422 memory: Arc::clone(&self.memory),
423 guards: self.guards.clone(),
424 hints: self.hints.clone(),
425 tracker: Arc::clone(&self.tracker),
426 context_manager: Arc::clone(&self.context_manager),
427 execution_strategy: Arc::clone(&self.execution_strategy),
428 output_transformer: Arc::clone(&self.output_transformer),
429 tool_registry: Arc::clone(&self.tool_registry),
430 config: self.config.clone(),
431 hooks: self.hooks.clone(),
432 };
433
434 self.strategy.stream(&runtime, input, session_id)
435 }
436}
437
438pub struct AgentSession<'a> {
445 agent: &'a Agent,
446 session_id: String,
448}
449
450impl AgentSession<'_> {
451 pub async fn say(&self, input: &str) -> Result<AgentOutput> {
460 let runtime = self.agent.to_runtime();
461 self.agent
462 .strategy
463 .execute(&runtime, input, &self.session_id)
464 .await
465 }
466
467 #[must_use]
471 pub fn stream(&self, input: &str) -> AgentStream {
472 self.agent.stream_with_session(input, &self.session_id)
473 }
474
475 #[must_use]
477 pub fn id(&self) -> &str {
478 &self.session_id
479 }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionRequest, CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Wraps `content` in an [`AgentOutput`] carrying zeroed usage.
    fn output_of(content: AgentOutputContent) -> AgentOutput {
        AgentOutput {
            content,
            usage: RunUsage::default(),
        }
    }

    #[test]
    fn test_run_usage_default() {
        let usage = RunUsage::default();
        assert_eq!(usage.tokens, 0);
        assert_eq!(usage.iterations, 0);
        assert_eq!(usage.duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let output = AgentOutput::text_with_usage("Hello".into(), RunUsage::default());
        assert_eq!(output.text(), "Hello");
        assert!(!output.is_error());
        assert_eq!(output.structured(), None);
        assert_eq!(output.error_message(), None);
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let output = output_of(AgentOutputContent::Structured(serde_json::json!({"key": "val"})));
        assert_eq!(output.text(), "");
        let payload = output
            .structured()
            .expect("structured payload should be present");
        assert_eq!(payload["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let output = output_of(AgentOutputContent::Error("boom".into()));
        assert_eq!(output.text(), "");
        assert!(output.is_error());
        assert_eq!(output.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let output = AgentOutput::text_with_usage("hi".into(), RunUsage::default());
        assert_eq!(output.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let output = output_of(AgentOutputContent::Structured(serde_json::json!(42)));
        assert_eq!(output.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let output = output_of(AgentOutputContent::Error("fail".into()));
        assert_eq!(output.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let recorded = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let output = AgentOutput::text_with_usage("x".into(), recorded);
        assert_eq!(output.usage.tokens, 100);
        assert_eq!(output.usage.iterations, 5);
        assert_eq!(output.usage.duration, Duration::from_millis(500));
    }

    /// Provider stub whose completions always answer `"ok"`; streaming is not supported.
    struct MockProvider {
        info: ModelInfo,
    }

    impl MockProvider {
        fn new() -> Self {
            let info = ModelInfo::new("mock", ModelTier::Small, 4_096, false, false, false);
            Self { info }
        }
    }

    #[async_trait]
    impl crate::traits::provider::Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            let usage = Usage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(CompletionResponse {
                content: ResponseContent::Text("ok".into()),
                usage,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[test]
    fn test_with_system_str_prompt() {
        let agent = Agent::with_system(MockProvider::new(), "You are helpful.");
        assert_eq!(
            agent.config.system_prompt.as_deref(),
            Some("You are helpful.")
        );
    }

    #[test]
    fn test_with_system_string_prompt() {
        let agent = Agent::with_system(MockProvider::new(), String::from("You are a researcher."));
        assert_eq!(
            agent.config.system_prompt.as_deref(),
            Some("You are a researcher.")
        );
    }

    #[test]
    fn test_with_system_builder_unchanged() {
        let built = Agent::builder()
            .provider(MockProvider::new())
            .system("test")
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn test_with_system_provider_configured() {
        let agent = Agent::with_system(MockProvider::new(), "test");
        assert_eq!(agent.provider.model_info().name, "mock");
    }
}