1use std::collections::{HashMap, HashSet};
31use std::path::Path;
32
33use serde::{Deserialize, Serialize};
34
35use crate::pdg::get_pdg_context;
36use crate::types::{DependenceType, Language, PdgInfo, SliceDirection};
37use crate::TldrResult;
38
39#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SliceNode {
46 pub line: u32,
48 pub code: String,
50 pub node_type: String,
52 pub definitions: Vec<String>,
54 pub uses: Vec<String>,
56 #[serde(skip_serializing_if = "Option::is_none")]
58 pub dep_type: Option<String>,
59 #[serde(skip_serializing_if = "Option::is_none")]
61 pub dep_label: Option<String>,
62}
63
64#[derive(Debug, Clone, Serialize, Deserialize)]
66pub struct SliceEdge {
67 pub from_line: u32,
69 pub to_line: u32,
71 pub dep_type: String,
73 pub label: String,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
79pub struct RichSlice {
80 pub nodes: Vec<SliceNode>,
82 pub edges: Vec<SliceEdge>,
84}
85
86pub fn get_slice(
116 source_or_path: &str,
117 function_name: &str,
118 line: u32,
119 direction: SliceDirection,
120 variable: Option<&str>,
121 language: Language,
122) -> TldrResult<HashSet<u32>> {
123 let pdg = get_pdg_context(source_or_path, function_name, language)?;
125
126 let start_nodes = find_nodes_for_line(&pdg, line);
128
129 if start_nodes.is_empty() {
130 return Ok(HashSet::new());
132 }
133
134 let slice = compute_slice(&pdg, &start_nodes, direction, variable);
136
137 let lines = nodes_to_lines(&pdg, &slice);
139
140 Ok(lines)
141}
142
143pub fn get_slice_rich(
160 source_or_path: &str,
161 function_name: &str,
162 line: u32,
163 direction: SliceDirection,
164 variable: Option<&str>,
165 language: Language,
166) -> TldrResult<RichSlice> {
167 let pdg = get_pdg_context(source_or_path, function_name, language)?;
169
170 let start_nodes = find_nodes_for_line(&pdg, line);
172
173 if start_nodes.is_empty() {
174 return Ok(RichSlice {
175 nodes: Vec::new(),
176 edges: Vec::new(),
177 });
178 }
179
180 let visited = compute_slice(&pdg, &start_nodes, direction, variable);
182
183 let source_lines = read_source_lines(source_or_path);
185
186 let visited_nodes: Vec<&crate::types::PdgNode> = pdg
188 .nodes
189 .iter()
190 .filter(|n| visited.contains(&n.id))
191 .collect();
192
193 let mut line_map: HashMap<u32, SliceNode> = HashMap::new();
196
197 for node in &visited_nodes {
198 for l in node.lines.0..=node.lines.1 {
199 if l == 0 {
200 continue;
201 }
202 let code = source_lines
203 .get((l as usize).wrapping_sub(1))
204 .map(|s| s.trim_end().to_string())
205 .unwrap_or_default();
206
207 let entry = line_map.entry(l).or_insert_with(|| SliceNode {
208 line: l,
209 code,
210 node_type: node.node_type.clone(),
211 definitions: Vec::new(),
212 uses: Vec::new(),
213 dep_type: None,
214 dep_label: None,
215 });
216
217 for d in &node.definitions {
219 if !entry.definitions.contains(d) {
220 entry.definitions.push(d.clone());
221 }
222 }
223 for u in &node.uses {
224 if !entry.uses.contains(u) {
225 entry.uses.push(u.clone());
226 }
227 }
228 }
229 }
230
231 let mut edges: Vec<SliceEdge> = Vec::new();
233 for edge in &pdg.edges {
234 if visited.contains(&edge.source_id) && visited.contains(&edge.target_id) {
235 let from_line = node_id_to_line(&pdg, edge.source_id);
237 let to_line = node_id_to_line(&pdg, edge.target_id);
238 if let (Some(from), Some(to)) = (from_line, to_line) {
239 let dep_str = match edge.dep_type {
240 DependenceType::Data => "data",
241 DependenceType::Control => "control",
242 };
243 edges.push(SliceEdge {
244 from_line: from,
245 to_line: to,
246 dep_type: dep_str.to_string(),
247 label: edge.label.clone(),
248 });
249
250 if let Some(node) = line_map.get_mut(&to) {
252 if node.dep_type.is_none() {
253 node.dep_type = Some(dep_str.to_string());
254 if !edge.label.is_empty() {
255 node.dep_label = Some(edge.label.clone());
256 }
257 }
258 }
259 }
260 }
261 }
262
263 edges.sort_by_key(|e| (e.from_line, e.to_line));
265 edges.dedup_by(|a, b| {
267 a.from_line == b.from_line
268 && a.to_line == b.to_line
269 && a.dep_type == b.dep_type
270 && a.label == b.label
271 });
272
273 let mut nodes: Vec<SliceNode> = line_map.into_values().collect();
275 nodes.sort_by_key(|n| n.line);
276
277 Ok(RichSlice { nodes, edges })
278}
279
/// Returns the source text split into lines, with line terminators stripped.
///
/// If `source_or_path` names an existing regular file, that file is read. It is
/// decoded as UTF-8, and invalid byte sequences are replaced with U+FFFD, so a
/// stray non-UTF-8 byte does not discard the whole file. If the file cannot be
/// read at all, an empty vector is returned; callers then get empty `code`
/// strings instead of the path itself being misread as source text. Any other
/// input is treated as the inline source text.
fn read_source_lines(source_or_path: &str) -> Vec<String> {
    let path = Path::new(source_or_path);
    // `is_file` already implies that the path exists.
    if !path.is_file() {
        return source_or_path.lines().map(str::to_string).collect();
    }
    match std::fs::read(path) {
        Ok(bytes) => String::from_utf8_lossy(&bytes)
            .lines()
            .map(str::to_string)
            .collect(),
        // Best-effort: an unreadable file yields no snippet text.
        Err(_) => Vec::new(),
    }
}
292
293fn node_id_to_line(pdg: &PdgInfo, node_id: usize) -> Option<u32> {
295 pdg.nodes
296 .iter()
297 .find(|n| n.id == node_id)
298 .map(|n| n.lines.0)
299 .filter(|&l| l > 0)
300}
301
302fn find_nodes_for_line(pdg: &PdgInfo, line: u32) -> Vec<usize> {
304 pdg.nodes
305 .iter()
306 .filter(|n| line >= n.lines.0 && line <= n.lines.1)
307 .map(|n| n.id)
308 .collect()
309}
310
311fn compute_slice(
313 pdg: &PdgInfo,
314 start_nodes: &[usize],
315 direction: SliceDirection,
316 variable: Option<&str>,
317) -> HashSet<usize> {
318 let mut visited = HashSet::new();
319 let mut worklist: Vec<usize> = start_nodes.to_vec();
320
321 while let Some(node_id) = worklist.pop() {
322 if visited.contains(&node_id) {
323 continue;
324 }
325 visited.insert(node_id);
326
327 let adjacent = match direction {
329 SliceDirection::Backward => {
330 pdg.edges
332 .iter()
333 .filter(|e| e.target_id == node_id)
334 .filter(|e| should_follow_edge(e, variable))
335 .map(|e| e.source_id)
336 .collect::<Vec<_>>()
337 }
338 SliceDirection::Forward => {
339 pdg.edges
341 .iter()
342 .filter(|e| e.source_id == node_id)
343 .filter(|e| should_follow_edge(e, variable))
344 .map(|e| e.target_id)
345 .collect::<Vec<_>>()
346 }
347 };
348
349 for adj in adjacent {
350 if !visited.contains(&adj) {
351 worklist.push(adj);
352 }
353 }
354 }
355
356 visited
357}
358
359fn should_follow_edge(edge: &crate::types::PdgEdge, variable: Option<&str>) -> bool {
361 match variable {
362 None => true, Some(var) => {
364 match edge.dep_type {
365 DependenceType::Control => true, DependenceType::Data => edge.label == var, }
368 }
369 }
370}
371
372fn nodes_to_lines(pdg: &PdgInfo, node_ids: &HashSet<usize>) -> HashSet<u32> {
374 let mut lines = HashSet::new();
375
376 for &node_id in node_ids {
377 if let Some(node) = pdg.nodes.iter().find(|n| n.id == node_id) {
378 for line in node.lines.0..=node.lines.1 {
380 if line > 0 {
381 lines.insert(line);
382 }
383 }
384 }
385 }
386
387 lines
388}
389
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_backward_slice_simple() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let slice = get_slice(
            source,
            "foo",
            4,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(!slice.is_empty(), "slice should not be empty");
        assert!(slice.contains(&4), "slice should include the criterion line");
    }

    #[test]
    fn test_forward_slice_simple() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let slice = get_slice(
            source,
            "foo",
            3,
            SliceDirection::Forward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(slice.contains(&3), "slice should include the starting line");
    }

    #[test]
    fn test_slice_with_variable_filter() {
        let source = r#"
def foo():
    x = 1
    y = 2
    z = x + y
    return z
"#;
        let slice = get_slice(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            Some("x"),
            Language::Python,
        )
        .unwrap();

        assert!(!slice.is_empty(), "slice should not be empty");
        assert!(slice.contains(&5), "slice should include the criterion line");
    }

    #[test]
    fn test_slice_line_not_in_function() {
        let source = "def foo(): pass";
        let slice = get_slice(
            source,
            "foo",
            999,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(
            slice.is_empty(),
            "slice for non-existent line should be empty"
        );
    }

    #[test]
    fn test_slice_returns_line_numbers() {
        let source = r#"
def foo():
    x = 1
    return x
"#;
        let slice = get_slice(
            source,
            "foo",
            3,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        for &line in &slice {
            assert!(line > 0, "line numbers should be positive");
        }
    }

    #[test]
    fn test_backward_slice_with_control_deps() {
        let source = r#"
def foo(cond):
    if cond:
        x = 1
    else:
        x = 2
    return x
"#;
        let slice = get_slice(
            source,
            "foo",
            6,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(
            !slice.is_empty(),
            "slice should include control dependencies"
        );
        assert!(slice.contains(&6), "slice should include the criterion line");
    }

    #[test]
    fn test_forward_slice_traces_all_vars() {
        let source = r#"
def foo():
    x = 1
    y = x
    z = y
    return z
"#;
        let slice = get_slice(
            source,
            "foo",
            3,
            SliceDirection::Forward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(
            slice.contains(&3),
            "forward slice should include the starting line"
        );
    }

    #[test]
    fn test_rich_slice_returns_nodes_with_code() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            4,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(!rich.nodes.is_empty(), "rich slice should have nodes");
        for node in &rich.nodes {
            assert!(!node.code.is_empty(), "each node should have code content");
            assert!(node.line > 0, "line numbers should be positive");
        }
    }

    #[test]
    fn test_rich_slice_nodes_sorted_by_line() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        let lines: Vec<u32> = rich.nodes.iter().map(|n| n.line).collect();
        let mut sorted = lines.clone();
        sorted.sort();
        assert_eq!(lines, sorted, "nodes should be sorted by line number");
    }

    #[test]
    fn test_rich_slice_code_is_trimmed() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        for node in &rich.nodes {
            assert_eq!(
                node.code,
                node.code.trim_end(),
                "code should have trailing whitespace trimmed"
            );
        }
    }

    #[test]
    fn test_rich_slice_preserves_definitions_and_uses() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        let has_defs = rich.nodes.iter().any(|n| !n.definitions.is_empty());
        let has_uses = rich.nodes.iter().any(|n| !n.uses.is_empty());
        assert!(
            has_defs || has_uses,
            "rich slice should preserve definition/use info from PDG"
        );
    }

    #[test]
    fn test_rich_slice_has_node_types() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        for node in &rich.nodes {
            assert!(
                !node.node_type.is_empty(),
                "each node should have a node_type"
            );
        }
    }

    #[test]
    fn test_rich_slice_edges_within_slice() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        let slice_lines: std::collections::HashSet<u32> =
            rich.nodes.iter().map(|n| n.line).collect();
        for edge in &rich.edges {
            assert!(
                slice_lines.contains(&edge.from_line),
                "edge from_line {} should be in slice",
                edge.from_line
            );
            assert!(
                slice_lines.contains(&edge.to_line),
                "edge to_line {} should be in slice",
                edge.to_line
            );
        }
    }

    #[test]
    fn test_rich_slice_edge_dep_types() {
        let source = r#"
def foo(cond):
    if cond:
        x = 1
    else:
        x = 2
    return x
"#;
        let rich = get_slice_rich(
            source,
            "foo",
            7,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        for edge in &rich.edges {
            assert!(
                edge.dep_type == "data" || edge.dep_type == "control",
                "edge dep_type should be 'data' or 'control', got '{}'",
                edge.dep_type
            );
        }
    }

    #[test]
    fn test_rich_slice_empty_for_invalid_line() {
        let source = "def foo(): pass";
        let rich = get_slice_rich(
            source,
            "foo",
            999,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        assert!(
            rich.nodes.is_empty(),
            "rich slice for non-existent line should have no nodes"
        );
        assert!(
            rich.edges.is_empty(),
            "rich slice for non-existent line should have no edges"
        );
    }

    #[test]
    fn test_rich_slice_from_file_path() {
        use std::io::Write;

        // Per-process file name so concurrent test runs cannot clobber each other.
        let path = std::env::temp_dir()
            .join(format!("test_slice_rich_{}.py", std::process::id()));
        {
            let mut f = std::fs::File::create(&path).unwrap();
            writeln!(f, "def bar():").unwrap();
            writeln!(f, "    a = 10").unwrap();
            writeln!(f, "    b = a + 1").unwrap();
            writeln!(f, "    return b").unwrap();
        } // `f` is closed here, before the file is read back.

        let result = get_slice_rich(
            path.to_str().unwrap(),
            "bar",
            4,
            SliceDirection::Backward,
            None,
            Language::Python,
        );
        // Clean up before any assertion so the file is not leaked on failure.
        std::fs::remove_file(&path).ok();
        let rich = result.unwrap();

        assert!(!rich.nodes.is_empty(), "should work with file path input");
        let has_return = rich.nodes.iter().any(|n| n.code.contains("return"));
        assert!(has_return, "should contain the criterion line code");
    }

    #[test]
    fn test_rich_slice_backward_compat_with_get_slice() {
        let source = r#"
def foo():
    x = 1
    y = x + 2
    return y
"#;
        let plain = get_slice(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();
        let rich = get_slice_rich(
            source,
            "foo",
            5,
            SliceDirection::Backward,
            None,
            Language::Python,
        )
        .unwrap();

        let rich_lines: HashSet<u32> = rich.nodes.iter().map(|n| n.line).collect();
        assert_eq!(
            plain, rich_lines,
            "rich slice lines should match plain slice lines"
        );
    }
}
849}