1use std::collections::HashMap;
21use std::path::PathBuf;
22
23use anyhow::Result;
24use clap::Args;
25use colored::Colorize;
26
27use tldr_core::ast::ParserPool;
28use tldr_core::{compute_taint_with_tree, get_cfg_context, get_dfg_context, Language, TaintInfo};
29
30use crate::output::OutputFormat;
31
32#[derive(Debug, Args)]
34pub struct TaintArgs {
35 pub file: PathBuf,
37
38 pub function: String,
40
41 #[arg(long, short = 'l')]
43 pub lang: Option<Language>,
44
45 #[arg(long, short = 'v')]
47 pub verbose: bool,
48}
49
50impl TaintArgs {
51 pub fn run(&self, format: OutputFormat, quiet: bool) -> Result<()> {
53 use crate::output::OutputWriter;
54
55 let writer = OutputWriter::new(format, quiet);
56
57 let language = self
59 .lang
60 .unwrap_or_else(|| Language::from_path(&self.file).unwrap_or(Language::Python));
61
62 writer.progress(&format!(
63 "Analyzing taint flows for {} in {}...",
64 self.function,
65 self.file.display()
66 ));
67
68 if !self.file.exists() {
70 return Err(anyhow::anyhow!("File not found: {}", self.file.display()));
71 }
72
73 let source = std::fs::read_to_string(&self.file)?;
74
75 let cfg = get_cfg_context(
77 self.file.to_str().unwrap_or_default(),
78 &self.function,
79 language,
80 )?;
81
82 let dfg = get_dfg_context(
84 self.file.to_str().unwrap_or_default(),
85 &self.function,
86 language,
87 )?;
88
89 let (fn_start, fn_end) = if cfg.blocks.is_empty() {
93 (1u32, source.lines().count() as u32)
94 } else {
95 let start = cfg.blocks.iter().map(|b| b.lines.0).min().unwrap_or(1);
96 let end = cfg
97 .blocks
98 .iter()
99 .map(|b| b.lines.1)
100 .max()
101 .unwrap_or(source.lines().count() as u32);
102 (start, end)
103 };
104
105 let statements: HashMap<u32, String> = source
107 .lines()
108 .enumerate()
109 .filter(|(i, _)| {
110 let line_num = (i + 1) as u32;
111 line_num >= fn_start && line_num <= fn_end
112 })
113 .map(|(i, line)| ((i + 1) as u32, line.to_string()))
114 .collect();
115
116 let pool = ParserPool::new();
118 let tree = pool.parse(&source, language).ok();
119
120 let result = compute_taint_with_tree(
122 &cfg,
123 &dfg.refs,
124 &statements,
125 tree.as_ref(),
126 Some(source.as_bytes()),
127 language,
128 )?;
129
130 match format {
132 OutputFormat::Text => {
133 let text = format_taint_text(&result, self.verbose);
134 writer.write_text(&text)?;
135 }
136 OutputFormat::Json | OutputFormat::Compact => {
137 let json = serde_json::to_string_pretty(&result)
138 .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
139 writer.write_text(&json)?;
140 }
141 OutputFormat::Dot => {
142 let json = serde_json::to_string_pretty(&result)
144 .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
145 writer.write_text(&json)?;
146 }
147 OutputFormat::Sarif => {
148 let json = serde_json::to_string_pretty(&result)
150 .map_err(|e| anyhow::anyhow!("JSON serialization failed: {}", e))?;
151 writer.write_text(&json)?;
152 }
153 }
154
155 Ok(())
156 }
157}
158
159fn format_taint_text(result: &TaintInfo, verbose: bool) -> String {
161 let mut output = String::new();
162
163 output.push_str(&format!(
165 "{}\n",
166 format!("Taint Analysis: {}", result.function_name)
167 .bold()
168 .cyan()
169 ));
170 output.push_str(&"=".repeat(50));
171 output.push('\n');
172
173 output.push_str(&format!(
175 "\n{} ({}):\n",
176 "Sources".bold(),
177 result.sources.len()
178 ));
179 if result.sources.is_empty() {
180 output.push_str(" No taint sources detected.\n");
181 } else {
182 for source in &result.sources {
183 output.push_str(&format!(
184 " Line {}: {} ({})\n",
185 source.line.to_string().yellow(),
186 source.var.green(),
187 format!("{:?}", source.source_type).cyan()
188 ));
189 if let Some(ref stmt) = source.statement {
190 output.push_str(&format!(" {}\n", stmt.trim().dimmed()));
191 }
192 }
193 }
194
195 output.push_str(&format!("\n{} ({}):\n", "Sinks".bold(), result.sinks.len()));
197 if result.sinks.is_empty() {
198 output.push_str(" No sinks detected.\n");
199 } else {
200 for sink in &result.sinks {
201 let status = if sink.tainted {
202 "TAINTED".red().bold().to_string()
203 } else {
204 "safe".green().to_string()
205 };
206 output.push_str(&format!(
207 " Line {}: {} ({}) - {}\n",
208 sink.line.to_string().yellow(),
209 sink.var.green(),
210 format!("{:?}", sink.sink_type).cyan(),
211 status
212 ));
213 if let Some(ref stmt) = sink.statement {
214 output.push_str(&format!(" {}\n", stmt.trim().dimmed()));
215 }
216 }
217 }
218
219 let vulns: Vec<_> = result.sinks.iter().filter(|s| s.tainted).collect();
221 output.push_str(&format!(
222 "\n{} ({}):\n",
223 "Vulnerabilities".bold().red(),
224 vulns.len()
225 ));
226 if vulns.is_empty() {
227 output.push_str(&format!(" {}\n", "No vulnerabilities found.".green()));
228 } else {
229 for sink in vulns {
230 output.push_str(&format!(
231 " {} Line {}: {} flows to {} sink\n",
232 "[!]".red().bold(),
233 sink.line.to_string().yellow(),
234 sink.var.red(),
235 format!("{:?}", sink.sink_type).cyan()
236 ));
237 }
238 }
239
240 if !result.flows.is_empty() {
242 output.push_str(&format!(
243 "\n{} ({}):\n",
244 "Taint Flows".bold(),
245 result.flows.len()
246 ));
247 for flow in &result.flows {
248 output.push_str(&format!(
249 " {} (line {}) -> {} (line {})\n",
250 flow.source.var.green(),
251 flow.source.line,
252 flow.sink.var.red(),
253 flow.sink.line
254 ));
255 if !flow.path.is_empty() {
256 output.push_str(&format!(
257 " Path: {}\n",
258 flow.path
259 .iter()
260 .map(|b| b.to_string())
261 .collect::<Vec<_>>()
262 .join(" -> ")
263 .dimmed()
264 ));
265 }
266 }
267 }
268
269 if verbose && !result.tainted_vars.is_empty() {
271 output.push_str(&format!("\n{}:\n", "Tainted Variables per Block".bold()));
272 let mut blocks: Vec<_> = result.tainted_vars.keys().collect();
273 blocks.sort();
274 for block_id in blocks {
275 if let Some(vars) = result.tainted_vars.get(block_id) {
276 if !vars.is_empty() {
277 output.push_str(&format!(
278 " Block {}: {}\n",
279 block_id,
280 vars.iter()
281 .map(|v| v.as_str())
282 .collect::<Vec<_>>()
283 .join(", ")
284 .yellow()
285 ));
286 }
287 }
288 }
289 }
290
291 if !result.sanitized_vars.is_empty() {
293 output.push_str(&format!(
294 "\n{}: {}\n",
295 "Sanitized Variables".bold(),
296 result
297 .sanitized_vars
298 .iter()
299 .map(|v| v.as_str())
300 .collect::<Vec<_>>()
301 .join(", ")
302 .green()
303 ));
304 }
305
306 output
307}
308
#[cfg(test)]
mod tests {
    use std::collections::{HashMap, HashSet};
    use std::io::Write;
    use std::mem::discriminant;

