1use crate::tool::ToolDef;
7use crate::types::{Message, Role, SgrError, ToolCall};
8use serde_json::Value;
9
10#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13 async fn structured_call(
16 &self,
17 messages: &[Message],
18 schema: &Value,
19 ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21 async fn tools_call(
23 &self,
24 messages: &[Message],
25 tools: &[ToolDef],
26 ) -> Result<Vec<ToolCall>, SgrError>;
27
28 async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
30}
31
32pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
36 if calls.is_empty() {
37 let text = content.trim();
38 if !text.is_empty() {
39 calls.push(ToolCall {
40 id: "synth_finish".into(),
41 name: "finish".into(),
42 arguments: serde_json::json!({"summary": text}),
43 });
44 }
45 }
46}
47
48fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
50 let schema_hint = format!(
51 "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
52 serde_json::to_string_pretty(schema).unwrap_or_default()
53 );
54
55 let mut msgs = Vec::with_capacity(messages.len() + 1);
56 let mut injected = false;
57
58 for msg in messages {
59 if msg.role == Role::System && !injected {
60 msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
62 injected = true;
63 } else {
64 msgs.push(msg.clone());
65 }
66 }
67
68 if !injected {
69 msgs.insert(0, Message::system(schema_hint));
71 }
72
73 msgs
74}
75
#[cfg(feature = "gemini")]
mod gemini_impl {
    use super::*;
    use crate::gemini::GeminiClient;

    // Adapts `GeminiClient`'s own request methods to the `LlmClient` trait.
    #[async_trait::async_trait]
    impl LlmClient for GeminiClient {
        // The schema is folded into the prompt by `inject_schema`. The flexible
        // response is then unpacked into `(output, tool_calls, raw_text)`.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let reply = self
                .flexible::<Value>(&inject_schema(messages, schema))
                .await?;
            let output = reply.output;
            let tool_calls = reply.tool_calls;
            Ok((output, tool_calls, reply.raw_text))
        }

        // NOTE(review): in method resolution, inherent methods take precedence
        // over trait methods, so this resolves to an inherent
        // `GeminiClient::tools_call`. That method is assumed to exist; if it
        // were removed or renamed, this call would silently bind to the trait
        // method and recurse into itself.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            self.tools_call(messages, tools).await
        }

        // Plain completion: the messages are sent unchanged, and only the raw
        // text of the flexible response is kept.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
107
#[cfg(feature = "openai")]
mod openai_impl {
    use super::*;
    use crate::openai::OpenAIClient;

    // Adapts `OpenAIClient`'s own request methods to the `LlmClient` trait.
    #[async_trait::async_trait]
    impl LlmClient for OpenAIClient {
        // The schema is folded into the prompt by `inject_schema`. The flexible
        // response is then unpacked into `(output, tool_calls, raw_text)`.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let reply = self
                .flexible::<Value>(&inject_schema(messages, schema))
                .await?;
            let output = reply.output;
            let tool_calls = reply.tool_calls;
            Ok((output, tool_calls, reply.raw_text))
        }

        // NOTE(review): in method resolution, inherent methods take precedence
        // over trait methods, so this resolves to an inherent
        // `OpenAIClient::tools_call`. That method is assumed to exist; if it
        // were removed or renamed, this call would silently bind to the trait
        // method and recurse into itself.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            self.tools_call(messages, tools).await
        }

        // Plain completion: the messages are sent unchanged, and only the raw
        // text of the flexible response is kept.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal schema shared by the `inject_schema` tests.
    fn object_schema() -> Value {
        serde_json::json!({"type": "object"})
    }

    #[test]
    fn inject_schema_appends_to_existing_system() {
        let msgs = vec![
            Message::system("You are a coding agent."),
            Message::user("hello"),
        ];
        let result = inject_schema(&msgs, &object_schema());

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        // The original prompt stays first; the schema hint is appended after it.
        assert!(result[0].content.starts_with("You are a coding agent."));
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_prepends_when_no_system() {
        let msgs = vec![Message::user("hello")];
        let result = inject_schema(&msgs, &object_schema());

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_prepends_into_empty_history() {
        let result = inject_schema(&[], &object_schema());

        assert_eq!(result.len(), 1);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
    }

    #[test]
    fn inject_schema_only_first_system_message() {
        let msgs = vec![
            Message::system("System 1"),
            Message::user("msg"),
            Message::system("System 2"),
        ];
        let result = inject_schema(&msgs, &object_schema());

        assert_eq!(result.len(), 3);
        assert!(result[0].content.starts_with("System 1"));
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "msg");
        // Later system messages are passed through verbatim.
        assert_eq!(result[2].role, Role::System);
        assert_eq!(result[2].content, "System 2");
    }

    #[test]
    fn inject_schema_embeds_schema_text() {
        let schema = serde_json::json!({
            "type": "object",
            "properties": {"answer": {"type": "integer"}}
        });
        let result = inject_schema(&[Message::user("q")], &schema);

        assert!(result[0].content.contains("\"properties\""));
        assert!(result[0].content.contains("\"answer\""));
        assert!(result[0].content.contains("Output raw JSON only."));
    }

    #[test]
    fn synthesize_finish_adds_trimmed_summary_when_no_calls() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, "  all done \n");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "synth_finish");
        assert_eq!(calls[0].name, "finish");
        assert_eq!(
            calls[0].arguments,
            serde_json::json!({"summary": "all done"})
        );
    }

    #[test]
    fn synthesize_finish_ignores_blank_content() {
        let mut calls: Vec<ToolCall> = Vec::new();
        synthesize_finish_if_empty(&mut calls, " \n\t ");

        assert!(calls.is_empty());
    }

    #[test]
    fn synthesize_finish_keeps_existing_calls() {
        let mut calls = vec![ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: serde_json::json!({}),
        }];
        synthesize_finish_if_empty(&mut calls, "some text");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
    }
}