// Source file: traitclaw_core/agent.rs

//! Agent — the main entry point for running AI agents.
2
3use std::sync::Arc;
4
5use crate::agent_builder::AgentBuilder;
6use crate::config::AgentConfig;
7use crate::streaming::AgentStream;
8use crate::traits::context_manager::ContextManager;
9use crate::traits::execution_strategy::ExecutionStrategy;
10use crate::traits::guard::Guard;
11use crate::traits::hint::Hint;
12use crate::traits::hook::AgentHook;
13use crate::traits::memory::Memory;
14use crate::traits::output_transformer::OutputTransformer;
15use crate::traits::provider::Provider;
16use crate::traits::strategy::{AgentRuntime, AgentStrategy};
17use crate::traits::tool::ErasedTool;
18use crate::traits::tool_registry::ToolRegistry;
19use crate::traits::tracker::Tracker;
20use crate::types::message::Message;
21use crate::Result;
22
/// Aggregate usage recorded for a single agent run.
///
/// The `Default` value is an all-zero record: no tokens, no iterations and
/// a zero duration.
#[derive(Clone, Debug, Default)]
pub struct RunUsage {
    /// Tokens consumed, summed over every LLM call made during the run.
    pub tokens: usize,
    /// How many times the agent loop iterated.
    pub iterations: usize,
    /// Elapsed wall-clock time for the whole run.
    pub duration: std::time::Duration,
}
33
34/// Output from an agent run.
35///
36/// This struct is marked `#[non_exhaustive]` — new fields may be added in
37/// future releases without breaking changes.
38#[derive(Debug, Clone)]
39#[non_exhaustive]
40pub struct AgentOutput {
41    /// The response content.
42    pub content: AgentOutputContent,
43    /// Usage statistics for this run.
44    pub usage: RunUsage,
45}
46
47/// The content type of an agent output.
48#[derive(Debug, Clone)]
49#[non_exhaustive]
50pub enum AgentOutputContent {
51    /// The agent returned a text response.
52    Text(String),
53    /// The agent returned a structured JSON response.
54    Structured(serde_json::Value),
55    /// The agent encountered an error.
56    Error(String),
57}
58
59impl AgentOutput {
60    /// Create a text output with usage.
61    #[must_use]
62    pub fn text_with_usage(text: String, usage: RunUsage) -> Self {
63        Self {
64            content: AgentOutputContent::Text(text),
65            usage,
66        }
67    }
68
69    /// Get the text content if this is a text output.
70    ///
71    /// Returns an empty string for `Structured` and `Error` variants.
72    /// Use [`structured()`](Self::structured) or `Display` for those.
73    #[must_use]
74    pub fn text(&self) -> &str {
75        match &self.content {
76            AgentOutputContent::Text(t) => t,
77            _ => "",
78        }
79    }
80
81    /// Get the error message if this is an error output.
82    #[must_use]
83    pub fn error_message(&self) -> Option<&str> {
84        match &self.content {
85            AgentOutputContent::Error(e) => Some(e),
86            _ => None,
87        }
88    }
89
90    /// Get the structured JSON value if this is a structured output.
91    #[must_use]
92    pub fn structured(&self) -> Option<&serde_json::Value> {
93        match &self.content {
94            AgentOutputContent::Structured(v) => Some(v),
95            _ => None,
96        }
97    }
98
99    /// Check if this is an error output.
100    #[must_use]
101    pub fn is_error(&self) -> bool {
102        matches!(&self.content, AgentOutputContent::Error(_))
103    }
104}
105
106impl std::fmt::Display for AgentOutput {
107    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108        match &self.content {
109            AgentOutputContent::Text(t) => write!(f, "{t}"),
110            AgentOutputContent::Structured(v) => write!(f, "{v}"),
111            AgentOutputContent::Error(e) => write!(f, "Error: {e}"),
112        }
113    }
114}
115
116/// The main agent struct.
117///
118/// An agent combines an LLM provider, tools, memory, and steering into
119/// a single runtime that processes user input and produces output.
120///
121/// Use [`Agent::builder()`] to construct an agent.
122pub struct Agent {
123    pub(crate) provider: Arc<dyn Provider>,
124    pub(crate) tools: Vec<Arc<dyn ErasedTool>>,
125    pub(crate) memory: Arc<dyn Memory>,
126    pub(crate) guards: Vec<Arc<dyn Guard>>,
127    pub(crate) hints: Vec<Arc<dyn Hint>>,
128    pub(crate) tracker: Arc<dyn Tracker>,
129    pub(crate) context_manager: Arc<dyn ContextManager>,
130    pub(crate) execution_strategy: Arc<dyn ExecutionStrategy>,
131    pub(crate) output_transformer: Arc<dyn OutputTransformer>,
132    pub(crate) tool_registry: Arc<dyn ToolRegistry>,
133    pub(crate) strategy: Box<dyn AgentStrategy>,
134    pub(crate) hooks: Vec<Arc<dyn AgentHook>>,
135    pub(crate) config: AgentConfig,
136}
137
138impl std::fmt::Debug for Agent {
139    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
140        f.debug_struct("Agent")
141            .field("model", &self.provider.model_info().name)
142            .field("tools", &self.tools.len())
143            .field("guards", &self.guards.len())
144            .field("hints", &self.hints.len())
145            .field("hooks", &self.hooks.len())
146            .field("config", &self.config)
147            .finish_non_exhaustive()
148    }
149}
150
151impl Agent {
152    /// Create a builder for constructing an agent.
153    #[must_use]
154    pub fn builder() -> AgentBuilder {
155        AgentBuilder::new()
156    }
157
158    /// Create an agent with just a provider and system prompt.
159    ///
160    /// This is a convenience shorthand equivalent to:
161    /// ```rust,ignore
162    /// Agent::builder()
163    ///     .provider(provider)
164    ///     .system(system)
165    ///     .build()
166    ///     .unwrap()
167    /// ```
168    ///
169    /// All other settings use their defaults (in-memory memory, no tools,
170    /// no guards, etc.). Use [`Agent::builder()`] for full customization.
171    ///
172    /// # Example
173    ///
174    /// ```rust,no_run
175    /// use traitclaw_core::prelude::*;
176    ///
177    /// # fn example(provider: impl traitclaw_core::traits::provider::Provider) {
178    /// let agent = Agent::with_system(provider, "You are a helpful assistant.");
179    /// # }
180    /// ```
181    /// # Panics
182    ///
183    /// This method cannot panic under normal usage — the internal `build()`
184    /// call only fails when no provider is set, and `with_system` always
185    /// provides one.
186    #[must_use]
187    pub fn with_system(provider: impl Provider, system: impl Into<String>) -> Self {
188        Agent::builder()
189            .provider(provider)
190            .system(system)
191            .build()
192            .expect("Agent::with_system is infallible: provider is always set")
193    }
194
195    /// Create an agent directly (prefer using `builder()`).
196    #[allow(clippy::too_many_arguments)]
197    pub(crate) fn new(
198        provider: Arc<dyn Provider>,
199        tools: Vec<Arc<dyn ErasedTool>>,
200        memory: Arc<dyn Memory>,
201        guards: Vec<Arc<dyn Guard>>,
202        hints: Vec<Arc<dyn Hint>>,
203        tracker: Arc<dyn Tracker>,
204        context_manager: Arc<dyn ContextManager>,
205        execution_strategy: Arc<dyn ExecutionStrategy>,
206        output_transformer: Arc<dyn OutputTransformer>,
207        tool_registry: Arc<dyn ToolRegistry>,
208        strategy: Box<dyn AgentStrategy>,
209        hooks: Vec<Arc<dyn AgentHook>>,
210        config: AgentConfig,
211    ) -> Self {
212        Self {
213            provider,
214            tools,
215            memory,
216            guards,
217            hints,
218            tracker,
219            context_manager,
220            execution_strategy,
221            output_transformer,
222            tool_registry,
223            strategy,
224            hooks,
225            config,
226        }
227    }
228
229    /// Build an [`AgentRuntime`] from this agent's components.
230    ///
231    /// The runtime is passed to the strategy for execution.
232    fn to_runtime(&self) -> AgentRuntime {
233        AgentRuntime {
234            provider: Arc::clone(&self.provider),
235            tools: self.tools.clone(),
236            memory: Arc::clone(&self.memory),
237            guards: self.guards.clone(),
238            hints: self.hints.clone(),
239            tracker: Arc::clone(&self.tracker),
240            context_manager: Arc::clone(&self.context_manager),
241            execution_strategy: Arc::clone(&self.execution_strategy),
242            output_transformer: Arc::clone(&self.output_transformer),
243            tool_registry: Arc::clone(&self.tool_registry),
244            hooks: self.hooks.clone(),
245            config: self.config.clone(),
246        }
247    }
248
249    /// Create a session bound to a specific session ID.
250    ///
251    /// The returned [`AgentSession`] routes all memory operations through this
252    /// session ID, providing conversation isolation.
253    ///
254    /// ```rust,no_run
255    /// # use traitclaw_core::prelude::*;
256    /// # async fn example(agent: &Agent) -> traitclaw_core::Result<()> {
257    /// let session = agent.session("user-123");
258    /// let output = session.say("Hello!").await?;
259    /// # Ok(())
260    /// # }
261    /// ```
262    #[must_use]
263    pub fn session(&self, id: impl Into<String>) -> AgentSession<'_> {
264        AgentSession {
265            agent: self,
266            session_id: id.into(),
267        }
268    }
269
270    /// Create a session with an auto-generated UUID v4 session ID.
271    ///
272    /// Useful when you want isolated conversations without managing IDs.
273    #[must_use]
274    pub fn session_auto(&self) -> AgentSession<'_> {
275        AgentSession {
276            agent: self,
277            session_id: uuid::Uuid::new_v4().to_string(),
278        }
279    }
280
281    /// Returns a reference to the agent's memory store.
282    ///
283    /// Useful for inspecting conversation history in tests or building
284    /// custom conversation management on top of the agent.
285    #[must_use]
286    pub fn memory(&self) -> &dyn Memory {
287        &*self.memory
288    }
289
290    /// Run the agent with user input and return the final output.
291    ///
292    /// This uses the `"default"` session for backward compatibility.
293    /// For session isolation, use [`Agent::session()`] or [`Agent::session_auto()`].
294    ///
295    /// # Errors
296    ///
297    /// Returns an error if the provider fails, tool execution fails,
298    /// memory operations fail, or max iterations are reached.
299    pub async fn run(&self, input: &str) -> Result<AgentOutput> {
300        let runtime = self.to_runtime();
301        self.strategy.execute(&runtime, input, "default").await
302    }
303
304    /// Run the agent and return a streaming response.
305    ///
306    /// This uses the `"default"` session for backward compatibility.
307    ///
308    /// Returns an `AgentStream` that yields `StreamEvent`s incrementally,
309    /// providing real-time output from the LLM.
310    #[must_use]
311    pub fn stream(&self, input: &str) -> AgentStream {
312        self.stream_with_session(input, "default")
313    }
314
315    /// Run the agent and return a structured output.
316    ///
317    /// The LLM is instructed to return JSON matching type `T`'s schema.
318    /// If deserialization fails, retries up to 3 times with feedback.
319    ///
320    /// When the provider's model supports native structured-output
321    /// (`model_info.supports_structured == true`), the `response_format`
322    /// is set on the `CompletionRequest` for guaranteed valid JSON.
323    /// Otherwise, schema instructions are injected into the system prompt.
324    ///
325    /// # ⚠️ Stateless Mode
326    ///
327    /// This method calls the provider directly, **bypassing** the agent
328    /// runtime loop. Memory, guards, hints, context strategy, and usage
329    /// tracking are not used. Use `run()` for full agent behavior.
330    ///
331    /// # Errors
332    ///
333    /// Returns an error if the provider fails or deserialization fails
334    /// after retries.
335    pub async fn run_structured<T>(&self, input: &str) -> Result<T>
336    where
337        T: serde::de::DeserializeOwned + schemars::JsonSchema,
338    {
339        let model_info = self.provider.model_info();
340        let schema = schemars::schema_for!(T);
341        let schema_json = serde_json::to_value(&schema)
342            .map_err(|e| crate::Error::Runtime(format!("Failed to serialize schema: {e}")))?;
343
344        let uses_native = model_info.supports_structured;
345
346        let mut messages = vec![];
347        if let Some(ref system_prompt) = self.config.system_prompt {
348            messages.push(Message::system(system_prompt));
349        }
350
351        // If the model doesn't support native structured output, inject schema instructions
352        if !uses_native {
353            let schema_str = serde_json::to_string_pretty(&schema_json)
354                .unwrap_or_else(|_| schema_json.to_string());
355            messages.push(Message::system(format!(
356                "You MUST respond ONLY with valid JSON matching this schema:\n```json\n{schema_str}\n```\nDo NOT include any text before or after the JSON."
357            )));
358        }
359
360        messages.push(Message::user(input));
361
362        let max_retries = 3;
363        let mut last_error = String::new();
364
365        for attempt in 0..=max_retries {
366            if attempt > 0 {
367                // Add retry feedback
368                messages.push(Message::system(format!(
369                    "Your previous response was not valid JSON. Error: {last_error}\n\
370                     Please try again. Respond ONLY with valid JSON."
371                )));
372            }
373
374            let response_format = if uses_native {
375                Some(crate::types::completion::ResponseFormat::JsonSchema {
376                    json_schema: schema_json.clone(),
377                })
378            } else {
379                None
380            };
381
382            let request = crate::types::completion::CompletionRequest {
383                model: model_info.name.clone(),
384                messages: messages.clone(),
385                tools: vec![],
386                max_tokens: self.config.max_tokens,
387                temperature: self.config.temperature,
388                response_format,
389                stream: false,
390            };
391
392            let response = self.provider.complete(request).await?;
393
394            let text = match response.content {
395                crate::types::completion::ResponseContent::Text(t) => t,
396                crate::types::completion::ResponseContent::ToolCalls(_) => {
397                    last_error = "Model returned tool calls instead of JSON".into();
398                    messages.push(Message::assistant("[tool calls returned]"));
399                    continue;
400                }
401            };
402
403            match serde_json::from_str::<T>(&text) {
404                Ok(value) => return Ok(value),
405                Err(e) => {
406                    last_error = format!("{e}");
407                    messages.push(Message::assistant(&text));
408                }
409            }
410        }
411
412        Err(crate::Error::Runtime(format!(
413            "Structured output failed after {max_retries} retries. Last error: {last_error}"
414        )))
415    }
416
417    /// Internal stream implementation supporting custom session IDs.
418    pub(crate) fn stream_with_session(&self, input: &str, session_id: &str) -> AgentStream {
419        let runtime = crate::traits::strategy::AgentRuntime {
420            provider: Arc::clone(&self.provider),
421            tools: self.tools.clone(),
422            memory: Arc::clone(&self.memory),
423            guards: self.guards.clone(),
424            hints: self.hints.clone(),
425            tracker: Arc::clone(&self.tracker),
426            context_manager: Arc::clone(&self.context_manager),
427            execution_strategy: Arc::clone(&self.execution_strategy),
428            output_transformer: Arc::clone(&self.output_transformer),
429            tool_registry: Arc::clone(&self.tool_registry),
430            config: self.config.clone(),
431            hooks: self.hooks.clone(),
432        };
433
434        self.strategy.stream(&runtime, input, session_id)
435    }
436}
437
438/// A session-scoped agent wrapper.
439///
440/// Binds an [`Agent`] to a specific `session_id`, routing all memory
441/// operations through that session for conversation isolation.
442///
443/// Created via [`Agent::session()`] or [`Agent::session_auto()`].
444pub struct AgentSession<'a> {
445    agent: &'a Agent,
446    /// The session ID this session is bound to.
447    session_id: String,
448}
449
450impl AgentSession<'_> {
451    /// Send a message within this session.
452    ///
453    /// Equivalent to [`Agent::run()`] but uses this session's ID for memory.
454    ///
455    /// # Errors
456    ///
457    /// Returns an error if the provider fails, tool execution fails,
458    /// memory operations fail, or max iterations are reached.
459    pub async fn say(&self, input: &str) -> Result<AgentOutput> {
460        let runtime = self.agent.to_runtime();
461        self.agent
462            .strategy
463            .execute(&runtime, input, &self.session_id)
464            .await
465    }
466
467    /// Execute the agent loop with a custom session ID, returning a stream.
468    ///
469    /// Equivalent to [`Agent::stream()`] but uses this session's ID for memory.
470    #[must_use]
471    pub fn stream(&self, input: &str) -> AgentStream {
472        self.agent.stream_with_session(input, &self.session_id)
473    }
474
475    /// Get the session ID.
476    #[must_use]
477    pub fn id(&self) -> &str {
478        &self.session_id
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::completion::{CompletionRequest, CompletionResponse, ResponseContent, Usage};
    use crate::types::model_info::{ModelInfo, ModelTier};
    use crate::types::stream::CompletionStream;
    use async_trait::async_trait;
    use std::time::Duration;

    /// Wrap `content` in an `AgentOutput` with zeroed usage.
    fn output_with(content: AgentOutputContent) -> AgentOutput {
        AgentOutput {
            content,
            usage: RunUsage::default(),
        }
    }

    #[test]
    fn test_run_usage_default() {
        let RunUsage {
            tokens,
            iterations,
            duration,
        } = RunUsage::default();
        assert_eq!(tokens, 0);
        assert_eq!(iterations, 0);
        assert_eq!(duration, Duration::ZERO);
    }

    #[test]
    fn test_text_output() {
        let output = AgentOutput::text_with_usage(String::from("Hello"), RunUsage::default());
        assert_eq!(output.text(), "Hello");
        assert!(!output.is_error());
        assert!(output.structured().is_none());
        assert!(output.error_message().is_none());
    }

    #[test]
    fn test_text_returns_empty_for_structured() {
        let output = output_with(AgentOutputContent::Structured(
            serde_json::json!({"key": "val"}),
        ));
        assert_eq!(output.text(), "");
        let value = output
            .structured()
            .expect("structured payload should be present");
        assert_eq!(value["key"], "val");
    }

    #[test]
    fn test_text_returns_empty_for_error() {
        let output = output_with(AgentOutputContent::Error("boom".into()));
        assert_eq!(output.text(), "");
        assert!(output.is_error());
        assert_eq!(output.error_message(), Some("boom"));
    }

    #[test]
    fn test_display_text() {
        let output = AgentOutput::text_with_usage("hi".to_owned(), RunUsage::default());
        assert_eq!(output.to_string(), "hi");
    }

    #[test]
    fn test_display_structured() {
        let output = output_with(AgentOutputContent::Structured(serde_json::json!(42)));
        assert_eq!(output.to_string(), "42");
    }

    #[test]
    fn test_display_error() {
        let output = output_with(AgentOutputContent::Error("fail".into()));
        assert_eq!(output.to_string(), "Error: fail");
    }

    #[test]
    fn test_usage_carried_through() {
        let usage = RunUsage {
            tokens: 100,
            iterations: 5,
            duration: Duration::from_millis(500),
        };
        let output = AgentOutput::text_with_usage("x".into(), usage);
        let carried = &output.usage;
        assert_eq!(carried.tokens, 100);
        assert_eq!(carried.iterations, 5);
        assert_eq!(carried.duration.as_millis(), 500);
    }

    // --- Agent::with_system() tests (Story 1.1) ---

    /// Minimal provider: always answers "ok" and reports a fixed "mock" model.
    struct MockProvider {
        info: ModelInfo,
    }

    impl MockProvider {
        fn new() -> Self {
            let info = ModelInfo::new("mock", ModelTier::Small, 4_096, false, false, false);
            Self { info }
        }
    }

    #[async_trait]
    impl crate::traits::provider::Provider for MockProvider {
        async fn complete(&self, _req: CompletionRequest) -> crate::Result<CompletionResponse> {
            let usage = Usage {
                prompt_tokens: 1,
                completion_tokens: 1,
                total_tokens: 2,
            };
            Ok(CompletionResponse {
                content: ResponseContent::Text("ok".into()),
                usage,
            })
        }

        async fn stream(&self, _req: CompletionRequest) -> crate::Result<CompletionStream> {
            unimplemented!()
        }

        fn model_info(&self) -> &ModelInfo {
            &self.info
        }
    }

    #[test]
    fn test_with_system_str_prompt() {
        // AC #1, #2: a &str prompt is accepted and stored on the config.
        let agent = Agent::with_system(MockProvider::new(), "You are helpful.");
        assert_eq!(agent.config.system_prompt.as_deref(), Some("You are helpful."));
    }

    #[test]
    fn test_with_system_string_prompt() {
        // AC #2: an owned String prompt is accepted too.
        let agent = Agent::with_system(MockProvider::new(), String::from("You are a researcher."));
        assert_eq!(
            agent.config.system_prompt.as_deref(),
            Some("You are a researcher.")
        );
    }

    #[test]
    fn test_with_system_builder_unchanged() {
        // AC #3: the builder path still produces an agent.
        let built = Agent::builder()
            .provider(MockProvider::new())
            .system("test")
            .build();
        assert!(built.is_ok());
    }

    #[test]
    fn test_with_system_provider_configured() {
        // AC #4: the supplied provider is the one the agent uses.
        let agent = Agent::with_system(MockProvider::new(), "test");
        assert_eq!(agent.provider.model_info().name, "mock");
    }
}