1use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_manager::ContextManager;
9#[allow(deprecated)]
10use crate::traits::context_strategy::ContextStrategy;
11use crate::traits::execution_strategy::ExecutionStrategy;
12use crate::traits::guard::Guard;
13use crate::traits::hint::Hint;
14use crate::traits::hook::AgentHook;
15use crate::traits::memory::Memory;
16#[allow(deprecated)]
17use crate::traits::output_processor::OutputProcessor;
18use crate::traits::output_transformer::OutputTransformer;
19use crate::traits::provider::Provider;
20use crate::traits::strategy::{AgentRuntime, AgentStrategy};
21use crate::traits::tool::ErasedTool;
22use crate::traits::tool_registry::ToolRegistry;
23use crate::traits::tracker::Tracker;
24use crate::types::message::Message;
25use crate::Result;
26
/// Resource counters attached to an [`AgentOutput`].
///
/// `Default` yields all-zero counters. `PartialEq`/`Eq` are derived so usage values can be
/// compared directly (every field is `Eq`).
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunUsage {
    /// Token count recorded for the run.
    pub tokens: usize,
    /// Number of iterations recorded for the run.
    pub iterations: usize,
    /// Time recorded for the run. Nothing in this file measures it; the code that builds the
    /// output is responsible for filling it in.
    pub duration: std::time::Duration,
}
37
/// Final result of an agent run: the produced content plus its usage counters.
///
/// Marked `#[non_exhaustive]`, so code outside this crate cannot build it with a struct
/// literal. `AgentOutput::text_with_usage` is the public constructor defined in this file.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct AgentOutput {
    /// What the run produced: text, a JSON value, or an error message.
    pub content: AgentOutputContent,
    /// Token, iteration and duration counters for the run.
    pub usage: RunUsage,
}
50
51#[derive(Debug, Clone)]
53#[non_exhaustive]
54pub enum AgentOutputContent {
55 Text(String),
57 Structured(serde_json::Value),
59 Error(String),
61}
62
63impl AgentOutput {
64 #[must_use]
66 pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
67 Self {
68 content: AgentOutputContent::Text(text),
69 usage,
70 }
71 }
72
73 #[must_use]
78 pub fn text(&self) -> &str {
79 match &self.content {
80 AgentOutputContent::Text(t) => t,
81 _ => "",
82 }
83 }
84
85 #[must_use]
87 pub fn error_message(&self) -> Option<&str> {
88 match &self.content {
89 AgentOutputContent::Error(e) => Some(e),
90 _ => None,
91 }
92 }
93
94 #[must_use]
96 pub fn structured(&self) -> Option<&serde_json::Value> {
97 match &self.content {
98 AgentOutputContent::Structured(v) => Some(v),
99 _ => None,
100 }
101 }
102
103 #[must_use]
105 pub fn is_error(&self) -> bool {
106 matches!(&self.content, AgentOutputContent::Error(_))
107 }
108}
109
110impl std::fmt::Display for AgentOutput {
111 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112 match &self.content {
113 AgentOutputContent::Text(t) => write!(f, "{t}"),
114 AgentOutputContent::Structured(v) => write!(f, "{v}"),
115 AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
116 }
117 }
118}
119
/// An LLM-backed agent: a provider plus the pluggable components an [`AgentStrategy`]
/// uses to execute or stream a run.
///
/// Construct one with `Agent::builder()` or `Agent::with_system(..)`. For `run`, `stream`
/// and session calls, every field except `strategy` is cloned into an [`AgentRuntime`], and
/// the strategy is then invoked with that runtime.
// The `ContextStrategy` and `OutputProcessor` imports carry the same `allow(deprecated)`;
// presumably those traits are deprecated but still stored here. TODO confirm.
#[allow(deprecated)]
pub struct Agent {
    /// Model backend; also the source of the model name shown by `Debug`.
    pub(crate) provider: Arc<dyn Provider>,
    /// Tools handed to the strategy through the runtime.
    pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
    /// Memory backend handed to the strategy through the runtime.
    pub(crate) memory: Arc<dyn Memory>,
    /// Guards handed to the strategy through the runtime.
    pub(crate) guards: Vec<Arc<dyn Guard>>,
    /// Hints handed to the strategy through the runtime.
    pub(crate) hints: Vec<Arc<dyn Hint>>,
    /// Tracker handed to the strategy through the runtime.
    pub(crate) tracker: Arc<dyn Tracker>,
    /// Context manager handed to the strategy through the runtime.
    pub(crate) context_manager: Arc<dyn ContextManager>,
    /// Context strategy handed to the strategy (see the `allow(deprecated)` note above).
    pub(crate) context_strategy: Arc<dyn ContextStrategy>,
    /// Execution strategy handed to the agent strategy through the runtime.
    pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
    /// Output transformer handed to the strategy through the runtime.
    pub(crate) output_transformer: Arc<dyn OutputTransformer>,
    /// Output processor handed to the strategy (see the `allow(deprecated)` note above).
    pub(crate) output_processor: Arc<dyn OutputProcessor>,
    /// Tool registry handed to the strategy through the runtime.
    pub(crate) tool_registry: Arc<dyn ToolRegistry>,
    /// Drives `execute`/`stream`; it is not part of the runtime but receives it.
    pub(crate) strategy: Box<dyn AgentStrategy>,
    /// Hooks handed to the strategy through the runtime.
    pub(crate) hooks: Vec<Arc<dyn AgentHook>>,
    /// Agent settings; `run_structured` reads `system_prompt`, `max_tokens` and `temperature`.
    pub(crate) config: AgentConfig,
}
144
145impl std::fmt::Debug for Agent {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 f.debug_struct("Agent")
148 .field("model", &self.provider.model_info().name)
149 .field("tools", &self.tools.len())
150 .field("guards", &self.guards.len())
151 .field("hints", &self.hints.len())
152 .field("hooks", &self.hooks.len())
153 .field("config", &self.config)
154 .finish_non_exhaustive()
155 }
156}
157
158#[allow(deprecated)]
159impl Agent {
160 #[must_use]
162 pub fn builder() -> AgentBuilder {
163 AgentBuilder::new()
164 }
165
166 #[must_use]
195 pub fn with_system(provider: impl Provider, system: impl Into<String>) -> Self {
196 Agent::builder()
197 .provider(provider)
198 .system(system)
199 .build()
200 .expect("Agent::with_system is infallible: provider is always set")
201 }
202
203 #[allow(clippy::too_many_arguments)]
205 pub(crate) fn new(
206 provider: Arc<dyn Provider>,
207 tools: Vec<Arc<dyn ErasedTool>>,
208 memory: Arc<dyn Memory>,
209 guards: Vec<Arc<dyn Guard>>,
210 hints: Vec<Arc<dyn Hint>>,
211 tracker: Arc<dyn Tracker>,
212 context_manager: Arc<dyn ContextManager>,
213 context_strategy: Arc<dyn ContextStrategy>,
214 execution_strategy: Arc<dyn ExecutionStrategy>,
215 output_transformer: Arc<dyn OutputTransformer>,
216 output_processor: Arc<dyn OutputProcessor>,
217 tool_registry: Arc<dyn ToolRegistry>,
218 strategy: Box<dyn AgentStrategy>,
219 hooks: Vec<Arc<dyn AgentHook>>,
220 config: AgentConfig,
221 ) -> Self {
222 Self {
223 provider,
224 tools,
225 memory,
226 guards,
227 hints,
228 tracker,
229 context_manager,
230 context_strategy,
231 execution_strategy,
232 output_transformer,
233 output_processor,
234 tool_registry,
235 strategy,
236 hooks,
237 config,
238 }
239 }
240
241 fn to_runtime(&self) -> AgentRuntime {
245 AgentRuntime {
246 provider: Arc::clone(&self.provider),
247 tools: self.tools.clone(),
248 memory: Arc::clone(&self.memory),
249 guards: self.guards.clone(),
250 hints: self.hints.clone(),
251 tracker: Arc::clone(&self.tracker),
252 context_manager: Arc::clone(&self.context_manager),
253 context_strategy: Arc::clone(&self.context_strategy),
254 execution_strategy: Arc::clone(&self.execution_strategy),
255 output_transformer: Arc::clone(&self.output_transformer),
256 output_processor: Arc::clone(&self.output_processor),
257 tool_registry: Arc::clone(&self.tool_registry),
258 hooks: self.hooks.clone(),
259 config: self.config.clone(),
260 }
261 }
262
263 #[must_use]
277 pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
278 AgentSession {
279 agent: self,
280 session_id: id.into(),
281 }
282 }
283
284 #[must_use]
288 pub fn session_auto(&self) -> AgentSession<'_> {
289 AgentSession {
290 agent: self,
291 session_id: uuid::Uuid::new_v4().to_string(),
292 }
293 }
294
295 pub async fn run(&self, input: &str) -> Result<AgentOutput> {
305 let runtime = self.to_runtime();
306 self.strategy.execute(&runtime, input, "default").await
307 }
308
309 #[must_use]
319 pub fn stream(&self, input: &str) -> AgentStream {
320 self.stream_with_session(input, "default")
321 }
322
323 pub async fn run_structured<T>(&self, input: &str) -> Result<T>
344 where
345 T: serde::de::DeserializeOwned + schemars::JsonSchema,
346 {
347 let model_info = self.provider.model_info();
348 let schema = schemars::schema_for!(T);
349 let schema_json = serde_json::to_value(&schema)
350 .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
351
352 let uses_native = model_info.supports_structured;
353
354 let mut messages = vec![];
355 if let Some(ref system_prompt) = self.config.system_prompt {
356 messages.push(Message::system(system_prompt));
357 }
358
359 if !uses_native {
361 let schema_str = serde_json::to_string_pretty(&schema_json)
362 .unwrap_or_else(|_| schema_json.to_string());
363 messages.push(Message::system(format!(
364 "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
365 )));
366 }
367
368 messages.push(Message::user(input));
369
370 let max_retries = 3;
371 let mut last_error = String::new();
372
373 for attempt in 0..=max_retries {
374 if attempt > 0 {
375 messages.push(Message::system(format!(
377 "Your previous response was not valid JSON. Error: {last_error}\n\
378 Please try again. Respond ONLY with valid JSON."
379 )));
380 }
381
382 let response_format = if uses_native {
383 Some(crate::types::completion::ResponseFormat::JsonSchema {
384 json_schema: schema_json.clone(),
385 })
386 } else {
387 None
388 };
389
390 let request = crate::types::completion::CompletionRequest {
391 model: model_info.name.clone(),
392 messages: messages.clone(),
393 tools: vec![],
394 max_tokens: self.config.max_tokens,
395 temperature: self.config.temperature,
396 response_format,
397 stream: false,
398 };
399
400 let response = self.provider.complete(request).await?;
401
402 let text = match response.content {
403 crate::types::completion::ResponseContent::Text(t) => t,
404 crate::types::completion::ResponseContent::ToolCalls(_) => {
405 last_error = "Model returned tool calls instead of JSON".into();
406 messages.push(Message::assistant("[tool calls returned]"));
407 continue;
408 }
409 };
410
411 match serde_json::from_str::<T>(&text) {
412 Ok(value) => return Ok(value),
413 Err(e) => {
414 last_error = format!("{e}");
415 messages.push(Message::assistant(&text));
416 }
417 }
418 }
419
420 Err(crate::Error::Runtime(format!(
421 "Structured output failed after {max_retries} retries. Last error: {last_error}"
422 )))
423 }
424
425 pub(crate) fn stream_with_session(&self, input: &str, session_id: &str) -> AgentStream {
427 let runtime = crate::traits::strategy::AgentRuntime {
428 provider: Arc::clone(&self.provider),
429 tools: self.tools.clone(),
430 memory: Arc::clone(&self.memory),
431 guards: self.guards.clone(),
432 hints: self.hints.clone(),
433 tracker: Arc::clone(&self.tracker),
434 context_manager: Arc::clone(&self.context_manager),
435 context_strategy: Arc::clone(&self.context_strategy),
436 execution_strategy: Arc::clone(&self.execution_strategy),
437 output_transformer: Arc::clone(&self.output_transformer),
438 output_processor: Arc::clone(&self.output_processor),
439 tool_registry: Arc::clone(&self.tool_registry),
440 config: self.config.clone(),
441 hooks: self.hooks.clone(),
442 };
443
444 self.strategy.stream(&runtime, input, session_id)
445 }
446}
447
448pub struct AgentSession<'a> {
455 agent: &'a Agent,
456 session_id: String,
458}
459
460impl AgentSession<'_> {
461 pub async fn say(&self, input: &str) -> Result<AgentOutput> {
470 let runtime = self.agent.to_runtime();
471 self.agent
472 .strategy
473 .execute(&runtime, input, &self.session_id)
474 .await
475 }
476
477 #[must_use]
481 pub fn stream(&self, input: &str) -> AgentStream {
482 self.agent.stream_with_session(input, &self.session_id)
483 }
484
485 #[must_use]
487 pub fn id(&self) -> &str {
488 &self.session_id
489 }
490}
491
#[cfg(test)]
mod tests {
    //! Unit tests for `AgentOutput` accessors and formatting, and for `Agent::with_system`.

