1use super::{AiBackend, AiEvent, AiRequest, AiResponse};
2use anyhow::{Context, Result};
3use async_trait::async_trait;
4use tokio::process::Command;
5use tokio::sync::mpsc;
6
/// Model handed to `gh copilot --model` when the backend has no override.
const DEFAULT_MODEL: &str = "gpt-4.1";

/// Tool permissions granted to the Copilot CLI, each passed with its own
/// `--allow-tool` flag.
///
/// Every entry is a `git` subcommand pattern used to inspect the repository.
/// NOTE(review): `git branch` can also create/delete branches depending on its
/// arguments — confirm whether the `shell(git:branch)` pattern restricts them.
const ALLOWED_TOOLS: &[&str] = &[
    // History and change inspection.
    "shell(git:diff)",
    "shell(git:log)",
    "shell(git:show)",
    "shell(git:status)",
    // Working-tree / ref / object queries.
    "shell(git:ls-files)",
    "shell(git:rev-parse)",
    "shell(git:branch)",
    "shell(git:cat-file)",
    "shell(git:rev-list)",
    // Authorship summaries.
    "shell(git:shortlog)",
    "shell(git:blame)",
];
25
/// AI backend that delegates requests to the GitHub Copilot CLI (`gh copilot`).
pub struct CopilotBackend {
    /// Optional model override; when `None`, requests use `DEFAULT_MODEL`.
    model: Option<String>,
    /// When `true`, diagnostic `[DEBUG]` lines are written to stderr.
    debug: bool,
}

