Skip to main content

semantic_diff/grouper/
llm.rs

1use super::{GroupingResponse, SemanticGroup};
2use std::collections::HashSet;
3use std::time::Duration;
4use tokio::process::Command;
5
/// Upper bound on bytes read from an LLM CLI's stdout (1 MiB), so a broken or
/// malicious backend cannot exhaust memory.
const MAX_RESPONSE_BYTES: usize = 1024 * 1024;
/// Upper bound on the extracted JSON text accepted for deserialization (100 KiB).
const MAX_JSON_SIZE: usize = 100 * 1024;
/// Upper bound on the number of semantic groups kept from one LLM response.
const MAX_GROUPS: usize = 20;
/// Upper bound on the number of file changes kept in a single group.
const MAX_CHANGES_PER_GROUP: usize = 200;
/// Upper bound on a group label, counted in `char`s.
const MAX_LABEL_LEN: usize = 80;
/// Upper bound on a group description, counted in `char`s.
const MAX_DESC_LEN: usize = 500;
18
/// Which LLM backend is available for semantic grouping.
///
/// Each variant selects an external CLI that is spawned as a child process.
/// `Eq` and `Hash` are derived so the backend can be used as a set or map key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum LlmBackend {
    /// Spawns the `claude` CLI.
    Claude,
    /// Spawns the `copilot` CLI.
    Copilot,
}
25
26/// Request semantic grouping from the detected LLM backend with a 30-second timeout.
27pub async fn request_grouping_with_timeout(
28    backend: LlmBackend,
29    model: &str,
30    summaries: &str,
31) -> anyhow::Result<Vec<SemanticGroup>> {
32    let model = model.to_string();
33    tokio::time::timeout(
34        Duration::from_secs(60),
35        request_grouping(backend, &model, summaries),
36    )
37    .await
38    .map_err(|_| anyhow::anyhow!("LLM timed out after 60s"))?
39}
40
41/// Invoke the LLM backend to group hunks by semantic intent.
42///
43/// Prompts are piped via stdin to prevent process table exposure of code diffs.
44/// Uses `tokio::process::Command::spawn()` so that aborting the JoinHandle
45/// drops the Child, which sends SIGKILL (critical for ROB-05 cancellation).
46pub async fn request_grouping(
47    backend: LlmBackend,
48    model: &str,
49    hunk_summaries: &str,
50) -> anyhow::Result<Vec<SemanticGroup>> {
51    let prompt = format!(
52        "Group these code changes by semantic intent at the HUNK level. \
53         Related hunks across different files should be in the same group.\n\
54         Return ONLY valid JSON.\n\
55         Schema: {{\"groups\": [{{\"label\": \"short name\", \"description\": \"one sentence\", \
56         \"changes\": [{{\"file\": \"path\", \"hunks\": [0, 1]}}]}}]}}\n\
57         Rules:\n\
58         - Every hunk of every file must appear in exactly one group\n\
59         - Use 2-5 groups (fewer for small changesets)\n\
60         - Labels should describe the PURPOSE (e.g. \"Auth refactor\", \"Test coverage\")\n\
61         - The \"hunks\" array contains 0-based hunk indices as shown in HUNK N: headers\n\
62         - A single file's hunks may be split across different groups if they serve different purposes\n\n\
63         Changed files and hunks:\n{hunk_summaries}",
64    );
65
66    let output = match backend {
67        LlmBackend::Claude => invoke_claude(&prompt, model).await?,
68        LlmBackend::Copilot => invoke_copilot(&prompt, model).await?,
69    };
70
71    // Extract JSON from potential markdown code fences
72    let json_str = extract_json(&output)?;
73
74    // FINDING-12: Validate JSON size before deserialization
75    if json_str.len() > MAX_JSON_SIZE {
76        anyhow::bail!(
77            "LLM JSON response too large ({} bytes, max {})",
78            json_str.len(),
79            MAX_JSON_SIZE
80        );
81    }
82
83    let response: GroupingResponse = serde_json::from_str(&json_str)?;
84
85    // Build set of valid (file, hunk_count) for validation
86    let known_files: HashSet<&str> = hunk_summaries
87        .lines()
88        .filter_map(|line| {
89            let line = line.trim();
90            if let Some(rest) = line.strip_prefix("FILE: ") {
91                let end = rest.find(" (")?;
92                Some(&rest[..end])
93            } else {
94                None
95            }
96        })
97        .collect();
98
99    // Validate: drop unknown files, enforce bounds (FINDING-13, 14, 15)
100    let validated_groups: Vec<SemanticGroup> = response
101        .groups
102        .into_iter()
103        .take(MAX_GROUPS) // FINDING-15: cap group count
104        .map(|group| {
105            let valid_changes: Vec<super::GroupedChange> = group
106                .changes()
107                .into_iter()
108                .filter(|change| {
109                    // Existing: check against known_files
110                    let known = known_files.contains(change.file.as_str());
111                    // FINDING-14: reject traversal paths and absolute paths
112                    let safe = !change.file.contains("..") && !change.file.starts_with('/');
113                    if !safe {
114                        tracing::warn!("Rejected LLM file path with traversal: {}", change.file);
115                    }
116                    known && safe
117                })
118                .take(MAX_CHANGES_PER_GROUP) // cap changes per group
119                .collect();
120            // FINDING-13: truncate label and description
121            SemanticGroup::new(
122                truncate_string(&group.label, MAX_LABEL_LEN),
123                truncate_string(&group.description, MAX_DESC_LEN),
124                valid_changes,
125            )
126        })
127        .filter(|group| !group.changes().is_empty())
128        .collect();
129
130    Ok(validated_groups)
131}
132
133/// Invoke the `claude` CLI and return the LLM response text.
134///
135/// Pipes the prompt via stdin to avoid exposing code diffs in the process table.
136/// The `-p` flag without an argument causes claude to read from stdin.
137async fn invoke_claude(prompt: &str, model: &str) -> anyhow::Result<String> {
138    use std::process::Stdio;
139    use tokio::io::{AsyncReadExt, AsyncWriteExt};
140
141    let mut child = Command::new("claude")
142        .args([
143            "-p",
144            "--output-format",
145            "json",
146            "--model",
147            model,
148            "--max-turns",
149            "1",
150        ])
151        .stdin(Stdio::piped())
152        .stdout(Stdio::piped())
153        .stderr(Stdio::piped())
154        .spawn()?;
155
156    // Write prompt to stdin, then close it
157    if let Some(mut stdin) = child.stdin.take() {
158        stdin.write_all(prompt.as_bytes()).await?;
159        // stdin is dropped here, closing the pipe
160    }
161
162    // Bounded read from stdout (FINDING-11: prevent OOM from oversized LLM response)
163    let stdout_pipe = child.stdout.take()
164        .ok_or_else(|| anyhow::anyhow!("failed to capture claude stdout"))?;
165    let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
166    let mut buf = Vec::with_capacity(8192);
167    let bytes_read = limited.read_to_end(&mut buf).await?;
168
169    if bytes_read >= MAX_RESPONSE_BYTES {
170        child.kill().await.ok();
171        anyhow::bail!("LLM response exceeded {} byte limit", MAX_RESPONSE_BYTES);
172    }
173
174    let status = child.wait().await?;
175    if !status.success() {
176        // Try to read stderr for diagnostics
177        let mut stderr_buf = Vec::new();
178        if let Some(mut stderr) = child.stderr.take() {
179            stderr.read_to_end(&mut stderr_buf).await.ok();
180        }
181        let stderr_str = String::from_utf8_lossy(&stderr_buf);
182        anyhow::bail!("claude exited with status {}: {}", status, stderr_str);
183    }
184
185    let stdout_str = String::from_utf8(buf)?;
186    let wrapper: serde_json::Value = serde_json::from_str(&stdout_str)?;
187    let result_text = wrapper["result"]
188        .as_str()
189        .ok_or_else(|| anyhow::anyhow!("missing result field in claude JSON output"))?;
190
191    Ok(result_text.to_string())
192}
193
194/// Invoke `copilot --yolo` and return the LLM response text.
195///
196/// Pipes the prompt via stdin to avoid exposing code diffs in the process table.
197/// Without a positional prompt argument, copilot reads from stdin.
198async fn invoke_copilot(prompt: &str, model: &str) -> anyhow::Result<String> {
199    use std::process::Stdio;
200    use tokio::io::{AsyncReadExt, AsyncWriteExt};
201
202    let mut child = Command::new("copilot")
203        .args(["--yolo", "--model", model])
204        .stdin(Stdio::piped())
205        .stdout(Stdio::piped())
206        .stderr(Stdio::piped())
207        .spawn()?;
208
209    // Write prompt to stdin, then close it
210    if let Some(mut stdin) = child.stdin.take() {
211        stdin.write_all(prompt.as_bytes()).await?;
212    }
213
214    // Bounded read from stdout (FINDING-11: prevent OOM from oversized LLM response)
215    let stdout_pipe = child.stdout.take()
216        .ok_or_else(|| anyhow::anyhow!("failed to capture copilot stdout"))?;
217    let mut limited = stdout_pipe.take(MAX_RESPONSE_BYTES as u64);
218    let mut buf = Vec::with_capacity(8192);
219    let bytes_read = limited.read_to_end(&mut buf).await?;
220
221    if bytes_read >= MAX_RESPONSE_BYTES {
222        child.kill().await.ok();
223        anyhow::bail!("LLM response exceeded {} byte limit", MAX_RESPONSE_BYTES);
224    }
225
226    let status = child.wait().await?;
227    if !status.success() {
228        let mut stderr_buf = Vec::new();
229        if let Some(mut stderr) = child.stderr.take() {
230            stderr.read_to_end(&mut stderr_buf).await.ok();
231        }
232        let stderr_str = String::from_utf8_lossy(&stderr_buf);
233        anyhow::bail!("copilot exited with status {}: {}", status, stderr_str);
234    }
235
236    Ok(String::from_utf8(buf)?)
237}
238
239/// Extract JSON from text that may be wrapped in ```json ... ``` code fences.
240fn extract_json(text: &str) -> anyhow::Result<String> {
241    let trimmed = text.trim();
242    // Try direct parse first
243    if trimmed.starts_with('{') {
244        return Ok(trimmed.to_string());
245    }
246    // Try extracting from code fences — find first `{` to last `}`
247    if let Some(start) = trimmed.find('{') {
248        if let Some(end) = trimmed.rfind('}') {
249            return Ok(trimmed[start..=end].to_string());
250        }
251    }
252    anyhow::bail!("no JSON object found in response")
253}
254
/// Return the first `max` characters (Unicode scalar values) of `s` as an owned
/// `String`, or all of `s` if it is shorter. Never splits a multi-byte UTF-8 sequence.
fn truncate_string(s: &str, max: usize) -> String {
    match s.char_indices().nth(max) {
        // `cut` is the byte offset of the first character beyond the limit.
        Some((cut, _)) => s[..cut].to_owned(),
        None => s.to_owned(),
    }
}
263
#[cfg(test)]
mod tests {
    use super::*;

