//! Copilot backend (`sr_ai/ai/copilot.rs`): implements [`AiBackend`] by
//! shelling out to the GitHub Copilot CLI (`gh copilot`).
1use super::{AiBackend, AiEvent, AiRequest, AiResponse};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use tokio::process::Command;
5use tokio::sync::mpsc;
6
/// Model passed to `gh copilot --model` when the caller supplies none.
const DEFAULT_MODEL: &str = "gpt-4.1";

/// Read-only tools the Copilot agent is allowed to use.
/// Uses gh copilot's `--allow-tool` syntax: `shell(cmd:subcommand)`.
/// No mutating git commands (add, commit, push, reset, clean, rm, etc.).
const ALLOWED_TOOLS: &[&str] = &[
    "shell(git:diff)",
    "shell(git:log)",
    "shell(git:show)",
    "shell(git:status)",
    "shell(git:ls-files)",
    "shell(git:rev-parse)",
    "shell(git:branch)",
    "shell(git:cat-file)",
    "shell(git:rev-list)",
    "shell(git:shortlog)",
    "shell(git:blame)",
];
25
/// AI backend that runs requests through the GitHub Copilot CLI
/// (`gh copilot`) as a subprocess.
pub struct CopilotBackend {
    // Model override; `DEFAULT_MODEL` is used when `None`.
    model: Option<String>,
    // When true, diagnostic output is printed to stderr.
    debug: bool,
}

impl CopilotBackend {
    /// Create a backend with an optional model override and a debug flag.
    pub fn new(model: Option<String>, debug: bool) -> Self {
        Self { model, debug }
    }
}
36
/// Build the system prompt, embedding the JSON schema when present.
///
/// Without a schema the base prompt is returned unchanged; with one, the
/// schema is appended inside a ```json fence together with strict
/// JSON-only response instructions.
fn build_system_prompt(base: &str, json_schema: Option<&str>) -> String {
    let Some(schema) = json_schema else {
        return base.to_string();
    };
    format!(
        "{base}\n\n\
         You MUST respond with valid JSON matching this schema:\n\
         ```json\n{schema}\n```\n\n\
         Respond ONLY with the JSON object, no markdown fences, no explanation."
    )
}
49
50#[async_trait]
51impl AiBackend for CopilotBackend {
52    fn name(&self) -> &str {
53        "copilot"
54    }
55
56    async fn is_available(&self) -> bool {
57        Command::new("gh")
58            .args(["copilot", "--version"])
59            .output()
60            .await
61            .is_ok_and(|o| o.status.success())
62    }
63
64    async fn request(
65        &self,
66        req: &AiRequest,
67        _events: Option<mpsc::UnboundedSender<AiEvent>>,
68    ) -> Result<AiResponse> {
69        let model = self.model.as_deref().unwrap_or(DEFAULT_MODEL);
70        let system = build_system_prompt(&req.system_prompt, req.json_schema.as_deref());
71
72        let mut cmd = Command::new("gh");
73        cmd.current_dir(&req.working_dir)
74            .arg("copilot")
75            .arg("-p")
76            .arg(&req.user_prompt)
77            .arg("-s")
78            .arg("--model")
79            .arg(model);
80
81        // Sandbox: only allow read-only git subcommands.
82        for tool in ALLOWED_TOOLS {
83            cmd.arg("--allow-tool").arg(tool);
84        }
85
86        cmd.arg("--no-custom-instructions")
87            .arg("--system-prompt")
88            .arg(&system);
89
90        if self.debug {
91            eprintln!("[DEBUG] Calling gh copilot (model={model})");
92        }
93
94        let output = cmd.output().await.context("failed to run gh copilot")?;
95
96        let raw = String::from_utf8_lossy(&output.stdout).to_string();
97        let stderr = String::from_utf8_lossy(&output.stderr);
98
99        if self.debug {
100            eprintln!("[DEBUG] gh copilot exit code: {}", output.status);
101            eprintln!(
102                "[DEBUG] Raw response (first 500 chars): {}",
103                &raw[..raw.len().min(500)]
104            );
105            if !stderr.is_empty() {
106                eprintln!("[DEBUG] Stderr: {stderr}");
107            }
108        }
109
110        if !output.status.success() {
111            anyhow::bail!(crate::error::SrAiError::AiBackend(format!(
112                "gh copilot failed (exit {}): {}",
113                output.status,
114                stderr.trim()
115            )));
116        }
117
118        // Extract JSON from response (may be wrapped in markdown fences)
119        let text = extract_json(&raw).unwrap_or(raw);
120
121        Ok(AiResponse { text, usage: None })
122    }
123}
124
125/// Extract JSON from a response that may contain markdown code fences.
126pub(crate) fn extract_json(raw: &str) -> Option<String> {
127    let trimmed = raw.trim();
128
129    // Try direct parse first
130    if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
131        return Some(trimmed.to_string());
132    }
133
134    // Try extracting from ```json ... ``` fences
135    if let Some(start) = trimmed.find("```json") {
136        let after = &trimmed[start + 7..];
137        if let Some(end) = after.find("```") {
138            let json_str = after[..end].trim();
139            if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
140                return Some(json_str.to_string());
141            }
142        }
143    }
144
145    // Try extracting from ``` ... ``` fences
146    if let Some(start) = trimmed.find("```") {
147        let after = &trimmed[start + 3..];
148        let after = if let Some(nl) = after.find('\n') {
149            &after[nl + 1..]
150        } else {
151            after
152        };
153        if let Some(end) = after.find("```") {
154            let json_str = after[..end].trim();
155            if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
156                return Some(json_str.to_string());
157            }
158        }
159    }
160
161    // Try finding first { ... } or [ ... ]
162    for (open, close) in [("{", "}"), ("[", "]")] {
163        if let Some(start) = trimmed.find(open)
164            && let Some(end) = trimmed.rfind(close)
165            && end > start
166        {
167            let candidate = &trimmed[start..=end];
168            if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
169                return Some(candidate.to_string());
170            }
171        }
172    }
173
174    None
175}
176
/// Unit tests for the fence-stripping JSON extractor, the system-prompt
/// builder, and backend metadata.
#[cfg(test)]
mod tests {
    use super::*;

