//! sr_ai/ai/copilot.rs — AI backend that shells out to the GitHub Copilot CLI.
1use super::{AiBackend, AiRequest, AiResponse};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use tokio::process::Command;
5
/// Model passed to `gh copilot` when the backend has no explicit override.
const DEFAULT_MODEL: &str = "gpt-4.1";
7
/// Backend that talks to the GitHub Copilot CLI (`gh copilot`) as a subprocess.
pub struct CopilotBackend {
    /// Optional model override; `DEFAULT_MODEL` is used when `None`.
    model: Option<String>,
    /// When true, emit `[DEBUG]` diagnostics to stderr around each CLI call.
    debug: bool,
}
12
13impl CopilotBackend {
14    pub fn new(model: Option<String>, debug: bool) -> Self {
15        Self { model, debug }
16    }
17}
18
/// Build the system prompt, embedding the JSON schema when present.
///
/// With no schema the base prompt is returned unchanged; with a schema, the
/// base prompt is followed by a strict instruction block demanding a bare
/// JSON response that matches the schema.
fn build_system_prompt(base: &str, json_schema: Option<&str>) -> String {
    match json_schema {
        None => base.to_string(),
        Some(schema) => {
            // Assemble in one buffer; the output bytes match the original
            // single format! template exactly.
            let mut prompt = String::with_capacity(base.len() + schema.len() + 160);
            prompt.push_str(base);
            prompt.push_str("\n\nYou MUST respond with valid JSON matching this schema:\n```json\n");
            prompt.push_str(schema);
            prompt.push_str("\n```\n\nRespond ONLY with the JSON object, no markdown fences, no explanation.");
            prompt
        }
    }
}
31
32#[async_trait]
33impl AiBackend for CopilotBackend {
34    fn name(&self) -> &str {
35        "copilot"
36    }
37
38    async fn is_available(&self) -> bool {
39        Command::new("gh")
40            .args(["copilot", "--version"])
41            .output()
42            .await
43            .is_ok_and(|o| o.status.success())
44    }
45
46    async fn request(&self, req: &AiRequest) -> Result<AiResponse> {
47        let model = self.model.as_deref().unwrap_or(DEFAULT_MODEL);
48        let system = build_system_prompt(&req.system_prompt, req.json_schema.as_deref());
49
50        let mut cmd = Command::new("gh");
51        cmd.current_dir(&req.working_dir)
52            .arg("copilot")
53            .arg("-p")
54            .arg(&req.user_prompt)
55            .arg("-s")
56            .arg("--model")
57            .arg(model)
58            .arg("--allow-tool")
59            .arg("shell(git:*)")
60            .arg("--no-custom-instructions")
61            .arg("--system-prompt")
62            .arg(&system);
63
64        if self.debug {
65            eprintln!("[DEBUG] Calling gh copilot (model={model})");
66        }
67
68        let output = cmd.output().await.context("failed to run gh copilot")?;
69
70        let raw = String::from_utf8_lossy(&output.stdout).to_string();
71        let stderr = String::from_utf8_lossy(&output.stderr);
72
73        if self.debug {
74            eprintln!("[DEBUG] gh copilot exit code: {}", output.status);
75            eprintln!(
76                "[DEBUG] Raw response (first 500 chars): {}",
77                &raw[..raw.len().min(500)]
78            );
79            if !stderr.is_empty() {
80                eprintln!("[DEBUG] Stderr: {stderr}");
81            }
82        }
83
84        if !output.status.success() {
85            anyhow::bail!(crate::error::SrAiError::AiBackend(format!(
86                "gh copilot failed (exit {}): {}",
87                output.status,
88                stderr.trim()
89            )));
90        }
91
92        // Extract JSON from response (may be wrapped in markdown fences)
93        let text = extract_json(&raw).unwrap_or(raw);
94
95        Ok(AiResponse { text })
96    }
97}
98
99/// Extract JSON from a response that may contain markdown code fences.
100pub(crate) fn extract_json(raw: &str) -> Option<String> {
101    let trimmed = raw.trim();
102
103    // Try direct parse first
104    if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
105        return Some(trimmed.to_string());
106    }
107
108    // Try extracting from ```json ... ``` fences
109    if let Some(start) = trimmed.find("```json") {
110        let after = &trimmed[start + 7..];
111        if let Some(end) = after.find("```") {
112            let json_str = after[..end].trim();
113            if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
114                return Some(json_str.to_string());
115            }
116        }
117    }
118
119    // Try extracting from ``` ... ``` fences
120    if let Some(start) = trimmed.find("```") {
121        let after = &trimmed[start + 3..];
122        let after = if let Some(nl) = after.find('\n') {
123            &after[nl + 1..]
124        } else {
125            after
126        };
127        if let Some(end) = after.find("```") {
128            let json_str = after[..end].trim();
129            if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
130                return Some(json_str.to_string());
131            }
132        }
133    }
134
135    // Try finding first { ... } or [ ... ]
136    for (open, close) in [("{", "}"), ("[", "]")] {
137        if let Some(start) = trimmed.find(open)
138            && let Some(end) = trimmed.rfind(close)
139            && end > start
140        {
141            let candidate = &trimmed[start..=end];
142            if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
143                return Some(candidate.to_string());
144            }
145        }
146    }
147
148    None
149}
150
#[cfg(test)]
mod tests {
    //! Unit tests for the pure helpers of this module (`extract_json`,
    //! `build_system_prompt`) and the backend's static metadata. The
    //! subprocess-driven `request` path is not covered here.