    /// This file's own source text, inspected by the structural tests below.
    const SRC: &str = include_str!("llm.rs");

    /// Return the source of `async fn <name>` up to the first of `end_markers`
    /// (tried in order) found after it, or to the end of the file.
    fn fn_source(name: &str, end_markers: &[&str]) -> &'static str {
        let start = SRC
            .find(&format!("async fn {name}"))
            .unwrap_or_else(|| panic!("{} not found", name));
        let body = &SRC[start..];
        let end = end_markers
            .iter()
            .find_map(|&marker| body[1..].find(marker))
            .unwrap_or(body.len());
        &body[..end]
    }

    /// Source of `invoke_claude`, bounded by the next top-level `async fn`.
    fn claude_source() -> &'static str {
        fn_source("invoke_claude", &["\nasync fn"])
    }

    /// Source of `invoke_copilot`, bounded by the next top-level doc comment or the
    /// test module.
    fn copilot_source() -> &'static str {
        fn_source("invoke_copilot", &["\n/// ", "\n#[cfg(test)]"])
    }

    /// Text between `.args([` and the following `])` in `fn_src`, if present.
    fn args_block(fn_src: &str) -> Option<&str> {
        let section = &fn_src[fn_src.find(".args([")?..];
        let args_end = section.find("])").expect("unclosed .args");
        Some(&section[..args_end])
    }

    // --- extract_json ---

    #[test]
    fn test_extract_json_direct() {
        let input = r#"{"groups": []}"#;
        assert_eq!(extract_json(input).unwrap(), input);
    }

    #[test]
    fn test_extract_json_code_fences() {
        let input = "```json\n{\"groups\": []}\n```";
        assert_eq!(extract_json(input).unwrap(), r#"{"groups": []}"#);
    }

    #[test]
    fn test_extract_json_surrounded_by_prose() {
        let input = "Here is the grouping:\n{\"groups\": []}\nLet me know if you need changes.";
        assert_eq!(extract_json(input).unwrap(), r#"{"groups": []}"#);
    }

    #[test]
    fn test_extract_json_no_json() {
        assert!(extract_json("no json here").is_err());
    }

    // --- Response parsing ---

    #[test]
    fn test_parse_hunk_level_response() {
        let json = r#"{
            "groups": [{
                "label": "Auth refactor",
                "description": "Refactored auth flow",
                "changes": [
                    {"file": "src/auth.rs", "hunks": [0, 2]},
                    {"file": "src/middleware.rs", "hunks": [1]}
                ]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert_eq!(response.groups.len(), 1);
        assert_eq!(response.groups[0].changes().len(), 2);
        assert_eq!(response.groups[0].changes()[0].hunks, vec![0, 2]);
    }

    #[test]
    fn test_parse_empty_hunks_means_all() {
        let json = r#"{
            "groups": [{
                "label": "Config",
                "description": "Config changes",
                "changes": [{"file": "config.toml", "hunks": []}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        assert!(response.groups[0].changes()[0].hunks.is_empty());
    }

    #[test]
    fn test_parse_files_fallback() {
        // LLM returns old "files" format instead of "changes"
        let json = r#"{
            "groups": [{
                "label": "Refactor",
                "description": "Code cleanup",
                "files": ["src/app.rs", "src/main.rs"]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let changes = response.groups[0].changes();
        assert_eq!(changes.len(), 2);
        assert_eq!(changes[0].file, "src/app.rs");
        assert!(changes[0].hunks.is_empty()); // all hunks
    }

    // --- Structural checks: prompts must travel over stdin, never argv ---

    /// Verify invoke_claude pipes stdin and writes the prompt there (structural test).
    #[test]
    fn test_invoke_claude_uses_stdin_pipe() {
        let claude_fn = claude_source();
        assert!(
            claude_fn.contains("Stdio::piped()"),
            "invoke_claude must use Stdio::piped() for stdin"
        );
        assert!(
            claude_fn.contains("write_all"),
            "invoke_claude must write prompt to stdin via write_all"
        );
        if let Some(args) = args_block(claude_fn) {
            assert!(
                !args.contains("prompt"),
                "invoke_claude must not pass prompt in .args()"
            );
        }
    }

    /// Verify invoke_copilot pipes stdin and writes the prompt there (structural test).
    #[test]
    fn test_invoke_copilot_uses_stdin_pipe() {
        let copilot_fn = copilot_source();
        assert!(
            copilot_fn.contains("Stdio::piped()"),
            "invoke_copilot must use Stdio::piped() for stdin"
        );
        assert!(
            copilot_fn.contains("write_all"),
            "invoke_copilot must write prompt to stdin via write_all"
        );
    }

    /// Verify neither invoke function passes the prompt string in .args().
    #[test]
    fn test_no_prompt_in_args() {
        for &(name, fn_src) in &[
            ("invoke_claude", claude_source()),
            ("invoke_copilot", copilot_source()),
        ] {
            if let Some(args) = args_block(fn_src) {
                assert!(
                    !args.contains("prompt"),
                    "{} .args() must not contain prompt variable",
                    name
                );
            }
        }
    }

    // --- Size limits ---

    #[test]
    fn test_read_bounded_under_limit() {
        // Content under MAX_RESPONSE_BYTES must fit within the bounded read.
        let data = "hello world";
        assert!(data.len() < MAX_RESPONSE_BYTES);
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
    }

    #[test]
    fn test_read_bounded_over_limit_constant() {
        // The stdout cap is 1 MiB; anything larger is rejected by the invoke_* helpers.
        assert_eq!(MAX_RESPONSE_BYTES, 1_048_576);
        let oversized = vec![b'x'; MAX_RESPONSE_BYTES + 1];
        assert!(oversized.len() > MAX_RESPONSE_BYTES);
    }

    // --- Validation ---

    #[test]
    fn test_validate_rejects_oversized_json() {
        // The extracted JSON keeps its full size, so request_grouping's
        // `json_str.len() > MAX_JSON_SIZE` check bails before deserialization.
        let large_json = format!(
            r#"{{"groups": [{{"label": "x", "description": "{}", "changes": []}}]}}"#,
            "a".repeat(MAX_JSON_SIZE + 1)
        );
        let extracted = extract_json(&large_json).unwrap();
        assert_eq!(extracted.len(), large_json.len());
        assert!(extracted.len() > MAX_JSON_SIZE);
    }

    #[test]
    fn test_validate_caps_groups_at_max() {
        // Build JSON with more than MAX_GROUPS groups
        let groups_json: Vec<String> = (0..30)
            .map(|i| {
                format!(
                    r#"{{"label": "Group {}", "description": "desc", "changes": [{{"file": "src/f{}.rs", "hunks": [0]}}]}}"#,
                    i, i
                )
            })
            .collect();
        let json = format!(r#"{{"groups": [{}]}}"#, groups_json.join(","));
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups.len(), 30);
        // Mirrors the `.take(MAX_GROUPS)` applied during validation.
        let capped: Vec<_> = response.groups.into_iter().take(MAX_GROUPS).collect();
        assert_eq!(capped.len(), 20);
    }

    #[test]
    fn test_validate_rejects_path_traversal() {
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "traversal",
                "changes": [{"file": "../../../etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        // Parsing keeps the path; it is filtered out during validation.
        assert!(change.file.contains(".."), "path should contain traversal");
    }

    #[test]
    fn test_validate_rejects_absolute_paths() {
        let json = r#"{
            "groups": [{
                "label": "Evil",
                "description": "absolute",
                "changes": [{"file": "/etc/passwd", "hunks": [0]}]
            }]
        }"#;
        let response: GroupingResponse = serde_json::from_str(json).unwrap();
        let change = &response.groups[0].changes()[0];
        // Parsing keeps the path; it is filtered out during validation.
        assert!(change.file.starts_with('/'), "path should be absolute");
    }

    #[test]
    fn test_validate_caps_changes_per_group() {
        // Build a group with more than MAX_CHANGES_PER_GROUP changes
        let changes: Vec<String> = (0..250)
            .map(|i| format!(r#"{{"file": "src/f{}.rs", "hunks": [0]}}"#, i))
            .collect();
        let json = format!(
            r#"{{"groups": [{{"label": "Big", "description": "lots", "changes": [{}]}}]}}"#,
            changes.join(",")
        );
        let response: GroupingResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(response.groups[0].changes().len(), 250);
        // Mirrors the `.take(MAX_CHANGES_PER_GROUP)` applied during validation.
        let capped: Vec<_> = response.groups[0]
            .changes()
            .into_iter()
            .take(MAX_CHANGES_PER_GROUP)
            .collect();
        assert_eq!(capped.len(), 200);
    }

    // --- truncate_string ---

    #[test]
    fn test_truncate_string_label() {
        let long_label = "a".repeat(100);
        let truncated = truncate_string(&long_label, MAX_LABEL_LEN);
        assert_eq!(truncated.chars().count(), MAX_LABEL_LEN);
    }

    #[test]
    fn test_truncate_string_description() {
        let long_desc = "b".repeat(600);
        let truncated = truncate_string(&long_desc, MAX_DESC_LEN);
        assert_eq!(truncated.chars().count(), MAX_DESC_LEN);
    }

    #[test]
    fn test_truncate_string_short_input_unchanged() {
        assert_eq!(truncate_string("Auth refactor", MAX_LABEL_LEN), "Auth refactor");
        assert_eq!(truncate_string("", MAX_LABEL_LEN), "");
    }

    #[test]
    fn test_truncate_string_respects_char_boundaries() {
        assert_eq!(truncate_string("日本語テキスト", 3), "日本語");
        assert_eq!(truncate_string("héllo", 2), "hé");
    }
}