// sayr_engine/agent.rs

1use std::sync::Arc;
2
3use serde::Deserialize;
4use serde_json::Value;
5
6use crate::error::{AgnoError, Result};
7use crate::governance::{AccessController, Action, Principal, Role as GovernanceRole};
8use crate::hooks::{AgentHook, ConfirmationHandler};
9use crate::knowledge::Retriever;
10use crate::llm::{LanguageModel, ModelCompletion};
11use crate::memory::ConversationMemory;
12use crate::message::{Message, Role};
13use crate::metrics::{MetricsTracker, RunGuard};
14use crate::telemetry::{TelemetryCollector, TelemetryLabels};
15use crate::tool::ToolRegistry;
16
17/// Structured instructions the language model should emit.
18#[derive(Debug, Deserialize, PartialEq)]
19#[serde(tag = "action", rename_all = "snake_case")]
20pub enum AgentDirective {
21    Respond { content: String },
22    CallTool { name: String, arguments: Value },
23}
24
25/// An AGNO-style agent that alternates between the LLM and registered tools.
26pub struct Agent<M: LanguageModel> {
27    system_prompt: String,
28    model: Arc<M>,
29    tools: ToolRegistry,
30    memory: ConversationMemory,
31    max_steps: usize,
32    input_schema: Option<serde_json::Value>,
33    output_schema: Option<serde_json::Value>,
34    hooks: Vec<Arc<dyn AgentHook>>,
35    retriever: Option<Arc<dyn Retriever>>,
36    require_tool_confirmation: bool,
37    confirmation_handler: Option<Arc<dyn ConfirmationHandler>>,
38    access_control: Option<Arc<AccessController>>,
39    principal: Principal,
40    metrics: Option<MetricsTracker>,
41    telemetry: Option<TelemetryCollector>,
42    streaming: bool,
43    workflow_label: Option<String>,
44}
45
46impl<M: LanguageModel> Agent<M> {
47    pub fn new(model: Arc<M>) -> Self {
48        Self {
49            system_prompt: "You are a helpful agent.".to_string(),
50            model,
51            tools: ToolRegistry::new(),
52            memory: ConversationMemory::default(),
53            max_steps: 6,
54            input_schema: None,
55            output_schema: None,
56            hooks: Vec::new(),
57            retriever: None,
58            require_tool_confirmation: false,
59            confirmation_handler: None,
60            access_control: None,
61            principal: Principal {
62                id: "anonymous".into(),
63                role: GovernanceRole::User,
64                tenant: None,
65            },
66            metrics: None,
67            telemetry: None,
68            streaming: false,
69            workflow_label: None,
70        }
71    }
72
73    pub fn with_system_prompt(mut self, prompt: impl Into<String>) -> Self {
74        self.system_prompt = prompt.into();
75        self
76    }
77
78    pub fn with_tools(mut self, tools: ToolRegistry) -> Self {
79        self.tools = tools;
80        self
81    }
82
83    pub fn with_memory(mut self, memory: ConversationMemory) -> Self {
84        self.memory = memory;
85        self
86    }
87
88    pub fn with_access_control(mut self, controller: Arc<AccessController>) -> Self {
89        self.access_control = Some(controller);
90        self
91    }
92
93    pub fn with_principal(mut self, principal: Principal) -> Self {
94        self.principal = principal;
95        self
96    }
97
98    pub fn with_metrics(mut self, metrics: MetricsTracker) -> Self {
99        self.metrics = Some(metrics);
100        self
101    }
102
103    pub fn with_telemetry(mut self, telemetry: TelemetryCollector) -> Self {
104        self.telemetry = Some(telemetry);
105        self
106    }
107
108    pub fn with_workflow_label(mut self, workflow: impl Into<String>) -> Self {
109        self.workflow_label = Some(workflow.into());
110        self
111    }
112
113    pub fn with_input_schema(mut self, schema: serde_json::Value) -> Self {
114        self.input_schema = Some(schema);
115        self
116    }
117
118    pub fn with_output_schema(mut self, schema: serde_json::Value) -> Self {
119        self.output_schema = Some(schema);
120        self
121    }
122
123    pub fn with_hook(mut self, hook: Arc<dyn AgentHook>) -> Self {
124        self.hooks.push(hook);
125        self
126    }
127
128    pub fn with_retriever(mut self, retriever: Arc<dyn Retriever>) -> Self {
129        self.retriever = Some(retriever);
130        self
131    }
132
133    pub fn require_tool_confirmation(mut self, handler: Arc<dyn ConfirmationHandler>) -> Self {
134        self.require_tool_confirmation = true;
135        self.confirmation_handler = Some(handler);
136        self
137    }
138
139    pub fn with_max_steps(mut self, max_steps: usize) -> Self {
140        self.max_steps = max_steps.max(1);
141        self
142    }
143
144    pub fn with_streaming(mut self, streaming: bool) -> Self {
145        self.streaming = streaming;
146        self
147    }
148
149    pub fn tools_mut(&mut self) -> &mut ToolRegistry {
150        &mut self.tools
151    }
152
153    pub fn tool_names(&self) -> Vec<String> {
154        self.tools.names()
155    }
156
157    pub fn set_principal(&mut self, principal: Principal) {
158        self.principal = principal;
159    }
160
161    pub fn attach_access_control(&mut self, controller: Arc<AccessController>) {
162        self.access_control = Some(controller);
163    }
164
165    pub fn attach_metrics(&mut self, metrics: MetricsTracker) {
166        self.metrics = Some(metrics);
167    }
168
169    pub fn attach_telemetry(&mut self, telemetry: TelemetryCollector) {
170        self.telemetry = Some(telemetry);
171    }
172
173    pub fn memory(&self) -> &ConversationMemory {
174        &self.memory
175    }
176
177    pub fn sync_memory_from(&mut self, memory: &ConversationMemory) {
178        self.memory = memory.clone();
179    }
180
181    pub fn take_memory_snapshot(&self) -> ConversationMemory {
182        self.memory.clone()
183    }
184
185    /// Run a single exchange with the agent. Returns the final assistant reply.
186    pub async fn respond(&mut self, user_input: impl Into<String>) -> Result<String> {
187        let principal = self.principal.clone();
188        self.respond_for(principal, user_input).await
189    }
190
191    pub async fn respond_for(
192        &mut self,
193        principal: Principal,
194        user_input: impl Into<String>,
195    ) -> Result<String> {
196        if let Some(ctrl) = &self.access_control {
197            if !ctrl.authorize(&principal, &Action::SendMessage) {
198                return Err(AgnoError::Protocol(
199                    "principal not authorized to send messages".into(),
200                ));
201            }
202        }
203
204        let base_labels = TelemetryLabels {
205            tenant: principal.tenant.clone(),
206            tool: None,
207            workflow: self.workflow_label.clone(),
208        };
209        if let Some(telemetry) = &self.telemetry {
210            telemetry.record(
211                "user_message",
212                serde_json::json!({"principal": principal.id.clone(), "tenant": principal.tenant}),
213                base_labels.clone(),
214            );
215        }
216
217        let mut run_guard: Option<RunGuard> = self
218            .metrics
219            .as_ref()
220            .map(|m| m.start_run(base_labels.clone()));
221        self.memory.push(Message::user(user_input));
222
223        for _ in 0..self.max_steps {
224            let contexts = self.retrieve_contexts().await?;
225            let system_prompt = self.build_system_message(&contexts)?;
226            let mut request_messages = vec![Message::system(system_prompt)];
227            request_messages.extend(self.memory.iter().cloned());
228            let snapshot: Vec<Message> = request_messages.clone();
229            for hook in &self.hooks {
230                hook.before_model(snapshot.as_slice()).await?;
231            }
232            let completion = self
233                .model
234                .complete_chat(&request_messages, &self.tools.describe(), self.streaming)
235                .await?;
236            for hook in &self.hooks {
237                let serialized = serde_json::to_string(&completion)
238                    .unwrap_or_else(|_| "<unserializable>".into());
239                hook.after_model(&serialized).await?;
240            }
241
242            if !completion.tool_calls.is_empty() {
243                for mut call in completion.tool_calls {
244                    if call.id.is_none() {
245                        call.id = Some(format!("call-{}", self.memory.len()));
246                    }
247                    if let Some(ctrl) = &self.access_control {
248                        if !ctrl.authorize(&principal, &Action::CallTool(call.name.clone())) {
249                            if let Some(guard) = run_guard.as_mut() {
250                                guard.record_failure(Some(call.name.clone()));
251                            }
252                            return Err(AgnoError::Protocol(format!(
253                                "principal `{}` not allowed to call tool `{}`",
254                                principal.id, call.name
255                            )));
256                        }
257                    }
258                    if self.require_tool_confirmation {
259                        if let Some(handler) = &self.confirmation_handler {
260                            let approved = handler.confirm_tool_call(&call).await?;
261                            if !approved {
262                                self.memory.push(Message::assistant(format!(
263                                    "Tool call `{}` rejected by guardrail",
264                                    call.name
265                                )));
266                                continue;
267                            }
268                        }
269                    }
270                    if let Some(guard) = run_guard.as_mut() {
271                        guard.record_tool_call(call.name.clone());
272                    }
273                    let call_id = call.id.clone();
274                    self.memory.push(Message {
275                        role: Role::Assistant,
276                        content: format!("Calling tool `{}`", call.name),
277                        tool_call: Some(call.clone()),
278                        tool_result: None,
279                        attachments: Vec::new(),
280                    });
281
282                    for hook in &self.hooks {
283                        hook.before_tool_call(
284                            self.memory
285                                .iter()
286                                .last()
287                                .unwrap()
288                                .tool_call
289                                .as_ref()
290                                .unwrap(),
291                        )
292                        .await?;
293                    }
294                    let output = match self.tools.call(&call.name, call.arguments.clone()).await {
295                        Ok(value) => value,
296                        Err(err) => {
297                            if let Some(guard) = run_guard.as_mut() {
298                                guard.record_failure(Some(call.name.clone()));
299                            }
300                            if let Some(telemetry) = &self.telemetry {
301                                telemetry.record_failure(
302                                    format!("tool::{}", call.name),
303                                    format!("{err}"),
304                                    0,
305                                    base_labels.clone().with_tool(call.name.clone()),
306                                );
307                            }
308                            return Err(err);
309                        }
310                    };
311                    let result_message =
312                        Message::tool_with_call(&call.name, output, call_id.clone());
313                    for hook in &self.hooks {
314                        if let Some(result) = result_message.tool_result.as_ref() {
315                            hook.after_tool_result(result).await?;
316                        }
317                    }
318                    self.memory.push(result_message);
319                }
320                continue;
321            }
322
323            match completion {
324                ModelCompletion {
325                    content: Some(content),
326                    tool_calls,
327                } if tool_calls.is_empty() => {
328                    self.memory.push(Message::assistant(&content));
329                    if let Some(guard) = run_guard.take() {
330                        guard.finish(true);
331                    }
332                    return Ok(content);
333                }
334                _ => {
335                    if let Some(guard) = run_guard.as_mut() {
336                        guard.record_failure(None);
337                    }
338                    return Err(AgnoError::Protocol(
339                        "Model response missing content and tool calls".into(),
340                    ));
341                }
342            }
343        }
344
345        if let Some(guard) = run_guard {
346            guard.finish(false);
347        }
348
349        Err(AgnoError::Protocol(
350            "Agent reached the step limit without returning a response".into(),
351        ))
352    }
353
354    async fn retrieve_contexts(&self) -> Result<Vec<String>> {
355        if let Some(retriever) = &self.retriever {
356            return Ok(retriever
357                .retrieve(
358                    self.memory
359                        .iter()
360                        .rev()
361                        .find(|m| m.role == Role::User)
362                        .map(|m| m.content.as_str())
363                        .unwrap_or_default(),
364                    3,
365                )
366                .await
367                .unwrap_or_default());
368        }
369        Ok(Vec::new())
370    }
371
372    fn build_system_message(&self, contexts: &[String]) -> Result<String> {
373        let mut prompt = String::new();
374        prompt.push_str(&self.system_prompt);
375        prompt.push_str("\n\nWhen a tool is relevant, call it with appropriate JSON arguments. Return a direct response when no tool is needed.\n");
376        if let Some(schema) = &self.input_schema {
377            prompt.push_str(&format!(
378                "User input is expected to follow this JSON shape: {}\n\n",
379                schema
380            ));
381        }
382        if let Some(schema) = &self.output_schema {
383            prompt.push_str(&format!(
384                "When responding directly, conform to this output schema: {}\n",
385                schema
386            ));
387        }
388        if self.tools.names().is_empty() {
389            prompt.push_str("No tools are available.\n");
390        } else {
391            prompt.push_str("Available tools:\n");
392            for tool in self.tools.describe() {
393                prompt.push_str(&format!("- {}: {}", tool.name, tool.description));
394                if let Some(params) = &tool.parameters {
395                    prompt.push_str(&format!(" (parameters: {})", params));
396                }
397                prompt.push('\n');
398            }
399        }
400        if !contexts.is_empty() {
401            prompt.push_str("\nContext snippets:\n");
402            for ctx in contexts {
403                prompt.push_str("- ");
404                prompt.push_str(ctx);
405                prompt.push('\n');
406            }
407        }
408
409        Ok(prompt)
410    }
411}
412
#[cfg(test)]
mod tests {
    use super::*;
    use async_trait::async_trait;