    use tempfile::NamedTempFile;
    use tldr_core::ast::ParserPool;
    use tldr_core::{
        compute_taint_with_tree, get_cfg_context, get_dfg_context, Language, TaintSinkType,
    };

    // Two Python functions:
    // - `safe_func` (lines 3-5) only passes a literal to `os.system`;
    // - `vulnerable_func` (lines 7-11) feeds `input()` and a parameter into
    //   string concatenation, `os.system` and `eval`.
    const PYTHON_FIXTURE: &str = r#"import os

def safe_func():
 x = "hardcoded"
 os.system(x)

def vulnerable_func(user_input):
 data = input("Enter: ")
 query = "SELECT * FROM users WHERE id = " + data
 os.system(user_input)
 eval(data)
"#;

    /// Writes `code` to a temporary `.py` file and runs the same
    /// CFG -> DFG -> scoped statements -> taint pipeline as the CLI, restricted
    /// to `function`.
    fn run_taint_on_function(code: &str, function: &str) -> tldr_core::TaintInfo {
        // `file` must stay alive until the analysis has read it.
        let mut file = NamedTempFile::with_suffix(".py").unwrap();
        file.write_all(code.as_bytes()).unwrap();
        file.flush().unwrap();
        let path = file.path().to_str().unwrap();

        let cfg = get_cfg_context(path, function, Language::Python).unwrap();
        let dfg = get_dfg_context(path, function, Language::Python).unwrap();

        // Smallest span covering every CFG block; whole file when there are none.
        let whole_file = (1u32, code.lines().count() as u32);
        let (first, last) = cfg
            .blocks
            .iter()
            .map(|block| (block.lines.0, block.lines.1))
            .fold(None, |span: Option<(u32, u32)>, (lo, hi)| match span {
                None => Some((lo, hi)),
                Some((a, b)) => Some((a.min(lo), b.max(hi))),
            })
            .unwrap_or(whole_file);

        let statements: HashMap<u32, String> = code
            .lines()
            .zip(1u32..)
            .filter(|&(_, n)| first <= n && n <= last)
            .map(|(text, n)| (n, text.to_owned()))
            .collect();

        let pool = ParserPool::new();
        let tree = pool.parse(code, Language::Python).ok();

        let info = compute_taint_with_tree(
            &cfg,
            &dfg.refs,
            &statements,
            tree.as_ref(),
            Some(code.as_bytes()),
            Language::Python,
        );
        info.unwrap()
    }

