// sgr_agent/client.rs
//! LlmClient trait — abstract LLM backend for agent use.
//!
//! Implementations wrap `GeminiClient` / `OpenAIClient` existing methods.
//! `structured_call` injects the schema into the system prompt for flexible parsing.

use crate::tool::ToolDef;
use crate::types::{Message, Role, SgrError, ToolCall};
use serde_json::Value;

10/// Abstract LLM client for agent framework.
11#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13    /// Structured call: send messages with schema injected into system prompt.
14    /// Returns (parsed_output, native_tool_calls, raw_text).
15    async fn structured_call(
16        &self,
17        messages: &[Message],
18        schema: &Value,
19    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21    /// Native function calling: send messages + tool defs, get tool calls.
22    /// This is STATELESS — no side effects on shared state.
23    async fn tools_call(
24        &self,
25        messages: &[Message],
26        tools: &[ToolDef],
27    ) -> Result<Vec<ToolCall>, SgrError>;
28
29    /// Stateful function calling with explicit response_id for chaining.
30    /// Returns (tool_calls, new_response_id).
31    /// When previous_response_id is Some, only delta messages are needed.
32    async fn tools_call_stateful(
33        &self,
34        messages: &[Message],
35        tools: &[ToolDef],
36        _previous_response_id: Option<&str>,
37    ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
38        // Default: delegate to stateless tools_call, no chaining
39        let calls = self.tools_call(messages, tools).await?;
40        Ok((calls, None))
41    }
42
43    /// Plain text completion (no schema, no tools).
44    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
45}
46
47/// When a model responds with text content instead of tool calls,
48/// synthesize a "finish" tool call so the agent loop gets the answer.
49/// Call this in `tools_call` implementations after extracting tool calls.
50pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
51    if calls.is_empty() {
52        let text = content.trim();
53        if !text.is_empty() {
54            calls.push(ToolCall {
55                id: "synth_finish".into(),
56                name: "finish".into(),
57                arguments: serde_json::json!({"summary": text}),
58            });
59        }
60    }
61}
62
63/// Inject schema into messages: append to existing system message or prepend a new one.
64fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
65    let schema_hint = format!(
66        "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
67        serde_json::to_string_pretty(schema).unwrap_or_default()
68    );
69
70    let mut msgs = Vec::with_capacity(messages.len() + 1);
71    let mut injected = false;
72
73    for msg in messages {
74        if msg.role == Role::System && !injected {
75            // Append schema to existing system message
76            msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
77            injected = true;
78        } else {
79            msgs.push(msg.clone());
80        }
81    }
82
83    if !injected {
84        // No system message found — prepend one
85        msgs.insert(0, Message::system(schema_hint));
86    }
87
88    msgs
89}
90
91#[cfg(feature = "gemini")]
92mod gemini_impl {
93    use super::*;
94    use crate::gemini::GeminiClient;
95
96    #[async_trait::async_trait]
97    impl LlmClient for GeminiClient {
98        async fn structured_call(
99            &self,
100            messages: &[Message],
101            schema: &Value,
102        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
103            let msgs = inject_schema(messages, schema);
104            let resp = self.flexible::<Value>(&msgs).await?;
105            Ok((resp.output, resp.tool_calls, resp.raw_text))
106        }
107
108        async fn tools_call(
109            &self,
110            messages: &[Message],
111            tools: &[ToolDef],
112        ) -> Result<Vec<ToolCall>, SgrError> {
113            self.tools_call(messages, tools).await
114        }
115
116        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
117            let resp = self.flexible::<Value>(messages).await?;
118            Ok(resp.raw_text)
119        }
120    }
121}
122
123#[cfg(feature = "openai")]
124mod openai_impl {
125    use super::*;
126    use crate::openai::OpenAIClient;
127
128    #[async_trait::async_trait]
129    impl LlmClient for OpenAIClient {
130        async fn structured_call(
131            &self,
132            messages: &[Message],
133            schema: &Value,
134        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
135            let msgs = inject_schema(messages, schema);
136            let resp = self.flexible::<Value>(&msgs).await?;
137            Ok((resp.output, resp.tool_calls, resp.raw_text))
138        }
139
140        async fn tools_call(
141            &self,
142            messages: &[Message],
143            tools: &[ToolDef],
144        ) -> Result<Vec<ToolCall>, SgrError> {
145            self.tools_call(messages, tools).await
146        }
147
148        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
149            let resp = self.flexible::<Value>(messages).await?;
150            Ok(resp.raw_text)
151        }
152    }
153}
154
155#[cfg(test)]
156mod tests {
157    use super::*;
158    use crate::tool::ToolDef;
159
160    /// Mock client that only implements the required trait methods.
161    /// tools_call_stateful uses the default impl (delegates to tools_call).
162    struct MockStatelessClient;
163
164    #[async_trait::async_trait]
165    impl LlmClient for MockStatelessClient {
166        async fn structured_call(
167            &self,
168            _: &[Message],
169            _: &Value,
170        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
171            Ok((None, vec![], String::new()))
172        }
173        async fn tools_call(
174            &self,
175            _: &[Message],
176            _: &[ToolDef],
177        ) -> Result<Vec<ToolCall>, SgrError> {
178            Ok(vec![ToolCall {
179                id: "tc1".into(),
180                name: "test_tool".into(),
181                arguments: serde_json::json!({"x": 1}),
182            }])
183        }
184        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
185            Ok(String::new())
186        }
187    }
188
189    #[tokio::test]
190    async fn tools_call_stateful_default_delegates() {
191        let client = MockStatelessClient;
192        let msgs = vec![Message::user("hi")];
193        let tools = vec![ToolDef {
194            name: "test_tool".into(),
195            description: "test".into(),
196            parameters: serde_json::json!({"type": "object"}),
197        }];
198
199        // Default impl delegates to tools_call, returns None for response_id
200        let (calls, response_id) = client
201            .tools_call_stateful(&msgs, &tools, None)
202            .await
203            .unwrap();
204        assert_eq!(calls.len(), 1);
205        assert_eq!(calls[0].name, "test_tool");
206        assert!(response_id.is_none(), "default impl returns no response_id");
207
208        // With previous_response_id — still delegates to stateless, ignores it
209        let (calls, response_id) = client
210            .tools_call_stateful(&msgs, &tools, Some("resp_abc"))
211            .await
212            .unwrap();
213        assert_eq!(calls.len(), 1);
214        assert!(response_id.is_none());
215    }
216
217    #[test]
218    fn inject_schema_appends_to_existing_system() {
219        let msgs = vec![
220            Message::system("You are a coding agent."),
221            Message::user("hello"),
222        ];
223        let schema = serde_json::json!({"type": "object"});
224        let result = inject_schema(&msgs, &schema);
225
226        assert_eq!(result.len(), 2);
227        assert!(result[0].content.contains("You are a coding agent."));
228        assert!(result[0].content.contains("Respond with valid JSON"));
229        assert_eq!(result[0].role, Role::System);
230    }
231
232    #[test]
233    fn inject_schema_prepends_when_no_system() {
234        let msgs = vec![Message::user("hello")];
235        let schema = serde_json::json!({"type": "object"});
236        let result = inject_schema(&msgs, &schema);
237
238        assert_eq!(result.len(), 2);
239        assert_eq!(result[0].role, Role::System);
240        assert!(result[0].content.contains("Respond with valid JSON"));
241        assert_eq!(result[1].role, Role::User);
242    }
243
244    #[test]
245    fn inject_schema_only_first_system_message() {
246        let msgs = vec![
247            Message::system("System 1"),
248            Message::user("msg"),
249            Message::system("System 2"),
250        ];
251        let schema = serde_json::json!({"type": "object"});
252        let result = inject_schema(&msgs, &schema);
253
254        assert_eq!(result.len(), 3);
255        // First system gets schema
256        assert!(result[0].content.contains("Respond with valid JSON"));
257        // Second system unchanged
258        assert_eq!(result[2].content, "System 2");
259    }
260}