    use super::*;

    // --- extract_json tests ---

    #[test]
    fn extract_direct_json() {
        let input = r#"{"commits": []}"#;
        assert_eq!(extract_json(input), Some(input.to_string()));
    }

    #[test]
    fn extract_from_json_fences() {
        let input = "Here is the plan:\n```json\n{\"commits\": []}\n```\nDone.";
        assert_eq!(extract_json(input), Some(r#"{"commits": []}"#.to_string()));
    }

    #[test]
    fn extract_from_plain_fences() {
        let input = "Result:\n```\n{\"commits\": [{\"order\": 1}]}\n```";
        assert_eq!(
            extract_json(input),
            Some(r#"{"commits": [{"order": 1}]}"#.to_string())
        );
    }

    #[test]
    fn extract_from_surrounding_text() {
        // Falls through to the brace-span search (no fences present).
        let input = "The result is {\"commits\": []} and that's it.";
        assert_eq!(extract_json(input), Some(r#"{"commits": []}"#.to_string()));
    }

    #[test]
    fn extract_array_json() {
        // Arrays are found via the `[`…`]` pair, not just objects.
        let input = "Here: [1, 2, 3] done";
        assert_eq!(extract_json(input), Some("[1, 2, 3]".to_string()));
    }

    #[test]
    fn extract_returns_none_for_invalid() {
        assert_eq!(extract_json("no json here"), None);
        assert_eq!(extract_json(""), None);
        // Braces alone are not enough — the span must parse as JSON.
        assert_eq!(extract_json("{not valid json}"), None);
    }

    #[test]
    fn extract_with_whitespace() {
        let input = "  \n  {\"key\": \"value\"}  \n  ";
        assert_eq!(extract_json(input), Some(r#"{"key": "value"}"#.to_string()));
    }

    // --- build_system_prompt tests ---

    #[test]
    fn system_prompt_without_schema() {
        let result = build_system_prompt("You are a commit assistant.", None);
        assert_eq!(result, "You are a commit assistant.");
    }

    #[test]
    fn system_prompt_with_schema() {
        let schema = r#"{"type": "object"}"#;
        let result = build_system_prompt("Base prompt.", Some(schema));
        assert!(result.starts_with("Base prompt."));
        assert!(result.contains("You MUST respond with valid JSON"));
        assert!(result.contains(schema));
        assert!(result.contains("no markdown fences"));
    }

    // --- backend metadata tests ---

    #[test]
    fn backend_name() {
        let backend = CopilotBackend::new(None, false);
        assert_eq!(backend.name(), "copilot");
    }

    #[test]
    fn default_model_constant() {
        assert_eq!(DEFAULT_MODEL, "gpt-4.1");
    }

    // --- build_system_prompt preserves base content ---

    #[test]
    fn system_prompt_preserves_multiline_base() {
        let base = "Line one.\nLine two.\nLine three.";
        let result = build_system_prompt(base, None);
        assert_eq!(result, base);

        let with_schema = build_system_prompt(base, Some("{}"));
        assert!(with_schema.starts_with(base));
    }
}