tree_sitter_lint_tree_sitter_grep/
lib.rs

1#![allow(clippy::into_iter_on_ref)]
2
3use std::{
4    fmt, io,
5    path::{Path, PathBuf},
6    sync::{
7        atomic::{AtomicBool, Ordering},
8        Arc, Mutex, OnceLock,
9    },
10};
11
12use ignore::DirEntry;
13use rayon::prelude::*;
14use termcolor::{BufferWriter, ColorChoice};
15use thiserror::Error;
16use tree_sitter::{Node, Query, QueryError};
17
18mod args;
19mod language;
20mod line_buffer;
21mod lines;
22mod macros;
23mod matcher;
24mod plugin;
25mod printer;
26mod project_file_walker;
27mod query_context;
28mod searcher;
29mod sink;
30mod treesitter;
31mod use_printer;
32mod use_searcher;
33
34pub use args::Args;
35use language::{BySupportedLanguage, SupportedLanguage};
36pub use plugin::PluginInitializeReturn;
37use query_context::QueryContext;
38use treesitter::maybe_get_query;
39use use_printer::get_printer;
40use use_searcher::get_searcher;
41
42#[derive(Debug, Error)]
43pub enum Error {
44    #[error("couldn't read query file {path_to_query_file:?}")]
45    QueryFileReadError {
46        path_to_query_file: PathBuf,
47        source: io::Error,
48    },
49    #[error("{}",
50        match .0.len() {
51            1 => {
52                let (supported_language, query_error) = &.0[0];
53                format!("couldn't parse query for {supported_language:?}: {query_error}")
54            }
55            _ => {
56                let mut attempted_parsings = .0
57                    .iter()
58                    .map(|(supported_language, _)| format!("{supported_language:?}"))
59                    .collect::<Vec<_>>();
60                attempted_parsings.sort();
61                format!(
62                    "couldn't parse query for {}",
63                    join_with_or(&attempted_parsings)
64                )
65            }
66        }
67    )]
68    NoSuccessfulQueryParsing(Vec<(SupportedLanguage, /* QueryError */ String)>),
69    #[error("query must include at least one capture (\"@whatever\")")]
70    NoCaptureInQuery,
71    #[error("invalid capture name '{capture_name}'")]
72    InvalidCaptureName { capture_name: String },
73    #[error("plugin expected '--filter-arg <ARGUMENT>'")]
74    FilterPluginExpectedArgument,
75    #[error("plugin couldn't parse argument {filter_arg:?}")]
76    FilterPluginCouldntParseArgument { filter_arg: String },
77}
78
79#[derive(Debug, Error)]
80pub enum NonFatalSearchError {
81    #[error("File {path:?} is not recognized as a {specified_language:?} file")]
82    ExplicitPathArgumentNotOfSpecifiedType {
83        path: PathBuf,
84        specified_language: SupportedLanguage,
85    },
86    #[error("File {path:?} does not belong to a recognized language")]
87    ExplicitPathArgumentNotOfKnownType { path: PathBuf },
88    #[error(
89        "File {path:?} has ambiguous file-type, could be {}. Try passing the --language flag",
90        join_with_or(
91            &.languages
92                .into_iter()
93                .map(|language| format!("{}", language))
94                .collect::<Vec<_>>()
95        )
96    )]
97    AmbiguousLanguageForFile {
98        path: PathBuf,
99        languages: Vec<SupportedLanguage>,
100    },
101    #[error("The provided query could not be parsed for the language of file {path:?}")]
102    QueryNotParseableForFile { path: PathBuf },
103    #[error("No files were searched")]
104    NothingSearched,
105    #[error("{error}")]
106    IgnoreError {
107        #[from]
108        error: ignore::Error,
109    },
110}
111
/// Why the capture to report could not be resolved for the query.
///
/// Being `Clone` lets the cached outcome inside `CaptureIndex` be handed out on every call;
/// it is turned into the matching fatal `Error` variant by the `From` impl.
#[derive(Clone)]
enum CaptureIndexError {
    /// No capture name was requested and the query declares no captures.
    NoCaptureInQuery,
    /// The requested capture name is not declared by the query.
    InvalidCaptureName {
        capture_name: String,
    },
}
117
118impl From<CaptureIndexError> for Error {
119    fn from(value: CaptureIndexError) -> Self {
120        match value {
121            CaptureIndexError::NoCaptureInQuery => Self::NoCaptureInQuery,
122            CaptureIndexError::InvalidCaptureName { capture_name } => {
123                Self::InvalidCaptureName { capture_name }
124            }
125        }
126    }
127}
128
129#[derive(Default)]
130struct CaptureIndex(OnceLock<Result<u32, CaptureIndexError>>);
131
132impl CaptureIndex {
133    pub fn get_or_init(
134        &self,
135        query: &Query,
136        capture_name: Option<&str>,
137    ) -> Result<u32, CaptureIndexError> {
138        self.0
139            .get_or_init(|| match capture_name {
140                None => match query.capture_names().len() {
141                    0 => Err(CaptureIndexError::NoCaptureInQuery),
142                    _ => Ok(0),
143                },
144                Some(capture_name) => query.capture_index_for_name(capture_name).ok_or_else(|| {
145                    CaptureIndexError::InvalidCaptureName {
146                        capture_name: capture_name.to_owned(),
147                    }
148                }),
149            })
150            .clone()
151    }
152}
153
/// Joins `list` English-style with commas and a final "or", using an Oxford comma for three
/// or more items: `a`, `a or b`, `a, b, or c`. An empty list yields an empty string.
fn join_with_or<TItem: fmt::Display>(list: &[TItem]) -> String {
    let count = list.len();
    let mut joined = String::new();
    for (position, item) in list.iter().enumerate() {
        joined.push_str(&item.to_string());
        // How many items still follow this one decides which separator (if any) comes next.
        let still_to_come = count - position - 1;
        match still_to_come {
            0 => {}
            1 if count == 2 => joined.push_str(" or "),
            1 => joined.push_str(", or "),
            _ => joined.push_str(", "),
        }
    }
    joined
}
166
167#[derive(Default)]
168struct CachedQueries(BySupportedLanguage<OnceLock<Result<Arc<Query>, QueryError>>>);
169
170impl CachedQueries {
171    fn get_and_cache_query_for_language(
172        &self,
173        query_text: &str,
174        language: SupportedLanguage,
175    ) -> Option<Arc<Query>> {
176        self.0[language]
177            .get_or_init(|| maybe_get_query(query_text, language.language()).map(Arc::new))
178            .as_ref()
179            .ok()
180            .cloned()
181    }
182
183    fn error_if_no_successful_query_parsing(&self) -> Result<(), Error> {
184        if !self.0.values().any(|query| {
185            query
186                .get()
187                .and_then(|result| result.as_ref().ok())
188                .is_some()
189        }) {
190            let attempted_parsings = self
191                .0
192                .iter()
193                .filter(|(_, value)| value.get().is_some())
194                .collect::<Vec<_>>();
195            assert!(
196                !attempted_parsings.is_empty(),
197                "Should've tried to parse in at least one language or else should've already failed on no candidate files"
198            );
199            return Err(Error::NoSuccessfulQueryParsing(
200                attempted_parsings
201                    .into_iter()
202                    .map(|(supported_language, once_lock)| {
203                        (
204                            supported_language,
205                            format!("{}", once_lock.get().unwrap().as_ref().unwrap_err()),
206                        )
207                    })
208                    .collect(),
209            ));
210        }
211
212        Ok(())
213    }
214}
215
216pub struct RunStatus {
217    pub matched: bool,
218    pub non_fatal_errors: Vec<NonFatalSearchError>,
219}
220
221enum SingleFileSearchError {
222    NonFatalSearchError(NonFatalSearchError),
223    FatalError(Error),
224}
225
226impl From<Error> for SingleFileSearchError {
227    fn from(value: Error) -> Self {
228        Self::FatalError(value)
229    }
230}
231
232impl From<NonFatalSearchError> for SingleFileSearchError {
233    fn from(value: NonFatalSearchError) -> Self {
234        Self::NonFatalSearchError(value)
235    }
236}
237
238impl<TSuccess> From<Error> for Result<TSuccess, SingleFileSearchError> {
239    fn from(value: Error) -> Self {
240        Err(value.into())
241    }
242}
243
244impl<TSuccess> From<NonFatalSearchError> for Result<TSuccess, SingleFileSearchError> {
245    fn from(value: NonFatalSearchError) -> Self {
246        Err(value.into())
247    }
248}
249
250pub struct OutputContext {
251    pub buffer_writer: BufferWriter,
252}
253
254impl OutputContext {
255    pub fn new(buffer_writer: BufferWriter) -> Self {
256        Self { buffer_writer }
257    }
258}
259
260pub fn run_print(args: Args) -> Result<RunStatus, Error> {
261    run_for_context(
262        args,
263        OutputContext::new(BufferWriter::stdout(ColorChoice::Never)),
264        |context: &OutputContext,
265         args: &Args,
266         path: &Path,
267         query_context: QueryContext,
268         matched: &AtomicBool| {
269            let printer = get_printer(&context.buffer_writer, args);
270            let mut printer = printer.borrow_mut();
271
272            printer.get_mut().clear();
273            let mut sink = printer.sink_with_path(path);
274            get_searcher(args)
275                .borrow_mut()
276                .search_path(query_context, path, &mut sink)
277                .unwrap();
278            if sink.has_match() {
279                matched.store(true, Ordering::SeqCst);
280            }
281            context.buffer_writer.print(printer.get_mut()).unwrap();
282        },
283    )
284}
285
286pub struct CaptureInfo<'node> {
287    pub node: Node<'node>,
288    pub pattern_index: usize,
289}
290
291pub fn run_with_callback(
292    args: Args,
293    callback: impl Fn(CaptureInfo, &[u8], &Path) + Sync,
294) -> Result<RunStatus, Error> {
295    run_for_context(
296        args,
297        (),
298        |_context: &(),
299         args: &Args,
300         path: &Path,
301         query_context: QueryContext,
302         matched: &AtomicBool| {
303            get_searcher(args)
304                .borrow_mut()
305                .search_path_callback::<_, io::Error>(
306                    query_context,
307                    path,
308                    |capture_info: CaptureInfo, file_contents: &[u8], path: &Path| {
309                        callback(capture_info, file_contents, path);
310                        matched.store(true, Ordering::SeqCst);
311                    },
312                )
313                .unwrap();
314        },
315    )
316}
317
318fn run_for_context<TContext: Sync>(
319    args: Args,
320    context: TContext,
321    search_file: impl Fn(&TContext, &Args, &Path, QueryContext, &AtomicBool) + Sync,
322) -> Result<RunStatus, Error> {
323    let query_text = args.get_loaded_query_text()?;
324    let filter = args.get_loaded_filter()?;
325    let cached_queries: CachedQueries = Default::default();
326    let capture_index = CaptureIndex::default();
327    let matched = AtomicBool::new(false);
328    let searched = AtomicBool::new(false);
329    let non_fatal_errors: Arc<Mutex<Vec<NonFatalSearchError>>> = Default::default();
330
331    for_each_project_file(
332        &args,
333        non_fatal_errors.clone(),
334        |project_file_dir_entry, matched_languages| {
335            searched.store(true, Ordering::SeqCst);
336            let language = match args.language {
337                Some(specified_language) => {
338                    if !matched_languages.contains(&specified_language) {
339                        return NonFatalSearchError::ExplicitPathArgumentNotOfSpecifiedType {
340                            path: project_file_dir_entry.path().to_owned(),
341                            specified_language,
342                        }
343                        .into();
344                    }
345                    specified_language
346                }
347                None => match matched_languages.len() {
348                    0 => {
349                        return NonFatalSearchError::ExplicitPathArgumentNotOfKnownType {
350                            path: project_file_dir_entry.path().to_owned(),
351                        }
352                        .into();
353                    }
354                    1 => matched_languages[0],
355                    _ => {
356                        let successfully_parsed_query_languages = matched_languages
357                            .iter()
358                            .filter_map(|&matched_language| {
359                                cached_queries
360                                    .get_and_cache_query_for_language(&query_text, matched_language)
361                                    .map(|_| matched_language)
362                            })
363                            .collect::<Vec<_>>();
364                        match successfully_parsed_query_languages.len() {
365                            0 => {
366                                return NonFatalSearchError::QueryNotParseableForFile {
367                                    path: project_file_dir_entry.path().to_owned(),
368                                }
369                                .into();
370                            }
371                            1 => successfully_parsed_query_languages[0],
372                            _ => {
373                                return NonFatalSearchError::AmbiguousLanguageForFile {
374                                    path: project_file_dir_entry.path().to_owned(),
375                                    languages: successfully_parsed_query_languages,
376                                }
377                                .into();
378                            }
379                        }
380                    }
381                },
382            };
383            let query = cached_queries
384                .get_and_cache_query_for_language(&query_text, language)
385                .ok_or_else(|| NonFatalSearchError::QueryNotParseableForFile {
386                    path: project_file_dir_entry.path().to_owned(),
387                })?;
388            let capture_index = capture_index
389                .get_or_init(&query, args.capture_name.as_deref())
390                .map_err(Error::from)?;
391            let path =
392                format_relative_path(project_file_dir_entry.path(), args.is_using_default_paths());
393
394            let query_context =
395                QueryContext::new(query, capture_index, language.language(), filter.clone());
396
397            search_file(&context, &args, path, query_context, &matched);
398
399            Ok(())
400        },
401    )?;
402
403    let mut non_fatal_errors = Arc::into_inner(non_fatal_errors)
404        .unwrap()
405        .into_inner()
406        .unwrap()
407        .into_iter()
408        .filter(|non_fatal_error| {
409            !matches!(
410                non_fatal_error,
411                NonFatalSearchError::QueryNotParseableForFile { .. }
412            )
413        })
414        .collect::<Vec<_>>();
415    if non_fatal_errors.is_empty() {
416        if !searched.load(Ordering::SeqCst) {
417            non_fatal_errors.push(NonFatalSearchError::NothingSearched);
418        } else {
419            cached_queries.error_if_no_successful_query_parsing()?;
420        }
421    }
422
423    Ok(RunStatus {
424        matched: matched.load(Ordering::SeqCst),
425        non_fatal_errors,
426    })
427}
428
429fn for_each_project_file(
430    args: &Args,
431    non_fatal_errors: Arc<Mutex<Vec<NonFatalSearchError>>>,
432    callback: impl Fn(DirEntry, Vec<SupportedLanguage>) -> Result<(), SingleFileSearchError> + Sync,
433) -> Result<(), Error> {
434    let fatal_error: Mutex<Option<Error>> = Default::default();
435    args.get_project_file_parallel_iterator(non_fatal_errors.clone())
436        .for_each(|(project_file_dir_entry, matched_languages)| {
437            if fatal_error.lock().unwrap().is_some() {
438                return;
439            }
440
441            if let Err(error) = callback(project_file_dir_entry, matched_languages) {
442                match error {
443                    SingleFileSearchError::NonFatalSearchError(error) => {
444                        non_fatal_errors.lock().unwrap().push(error);
445                    }
446                    SingleFileSearchError::FatalError(error) => {
447                        *fatal_error.lock().unwrap() = Some(error);
448                    }
449                }
450            }
451        });
452
453    match fatal_error.into_inner().unwrap() {
454        Some(fatal_error) => Err(fatal_error),
455        None => Ok(()),
456    }
457}
458
/// Runs `$block` at most once per expansion site for the lifetime of the process.
///
/// Each expansion owns its own `static` [`std::sync::OnceLock`] inside a private block.
/// Several invocations, even in the same scope, are therefore tracked independently and do not
/// collide on the `static`'s name. The expansion uses absolute `::std` paths, so callers do not
/// need `OnceLock` imported.
#[macro_export]
macro_rules! only_run_once {
    ($block:block) => {{
        static ONCE_LOCK: ::std::sync::OnceLock<()> = ::std::sync::OnceLock::new();
        ONCE_LOCK.get_or_init(|| {
            $block;
        });
    }};
}
468
/// Strips a leading `./` from `path` when searching the default paths, so results read as
/// `src/lib.rs` rather than `./src/lib.rs`.
///
/// When not using the default paths (presumably the user passed explicit paths), `path` is
/// returned unchanged. Paths without a leading `./` component are also returned unchanged.
fn format_relative_path(path: &Path, is_using_default_paths: bool) -> &Path {
    if !is_using_default_paths {
        return path;
    }
    path.strip_prefix("./").unwrap_or(path)
}