    use crate::tool::Tool;
    use crate::StubModel;

    /// Minimal tool that returns its input unchanged.
    struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes the `text` field back"
        }

        async fn call(&self, input: Value) -> Result<Value> {
            Ok(input)
        }
    }

    /// A plain reply ends the exchange with two messages in memory: user + assistant.
    #[tokio::test]
    async fn returns_llm_response_without_tools() {
        let model = StubModel::new(vec![r#"{"action":"respond","content":"Hello!"}"#.into()]);
        let mut agent = Agent::new(model);

        let reply = agent
            .respond("hi")
            .await
            .expect("agent should return the scripted reply");

        assert_eq!(reply, "Hello!");
        assert_eq!(agent.memory().len(), 2);
    }

    /// A tool call followed by a reply leaves four messages in memory:
    /// user, tool call, tool result, assistant.
    #[tokio::test]
    async fn executes_tool_then_replies() {
        let model = StubModel::new(vec![
            r#"{"action":"call_tool","name":"echo","arguments":{"text":"ping"}}"#.into(),
            r#"{"action":"respond","content":"Echoed your request."}"#.into(),
        ]);
        let mut tools = ToolRegistry::new();
        tools.register(EchoTool);

        let mut agent = Agent::new(model).with_tools(tools);

        let reply = agent
            .respond("say ping")
            .await
            .expect("agent should reply after running the echo tool");

        assert_eq!(reply, "Echoed your request.");
        assert_eq!(agent.memory().len(), 4);
    }

    /// The system message lists registered tools together with their descriptions.
    #[tokio::test]
    async fn includes_tool_metadata_in_prompt() {
        /// Tool that also advertises a parameter schema.
        struct DescribingTool;

        #[async_trait]
        impl Tool for DescribingTool {
            fn name(&self) -> &str {
                "describe"
            }

            fn description(&self) -> &str {
                "Replies with metadata"
            }

            fn parameters(&self) -> Option<Value> {
                Some(serde_json::json!({"type":"object","properties":{"id":{"type":"string"}}}))
            }

            async fn call(&self, _input: Value) -> Result<Value> {
                Ok(serde_json::json!({"ok": true}))
            }
        }

        let model = StubModel::new(vec![r#"{"action":"respond","content":"done"}"#.into()]);
        let mut tools = ToolRegistry::new();
        tools.register(DescribingTool);

        let agent = Agent::new(model).with_tools(tools);
        let prompt = agent
            .build_system_message(&[])
            .expect("building the system message should not fail");

        assert!(prompt.contains("Replies with metadata"));
        assert!(prompt.contains("Available tools"));
    }
}
499}