Skip to main content

traitclaw_core/
agent.rs

1//! Agent — the main entry point for running AI agents.
2
3use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_strategy::ContextStrategy;
9use crate::traits::execution_strategy::ExecutionStrategy;
10use crate::traits::guard::Guard;
11use crate::traits::hint::Hint;
12use crate::traits::memory::Memory;
13use crate::traits::output_processor::OutputProcessor;
14use crate::traits::provider::Provider;
15use crate::traits::tool::ErasedTool;
16use crate::traits::tracker::Tracker;
17use crate::types::message::Message;
18use crate::Result;
19
/// Usage statistics from an agent run.
///
/// Derives `PartialEq`/`Eq` so whole usage values can be compared directly
/// (for example in tests) instead of field by field.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct RunUsage {
    /// Total tokens consumed across all LLM calls.
    pub tokens: usize,
    /// Number of agent loop iterations.
    pub iterations: usize,
    /// Wall-clock duration of the run.
    pub duration: std::time::Duration,
}
30
31/// Output from an agent run.
32///
33/// This struct is marked `#[non_exhaustive]` — new fields may be added in
34/// future releases without breaking changes.
35#[derive(Debug, Clone)]
36#[non_exhaustive]
37pub struct AgentOutput {
38    /// The response content.
39    pub content: AgentOutputContent,
40    /// Usage statistics for this run.
41    pub usage: RunUsage,
42}
43
44/// The content type of an agent output.
45#[derive(Debug, Clone)]
46#[non_exhaustive]
47pub enum AgentOutputContent {
48    /// The agent returned a text response.
49    Text(String),
50    /// The agent returned a structured JSON response.
51    Structured(serde_json::Value),
52    /// The agent encountered an error.
53    Error(String),
54}
55
56impl AgentOutput {
57    /// Create a text output with usage.
58    #[must_use]
59    pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
60        Self {
61            content: AgentOutputContent::Text(text),
62            usage,
63        }
64    }
65
66    /// Get the text content if this is a text output.
67    ///
68    /// Returns an empty string for `Structured` and `Error` variants.
69    /// Use [`structured()`](Self::structured) or [`Display`] for those.
70    #[must_use]
71    pub fn text(&self) -> &str {
72        match &self.content {
73            AgentOutputContent::Text(t) => t,
74            _ => "",
75        }
76    }
77
78    /// Get the error message if this is an error output.
79    #[must_use]
80    pub fn error_message(&self) -> Option<&str> {
81        match &self.content {
82            AgentOutputContent::Error(e) => Some(e),
83            _ => None,
84        }
85    }
86
87    /// Get the structured JSON value if this is a structured output.
88    #[must_use]
89    pub fn structured(&self) -> Option<&serde_json::Value> {
90        match &self.content {
91            AgentOutputContent::Structured(v) => Some(v),
92            _ => None,
93        }
94    }
95
96    /// Check if this is an error output.
97    #[must_use]
98    pub fn is_error(&self) -> bool {
99        matches!(&self.content, AgentOutputContent::Error(_))
100    }
101}
102
103impl std::fmt::Display for AgentOutput {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        match &self.content {
106            AgentOutputContent::Text(t) => write!(f, "{t}"),
107            AgentOutputContent::Structured(v) => write!(f, "{v}"),
108            AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
109        }
110    }
111}
112
113/// The main agent struct.
114///
115/// An agent combines an LLM provider, tools, memory, and steering into
116/// a single runtime that processes user input and produces output.
117///
118/// Use [`Agent::builder()`] to construct an agent.
119pub struct Agent {
120    pub(crate) provider: Arc<dyn Provider>,
121    pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
122    pub(crate) memory: Arc<dyn Memory>,
123    pub(crate) guards: Vec<Arc<dyn Guard>>,
124    pub(crate) hints: Vec<Arc<dyn Hint>>,
125    pub(crate) tracker: Arc<dyn Tracker>,
126    pub(crate) context_strategy: Arc<dyn ContextStrategy>,
127    pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
128    pub(crate) output_processor: Arc<dyn OutputProcessor>,
129    pub(crate) config: AgentConfig,
130}
131
132impl std::fmt::Debug for Agent {
133    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
134        f.debug_struct("Agent")
135            .field("model", &self.provider.model_info().name)
136            .field("tools", &self.tools.len())
137            .field("guards", &self.guards.len())
138            .field("hints", &self.hints.len())
139            .field("config", &self.config)
140            .finish_non_exhaustive()
141    }
142}
143
144impl Agent {
145    /// Create a builder for constructing an agent.
146    #[must_use]
147    pub fn builder() -> AgentBuilder {
148        AgentBuilder::new()
149    }
150
151    /// Create an agent directly (prefer using `builder()`).
152    #[allow(clippy::too_many_arguments)]
153    pub(crate) fn new(
154        provider: Arc<dyn Provider>,
155        tools: Vec<Arc<dyn ErasedTool>>,
156        memory: Arc<dyn Memory>,
157        guards: Vec<Arc<dyn Guard>>,
158        hints: Vec<Arc<dyn Hint>>,
159        tracker: Arc<dyn Tracker>,
160        context_strategy: Arc<dyn ContextStrategy>,
161        execution_strategy: Arc<dyn ExecutionStrategy>,
162        output_processor: Arc<dyn OutputProcessor>,
163        config: AgentConfig,
164    ) -> Self {
165        Self {
166            provider,
167            tools,
168            memory,
169            guards,
170            hints,
171            tracker,
172            context_strategy,
173            execution_strategy,
174            output_processor,
175            config,
176        }
177    }
178
179    /// Create a session bound to a specific session ID.
180    ///
181    /// The returned [`AgentSession`] routes all memory operations through this
182    /// session ID, providing conversation isolation.
183    ///
184    /// ```rust,no_run
185    /// # use traitclaw_core::prelude::*;
186    /// # async fn example(agent: &Agent) -> traitclaw_core::Result<()> {
187    /// let session = agent.session("user-123");
188    /// let output = session.say("Hello!").await?;
189    /// # Ok(())
190    /// # }
191    /// ```
192    #[must_use]
193    pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
194        AgentSession {
195            agent: self,
196            session_id: id.into(),
197        }
198    }
199
200    /// Create a session with an auto-generated UUID v4 session ID.
201    ///
202    /// Useful when you want isolated conversations without managing IDs.
203    #[must_use]
204    pub fn session_auto(&self) -> AgentSession<'_> {
205        AgentSession {
206            agent: self,
207            session_id: uuid::Uuid::new_v4().to_string(),
208        }
209    }
210
211    /// Run the agent with user input and return the final output.
212    ///
213    /// This uses the `"default"` session for backward compatibility.
214    /// For session isolation, use [`Agent::session()`] or [`Agent::session_auto()`].
215    ///
216    /// # Errors
217    ///
218    /// Returns an error if the provider fails, tool execution fails,
219    /// memory operations fail, or max iterations are reached.
220    pub async fn run(&self, input: &str) -> Result<AgentOutput> {
221        crate::runtime::run_agent(self, input, "default").await
222    }
223
224    /// Run the agent and return a streaming response.
225    ///
226    /// This uses the `"default"` session for backward compatibility.
227    ///
228    /// Returns an [`AgentStream`] that yields [`StreamEvent`]s incrementally,
229    /// allowing you to display text as it is generated.
230    ///
231    /// [`StreamEvent`]: crate::types::stream::StreamEvent
232    #[must_use]
233    pub fn stream(&self, input: &str) -> AgentStream {
234        crate::streaming::stream_agent(self, input.to_string(), "default".to_string())
235    }
236
237    /// Run the agent and return a structured output.
238    ///
239    /// The LLM is instructed to return JSON matching type `T`'s schema.
240    /// If deserialization fails, retries up to 3 times with feedback.
241    ///
242    /// When the provider's model supports native structured-output
243    /// (`model_info.supports_structured == true`), the `response_format`
244    /// is set on the `CompletionRequest` for guaranteed valid JSON.
245    /// Otherwise, schema instructions are injected into the system prompt.
246    ///
247    /// # ⚠️ Stateless Mode
248    ///
249    /// This method calls the provider directly, **bypassing** the agent
250    /// runtime loop. Memory, guards, hints, context strategy, and usage
251    /// tracking are not used. Use `run()` for full agent behavior.
252    ///
253    /// # Errors
254    ///
255    /// Returns an error if the provider fails or deserialization fails
256    /// after retries.
257    pub async fn run_structured<T>(&self, input: &str) -> Result<T>
258    where
259        T: serde::de::DeserializeOwned + schemars::JsonSchema,
260    {
261        let model_info = self.provider.model_info();
262        let schema = schemars::schema_for!(T);
263        let schema_json = serde_json::to_value(&schema)
264            .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
265
266        let uses_native = model_info.supports_structured;
267
268        let mut messages = vec![];
269        if let Some(ref system_prompt) = self.config.system_prompt {
270            messages.push(Message::system(system_prompt));
271        }
272
273        // If the model doesn't support native structured output, inject schema instructions
274        if !uses_native {
275            let schema_str = serde_json::to_string_pretty(&schema_json)
276                .unwrap_or_else(|_| schema_json.to_string());
277            messages.push(Message::system(format!(
278                "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
279            )));
280        }
281
282        messages.push(Message::user(input));
283
284        let max_retries = 3;
285        let mut last_error = String::new();
286
287        for attempt in 0..=max_retries {
288            if attempt > 0 {
289                // Add retry feedback
290                messages.push(Message::system(format!(
291                    "Your previous response was not valid JSON. Error: {last_error}\n\
292                     Please try again. Respond ONLY with valid JSON."
293                )));
294            }
295
296            let response_format = if uses_native {
297                Some(crate::types::completion::ResponseFormat::JsonSchema {
298                    json_schema: schema_json.clone(),
299                })
300            } else {
301                None
302            };
303
304            let request = crate::types::completion::CompletionRequest {
305                model: model_info.name.clone(),
306                messages: messages.clone(),
307                tools: vec![],
308                max_tokens: self.config.max_tokens,
309                temperature: self.config.temperature,
310                response_format,
311                stream: false,
312            };
313
314            let response = self.provider.complete(request).await?;
315
316            let text = match response.content {
317                crate::types::completion::ResponseContent::Text(t) => t,
318                crate::types::completion::ResponseContent::ToolCalls(_) => {
319                    last_error = "Model returned tool calls instead of JSON".into();
320                    messages.push(Message::assistant("[tool calls returned]"));
321                    continue;
322                }
323            };
324
325            match serde_json::from_str::<T>(&text) {
326                Ok(value) => return Ok(value),
327                Err(e) => {
328                    last_error = format!("{e}");
329                    messages.push(Message::assistant(&text));
330                }
331            }
332        }
333
334        Err(crate::Error::Runtime(format!(
335            "Structured output failed after {max_retries} retries. Last error: {last_error}"
336        )))
337    }
338}
339
340/// A session-scoped agent wrapper.
341///
342/// Binds an [`Agent`] to a specific `session_id`, routing all memory
343/// operations through that session for conversation isolation.
344///
345/// Created via [`Agent::session()`] or [`Agent::session_auto()`].
346pub struct AgentSession<'a> {
347    agent: &'a Agent,
348    /// The session ID this session is bound to.
349    session_id: String,
350}
351
352impl AgentSession<'_> {
353    /// Send a message within this session.
354    ///
355    /// Equivalent to [`Agent::run()`] but uses this session's ID for memory.
356    ///
357    /// # Errors
358    ///
359    /// Returns an error if the provider fails, tool execution fails,
360    /// memory operations fail, or max iterations are reached.
361    pub async fn say(&self, input: &str) -> Result<AgentOutput> {
362        crate::runtime::run_agent(self.agent, input, &self.session_id).await
363    }
364
365    /// Stream a response within this session.
366    ///
367    /// Equivalent to [`Agent::stream()`] but uses this session's ID for memory.
368    #[must_use]
369    pub fn stream(&self, input: &str) -> AgentStream {
370        crate::streaming::stream_agent(self.agent, input.to_string(), self.session_id.clone())
371    }
372
373    /// Get the session ID.
374    #[must_use]
375    pub fn id(&self) -> &str {
376        &self.session_id
377    }
378}
379
#[cfg(test)]
mod tests {
    //! Unit tests for the `AgentOutput` accessors, `Display`, and `RunUsage`.