    use super::*;
    use crate::types::completion::{CompletionRequest, CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Wraps `content` in an `AgentOutput` with zeroed usage.
    fn output_with(content: AgentOutputContent) -> AgentOutput {
        AgentOutput {
            content,
            usage: RunUsage::default(),
        }
    }

    #[test]
    fn test_run_usage_default() {
        let RunUsage {
            tokens,
            iterations,
            duration,
        } = RunUsage::default();
        assert_eq!(tokens, 0);
        assert_eq!(iterations, 0);
        assert_eq!(duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let out = AgentOutput::text_with_usage(String::from("Hello"), RunUsage::default());
        assert_eq!(out.text(), "Hello");
        assert!(!out.is_error());
        assert_eq!(out.structured(), None);
        assert_eq!(out.error_message(), None);
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let out = output_with(AgentOutputContent::Structured(serde_json::json!({"key": "val"})));
        assert_eq!(out.text(), "");
        let value = out
            .structured()
            .expect("structured content should be exposed");
        assert_eq!(value["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let out = output_with(AgentOutputContent::Error(String::from("boom")));
        assert_eq!(out.text(), "");
        assert!(out.is_error());
        assert_eq!(out.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let out = AgentOutput::text_with_usage(String::from("hi"), RunUsage::default());
        assert_eq!(out.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let out = output_with(AgentOutputContent::Structured(serde_json::json!(42)));
        assert_eq!(out.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let out = output_with(AgentOutputContent::Error(String::from("fail")));
        assert_eq!(out.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let usage = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let out = AgentOutput::text_with_usage(String::from("x"), usage);
        let RunUsage {
            tokens,
            iterations,
            duration,
        } = &out.usage;
        assert_eq!(*tokens, 100);
        assert_eq!(*iterations, 5);
        assert_eq!(duration.as_millis(), 500);
    }

    /// Provider stub: every completion returns the text `"ok"`; streaming panics.
    struct MockProvider {
        info: ModelInfo,
    }

    impl MockProvider {
        fn new() -> Self {
            let info = ModelInfo::new("mock", ModelTier::Small, 4_096, false, false, false);
            Self { info }
        }
    }

    #[async_trait]
    impl crate::traits::provider::Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            let usage = Usage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(CompletionResponse {
                content: ResponseContent::Text("ok".into()),
                usage,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[test]
    fn test_with_system_str_prompt() {
        let agent = Agent::with_system(MockProvider::new(), "You are helpful.");
        assert_eq!(agent.config.system_prompt.as_deref(), Some("You are helpful."));
    }

    #[test]
    fn test_with_system_string_prompt() {
        let agent = Agent::with_system(MockProvider::new(), "You are a researcher.".to_string());
        assert_eq!(
            agent.config.system_prompt.as_deref(),
            Some("You are a researcher.")
        );
    }

    #[test]
    fn test_with_system_builder_unchanged() {
        let built = Agent::builder()
            .provider(MockProvider::new())
            .system("test")
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn test_with_system_provider_configured() {
        let agent = Agent::with_system(MockProvider::new(), "test");
        assert_eq!(agent.provider.model_info().name, "mock");
    }
}
648}