1use crate::tool::ToolDef;
7use crate::types::{Message, Role, SgrError, ToolCall};
8use serde_json::Value;
9
10#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13 async fn structured_call(
16 &self,
17 messages: &[Message],
18 schema: &Value,
19 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21 async fn tools_call(
24 &self,
25 messages: &[Message],
26 tools: &[ToolDef],
27 ) -> Result<Vec<ToolCall>, SgrError>;
28
29 async fn tools_call_stateful(
33 &self,
34 messages: &[Message],
35 tools: &[ToolDef],
36 _previous_response_id: Option<&str>,
37 ) -> Result<(Vec<ToolCall>, Option<String>), SgrError> {
38 let calls = self.tools_call(messages, tools).await?;
40 Ok((calls, None))
41 }
42
43 async fn tools_call_with_text(
47 &self,
48 messages: &[Message],
49 tools: &[ToolDef],
50 ) -> Result<(Vec<ToolCall>, String), SgrError> {
51 let calls = self.tools_call(messages, tools).await?;
52 Ok((calls, String::new()))
53 }
54
55 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
57}
58
59pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
63 if calls.is_empty() {
64 let text = content.trim();
65 if !text.is_empty() {
66 calls.push(ToolCall {
67 id: "synth_finish".into(),
68 name: "finish".into(),
69 arguments: serde_json::json!({"summary": text}),
70 });
71 }
72 }
73}
74
75fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
77 let schema_hint = format!(
78 "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
79 serde_json::to_string_pretty(schema).unwrap_or_default()
80 );
81
82 let mut msgs = Vec::with_capacity(messages.len() + 1);
83 let mut injected = false;
84
85 for msg in messages {
86 if msg.role == Role::System && !injected {
87 msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
89 injected = true;
90 } else {
91 msgs.push(msg.clone());
92 }
93 }
94
95 if !injected {
96 msgs.insert(0, Message::system(schema_hint));
98 }
99
100 msgs
101}
102
/// [`LlmClient`] implementation for `GeminiClient`; compiled only with the
/// `gemini` feature.
#[cfg(feature = "gemini")]
mod gemini_impl {
    use super::*;
    use crate::gemini::GeminiClient;

    #[async_trait::async_trait]
    impl LlmClient for GeminiClient {
        /// Adds the schema instructions via `inject_schema`, runs a
        /// `flexible::<Value>` request, and returns its output, tool calls and
        /// raw text.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let reply = self.flexible::<Value>(&prompt).await?;
            Ok((reply.output, reply.tool_calls, reply.raw_text))
        }

        /// Forwards to `GeminiClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Inherent methods take precedence over trait methods, so this
            // targets GeminiClient's inherent `tools_call`.
            // NOTE(review): if that inherent method is ever removed or renamed,
            // this resolves to the trait method itself and recurses without
            // bound. async_trait's boxing lets that compile.
            GeminiClient::tools_call(self, messages, tools).await
        }

        /// Runs a `flexible::<Value>` request without schema instructions and
        /// returns only the raw text.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
134
/// [`LlmClient`] implementation for `OpenAIClient`; compiled only with the
/// `openai` feature.
#[cfg(feature = "openai")]
mod openai_impl {
    use super::*;
    use crate::openai::OpenAIClient;

    #[async_trait::async_trait]
    impl LlmClient for OpenAIClient {
        /// Adds the schema instructions via `inject_schema`, runs a
        /// `flexible::<Value>` request, and returns its output, tool calls and
        /// raw text.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let reply = self.flexible::<Value>(&prompt).await?;
            Ok((reply.output, reply.tool_calls, reply.raw_text))
        }

        /// Forwards to `OpenAIClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Inherent methods take precedence over trait methods, so this
            // targets OpenAIClient's inherent `tools_call`.
            // NOTE(review): if that inherent method is ever removed or renamed,
            // this resolves to the trait method itself and recurses without
            // bound. async_trait's boxing lets that compile.
            OpenAIClient::tools_call(self, messages, tools).await
        }

        /// Runs a `flexible::<Value>` request without schema instructions and
        /// returns only the raw text.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