impl CopilotBackend {
    /// Builds a backend from an optional model override and a debug flag.
    pub fn new(model: Option<String>, debug: bool) -> Self {
        Self { debug, model }
    }
}
36
/// Builds the system prompt sent to the model.
///
/// Without a schema, `base` is returned unchanged. With one, instructions are
/// appended telling the model to reply only with JSON matching `json_schema`,
/// which is shown inside a ```` ```json ```` fence.
fn build_system_prompt(base: &str, json_schema: Option<&str>) -> String {
    let Some(schema) = json_schema else {
        // No structured output requested: use the caller's prompt verbatim.
        return base.to_owned();
    };

    format!(
        "{base}\n\n\
         You MUST respond with valid JSON matching this schema:\n\
         ```json\n{schema}\n```\n\n\
         Respond ONLY with the JSON object, no markdown fences, no explanation."
    )
}
49
50#[async_trait]
51impl AiBackend for CopilotBackend {
52 fn name(&self) -> &str {
53 "copilot"
54 }
55
56 async fn is_available(&self) -> bool {
57 Command::new("gh")
58 .args(["copilot", "--version"])
59 .output()
60 .await
61 .is_ok_and(|o| o.status.success())
62 }
63
64 async fn request(
65 &self,
66 req: &AiRequest,
67 _events: Option<mpsc::UnboundedSender<AiEvent>>,
68 ) -> Result<AiResponse> {
69 let model = self.model.as_deref().unwrap_or(DEFAULT_MODEL);
70 let system = build_system_prompt(&req.system_prompt, req.json_schema.as_deref());
71
72 let mut cmd = Command::new("gh");
73 cmd.current_dir(&req.working_dir)
74 .arg("copilot")
75 .arg("-p")
76 .arg(&req.user_prompt)
77 .arg("-s")
78 .arg("--model")
79 .arg(model);
80
81 for tool in ALLOWED_TOOLS {
83 cmd.arg("--allow-tool").arg(tool);
84 }
85
86 cmd.arg("--no-custom-instructions")
87 .arg("--system-prompt")
88 .arg(&system);
89
90 if self.debug {
91 eprintln!("[DEBUG] Calling gh copilot (model={model})");
92 }
93
94 let output = cmd.output().await.context("failed to run gh copilot")?;
95
96 let raw = String::from_utf8_lossy(&output.stdout).to_string();
97 let stderr = String::from_utf8_lossy(&output.stderr);
98
99 if self.debug {
100 eprintln!("[DEBUG] gh copilot exit code: {}", output.status);
101 eprintln!(
102 "[DEBUG] Raw response (first 500 chars): {}",
103 &raw[..raw.len().min(500)]
104 );
105 if !stderr.is_empty() {
106 eprintln!("[DEBUG] Stderr: {stderr}");
107 }
108 }
109
110 if !output.status.success() {
111 anyhow::bail!(crate::error::SrAiError::AiBackend(format!(
112 "gh copilot failed (exit {}): {}",
113 output.status,
114 stderr.trim()
115 )));
116 }
117
118 let text = extract_json(&raw).unwrap_or(raw);
120
121 Ok(AiResponse { text, usage: None })
122 }
123}
124
125pub(crate) fn extract_json(raw: &str) -> Option<String> {
127 let trimmed = raw.trim();
128
129 if serde_json::from_str::<serde_json::Value>(trimmed).is_ok() {
131 return Some(trimmed.to_string());
132 }
133
134 if let Some(start) = trimmed.find("```json") {
136 let after = &trimmed[start + 7..];
137 if let Some(end) = after.find("```") {
138 let json_str = after[..end].trim();
139 if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
140 return Some(json_str.to_string());
141 }
142 }
143 }
144
145 if let Some(start) = trimmed.find("```") {
147 let after = &trimmed[start + 3..];
148 let after = if let Some(nl) = after.find('\n') {
149 &after[nl + 1..]
150 } else {
151 after
152 };
153 if let Some(end) = after.find("```") {
154 let json_str = after[..end].trim();
155 if serde_json::from_str::<serde_json::Value>(json_str).is_ok() {
156 return Some(json_str.to_string());
157 }
158 }
159 }
160
161 for (open, close) in [("{", "}"), ("[", "]")] {
163 if let Some(start) = trimmed.find(open)
164 && let Some(end) = trimmed.rfind(close)
165 && end > start
166 {
167 let candidate = &trimmed[start..=end];
168 if serde_json::from_str::<serde_json::Value>(candidate).is_ok() {
169 return Some(candidate.to_string());
170 }
171 }
172 }
173
174 None
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    // --- extract_json -------------------------------------------------------

    #[test]
    fn extract_direct_json() {
        let payload = r#"{"commits": []}"#;
        assert_eq!(extract_json(payload).as_deref(), Some(payload));
    }

    #[test]
    fn extract_from_json_fences() {
        let reply = "Here is the plan:\n```json\n{\"commits\": []}\n```\nDone.";
        assert_eq!(extract_json(reply).as_deref(), Some(r#"{"commits": []}"#));
    }

    #[test]
    fn extract_from_plain_fences() {
        let reply = "Result:\n```\n{\"commits\": [{\"order\": 1}]}\n```";
        assert_eq!(
            extract_json(reply).as_deref(),
            Some(r#"{"commits": [{"order": 1}]}"#)
        );
    }

    #[test]
    fn extract_from_surrounding_text() {
        let reply = "The result is {\"commits\": []} and that's it.";
        assert_eq!(extract_json(reply).as_deref(), Some(r#"{"commits": []}"#));
    }

    #[test]
    fn extract_array_json() {
        assert_eq!(
            extract_json("Here: [1, 2, 3] done").as_deref(),
            Some("[1, 2, 3]")
        );
    }

    #[test]
    fn extract_returns_none_for_invalid() {
        for reply in ["no json here", "", "{not valid json}"] {
            assert_eq!(extract_json(reply), None, "expected no JSON in {reply:?}");
        }
    }

    #[test]
    fn extract_with_whitespace() {
        let reply = " \n {\"key\": \"value\"} \n ";
        assert_eq!(extract_json(reply).as_deref(), Some(r#"{"key": "value"}"#));
    }

    // --- build_system_prompt ------------------------------------------------

    #[test]
    fn system_prompt_without_schema() {
        let base = "You are a commit assistant.";
        assert_eq!(build_system_prompt(base, None), base);
    }

    #[test]
    fn system_prompt_with_schema() {
        let schema = r#"{"type": "object"}"#;
        let prompt = build_system_prompt("Base prompt.", Some(schema));
        assert!(prompt.starts_with("Base prompt."));
        for fragment in ["You MUST respond with valid JSON", schema, "no markdown fences"] {
            assert!(prompt.contains(fragment), "missing {fragment:?} in prompt");
        }
    }

    #[test]
    fn system_prompt_preserves_multiline_base() {
        let base = "Line one.\nLine two.\nLine three.";
        assert_eq!(build_system_prompt(base, None), base);
        assert!(build_system_prompt(base, Some("{}")).starts_with(base));
    }

    // --- backend metadata ---------------------------------------------------

    #[test]
    fn backend_name() {
        assert_eq!(CopilotBackend::new(None, false).name(), "copilot");
    }

    #[test]
    fn default_model_constant() {
        assert_eq!(DEFAULT_MODEL, "gpt-4.1");
    }
}