1use super::{GroupingResponse, SemanticGroup};
2use std::collections::HashSet;
3use std::time::Duration;
4use tokio::process::Command;
5
/// Hard cap on raw bytes read from an LLM child process's stdout (1 MiB).
const MAX_RESPONSE_BYTES: usize = 1_048_576;
/// Hard cap on the extracted JSON payload handed to serde (100 KiB).
const MAX_JSON_SIZE: usize = 102_400;
/// Maximum number of semantic groups accepted from the LLM response.
const MAX_GROUPS: usize = 20;
/// Maximum number of file/hunk change entries kept per group.
const MAX_CHANGES_PER_GROUP: usize = 200;
/// Maximum group label length, in characters.
const MAX_LABEL_LEN: usize = 80;
/// Maximum group description length, in characters.
const MAX_DESC_LEN: usize = 500;
18
/// Which command-line LLM tool to shell out to for grouping requests.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum LlmBackend {
    /// The `claude` CLI (invoked with `--output-format json`).
    Claude,
    /// The `copilot` CLI.
    Copilot,
}
25
26pub async fn request_grouping_with_timeout(
28 backend: LlmBackend,
29 model: &str,
30 summaries: &str,
31) -> anyhow::Result<Vec<SemanticGroup>> {
32 let model = model.to_string();
33 tokio::time::timeout(
34 Duration::from_secs(60),
35 request_grouping(backend, &model, summaries),
36 )
37 .await
38 .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
39}
40
/// Asks the selected LLM backend to group `hunk_summaries` by semantic
/// intent, then parses and sanitizes the model's JSON reply.
///
/// Validation applied to the raw response:
/// - the JSON object is extracted from any surrounding prose/fences,
/// - its size is bounded by `MAX_JSON_SIZE`,
/// - file paths must appear in `hunk_summaries` (`FILE: <path> (...` lines)
///   and must not contain `..` or start with `/`,
/// - group count, changes per group, and label/description lengths are
///   capped; groups left with no valid changes are dropped.
pub async fn request_grouping(
    backend: LlmBackend,
    model: &str,
    hunk_summaries: &str,
) -> anyhow::Result<Vec<SemanticGroup>> {
    // The schema described here is a contract with `GroupingResponse`
    // deserialization — keep the two in sync.
    let prompt = format!(
        "Group these code changes by semantic intent at the HUNK level. \
         Related hunks across different files should be in the same group.\n\
         Return ONLY valid JSON.\n\
         Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
         \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
         Rules:\n\
         - Every hunk of every file must appear in exactly one group\n\
         - Use 2-5 groups (fewer for small changesets)\n\
         - Labels should describe the PURPOSE (e.g. \"Auth refactor\", \"Test coverage\")\n\
         - The \"hunks\" array contains 0-based hunk indices as shown in HUNK N: headers\n\
         - A single file's hunks may be split across different groups if they serve different purposes\n\n\
         Changed files and hunks:\n{hunk_summaries}",
    );

    let output = match backend {
        LlmBackend::Claude => invoke_claude(&prompt, model).await?,
        LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
    };

    let json_str = extract_json(&output)?;

    // Bound the payload before handing it to serde.
    if json_str.len() > MAX_JSON_SIZE {
        anyhow::bail!(
            "LLM JSON response too large ({} bytes, max {})",
            json_str.len(),
            MAX_JSON_SIZE
        );
    }

    let response: GroupingResponse = serde_json::from_str(&json_str)?;

    // Allow-list of paths parsed back out of the summaries we sent; any
    // path the LLM invents that is not in this set is silently dropped.
    let known_files: HashSet<&str> = hunk_summaries
        .lines()
        .filter_map(|line| {
            let line = line.trim();
            if let Some(rest) = line.strip_prefix("FILE: ") {
                // The path ends where the " (" suffix begins; header lines
                // without that suffix are skipped by the `?`.
                let end = rest.find(" (")?;
                Some(&rest[..end])
            } else {
                None
            }
        })
        .collect();

    let validated_groups: Vec<SemanticGroup> = response
        .groups
        .into_iter()
        .take(MAX_GROUPS)
        .map(|group| {
            let valid_changes: Vec<super::GroupedChange> = group
                .changes()
                .into_iter()
                .filter(|change| {
                    let known = known_files.contains(change.file.as_str());
                    // Defense in depth: reject traversal and absolute paths
                    // even though unknown paths are already filtered above.
                    let safe = !change.file.contains("..") && !change.file.starts_with('/');
                    if !safe {
                        tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
                    }
                    known && safe
                })
                .take(MAX_CHANGES_PER_GROUP)
                .collect();
            SemanticGroup::new(
                truncate_string(&group.label, MAX_LABEL_LEN),
                truncate_string(&group.description, MAX_DESC_LEN),
                valid_changes,
            )
        })
        // A group whose every change was rejected carries no information.
        .filter(|group| !group.changes().is_empty())
        .collect();

    Ok(validated_groups)
}
132
133async fn invoke_claude(prompt: &str, model: &str) -> anyhow::Result<String> {
138 use std::process::Stdio;
139 use tokio::io::{AsyncReadExt, AsyncWriteExt};
140
141 let mut child = Command::new("claude")
142 .args([
143 "-p",
144 "--output-format",
145 "json",
146 "--model",
147 model,
148 "--max-turns",
149 "1",
150 ])
151 .stdin(Stdio::piped())
152 .stdout(Stdio::piped())
153 .stderr(Stdio::piped())
154 .spawn()?;
155
156 if let Some(mut stdin) = child.stdin.take() {
158 stdin.write_all(prompt.as_bytes()).await?;
159 }
161
162 let stdout_pipe = child.stdout.take()
164 .ok_or_else(|| anyhow::anyhow!("failed to capture claude stdout"))?;
165 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
166 let mut buf = Vec::with_capacity(8192);
167 let bytes_read = limited.read_to_end(&mut buf).await?;
168
169 if bytes_read >= MAX_RESPONSE_BYTES {
170 child.kill().await.ok();
171 anyhow::bail!("LLM response exceeded {} byte limit", MAX_RESPONSE_BYTES);
172 }
173
174 let status = child.wait().await?;
175 if !status.success() {
176 let mut stderr_buf = Vec::new();
178 if let Some(mut stderr) = child.stderr.take() {
179 stderr.read_to_end(&mut stderr_buf).await.ok();
180 }
181 let stderr_str = String::from_utf8_lossy(&stderr_buf);
182 anyhow::bail!("claude exited with status {}: {}", status, stderr_str);
183 }
184
185 let stdout_str = String::from_utf8(buf)?;
186 let wrapper: serde_json::Value = serde_json::from_str(&stdout_str)?;
187 let result_text = wrapper["result"]
188 .as_str()
189 .ok_or_else(|| anyhow::anyhow!("missing result field in claude JSON output"))?;
190
191 Ok(result_text.to_string())
192}
193
194async fn invoke_copilot(prompt: &str, model: &str) -> anyhow::Result<String> {
199 use std::process::Stdio;
200 use tokio::io::{AsyncReadExt, AsyncWriteExt};
201
202 let mut child = Command::new("copilot")
203 .args(["--yolo", "--model", model])
204 .stdin(Stdio::piped())
205 .stdout(Stdio::piped())
206 .stderr(Stdio::piped())
207 .spawn()?;
208
209 if let Some(mut stdin) = child.stdin.take() {
211 stdin.write_all(prompt.as_bytes()).await?;
212 }
213
214 let stdout_pipe = child.stdout.take()
216 .ok_or_else(|| anyhow::anyhow!("failed to capture copilot stdout"))?;
217 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
218 let mut buf = Vec::with_capacity(8192);
219 let bytes_read = limited.read_to_end(&mut buf).await?;
220
221 if bytes_read >= MAX_RESPONSE_BYTES {
222 child.kill().await.ok();
223 anyhow::bail!("LLM response exceeded {} byte limit", MAX_RESPONSE_BYTES);
224 }
225
226 let status = child.wait().await?;
227 if !status.success() {
228 let mut stderr_buf = Vec::new();
229 if let Some(mut stderr) = child.stderr.take() {
230 stderr.read_to_end(&mut stderr_buf).await.ok();
231 }
232 let stderr_str = String::from_utf8_lossy(&stderr_buf);
233 anyhow::bail!("copilot exited with status {}: {}", status, stderr_str);
234 }
235
236 Ok(String::from_utf8(buf)?)
237}
238
239fn extract_json(text: &str) -> anyhow::Result<String> {
241 let trimmed = text.trim();
242 if trimmed.starts_with('{') {
244 return Ok(trimmed.to_string());
245 }
246 if let Some(start) = trimmed.find('{') {
248 if let Some(end) = trimmed.rfind('}') {
249 return Ok(trimmed[start..=end].to_string());
250 }
251 }
252 anyhow::bail!("no JSON object found in response")
253}
254
/// Returns `s` truncated to at most `max` characters (not bytes), always
/// cutting on a character boundary so multi-byte UTF-8 is never split.
///
/// Single pass: `char_indices().nth(max)` finds the byte offset of the
/// (max+1)-th character if it exists; the old version walked the string
/// twice (`chars().count()` then `chars().take()`).
fn truncate_string(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // There is a character at index `max`, so the string is too long;
        // slice up to (not including) its byte offset.
        Some((byte_idx, _)) => s[..byte_idx].to_string(),
        // Fewer than or exactly `max` characters: keep everything.
        None => s.to_string(),
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;

    // ---- extract_json ----------------------------------------------------

    #[test]
    fn test_extract_json_direct() {
        // A bare JSON object passes through unchanged.
        let input = r#"{"groups": []}"#;
        assert_eq!(extract_json(input).unwrap(), input);
    }

    #[test]
    fn test_extract_json_code_fences() {
        // Markdown code fences around the object are stripped.
        let input = "```json\n{\"groups\": []}\n```";
        assert_eq!(extract_json(input).unwrap(), r#"{"groups": []}"#);
    }

    #[test]
    fn test_extract_json_no_json() {
        assert!(extract_json("no json here").is_err());
    }

    // ---- GroupingResponse schema parsing ---------------------------------

    #[test]
    fn test_parse_hunk_level_response() {
        // The canonical schema: per-file change entries with hunk indices.
        let json = r#"{
            "groups": [{
                "label": "Auth refactor",
                "description": "Refactored auth flow",
                "changes": [
                    {"file": "src/auth.rs", "hunks": [0, 2]},
                    {"file": "src/middleware.rs", "hunks": [1]}
                ]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.groups.len(), 1);
        assert_eq!(response.groups[0].changes().len(), 2);
        assert_eq!(response.groups[0].changes()[0].hunks, vec![0, 2]);
    }

    #[test]
    fn test_parse_empty_hunks_means_all() {
        // An empty hunk list is accepted and preserved as empty.
        let json = r#"{
            "groups": [{
                "label": "Config",
                "description": "Config changes",
                "changes": [{"file": "config.toml", "hunks": []}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert!(response.groups[0].changes()[0].hunks.is_empty());
    }

    // ---- self-referential source checks ----------------------------------
    // These tests scan this very file (via include_str!) to pin the security
    // property that prompts travel over stdin, never through argv.

    #[test]
    fn test_invoke_claude_uses_stdin_pipe() {
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        // Region ends at the next top-level async fn (skip index 0 so we
        // don't match our own starting line).
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        assert!(
            claude_fn.contains("Stdio::piped()"),
            "invoke_claude must use Stdio::piped() for stdin"
        );
        assert!(
            claude_fn.contains("write_all"),
            "invoke_claude must write prompt to stdin via write_all"
        );
        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude must not pass prompt in .args()"
            );
        }
    }

    #[test]
    fn test_invoke_copilot_uses_stdin_pipe() {
        let src = include_str!("llm.rs");
        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        // Region ends at the next doc comment or the test module marker.
        let end = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end];

        assert!(
            copilot_fn.contains("Stdio::piped()"),
            "invoke_copilot must use Stdio::piped() for stdin"
        );
        assert!(
            copilot_fn.contains("write_all"),
            "invoke_copilot must write prompt to stdin via write_all"
        );
    }

    #[test]
    fn test_no_prompt_in_args() {
        // Both backends: the first .args([...]) list must never mention the
        // prompt variable.
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude .args() must not contain prompt variable"
            );
        }

        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        let end2 = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end2];

        if let Some(args_start) = copilot_fn.find(".args([") {
            let args_section = &copilot_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_copilot .args() must not contain prompt variable"
            );
        }
    }

    #[test]
    fn test_parse_files_fallback() {
        // A looser response shape with a plain "files" array (no hunk info)
        // still deserializes; changes() yields entries with empty hunk lists.
        let json = r#"{
            "groups": [{
                "label": "Refactor",
                "description": "Code cleanup",
                "files": ["src/app.rs", "src/main.rs"]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let changes = response.groups[0].changes();
        assert_eq!(changes.len(), 2);
        assert_eq!(changes[0].file, "src/app.rs");
        assert!(changes[0].hunks.is_empty());
    }

    // ---- limit constants --------------------------------------------------
    // These only pin the constants' values; the bounded-read behavior itself
    // is exercised in the invoke_* functions.

    #[test]
    fn test_read_bounded_under_limit() {
        let data = "hello world";
        assert!(data.len() < MAX_RESPONSE_BYTES);
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
    }

    #[test]
    fn test_read_bounded_over_limit_constant() {
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
        let oversized = vec![b'x'; MAX_RESPONSE_BYTES];
        assert!(oversized.len() >= MAX_RESPONSE_BYTES);
    }

    #[test]
    fn test_validate_rejects_oversized_json() {
        // Builds a payload just over MAX_JSON_SIZE to confirm the threshold
        // arithmetic; request_grouping bails on such input.
        let large_json = format!(r#"{{"groups": [{{"label": "x", "description": "{}", "changes": []}}]}}"#,
            "a".repeat(MAX_JSON_SIZE + 1));
        assert!(large_json.len() > MAX_JSON_SIZE);
    }

    #[test]
    fn test_validate_caps_groups_at_max() {
        // 30 groups parse fine; the take(MAX_GROUPS) cap trims them to 20.
        let mut groups_json = Vec::new();
        for i in 0..30 {
            groups_json.push(format!(
                r#"{{"label": "Group {}", "description": "desc", "changes": [{{"file": "src/f{}.rs", "hunks": [0]}}]}}"#,
                i, i
            ));
        }
        let json = format!(r#"{{"groups": [{}]}}"#, groups_json.join(","));
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups.len(), 30);
        let capped: Vec<_> = response.groups.into_iter().take(MAX_GROUPS).collect();
        assert_eq!(capped.len(), 20);
    }

    #[test]
    fn test_validate_rejects_path_traversal() {
        // Confirms the parsed path retains its ".." so the validation filter
        // in request_grouping has something to reject.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "traversal",
                "changes": [{"file": "../../../etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.contains(".."), "path should contain traversal");
    }

    #[test]
    fn test_validate_rejects_absolute_paths() {
        // Same idea for absolute paths: parsing must preserve the leading '/'.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "absolute",
                "changes": [{"file": "/etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.starts_with('/'), "path should be absolute");
    }

    // ---- truncation -------------------------------------------------------

    #[test]
    fn test_truncate_string_label() {
        let long_label = "a".repeat(100);
        let truncated = truncate_string(&long_label, MAX_LABEL_LEN);
        assert_eq!(truncated.chars().count(), MAX_LABEL_LEN);
    }

    #[test]
    fn test_truncate_string_description() {
        let long_desc = "b".repeat(600);
        let truncated = truncate_string(&long_desc, MAX_DESC_LEN);
        assert_eq!(truncated.chars().count(), MAX_DESC_LEN);
    }

    #[test]
    fn test_validate_caps_changes_per_group() {
        // 250 changes parse fine; take(MAX_CHANGES_PER_GROUP) trims to 200.
        let mut changes = Vec::new();
        for i in 0..250 {
            changes.push(format!(r#"{{"file": "src/f{}.rs", "hunks": [0]}}"#, i));
        }
        let json = format!(
            r#"{{"groups": [{{"label": "Big", "description": "lots", "changes": [{}]}}]}}"#,
            changes.join(",")
        );
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups[0].changes().len(), 250);
        let capped: Vec<_> = response.groups[0].changes().into_iter().take(MAX_CHANGES_PER_GROUP).collect();
        assert_eq!(capped.len(), 200);
    }
}