Skip to main content

traitclaw_core/
agent.rs

1//! Agent — the main entry point for running AI agents.
2
3use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_manager::ContextManager;
9#[allow(deprecated)]
10use crate::traits::context_strategy::ContextStrategy;
11use crate::traits::execution_strategy::ExecutionStrategy;
12use crate::traits::guard::Guard;
13use crate::traits::hint::Hint;
14use crate::traits::hook::AgentHook;
15use crate::traits::memory::Memory;
16#[allow(deprecated)]
17use crate::traits::output_processor::OutputProcessor;
18use crate::traits::output_transformer::OutputTransformer;
19use crate::traits::provider::Provider;
20use crate::traits::strategy::{AgentRuntime, AgentStrategy};
21use crate::traits::tool::ErasedTool;
22use crate::traits::tool_registry::ToolRegistry;
23use crate::traits::tracker::Tracker;
24use crate::types::message::Message;
25use crate::Result;
26
/// Aggregate resource usage recorded for a single agent run.
///
/// [`Default`] yields an all-zero value, including a zero-length
/// [`duration`](Self::duration).
#[derive(Debug, Clone, Default)]
pub struct RunUsage {
    /// Tokens consumed, summed over every LLM call made during the run.
    pub tokens: usize,
    /// How many iterations the agent loop performed.
    pub iterations: usize,
    /// Wall-clock time the run took.
    pub duration: std::time::Duration,
}
37
38/// Output from an agent run.
39///
40/// This struct is marked `#[non_exhaustive]` — new fields may be added in
41/// future releases without breaking changes.
42#[derive(Debug, Clone)]
43#[non_exhaustive]
44pub struct AgentOutput {
45    /// The response content.
46    pub content: AgentOutputContent,
47    /// Usage statistics for this run.
48    pub usage: RunUsage,
49}
50
51/// The content type of an agent output.
52#[derive(Debug, Clone)]
53#[non_exhaustive]
54pub enum AgentOutputContent {
55    /// The agent returned a text response.
56    Text(String),
57    /// The agent returned a structured JSON response.
58    Structured(serde_json::Value),
59    /// The agent encountered an error.
60    Error(String),
61}
62
63impl AgentOutput {
64    /// Create a text output with usage.
65    #[must_use]
66    pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
67        Self {
68            content: AgentOutputContent::Text(text),
69            usage,
70        }
71    }
72
73    /// Get the text content if this is a text output.
74    ///
75    /// Returns an empty string for `Structured` and `Error` variants.
76    /// Use [`structured()`](Self::structured) or [`Display`] for those.
77    #[must_use]
78    pub fn text(&self) -> &str {
79        match &self.content {
80            AgentOutputContent::Text(t) => t,
81            _ => "",
82        }
83    }
84
85    /// Get the error message if this is an error output.
86    #[must_use]
87    pub fn error_message(&self) -> Option<&str> {
88        match &self.content {
89            AgentOutputContent::Error(e) => Some(e),
90            _ => None,
91        }
92    }
93
94    /// Get the structured JSON value if this is a structured output.
95    #[must_use]
96    pub fn structured(&self) -> Option<&serde_json::Value> {
97        match &self.content {
98            AgentOutputContent::Structured(v) => Some(v),
99            _ => None,
100        }
101    }
102
103    /// Check if this is an error output.
104    #[must_use]
105    pub fn is_error(&self) -> bool {
106        matches!(&self.content, AgentOutputContent::Error(_))
107    }
108}
109
110impl std::fmt::Display for AgentOutput {
111    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
112        match &self.content {
113            AgentOutputContent::Text(t) => write!(f, "{t}"),
114            AgentOutputContent::Structured(v) => write!(f, "{v}"),
115            AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
116        }
117    }
118}
119
120/// The main agent struct.
121///
122/// An agent combines an LLM provider, tools, memory, and steering into
123/// a single runtime that processes user input and produces output.
124///
125/// Use [`Agent::builder()`] to construct an agent.
126#[allow(deprecated)]
127pub struct Agent {
128    pub(crate) provider: Arc<dyn Provider>,
129    pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
130    pub(crate) memory: Arc<dyn Memory>,
131    pub(crate) guards: Vec<Arc<dyn Guard>>,
132    pub(crate) hints: Vec<Arc<dyn Hint>>,
133    pub(crate) tracker: Arc<dyn Tracker>,
134    pub(crate) context_manager: Arc<dyn ContextManager>,
135    pub(crate) context_strategy: Arc<dyn ContextStrategy>,
136    pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
137    pub(crate) output_transformer: Arc<dyn OutputTransformer>,
138    pub(crate) output_processor: Arc<dyn OutputProcessor>,
139    pub(crate) tool_registry: Arc<dyn ToolRegistry>,
140    pub(crate) strategy: Box<dyn AgentStrategy>,
141    pub(crate) hooks: Vec<Arc<dyn AgentHook>>,
142    pub(crate) config: AgentConfig,
143}
144
145impl std::fmt::Debug for Agent {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        f.debug_struct("Agent")
148            .field("model", &self.provider.model_info().name)
149            .field("tools", &self.tools.len())
150            .field("guards", &self.guards.len())
151            .field("hints", &self.hints.len())
152            .field("hooks", &self.hooks.len())
153            .field("config", &self.config)
154            .finish_non_exhaustive()
155    }
156}
157
158#[allow(deprecated)]
159impl Agent {
160    /// Create a builder for constructing an agent.
161    #[must_use]
162    pub fn builder() -> AgentBuilder {
163        AgentBuilder::new()
164    }
165
166    /// Create an agent with just a provider and system prompt.
167    ///
168    /// This is a convenience shorthand equivalent to:
169    /// ```rust,ignore
170    /// Agent::builder()
171    ///     .provider(provider)
172    ///     .system(system)
173    ///     .build()
174    ///     .unwrap()
175    /// ```
176    ///
177    /// All other settings use their defaults (in-memory memory, no tools,
178    /// no guards, etc.). Use [`Agent::builder()`] for full customization.
179    ///
180    /// # Example
181    ///
182    /// ```rust,no_run
183    /// use traitclaw_core::prelude::*;
184    ///
185    /// # fn example(provider: impl traitclaw_core::traits::provider::Provider) {
186    /// let agent = Agent::with_system(provider, "You are a helpful assistant.");
187    /// # }
188    /// ```
189    /// # Panics
190    ///
191    /// This method cannot panic under normal usage — the internal `build()`
192    /// call only fails when no provider is set, and `with_system` always
193    /// provides one.
194    #[must_use]
195    pub fn with_system(provider: impl Provider, system: impl Into<String>) -> Self {
196        Agent::builder()
197            .provider(provider)
198            .system(system)
199            .build()
200            .expect("Agent::with_system is infallible: provider is always set")
201    }
202
203    /// Create an agent directly (prefer using `builder()`).
204    #[allow(clippy::too_many_arguments)]
205    pub(crate) fn new(
206        provider: Arc<dyn Provider>,
207        tools: Vec<Arc<dyn ErasedTool>>,
208        memory: Arc<dyn Memory>,
209        guards: Vec<Arc<dyn Guard>>,
210        hints: Vec<Arc<dyn Hint>>,
211        tracker: Arc<dyn Tracker>,
212        context_manager: Arc<dyn ContextManager>,
213        context_strategy: Arc<dyn ContextStrategy>,
214        execution_strategy: Arc<dyn ExecutionStrategy>,
215        output_transformer: Arc<dyn OutputTransformer>,
216        output_processor: Arc<dyn OutputProcessor>,
217        tool_registry: Arc<dyn ToolRegistry>,
218        strategy: Box<dyn AgentStrategy>,
219        hooks: Vec<Arc<dyn AgentHook>>,
220        config: AgentConfig,
221    ) -> Self {
222        Self {
223            provider,
224            tools,
225            memory,
226            guards,
227            hints,
228            tracker,
229            context_manager,
230            context_strategy,
231            execution_strategy,
232            output_transformer,
233            output_processor,
234            tool_registry,
235            strategy,
236            hooks,
237            config,
238        }
239    }
240
241    /// Build an [`AgentRuntime`] from this agent's components.
242    ///
243    /// The runtime is passed to the strategy for execution.
244    fn to_runtime(&self) -> AgentRuntime {
245        AgentRuntime {
246            provider: Arc::clone(&self.provider),
247            tools: self.tools.clone(),
248            memory: Arc::clone(&self.memory),
249            guards: self.guards.clone(),
250            hints: self.hints.clone(),
251            tracker: Arc::clone(&self.tracker),
252            context_manager: Arc::clone(&self.context_manager),
253            context_strategy: Arc::clone(&self.context_strategy),
254            execution_strategy: Arc::clone(&self.execution_strategy),
255            output_transformer: Arc::clone(&self.output_transformer),
256            output_processor: Arc::clone(&self.output_processor),
257            tool_registry: Arc::clone(&self.tool_registry),
258            hooks: self.hooks.clone(),
259            config: self.config.clone(),
260        }
261    }
262
263    /// Create a session bound to a specific session ID.
264    ///
265    /// The returned [`AgentSession`] routes all memory operations through this
266    /// session ID, providing conversation isolation.
267    ///
268    /// ```rust,no_run
269    /// # use traitclaw_core::prelude::*;
270    /// # async fn example(agent: &Agent) -> traitclaw_core::Result<()> {
271    /// let session = agent.session("user-123");
272    /// let output = session.say("Hello!").await?;
273    /// # Ok(())
274    /// # }
275    /// ```
276    #[must_use]
277    pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
278        AgentSession {
279            agent: self,
280            session_id: id.into(),
281        }
282    }
283
284    /// Create a session with an auto-generated UUID v4 session ID.
285    ///
286    /// Useful when you want isolated conversations without managing IDs.
287    #[must_use]
288    pub fn session_auto(&self) -> AgentSession<'_> {
289        AgentSession {
290            agent: self,
291            session_id: uuid::Uuid::new_v4().to_string(),
292        }
293    }
294
295    /// Run the agent with user input and return the final output.
296    ///
297    /// This uses the `"default"` session for backward compatibility.
298    /// For session isolation, use [`Agent::session()`] or [`Agent::session_auto()`].
299    ///
300    /// # Errors
301    ///
302    /// Returns an error if the provider fails, tool execution fails,
303    /// memory operations fail, or max iterations are reached.
304    pub async fn run(&self, input: &str) -> Result<AgentOutput> {
305        let runtime = self.to_runtime();
306        self.strategy.execute(&runtime, input, "default").await
307    }
308
309    /// Run the agent and return a streaming response.
310    ///
311    /// This uses the `"default"` session for backward compatibility.
312    ///
313    /// Returns an [`AgentStream`] that yields [`StreamEvent`]s incrementally,
314    /// providing real-time output from the LLM.
315    ///
316    /// [`AgentStream`]: crate::streaming::AgentStream
317    /// [`StreamEvent`]: crate::types::stream::StreamEvent
318    #[must_use]
319    pub fn stream(&self, input: &str) -> AgentStream {
320        self.stream_with_session(input, "default")
321    }
322
323    /// Run the agent and return a structured output.
324    ///
325    /// The LLM is instructed to return JSON matching type `T`'s schema.
326    /// If deserialization fails, retries up to 3 times with feedback.
327    ///
328    /// When the provider's model supports native structured-output
329    /// (`model_info.supports_structured == true`), the `response_format`
330    /// is set on the `CompletionRequest` for guaranteed valid JSON.
331    /// Otherwise, schema instructions are injected into the system prompt.
332    ///
333    /// # ⚠️ Stateless Mode
334    ///
335    /// This method calls the provider directly, **bypassing** the agent
336    /// runtime loop. Memory, guards, hints, context strategy, and usage
337    /// tracking are not used. Use `run()` for full agent behavior.
338    ///
339    /// # Errors
340    ///
341    /// Returns an error if the provider fails or deserialization fails
342    /// after retries.
343    pub async fn run_structured<T>(&self, input: &str) -> Result<T>
344    where
345        T: serde::de::DeserializeOwned + schemars::JsonSchema,
346    {
347        let model_info = self.provider.model_info();
348        let schema = schemars::schema_for!(T);
349        let schema_json = serde_json::to_value(&schema)
350            .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
351
352        let uses_native = model_info.supports_structured;
353
354        let mut messages = vec![];
355        if let Some(ref system_prompt) = self.config.system_prompt {
356            messages.push(Message::system(system_prompt));
357        }
358
359        // If the model doesn't support native structured output, inject schema instructions
360        if !uses_native {
361            let schema_str = serde_json::to_string_pretty(&schema_json)
362                .unwrap_or_else(|_| schema_json.to_string());
363            messages.push(Message::system(format!(
364                "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
365            )));
366        }
367
368        messages.push(Message::user(input));
369
370        let max_retries = 3;
371        let mut last_error = String::new();
372
373        for attempt in 0..=max_retries {
374            if attempt > 0 {
375                // Add retry feedback
376                messages.push(Message::system(format!(
377                    "Your previous response was not valid JSON. Error: {last_error}\n\
378                     Please try again. Respond ONLY with valid JSON."
379                )));
380            }
381
382            let response_format = if uses_native {
383                Some(crate::types::completion::ResponseFormat::JsonSchema {
384                    json_schema: schema_json.clone(),
385                })
386            } else {
387                None
388            };
389
390            let request = crate::types::completion::CompletionRequest {
391                model: model_info.name.clone(),
392                messages: messages.clone(),
393                tools: vec![],
394                max_tokens: self.config.max_tokens,
395                temperature: self.config.temperature,
396                response_format,
397                stream: false,
398            };
399
400            let response = self.provider.complete(request).await?;
401
402            let text = match response.content {
403                crate::types::completion::ResponseContent::Text(t) => t,
404                crate::types::completion::ResponseContent::ToolCalls(_) => {
405                    last_error = "Model returned tool calls instead of JSON".into();
406                    messages.push(Message::assistant("[tool calls returned]"));
407                    continue;
408                }
409            };
410
411            match serde_json::from_str::<T>(&text) {
412                Ok(value) => return Ok(value),
413                Err(e) => {
414                    last_error = format!("{e}");
415                    messages.push(Message::assistant(&text));
416                }
417            }
418        }
419
420        Err(crate::Error::Runtime(format!(
421            "Structured output failed after {max_retries} retries. Last error: {last_error}"
422        )))
423    }
424
425    /// Internal stream implementation supporting custom session IDs.
426    pub(crate) fn stream_with_session(&self, input: &str, session_id: &str) -> AgentStream {
427        let runtime = crate::traits::strategy::AgentRuntime {
428            provider: Arc::clone(&self.provider),
429            tools: self.tools.clone(),
430            memory: Arc::clone(&self.memory),
431            guards: self.guards.clone(),
432            hints: self.hints.clone(),
433            tracker: Arc::clone(&self.tracker),
434            context_manager: Arc::clone(&self.context_manager),
435            context_strategy: Arc::clone(&self.context_strategy),
436            execution_strategy: Arc::clone(&self.execution_strategy),
437            output_transformer: Arc::clone(&self.output_transformer),
438            output_processor: Arc::clone(&self.output_processor),
439            tool_registry: Arc::clone(&self.tool_registry),
440            config: self.config.clone(),
441            hooks: self.hooks.clone(),
442        };
443
444        self.strategy.stream(&runtime, input, session_id)
445    }
446}
447
448/// A session-scoped agent wrapper.
449///
450/// Binds an [`Agent`] to a specific `session_id`, routing all memory
451/// operations through that session for conversation isolation.
452///
453/// Created via [`Agent::session()`] or [`Agent::session_auto()`].
454pub struct AgentSession<'a> {
455    agent: &'a Agent,
456    /// The session ID this session is bound to.
457    session_id: String,
458}
459
460impl AgentSession<'_> {
461    /// Send a message within this session.
462    ///
463    /// Equivalent to [`Agent::run()`] but uses this session's ID for memory.
464    ///
465    /// # Errors
466    ///
467    /// Returns an error if the provider fails, tool execution fails,
468    /// memory operations fail, or max iterations are reached.
469    pub async fn say(&self, input: &str) -> Result<AgentOutput> {
470        let runtime = self.agent.to_runtime();
471        self.agent
472            .strategy
473            .execute(&runtime, input, &self.session_id)
474            .await
475    }
476
477    /// Execute the agent loop with a custom session ID, returning a stream.
478    ///
479    /// Equivalent to [`Agent::stream()`] but uses this session's ID for memory.
480    #[must_use]
481    pub fn stream(&self, input: &str) -> AgentStream {
482        self.agent.stream_with_session(input, &self.session_id)
483    }
484
485    /// Get the session ID.
486    #[must_use]
487    pub fn id(&self) -> &str {
488        &self.session_id
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    use crate::types::completion::{CompletionRequest, CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Wrap `content` in an `AgentOutput` with zeroed usage.
    fn output_with(content: AgentOutputContent) -> AgentOutput {
        AgentOutput {
            content,
            usage: RunUsage::default(),
        }
    }

    // --- AgentOutput / RunUsage ---

    #[test]
    fn test_run_usage_default() {
        let RunUsage {
            tokens,
            iterations,
            duration,
        } = RunUsage::default();
        assert_eq!(tokens, 0);
        assert_eq!(iterations, 0);
        assert_eq!(duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let output = AgentOutput::text_with_usage(String::from("Hello"), RunUsage::default());
        assert_eq!(output.text(), "Hello");
        assert!(!output.is_error());
        assert!(output.structured().is_none());
        assert!(output.error_message().is_none());
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let output = output_with(AgentOutputContent::Structured(
            serde_json::json!({"key": "val"}),
        ));
        assert_eq!(output.text(), "");
        let value = output
            .structured()
            .expect("structured variant should expose its JSON value");
        assert_eq!(value["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let output = output_with(AgentOutputContent::Error(String::from("boom")));
        assert_eq!(output.text(), "");
        assert!(output.is_error());
        assert_eq!(output.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let output = AgentOutput::text_with_usage(String::from("hi"), RunUsage::default());
        assert_eq!(output.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let output = output_with(AgentOutputContent::Structured(serde_json::json!(42)));
        assert_eq!(output.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let output = output_with(AgentOutputContent::Error(String::from("fail")));
        assert_eq!(output.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let usage = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let output = AgentOutput::text_with_usage(String::from("x"), usage);
        let carried = &output.usage;
        assert_eq!(carried.tokens, 100);
        assert_eq!(carried.iterations, 5);
        assert_eq!(carried.duration.as_millis(), 500);
    }

    // --- Agent::with_system() tests (Story 1.1) ---

    /// Minimal provider: `complete` always answers `"ok"`, model name is `"mock"`.
    struct MockProvider {
        info: ModelInfo,
    }

    impl MockProvider {
        fn new() -> Self {
            let info = ModelInfo::new("mock", ModelTier::Small, 4_096, false, false, false);
            Self { info }
        }
    }

    #[async_trait]
    impl crate::traits::provider::Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            let usage = Usage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(CompletionResponse {
                content: ResponseContent::Text("ok".into()),
                usage,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[test]
    fn test_with_system_str_prompt() {
        // AC #1, #2: with_system accepts &str and creates a valid agent
        let agent = Agent::with_system(MockProvider::new(), "You are helpful.");
        let prompt = agent.config.system_prompt.as_deref();
        assert_eq!(prompt, Some("You are helpful."));
    }

    #[test]
    fn test_with_system_string_prompt() {
        // AC #2: with_system accepts String
        let owned = String::from("You are a researcher.");
        let agent = Agent::with_system(MockProvider::new(), owned);
        let prompt = agent.config.system_prompt.as_deref();
        assert_eq!(prompt, Some("You are a researcher."));
    }

    #[test]
    fn test_with_system_builder_unchanged() {
        // AC #3: builder API is unchanged (still works)
        let built = Agent::builder()
            .provider(MockProvider::new())
            .system("test")
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn test_with_system_provider_configured() {
        // AC #4: agent has correct provider
        let agent = Agent::with_system(MockProvider::new(), "test");
        let info = agent.provider.model_info();
        assert_eq!(info.name, "mock");
    }
}
644        // AC #4: agent has correct provider
645        let agent = Agent::with_system(MockProvider::new(), "test");
646        assert_eq!(agent.provider.model_info().name, "mock");
647    }
648}