    // --- extract_json tests ---

    // Bare JSON should be returned as-is (direct-parse fast path).
    #[test]
    fn extract_direct_json() {
        let input = r#"{"commits": []}"#;
        assert_eq!(extract_json(input), Some(input.to_string()));
    }

    // JSON inside a ```json fence, with prose before and after.
    #[test]
    fn extract_from_json_fences() {
        let input = "Here is the plan:\n```json\n{\"commits\": []}\n```\nDone.";
        assert_eq!(extract_json(input), Some(r#"{"commits": []}"#.to_string()));
    }

    // JSON inside a plain ``` fence (no language tag).
    #[test]
    fn extract_from_plain_fences() {
        let input = "Result:\n```\n{\"commits\": [{\"order\": 1}]}\n```";
        assert_eq!(
            extract_json(input),
            Some(r#"{"commits": [{"order": 1}]}"#.to_string())
        );
    }

    // Unfenced JSON embedded in prose (brace-span fallback).
    #[test]
    fn extract_from_surrounding_text() {
        let input = "The result is {\"commits\": []} and that's it.";
        assert_eq!(extract_json(input), Some(r#"{"commits": []}"#.to_string()));
    }

    // Top-level arrays are extracted too, not just objects.
    #[test]
    fn extract_array_json() {
        let input = "Here: [1, 2, 3] done";
        assert_eq!(extract_json(input), Some("[1, 2, 3]".to_string()));
    }

    // No valid JSON anywhere -> None (including invalid brace spans).
    #[test]
    fn extract_returns_none_for_invalid() {
        assert_eq!(extract_json("no json here"), None);
        assert_eq!(extract_json(""), None);
        assert_eq!(extract_json("{not valid json}"), None);
    }

    // Surrounding whitespace is trimmed before the direct parse.
    #[test]
    fn extract_with_whitespace() {
        let input = "  \n  {\"key\": \"value\"}  \n  ";
        assert_eq!(extract_json(input), Some(r#"{"key": "value"}"#.to_string()));
    }

    // --- build_system_prompt tests ---

    // No schema: base prompt passes through unchanged.
    #[test]
    fn system_prompt_without_schema() {
        let result = build_system_prompt("You are a commit assistant.", None);
        assert_eq!(result, "You are a commit assistant.");
    }

    // With schema: base comes first, then the JSON-only instructions.
    #[test]
    fn system_prompt_with_schema() {
        let schema = r#"{"type": "object"}"#;
        let result = build_system_prompt("Base prompt.", Some(schema));
        assert!(result.starts_with("Base prompt."));
        assert!(result.contains("You MUST respond with valid JSON"));
        assert!(result.contains(schema));
        assert!(result.contains("no markdown fences"));
    }

    // --- backend metadata tests ---

    #[test]
    fn backend_name() {
        let backend = CopilotBackend::new(None, false);
        assert_eq!(backend.name(), "copilot");
    }

    #[test]
    fn default_model_constant() {
        assert_eq!(DEFAULT_MODEL, "gpt-4.1");
    }

    // --- build_system_prompt preserves base content ---

    // Multi-line base prompts survive both the schema and no-schema paths.
    #[test]
    fn system_prompt_preserves_multiline_base() {
        let base = "Line one.\nLine two.\nLine three.";
        let result = build_system_prompt(base, None);
        assert_eq!(result, base);

        let with_schema = build_system_prompt(base, Some("{}"));
        assert!(with_schema.starts_with(base));
    }
}
271}