    use super::*;
    use std::time::Duration;

    #[test]
    fn test_run_usage_default() {
        let usage = RunUsage::default();
        assert_eq!(usage.tokens, 0);
        assert_eq!(usage.iterations, 0);
        assert_eq!(usage.duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let output = AgentOutput::text_with_usage(String::from("Hello"), RunUsage::default());
        assert_eq!(output.text(), "Hello");
        assert!(!output.is_error());
        assert!(output.structured().is_none());
        assert!(output.error_message().is_none());
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let output = AgentOutput {
            content: AgentOutputContent::Structured(serde_json::json!({"key": "val"})),
            usage: RunUsage::default(),
        };
        assert_eq!(output.text(), "");
        let value = output.structured().expect("structured content expected");
        assert_eq!(value["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let output = AgentOutput {
            content: AgentOutputContent::Error(String::from("boom")),
            usage: RunUsage::default(),
        };
        assert_eq!(output.text(), "");
        assert!(output.is_error());
        assert_eq!(output.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let output = AgentOutput::text_with_usage(String::from("hi"), RunUsage::default());
        assert_eq!(output.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let output = AgentOutput {
            content: AgentOutputContent::Structured(serde_json::json!(42)),
            usage: RunUsage::default(),
        };
        assert_eq!(output.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let output = AgentOutput {
            content: AgentOutputContent::Error(String::from("fail")),
            usage: RunUsage::default(),
        };
        assert_eq!(output.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let stats = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let output = AgentOutput::text_with_usage(String::from("x"), stats);
        assert_eq!(output.usage.tokens, 100);
        assert_eq!(output.usage.iterations, 5);
        assert_eq!(output.usage.duration.as_millis(), 500);
    }
}