Skip to main content

sgr_agent/
client.rs

1//! LlmClient trait — abstract LLM backend for agent use.
2//!
3//! Implementations wrap `GeminiClient` / `OpenAIClient` existing methods.
4//! `structured_call` injects the schema into the system prompt for flexible parsing.
5
6use crate::tool::ToolDef;
7use crate::types::{Message, Role, SgrError, ToolCall};
8use serde_json::Value;
9
10/// Abstract LLM client for agent framework.
11#[async_trait::async_trait]
12pub trait LlmClient: Send + Sync {
13    /// Structured call: send messages with schema injected into system prompt.
14    /// Returns (parsed_output, native_tool_calls, raw_text).
15    async fn structured_call(
16        &self,
17        messages: &[Message],
18        schema: &Value,
19    ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError>;
20
21    /// Native function calling: send messages + tool defs, get tool calls.
22    async fn tools_call(
23        &self,
24        messages: &[Message],
25        tools: &[ToolDef],
26    ) -> Result<Vec<ToolCall>, SgrError>;
27
28    /// Plain text completion (no schema, no tools).
29    async fn complete(&self, messages: &[Message]) -> Result<String, SgrError>;
30}
31
32/// When a model responds with text content instead of tool calls,
33/// synthesize a "finish" tool call so the agent loop gets the answer.
34/// Call this in `tools_call` implementations after extracting tool calls.
35pub fn synthesize_finish_if_empty(calls: &mut Vec<ToolCall>, content: &str) {
36    if calls.is_empty() {
37        let text = content.trim();
38        if !text.is_empty() {
39            calls.push(ToolCall {
40                id: "synth_finish".into(),
41                name: "finish".into(),
42                arguments: serde_json::json!({"summary": text}),
43            });
44        }
45    }
46}
47
48/// Inject schema into messages: append to existing system message or prepend a new one.
49fn inject_schema(messages: &[Message], schema: &Value) -> Vec<Message> {
50    let schema_hint = format!(
51        "\n\nRespond with valid JSON matching this schema:\n{}\n\nDo NOT wrap in markdown code blocks. Output raw JSON only.",
52        serde_json::to_string_pretty(schema).unwrap_or_default()
53    );
54
55    let mut msgs = Vec::with_capacity(messages.len() + 1);
56    let mut injected = false;
57
58    for msg in messages {
59        if msg.role == Role::System && !injected {
60            // Append schema to existing system message
61            msgs.push(Message::system(format!("{}{}", msg.content, schema_hint)));
62            injected = true;
63        } else {
64            msgs.push(msg.clone());
65        }
66    }
67
68    if !injected {
69        // No system message found — prepend one
70        msgs.insert(0, Message::system(schema_hint));
71    }
72
73    msgs
74}
75
#[cfg(feature = "gemini")]
mod gemini_impl {
    use super::*;
    use crate::gemini::GeminiClient;

    /// [`LlmClient`] adapter over the existing `GeminiClient` methods.
    #[async_trait::async_trait]
    impl LlmClient for GeminiClient {
        /// Adds the schema hint via `inject_schema`, then runs `flexible::<Value>`.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let reply = self.flexible::<Value>(&prompt).await?;
            let (parsed, calls, text) = (reply.output, reply.tool_calls, reply.raw_text);
            Ok((parsed, calls, text))
        }

        /// Forwards to `GeminiClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Method-call syntax prefers an inherent method over this trait method
            // of the same name. This therefore delegates (presumably to an inherent
            // `GeminiClient::tools_call`) rather than recursing. Keep this exact
            // form when editing.
            self.tools_call(messages, tools).await
        }

        /// Runs `flexible::<Value>` on the untouched messages and keeps only the
        /// raw text.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
107
#[cfg(feature = "openai")]
mod openai_impl {
    use super::*;
    use crate::openai::OpenAIClient;

    /// [`LlmClient`] adapter over the existing `OpenAIClient` methods.
    #[async_trait::async_trait]
    impl LlmClient for OpenAIClient {
        /// Adds the schema hint via `inject_schema`, then runs `flexible::<Value>`.
        async fn structured_call(
            &self,
            messages: &[Message],
            schema: &Value,
        ) -> Result<(Option<Value>, Vec<ToolCall>, String), SgrError> {
            let prompt = inject_schema(messages, schema);
            let reply = self.flexible::<Value>(&prompt).await?;
            let (parsed, calls, text) = (reply.output, reply.tool_calls, reply.raw_text);
            Ok((parsed, calls, text))
        }

        /// Forwards to `OpenAIClient`'s own `tools_call`.
        async fn tools_call(
            &self,
            messages: &[Message],
            tools: &[ToolDef],
        ) -> Result<Vec<ToolCall>, SgrError> {
            // Method-call syntax prefers an inherent method over this trait method
            // of the same name. This therefore delegates (presumably to an inherent
            // `OpenAIClient::tools_call`) rather than recursing. Keep this exact
            // form when editing.
            self.tools_call(messages, tools).await
        }

        /// Runs `flexible::<Value>` on the untouched messages and keeps only the
        /// raw text.
        async fn complete(&self, messages: &[Message]) -> Result<String, SgrError> {
            Ok(self.flexible::<Value>(messages).await?.raw_text)
        }
    }
}
139
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn inject_schema_appends_to_existing_system() {
        let msgs = vec![
            Message::system("You are a coding agent."),
            Message::user("hello"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        // The original prompt stays first; the schema instruction follows it.
        assert!(result[0].content.starts_with("You are a coding agent."));
        assert!(result[0].content.contains("Respond with valid JSON"));
        // Non-system messages pass through untouched.
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_prepends_when_no_system() {
        let msgs = vec![Message::user("hello")];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 2);
        assert_eq!(result[0].role, Role::System);
        assert!(result[0].content.contains("Respond with valid JSON"));
        assert_eq!(result[1].role, Role::User);
        assert_eq!(result[1].content, "hello");
    }

    #[test]
    fn inject_schema_only_first_system_message() {
        let msgs = vec![
            Message::system("System 1"),
            Message::user("msg"),
            Message::system("System 2"),
        ];
        let schema = serde_json::json!({"type": "object"});
        let result = inject_schema(&msgs, &schema);

        assert_eq!(result.len(), 3);
        // First system gets schema
        assert!(result[0].content.starts_with("System 1"));
        assert!(result[0].content.contains("Respond with valid JSON"));
        // Everything after it is unchanged, including the second system message.
        assert_eq!(result[1].content, "msg");
        assert_eq!(result[2].content, "System 2");
    }

    #[test]
    fn inject_schema_embeds_pretty_printed_schema() {
        let schema = serde_json::json!({"type": "object", "required": ["answer"]});
        let result = inject_schema(&[Message::user("q")], &schema);

        let pretty = serde_json::to_string_pretty(&schema).unwrap();
        assert!(result[0].content.contains(&pretty));
        assert!(result[0].content.contains("Output raw JSON only."));
    }

    #[test]
    fn synthesize_finish_wraps_trimmed_text() {
        let mut calls = Vec::new();
        synthesize_finish_if_empty(&mut calls, "  final answer \n");

        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].id, "synth_finish");
        assert_eq!(calls[0].name, "finish");
        assert_eq!(
            calls[0].arguments,
            serde_json::json!({"summary": "final answer"})
        );
    }

    #[test]
    fn synthesize_finish_skips_blank_text() {
        let mut calls: Vec<ToolCall> = Vec::new();
        synthesize_finish_if_empty(&mut calls, "");
        synthesize_finish_if_empty(&mut calls, "   \n\t");
        assert!(calls.is_empty());
    }

    #[test]
    fn synthesize_finish_keeps_existing_calls() {
        let mut calls = vec![ToolCall {
            id: "call_1".into(),
            name: "read_file".into(),
            arguments: serde_json::json!({"path": "src/lib.rs"}),
        }];
        synthesize_finish_if_empty(&mut calls, "some prose");

        // A real tool call is present, so no synthetic `finish` is added.
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0].name, "read_file");
    }
}