166
#[cfg(test)]
mod tests {
    // `use super::*` already brings in ToolDef, Message, Role, ToolCall, etc.
    use super::*;

    /// Client that implements only the required trait methods, so the trait's
    /// default methods are the ones exercised.
    struct MockStatelessClient;

    #[async_trait::async_trait]
    impl LlmClient for MockStatelessClient {
        async fn structured_call(
            &self,
            _: &[Message],
            _: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            Ok((None, vec![], String::new()))
        }
        async fn tools_call(
            &self,
            _: &[Message],
            _: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            Ok(vec![ToolCall {
                id: "tc1".into(),
                name: "test_tool".into(),
                arguments: serde_json::json!({"x": 1}),
            }])
        }
        async fn complete(&self, _: &[Message]) -> Result<String, SgrError> {
            Ok(String::new())
        }
    }

    /// A single trivial tool definition shared by the tests.
    fn test_tools() -> Vec<ToolDef> {
        vec![ToolDef {
            name: "test_tool".into(),
            description: "test".into(),
            parameters: serde_json::json!({"type": "object"}),
        }]
    }

    #[tokio::test]
    async fn tools_call_stateful_default_delegates() {
        let client = MockStatelessClient;
        let msgs = vec![Message::user("hi")];
        let tools = test_tools();

        let (calls, response_id) = client
            .tools_call_stateful(&msgs, &tools, None)
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert!(response_id.is_none(), "default impl returns no response_id");

        // The default ignores any previous response id it is given.
        let (calls, response_id) = client
            .tools_call_stateful(&msgs, &tools, Some("resp_abc"))
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert!(response_id.is_none());
    }

    #[tokio::test]
    async fn tools_call_with_text_default_returns_empty_text() {
        let client = MockStatelessClient;
        let msgs = vec![Message::user("hi")];

        let (calls, text) = client
            .tools_call_with_text(&msgs, &test_tools())
            .await
            .unwrap();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
        assert!(text.is_empty(), "default impl returns no text");
    }

    #[test]
    fn synthesize_finish_from_text_when_no_calls() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, "  all done \n");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "finish");
        assert_eq!(calls[0].arguments, serde_json::json!({"summary": "all done"}));
    }

    #[test]
    fn synthesize_finish_skips_blank_text() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, " \n\t ");
        assert!(calls.is_empty());

        synthesize_finish_if_empty(&mut calls, "");
        assert!(calls.is_empty());
    }

    #[test]
    fn synthesize_finish_keeps_existing_calls() {
        let mut calls = vec![ToolCall {
            id: "tc1".into(),
            name: "test_tool".into(),
            arguments: serde_json::json!({}),
        }];
        synthesize_finish_if_empty(&mut calls, "some text");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "test_tool");
    }

    #[test]
    fn inject_schema_appends_to_existing_system() {
        let msgs = vec![
            Message::system("You are a coding agent."),
            Message::user("hello"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        // The original system text comes first, followed by the hint.
        assert!(result[0].content.starts_with("You are a coding agent."));
        assert!(result[0].content.contains("Respond with valid JSON"));
        // The pretty-printed schema is embedded in the hint.
        assert!(result[0].content.contains("\"type\": \"object\""));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_prepends_when_no_system() {
        let msgs = vec![Message::user("hello")];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_only_first_system_message() {
        let msgs = vec![
            Message::system("System 1"),
            Message::user("msg"),
            Message::system("System 2"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 3);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.starts_with("System 1"));
        assert!(result[0].content.contains("Respond with valid JSON"));
        // Order and the non-first messages are preserved untouched.
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "msg");
        assert_eq!(result[2].role, Role::System);
        assert_eq!(result[2].content, "System 2");
    }
}