1use super::{GroupingResponse, SemanticGroup};
2use std::collections::HashSet;
3use std::time::Duration;
4use tokio::process::Command;
5
/// Hard cap on raw bytes read from an LLM subprocess's stdout; guards
/// against a runaway process streaming unbounded output (1 MiB).
const MAX_RESPONSE_BYTES: usize = 1_048_576;
/// Maximum size of the JSON payload extracted from the LLM output (100 KiB).
const MAX_JSON_SIZE: usize = 102_400;
/// Upper bound on the number of semantic groups kept from a response.
const MAX_GROUPS: usize = 20;
/// Upper bound on the number of file/hunk changes kept per group.
const MAX_CHANGES_PER_GROUP: usize = 200;
/// Maximum length in characters for a group label; longer labels are truncated.
const MAX_LABEL_LEN: usize = 80;
/// Maximum length in characters for a group description; longer ones are truncated.
const MAX_DESC_LEN: usize = 500;
18
/// Which LLM CLI backend to shell out to for grouping requests.
// `Eq` and `Hash` added alongside the existing `PartialEq` so the backend can
// be used as a map/set key and compared exhaustively; purely additive.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LlmBackend {
    /// Anthropic's `claude` CLI (invoked in JSON output mode).
    Claude,
    /// GitHub's `copilot` CLI.
    Copilot,
}
25
26pub async fn request_grouping_with_timeout(
28 backend: LlmBackend,
29 model: &str,
30 summaries: &str,
31) -> anyhow::Result<Vec<SemanticGroup>> {
32 let model = model.to_string();
33 tokio::time::timeout(
34 Duration::from_secs(60),
35 request_grouping(backend, &model, summaries),
36 )
37 .await
38 .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
39}
40
41pub async fn request_grouping(
47 backend: LlmBackend,
48 model: &str,
49 hunk_summaries: &str,
50) -> anyhow::Result<Vec<SemanticGroup>> {
51 let prompt = format!(
52 "Group these code changes by semantic intent at the HUNK level. \
53 Related hunks across different files should be in the same group.\n\
54 Return ONLY valid JSON.\n\
55 Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
56 \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
57 Rules:\n\
58 - Every hunk of every file must appear in exactly one group\n\
59 - Use 2-5 groups (fewer for small changesets)\n\
60 - Labels should describe the PURPOSE (e.g. \"Auth refactor\", \"Test coverage\")\n\
61 - The \"hunks\" array contains 0-based hunk indices as shown in HUNK N: headers\n\
62 - A single file's hunks may be split across different groups if they serve different purposes\n\n\
63 Changed files and hunks:\n{hunk_summaries}",
64 );
65
66 let output = match backend {
67 LlmBackend::Claude => invoke_claude(&prompt, model).await?,
68 LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
69 };
70
71 let json_str = extract_json(&output)?;
73
74 if json_str.len() > MAX_JSON_SIZE {
76 anyhow::bail!(
77 "LLM JSON response too large ({} bytes, max {})",
78 json_str.len(),
79 MAX_JSON_SIZE
80 );
81 }
82
83 let response: GroupingResponse = serde_json::from_str(&json_str)?;
84
85 let known_files: HashSet<&str> = hunk_summaries
87 .lines()
88 .filter_map(|line| {
89 let line = line.trim();
90 if let Some(rest) = line.strip_prefix("FILE: ") {
91 let end = rest.find(" (")?;
92 Some(&rest[..end])
93 } else {
94 None
95 }
96 })
97 .collect();
98
99 let validated_groups: Vec<SemanticGroup> = response
101 .groups
102 .into_iter()
103 .take(MAX_GROUPS) .map(|group| {
105 let valid_changes: Vec<super::GroupedChange> = group
106 .changes()
107 .into_iter()
108 .filter(|change| {
109 let known = known_files.contains(change.file.as_str());
111 let safe = !change.file.contains("..") && !change.file.starts_with('/');
113 if !safe {
114 tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
115 }
116 known && safe
117 })
118 .take(MAX_CHANGES_PER_GROUP) .collect();
120 SemanticGroup::new(
122 truncate_string(&group.label, MAX_LABEL_LEN),
123 truncate_string(&group.description, MAX_DESC_LEN),
124 valid_changes,
125 )
126 })
127 .filter(|group| !group.changes().is_empty())
128 .collect();
129
130 Ok(validated_groups)
131}
132
133pub async fn request_incremental_grouping(
138 backend: LlmBackend,
139 model: &str,
140 summaries: &str,
141) -> anyhow::Result<Vec<SemanticGroup>> {
142 let model = model.to_string();
143 tokio::time::timeout(
144 Duration::from_secs(60),
145 request_incremental(backend, &model, summaries),
146 )
147 .await
148 .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
149}
150
151async fn request_incremental(
152 backend: LlmBackend,
153 model: &str,
154 hunk_summaries: &str,
155) -> anyhow::Result<Vec<SemanticGroup>> {
156 let prompt = format!(
157 "You are updating an existing grouping of code changes. \
158 New or modified files have been added to the working tree.\n\
159 Assign the NEW/MODIFIED hunks to the EXISTING groups listed above, or create new groups if they don't fit.\n\
160 Return ONLY valid JSON with assignments for the NEW/MODIFIED files only.\n\
161 Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
162 \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
163 Rules:\n\
164 - Every hunk of every NEW/MODIFIED file must appear in exactly one group\n\
165 - Reuse existing group labels when the change fits that group's purpose\n\
166 - Create new groups only when a change serves a genuinely different purpose\n\
167 - Use the same label string (case-sensitive) when assigning to an existing group\n\
168 - The \"hunks\" array contains 0-based hunk indices\n\
169 - Do NOT include unchanged files in your response\n\n\
170 {hunk_summaries}",
171 );
172
173 let output = match backend {
174 LlmBackend::Claude => invoke_claude(&prompt, model).await?,
175 LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
176 };
177
178 let json_str = extract_json(&output)?;
179
180 if json_str.len() > MAX_JSON_SIZE {
181 anyhow::bail!(
182 "LLM JSON response too large ({} bytes, max {})",
183 json_str.len(),
184 MAX_JSON_SIZE
185 );
186 }
187
188 let response: GroupingResponse = serde_json::from_str(&json_str)?;
189
190 let known_files: HashSet<&str> = hunk_summaries
192 .lines()
193 .filter_map(|line| {
194 let line = line.trim();
195 if let Some(rest) = line.strip_prefix("FILE: ") {
196 let end = rest.find(" (")?;
197 Some(&rest[..end])
198 } else {
199 None
200 }
201 })
202 .collect();
203
204 let validated_groups: Vec<SemanticGroup> = response
205 .groups
206 .into_iter()
207 .take(MAX_GROUPS)
208 .map(|group| {
209 let valid_changes: Vec<super::GroupedChange> = group
210 .changes()
211 .into_iter()
212 .filter(|change| {
213 let known = known_files.contains(change.file.as_str());
214 let safe = !change.file.contains("..") && !change.file.starts_with('/');
215 if !safe {
216 tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
217 }
218 known && safe
219 })
220 .take(MAX_CHANGES_PER_GROUP)
221 .collect();
222 SemanticGroup::new(
223 truncate_string(&group.label, MAX_LABEL_LEN),
224 truncate_string(&group.description, MAX_DESC_LEN),
225 valid_changes,
226 )
227 })
228 .filter(|group| !group.changes().is_empty())
229 .collect();
230
231 Ok(validated_groups)
232}
233
234async fn invoke_claude(prompt: &str, model: &str) -> anyhow::Result<String> {
239 use std::process::Stdio;
240 use tokio::io::{AsyncReadExt, AsyncWriteExt};
241
242 let mut child = Command::new("claude")
243 .args([
244 "-p",
245 "--output-format",
246 "json",
247 "--model",
248 model,
249 "--max-turns",
250 "1",
251 ])
252 .stdin(Stdio::piped())
253 .stdout(Stdio::piped())
254 .stderr(Stdio::piped())
255 .spawn()?;
256
257 if let Some(mut stdin) = child.stdin.take() {
259 stdin.write_all(prompt.as_bytes()).await?;
260 }
262
263 let stdout_pipe = child.stdout.take()
265 .ok_or_else(|| anyhow::anyhow!("failed to capture claude stdout"))?;
266 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
267 let mut buf = Vec::with_capacity(8192);
268 let bytes_read = limited.read_to_end(&mut buf).await?;
269
270 if bytes_read >= MAX_RESPONSE_BYTES {
271 child.kill().await.ok();
272 anyhow::bail!("LLM response exceeded {MAX_RESPONSE_BYTES} byte limit");
273 }
274
275 let status = child.wait().await?;
276 if !status.success() {
277 let mut stderr_buf = Vec::new();
279 if let Some(mut stderr) = child.stderr.take() {
280 stderr.read_to_end(&mut stderr_buf).await.ok();
281 }
282 let stderr_str = String::from_utf8_lossy(&stderr_buf);
283 anyhow::bail!("claude exited with status {status}: {stderr_str}");
284 }
285
286 let stdout_str = String::from_utf8(buf)?;
287 let wrapper: serde_json::Value = serde_json::from_str(&stdout_str)?;
288 let result_text = wrapper["result"]
289 .as_str()
290 .ok_or_else(|| anyhow::anyhow!("missing result field in claude JSON output"))?;
291
292 Ok(result_text.to_string())
293}
294
295async fn invoke_copilot(prompt: &str, model: &str) -> anyhow::Result<String> {
300 use std::process::Stdio;
301 use tokio::io::{AsyncReadExt, AsyncWriteExt};
302
303 let mut child = Command::new("copilot")
304 .args(["--yolo", "--model", model])
305 .stdin(Stdio::piped())
306 .stdout(Stdio::piped())
307 .stderr(Stdio::piped())
308 .spawn()?;
309
310 if let Some(mut stdin) = child.stdin.take() {
312 stdin.write_all(prompt.as_bytes()).await?;
313 }
314
315 let stdout_pipe = child.stdout.take()
317 .ok_or_else(|| anyhow::anyhow!("failed to capture copilot stdout"))?;
318 let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
319 let mut buf = Vec::with_capacity(8192);
320 let bytes_read = limited.read_to_end(&mut buf).await?;
321
322 if bytes_read >= MAX_RESPONSE_BYTES {
323 child.kill().await.ok();
324 anyhow::bail!("LLM response exceeded {MAX_RESPONSE_BYTES} byte limit");
325 }
326
327 let status = child.wait().await?;
328 if !status.success() {
329 let mut stderr_buf = Vec::new();
330 if let Some(mut stderr) = child.stderr.take() {
331 stderr.read_to_end(&mut stderr_buf).await.ok();
332 }
333 let stderr_str = String::from_utf8_lossy(&stderr_buf);
334 anyhow::bail!("copilot exited with status {status}: {stderr_str}");
335 }
336
337 Ok(String::from_utf8(buf)?)
338}
339
340fn extract_json(text: &str) -> anyhow::Result<String> {
342 let trimmed = text.trim();
343 if trimmed.starts_with('{') {
345 return Ok(trimmed.to_string());
346 }
347 if let Some(start) = trimmed.find('{') {
349 if let Some(end) = trimmed.rfind('}') {
350 return Ok(trimmed[start..=end].to_string());
351 }
352 }
353 anyhow::bail!("no JSON object found in response")
354}
355
/// Returns `s` truncated to at most `max` characters (not bytes), so a
/// multi-byte UTF-8 code point is never split.
///
/// Single pass: the previous version walked the whole string with
/// `chars().count()` and then walked it again with `take`; `char_indices()
/// .nth(max)` stops at the cut point instead.
fn truncate_string(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // `idx` is the byte offset of the first character past the limit,
        // so slicing up to it keeps exactly `max` characters.
        Some((idx, _)) => s[..idx].to_string(),
        None => s.to_string(),
    }
}
364
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_json_direct() {
        // A response that is already a bare JSON object passes through verbatim.
        let input = r#"{"groups": []}"#;
        assert_eq!(extract_json(input).unwrap(), input);
    }

    #[test]
    fn test_extract_json_code_fences() {
        // JSON wrapped in Markdown code fences is sliced out between the
        // first '{' and the last '}'.
        let input = "```json\n{\"groups\": []}\n```";
        assert_eq!(extract_json(input).unwrap(), r#"{"groups": []}"#);
    }

    #[test]
    fn test_extract_json_no_json() {
        // Text with no JSON object must yield an error, not a panic.
        assert!(extract_json("no json here").is_err());
    }

    #[test]
    fn test_parse_hunk_level_response() {
        // Hunk-level schema: each change carries a file path plus 0-based
        // hunk indices.
        let json = r#"{
            "groups": [{
                "label": "Auth refactor",
                "description": "Refactored auth flow",
                "changes": [
                    {"file": "src/auth.rs", "hunks": [0, 2]},
                    {"file": "src/middleware.rs", "hunks": [1]}
                ]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.groups.len(), 1);
        assert_eq!(response.groups[0].changes().len(), 2);
        assert_eq!(response.groups[0].changes()[0].hunks, vec![0, 2]);
    }

    #[test]
    fn test_parse_empty_hunks_means_all() {
        // An empty "hunks" list deserializes as-is; callers interpret it as
        // "the whole file".
        let json = r#"{
            "groups": [{
                "label": "Config",
                "description": "Config changes",
                "changes": [{"file": "config.toml", "hunks": []}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert!(response.groups[0].changes()[0].hunks.is_empty());
    }

    #[test]
    fn test_invoke_claude_uses_stdin_pipe() {
        // Self-scan this source file to enforce that invoke_claude pipes the
        // prompt through stdin rather than argv (argv is visible in process
        // listings and has length limits).
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        // The scanned region ends at the next top-level `async fn`.
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        assert!(
            claude_fn.contains("Stdio::piped()"),
            "invoke_claude must use Stdio::piped() for stdin"
        );
        assert!(
            claude_fn.contains("write_all"),
            "invoke_claude must write prompt to stdin via write_all"
        );
        // The .args([...]) list must never mention the prompt variable.
        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude must not pass prompt in .args()"
            );
        }
    }

    #[test]
    fn test_invoke_copilot_uses_stdin_pipe() {
        // Same stdin-pipe guarantee for the copilot backend.
        let src = include_str!("llm.rs");
        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        // The scanned region ends at the next doc comment or the test module.
        let end = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end];

        assert!(
            copilot_fn.contains("Stdio::piped()"),
            "invoke_copilot must use Stdio::piped() for stdin"
        );
        assert!(
            copilot_fn.contains("write_all"),
            "invoke_copilot must write prompt to stdin via write_all"
        );
    }

    #[test]
    fn test_no_prompt_in_args() {
        // Belt-and-braces: neither backend may place the prompt in argv.
        let src = include_str!("llm.rs");
        let claude_start = src.find("async fn invoke_claude").expect("invoke_claude not found");
        let claude_body = &src[claude_start..];
        let end = claude_body[1..].find("\nasync fn").unwrap_or(claude_body.len());
        let claude_fn = &claude_body[..end];

        if let Some(args_start) = claude_fn.find(".args([") {
            let args_section = &claude_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_claude .args() must not contain prompt variable"
            );
        }

        let copilot_start = src.find("async fn invoke_copilot").expect("invoke_copilot not found");
        let copilot_body = &src[copilot_start..];
        let end2 = copilot_body[1..].find("\n/// ").or_else(|| copilot_body[1..].find("\n#[cfg(test)]")).unwrap_or(copilot_body.len());
        let copilot_fn = &copilot_body[..end2];

        if let Some(args_start) = copilot_fn.find(".args([") {
            let args_section = &copilot_fn[args_start..];
            let args_end = args_section.find("])").expect("unclosed .args");
            let args_content = &args_section[..args_end];
            assert!(
                !args_content.contains("prompt"),
                "invoke_copilot .args() must not contain prompt variable"
            );
        }
    }

    #[test]
    fn test_parse_files_fallback() {
        // Legacy schema fallback: a plain "files" list maps to changes with
        // empty hunk lists.
        let json = r#"{
            "groups": [{
                "label": "Refactor",
                "description": "Code cleanup",
                "files": ["src/app.rs", "src/main.rs"]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let changes = response.groups[0].changes();
        assert_eq!(changes.len(), 2);
        assert_eq!(changes[0].file, "src/app.rs");
        assert!(changes[0].hunks.is_empty());
    }

    #[test]
    fn test_read_bounded_under_limit() {
        // Sanity-check the stdout cap constant against a tiny payload.
        let data = "hello world";
        assert!(data.len() < MAX_RESPONSE_BYTES);
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
    }

    #[test]
    fn test_read_bounded_over_limit_constant() {
        // A buffer of exactly the cap size must trip the `>=` limit check.
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
        let oversized = vec![b'x'; MAX_RESPONSE_BYTES];
        assert!(oversized.len() >= MAX_RESPONSE_BYTES);
    }

    #[test]
    fn test_validate_rejects_oversized_json() {
        // A JSON payload longer than MAX_JSON_SIZE should fail the size gate
        // before deserialization.
        let large_json = format!(r#"{{"groups": [{{"label": "x", "description": "{}", "changes": []}}]}}"#,
            "a".repeat(MAX_JSON_SIZE + 1));
        assert!(large_json.len() > MAX_JSON_SIZE);
    }

    #[test]
    fn test_validate_caps_groups_at_max() {
        // 30 groups parse fine, but validation keeps only MAX_GROUPS of them.
        let mut groups_json = Vec::new();
        for i in 0..30 {
            groups_json.push(format!(
                r#"{{"label": "Group {}", "description": "desc", "changes": [{{"file": "src/f{}.rs", "hunks": [0]}}]}}"#,
                i, i
            ));
        }
        let json = format!(r#"{{"groups": [{}]}}"#, groups_json.join(","));
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups.len(), 30);
        let capped: Vec<_> = response.groups.into_iter().take(MAX_GROUPS).collect();
        assert_eq!(capped.len(), 20);
    }

    #[test]
    fn test_validate_rejects_path_traversal() {
        // "../" paths must be detectable so validation can drop them.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "traversal",
                "changes": [{"file": "../../../etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.contains(".."), "path should contain traversal");
    }

    #[test]
    fn test_validate_rejects_absolute_paths() {
        // Absolute paths must be detectable so validation can drop them.
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "absolute",
                "changes": [{"file": "/etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        assert!(change.file.starts_with('/'), "path should be absolute");
    }

    #[test]
    fn test_truncate_string_label() {
        // Labels longer than MAX_LABEL_LEN are cut to exactly that many chars.
        let long_label = "a".repeat(100);
        let truncated = truncate_string(&long_label, MAX_LABEL_LEN);
        assert_eq!(truncated.chars().count(), MAX_LABEL_LEN);
    }

    #[test]
    fn test_truncate_string_description() {
        // Descriptions longer than MAX_DESC_LEN are cut likewise.
        let long_desc = "b".repeat(600);
        let truncated = truncate_string(&long_desc, MAX_DESC_LEN);
        assert_eq!(truncated.chars().count(), MAX_DESC_LEN);
    }

    #[test]
    fn test_validate_caps_changes_per_group() {
        // 250 changes parse fine, but validation keeps only
        // MAX_CHANGES_PER_GROUP per group.
        let mut changes = Vec::new();
        for i in 0..250 {
            changes.push(format!(r#"{{"file": "src/f{}.rs", "hunks": [0]}}"#, i));
        }
        let json = format!(
            r#"{{"groups": [{{"label": "Big", "description": "lots", "changes": [{}]}}]}}"#,
            changes.join(",")
        );
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups[0].changes().len(), 250);
        let capped: Vec<_> = response.groups[0].changes().into_iter().take(MAX_CHANGES_PER_GROUP).collect();
        assert_eq!(capped.len(), 200);
    }
}