    #[test]
    fn test_scoped_to_function() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");

        // Nothing from `safe_func` (lines 3-5) may leak into the result.
        for src in &result.sources {
            assert!(
                (7..=11).contains(&src.line),
                "Source on line {} is outside vulnerable_func's range (7-11). \
                 Leaking from another function! var={}, type={:?}",
                src.line,
                src.var,
                src.source_type
            );
        }

        for snk in &result.sinks {
            assert!(
                (7..=11).contains(&snk.line),
                "Sink on line {} is outside vulnerable_func's range (7-11). \
                 Leaking from another function! var={}, type={:?}",
                snk.line,
                snk.var,
                snk.sink_type
            );
        }

        assert!(
            !result.sources.is_empty(),
            "Should detect sources in vulnerable_func"
        );
    }

    #[test]
    fn test_sinks_detected() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");
        let kinds: Vec<TaintSinkType> = result.sinks.iter().map(|snk| snk.sink_type).collect();

        assert!(
            kinds.contains(&TaintSinkType::ShellExec),
            "Should detect os.system as ShellExec sink, got: {:?}",
            kinds
        );
        assert!(
            kinds.contains(&TaintSinkType::CodeEval),
            "Should detect eval as CodeEval sink, got: {:?}",
            kinds
        );
    }

    #[test]
    fn test_sources_are_deduplicated() {
        let result = run_taint_on_function(PYTHON_FIXTURE, "vulnerable_func");

        // A (line, kind, variable) triple must appear at most once.
        let mut source_keys = HashSet::new();
        for src in &result.sources {
            let fresh =
                source_keys.insert((src.line, discriminant(&src.source_type), src.var.clone()));
            assert!(
                fresh,
                "Duplicate source: line={}, var={}, type={:?}",
                src.line,
                src.var,
                src.source_type
            );
        }

        let mut sink_keys = HashSet::new();
        for snk in &result.sinks {
            let fresh = sink_keys.insert((snk.line, discriminant(&snk.sink_type), snk.var.clone()));
            assert!(
                fresh,
                "Duplicate sink: line={}, var={}, type={:?}",
                snk.line,
                snk.var,
                snk.sink_type
            );
        }
    }
}