1use std::path::PathBuf;
7
8use anyhow::{anyhow, Result};
9use clap::Args;
10use serde::Serialize;
11
12use tldr_core::analysis::{
13 compute_dice_similarity, interpret_similarity, normalize_tokens, NormalizationMode,
14};
15
16use crate::output::{OutputFormat, OutputWriter};
17
18#[derive(Debug, Args)]
20pub struct DiceArgs {
21 pub target1: String,
23
24 pub target2: String,
26
27 #[arg(long, default_value = "all")]
29 pub normalize: String,
30
31 #[arg(long = "language")]
33 pub language: Option<String>,
34
35 #[arg(short, long, default_value = "json")]
37 pub output: String,
38}
39
/// A comparison target parsed from the command line by `parse_target`.
#[derive(Debug)]
enum Target {
    /// A whole source file: `(path)`.
    File(PathBuf),
    /// A named function inside a file: `(path, function_name)`.
    Function(PathBuf, String),
    /// A line range inside a file: `(path, start, end)`, 1-based and inclusive.
    Block(PathBuf, usize, usize),
}
47
48#[derive(Debug, Serialize)]
50struct DiceSimilarityReport {
51 target1: String,
53 target2: String,
55 dice_coefficient: f64,
57 interpretation: String,
59 tokens1_count: usize,
61 tokens2_count: usize,
63}
64
65impl DiceArgs {
66 pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
68 let writer = OutputWriter::new(format, quiet);
69
70 writer.progress(&format!(
71 "Comparing similarity between {} and {}...",
72 self.target1, self.target2
73 ));
74
75 let target1 = parse_target(&self.target1)?;
76 let target2 = parse_target(&self.target2)?;
77
78 let normalization =
79 NormalizationMode::parse(&self.normalize).unwrap_or(NormalizationMode::All);
80
81 let (source1, lang1) = get_source(&target1, self.language.as_deref())?;
83 let (source2, lang2) = get_source(&target2, self.language.as_deref())?;
84
85 let tokens1 = normalize_tokens(&source1, &lang1, normalization)
87 .map_err(|e| anyhow!("Failed to tokenize target1: {}", e))?;
88 let tokens2 = normalize_tokens(&source2, &lang2, normalization)
89 .map_err(|e| anyhow!("Failed to tokenize target2: {}", e))?;
90
91 let dice = compute_dice_similarity(&tokens1, &tokens2);
93
94 let report = DiceSimilarityReport {
95 target1: self.target1.clone(),
96 target2: self.target2.clone(),
97 dice_coefficient: dice,
98 interpretation: interpret_similarity(dice),
99 tokens1_count: tokens1.len(),
100 tokens2_count: tokens2.len(),
101 };
102
103 let effective_format = match self.output.as_str() {
105 "text" => OutputFormat::Text,
106 "json" => format,
107 _ => format,
108 };
109
110 if matches!(effective_format, OutputFormat::Text) {
111 let text = format_dice_text(&report);
112 writer.write_text(&text)?;
113 } else {
114 writer.write(&report)?;
115 }
116
117 Ok(())
118 }
119}
120
121fn parse_target(s: &str) -> Result<Target> {
123 if let Some((path, func)) = s.split_once("::") {
125 return Ok(Target::Function(PathBuf::from(path), func.to_string()));
126 }
127
128 let parts: Vec<&str> = s.rsplitn(3, ':').collect();
132
133 if parts.len() == 3 {
134 if let (Ok(end), Ok(start)) = (parts[0].parse::<usize>(), parts[1].parse::<usize>()) {
136 return Ok(Target::Block(PathBuf::from(parts[2]), start, end));
137 }
138 }
139
140 Ok(Target::File(PathBuf::from(s)))
142}
143
144fn get_source(target: &Target, lang_hint: Option<&str>) -> Result<(String, String)> {
146 match target {
147 Target::File(path) => {
148 let source = std::fs::read_to_string(path)
149 .map_err(|e| anyhow!("Failed to read {}: {}", path.display(), e))?;
150 let lang = lang_hint
151 .map(String::from)
152 .or_else(|| detect_language(path))
153 .ok_or_else(|| anyhow!("Could not detect language for {}", path.display()))?;
154 Ok((source, lang))
155 }
156 Target::Function(path, _func_name) => {
157 let source = std::fs::read_to_string(path)
160 .map_err(|e| anyhow!("Failed to read {}: {}", path.display(), e))?;
161 let lang = lang_hint
162 .map(String::from)
163 .or_else(|| detect_language(path))
164 .ok_or_else(|| anyhow!("Could not detect language"))?;
165 Ok((source, lang))
166 }
167 Target::Block(path, start, end) => {
168 let source = std::fs::read_to_string(path)
169 .map_err(|e| anyhow!("Failed to read {}: {}", path.display(), e))?;
170 let lines: Vec<&str> = source.lines().collect();
171
172 let start_idx = start.saturating_sub(1);
174 let end_idx = (*end).min(lines.len());
175
176 let block = lines
177 .get(start_idx..end_idx)
178 .map(|l| l.join("\n"))
179 .unwrap_or_default();
180
181 let lang = lang_hint
182 .map(String::from)
183 .or_else(|| detect_language(path))
184 .ok_or_else(|| anyhow!("Could not detect language"))?;
185
186 Ok((block, lang))
187 }
188 }
189}
190
191fn detect_language(path: &std::path::Path) -> Option<String> {
193 tldr_core::Language::from_path(path).map(|l| l.to_string())
194}
195
196fn format_dice_text(report: &DiceSimilarityReport) -> String {
198 use std::fmt::Write;
199
200 let mut output = String::new();
201
202 writeln!(output, "Similarity Comparison").unwrap();
203 writeln!(output, "=====================").unwrap();
204 writeln!(output).unwrap();
205 writeln!(
206 output,
207 "Target 1: {} ({} tokens)",
208 report.target1, report.tokens1_count
209 )
210 .unwrap();
211 writeln!(
212 output,
213 "Target 2: {} ({} tokens)",
214 report.target2, report.tokens2_count
215 )
216 .unwrap();
217 writeln!(output).unwrap();
218 writeln!(
219 output,
220 "Dice coefficient: {:.2}%",
221 report.dice_coefficient * 100.0
222 )
223 .unwrap();
224 writeln!(output, "Interpretation: {}", report.interpretation).unwrap();
225
226 output
227}