squawk_linter/
lib.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::changing_column_type;
43use rules::constraint_missing_not_valid;
44use rules::disallow_unique_constraint;
45use rules::prefer_bigint_over_int;
46use rules::prefer_bigint_over_smallint;
47use rules::prefer_identity;
48use rules::prefer_robust_stmts;
49use rules::prefer_text_field;
50use rules::prefer_timestamptz;
51use rules::renaming_column;
52use rules::renaming_table;
53use rules::require_concurrent_index_creation;
54use rules::require_concurrent_index_deletion;
55use rules::require_timeout_settings;
56use rules::transaction_nesting;
57// xtask:new-rule:rule-import
58
59#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
60pub enum Rule {
61    RequireConcurrentIndexCreation,
62    RequireConcurrentIndexDeletion,
63    ConstraintMissingNotValid,
64    AddingFieldWithDefault,
65    AddingForeignKeyConstraint,
66    ChangingColumnType,
67    AddingNotNullableField,
68    AddingSerialPrimaryKeyField,
69    RenamingColumn,
70    RenamingTable,
71    DisallowedUniqueConstraint,
72    BanDropDatabase,
73    PreferBigintOverInt,
74    PreferBigintOverSmallint,
75    PreferIdentity,
76    PreferRobustStmts,
77    PreferTextField,
78    PreferTimestampTz,
79    BanCharField,
80    BanDropColumn,
81    BanDropTable,
82    BanDropNotNull,
83    TransactionNesting,
84    AddingRequiredField,
85    BanConcurrentIndexCreationInTransaction,
86    UnusedIgnore,
87    BanCreateDomainWithConstraint,
88    BanAlterDomainWithAddConstraint,
89    BanTruncateCascade,
90    RequireTimeoutSettings,
91    // xtask:new-rule:error-name
92}
93
94impl TryFrom<&str> for Rule {
95    type Error = String;
96
97    fn try_from(s: &str) -> Result<Self, Self::Error> {
98        match s {
99            "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
100            "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
101            "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
102            "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
103            "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
104            "changing-column-type" => Ok(Rule::ChangingColumnType),
105            "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
106            "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
107            "renaming-column" => Ok(Rule::RenamingColumn),
108            "renaming-table" => Ok(Rule::RenamingTable),
109            "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
110            "ban-drop-database" => Ok(Rule::BanDropDatabase),
111            "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
112            "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
113            "prefer-identity" => Ok(Rule::PreferIdentity),
114            "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
115            "prefer-text-field" => Ok(Rule::PreferTextField),
116            // this is typo'd so we just support both
117            "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
118            "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
119            "ban-char-field" => Ok(Rule::BanCharField),
120            "ban-drop-column" => Ok(Rule::BanDropColumn),
121            "ban-drop-table" => Ok(Rule::BanDropTable),
122            "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
123            "transaction-nesting" => Ok(Rule::TransactionNesting),
124            "adding-required-field" => Ok(Rule::AddingRequiredField),
125            "ban-concurrent-index-creation-in-transaction" => {
126                Ok(Rule::BanConcurrentIndexCreationInTransaction)
127            }
128            "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
129            "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
130            "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
131            "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
132            // xtask:new-rule:str-name
133            _ => Err(format!("Unknown violation name: {s}")),
134        }
135    }
136}
137
/// Error returned by `Rule`'s `FromStr` impl when the input names no known rule.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    /// The rejected input, kept verbatim for the error message.
    val: String,
}

impl fmt::Display for UnknownRuleName {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let UnknownRuleName { val } = self;
        write!(f, "invalid rule name {val}")
    }
}

impl std::error::Error for UnknownRuleName {}
150
151impl std::str::FromStr for Rule {
152    type Err = UnknownRuleName;
153    fn from_str(s: &str) -> Result<Self, Self::Err> {
154        Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
155    }
156}
157
158impl fmt::Display for Rule {
159    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
160        let val = match &self {
161            Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
162            Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
163            Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
164            Rule::AddingFieldWithDefault => "adding-field-with-default",
165            Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
166            Rule::ChangingColumnType => "changing-column-type",
167            Rule::AddingNotNullableField => "adding-not-nullable-field",
168            Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
169            Rule::RenamingColumn => "renaming-column",
170            Rule::RenamingTable => "renaming-table",
171            Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
172            Rule::BanDropDatabase => "ban-drop-database",
173            Rule::PreferBigintOverInt => "prefer-bigint-over-int",
174            Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
175            Rule::PreferIdentity => "prefer-identity",
176            Rule::PreferRobustStmts => "prefer-robust-stmts",
177            Rule::PreferTextField => "prefer-text-field",
178            Rule::PreferTimestampTz => "prefer-timestamp-tz",
179            Rule::BanCharField => "ban-char-field",
180            Rule::BanDropColumn => "ban-drop-column",
181            Rule::BanDropTable => "ban-drop-table",
182            Rule::BanDropNotNull => "ban-drop-not-null",
183            Rule::TransactionNesting => "transaction-nesting",
184            Rule::AddingRequiredField => "adding-required-field",
185            Rule::BanConcurrentIndexCreationInTransaction => {
186                "ban-concurrent-index-creation-in-transaction"
187            }
188            Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
189            Rule::UnusedIgnore => "unused-ignore",
190            Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
191            Rule::BanTruncateCascade => "ban-truncate-cascade",
192            Rule::RequireTimeoutSettings => "require-timeout-settings",
193            // xtask:new-rule:variant-to-name
194        };
195        write!(f, "{val}")
196    }
197}
198
199impl<'de> Deserialize<'de> for Rule {
200    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
201    where
202        D: serde::Deserializer<'de>,
203    {
204        let s = String::deserialize(deserializer)?;
205        s.parse().map_err(serde::de::Error::custom)
206    }
207}
208
209#[derive(Debug, Clone, PartialEq, Eq)]
210pub struct Fix {
211    pub title: String,
212    pub edits: Vec<Edit>,
213}
214
215impl Fix {
216    fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
217        Fix {
218            title: title.into(),
219            edits,
220        }
221    }
222}
223
224#[derive(Debug, Clone, PartialEq, Eq)]
225pub struct Edit {
226    pub text_range: TextRange,
227    pub text: Option<String>,
228}
229impl Edit {
230    pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
231        Self {
232            text_range: TextRange::new(at, at),
233            text: Some(text.into()),
234        }
235    }
236    pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
237        Self {
238            text_range,
239            text: Some(text.into()),
240        }
241    }
242}
243
244#[derive(Debug, Clone, PartialEq, Eq)]
245pub struct Violation {
246    // TODO: should this be String instead?
247    pub code: Rule,
248    pub message: String,
249    pub text_range: TextRange,
250    pub help: Option<String>,
251    pub fix: Option<Fix>,
252}
253
254impl Violation {
255    #[must_use]
256    pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
257        let range = node.text_range();
258
259        let start = node
260            .children_with_tokens()
261            .find(|x| !x.kind().is_trivia())
262            .map(|x| x.text_range().start())
263            // Not sure we actually hit this, but just being safe
264            .unwrap_or_else(|| range.start());
265
266        Self {
267            code,
268            text_range: TextRange::new(start, range.end()),
269            message,
270            help: None,
271            fix: None,
272        }
273    }
274
275    #[must_use]
276    pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
277        Self {
278            code,
279            text_range,
280            message,
281            help: None,
282            fix: None,
283        }
284    }
285
286    fn fix(mut self, fix: Option<Fix>) -> Violation {
287        self.fix = fix;
288        self
289    }
290    fn help(mut self, help: impl Into<String>) -> Violation {
291        self.help = Some(help.into());
292        self
293    }
294}
295
296#[derive(Default)]
297pub struct LinterSettings {
298    pub pg_version: Version,
299    pub assume_in_transaction: bool,
300}
301
302pub struct Linter {
303    errors: Vec<Violation>,
304    ignores: Vec<Ignore>,
305    pub rules: HashSet<Rule>,
306    pub settings: LinterSettings,
307}
308
309impl Linter {
310    fn report(&mut self, error: Violation) {
311        self.errors.push(error);
312    }
313
314    fn ignore(&mut self, ignore: Ignore) {
315        self.ignores.push(ignore);
316    }
317
318    #[must_use]
319    pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
320        if self.rules.contains(&Rule::AddingFieldWithDefault) {
321            adding_field_with_default(self, file);
322        }
323        if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
324            adding_foreign_key_constraint(self, file);
325        }
326        if self.rules.contains(&Rule::AddingNotNullableField) {
327            adding_not_null_field(self, file);
328        }
329        if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
330            adding_primary_key_constraint(self, file);
331        }
332        if self.rules.contains(&Rule::AddingRequiredField) {
333            adding_required_field(self, file);
334        }
335        if self.rules.contains(&Rule::BanDropDatabase) {
336            ban_drop_database(self, file);
337        }
338        if self.rules.contains(&Rule::BanCharField) {
339            ban_char_field(self, file);
340        }
341        if self
342            .rules
343            .contains(&Rule::BanConcurrentIndexCreationInTransaction)
344        {
345            ban_concurrent_index_creation_in_transaction(self, file);
346        }
347        if self.rules.contains(&Rule::BanDropColumn) {
348            ban_drop_column(self, file);
349        }
350        if self.rules.contains(&Rule::BanDropNotNull) {
351            ban_drop_not_null(self, file);
352        }
353        if self.rules.contains(&Rule::BanDropTable) {
354            ban_drop_table(self, file);
355        }
356        if self.rules.contains(&Rule::ChangingColumnType) {
357            changing_column_type(self, file);
358        }
359        if self.rules.contains(&Rule::ConstraintMissingNotValid) {
360            constraint_missing_not_valid(self, file);
361        }
362        if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
363            disallow_unique_constraint(self, file);
364        }
365        if self.rules.contains(&Rule::PreferBigintOverInt) {
366            prefer_bigint_over_int(self, file);
367        }
368        if self.rules.contains(&Rule::PreferBigintOverSmallint) {
369            prefer_bigint_over_smallint(self, file);
370        }
371        if self.rules.contains(&Rule::PreferIdentity) {
372            prefer_identity(self, file);
373        }
374        if self.rules.contains(&Rule::PreferRobustStmts) {
375            prefer_robust_stmts(self, file);
376        }
377        if self.rules.contains(&Rule::PreferTextField) {
378            prefer_text_field(self, file);
379        }
380        if self.rules.contains(&Rule::PreferTimestampTz) {
381            prefer_timestamptz(self, file);
382        }
383        if self.rules.contains(&Rule::RenamingColumn) {
384            renaming_column(self, file);
385        }
386        if self.rules.contains(&Rule::RenamingTable) {
387            renaming_table(self, file);
388        }
389        if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
390            require_concurrent_index_creation(self, file);
391        }
392        if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
393            require_concurrent_index_deletion(self, file);
394        }
395        if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
396            ban_create_domain_with_constraint(self, file);
397        }
398        if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
399            ban_alter_domain_with_add_constraint(self, file);
400        }
401        if self.rules.contains(&Rule::TransactionNesting) {
402            transaction_nesting(self, file);
403        }
404        if self.rules.contains(&Rule::BanTruncateCascade) {
405            ban_truncate_cascade(self, file);
406        }
407        if self.rules.contains(&Rule::RequireTimeoutSettings) {
408            require_timeout_settings(self, file);
409        }
410        // xtask:new-rule:rule-call
411
412        // locate any ignores in the file
413        find_ignores(self, &file.syntax_node());
414
415        self.errors(text)
416    }
417
418    fn errors(&mut self, text: &str) -> Vec<Violation> {
419        let ignore_index = IgnoreIndex::new(text, &self.ignores);
420        let mut errors: Vec<Violation> = self
421            .errors
422            .iter()
423            // TODO: we should have errors for when there was an ignore but that
424            // ignore didn't actually ignore anything
425            .filter(|err| !ignore_index.contains(err.text_range, err.code))
426            .cloned()
427            .collect::<Vec<_>>();
428        // ensure we order them by where they appear in the file
429        errors.sort_by_key(|x| x.text_range.start());
430        errors
431    }
432
433    pub fn with_all_rules() -> Self {
434        let rules = all::<Rule>().collect::<HashSet<_>>();
435        Linter::from(rules)
436    }
437
438    pub fn without_rules(exclude: &[Rule]) -> Self {
439        let all_rules = all::<Rule>().collect::<HashSet<_>>();
440        let mut exclude_set = HashSet::with_capacity(exclude.len());
441        for e in exclude {
442            exclude_set.insert(e);
443        }
444
445        let rules = all_rules
446            .into_iter()
447            .filter(|x| !exclude_set.contains(x))
448            .collect::<HashSet<_>>();
449
450        Linter::from(rules)
451    }
452
453    pub fn from(rules: impl Into<HashSet<Rule>>) -> Self {
454        Self {
455            errors: vec![],
456            ignores: vec![],
457            rules: rules.into(),
458            settings: LinterSettings::default(),
459        }
460    }
461}
462
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;

    /// Both spellings of the timestamptz rule name must resolve to the same variant.
    #[test]
    fn prefer_timestamp_aliases() {
        let [canonical, alias] =
            ["prefer-timestamp-tz", "prefer-timestamptz"].map(|name| name.parse::<Rule>().unwrap());
        assert_eq!(canonical, alias);
        assert_debug_snapshot!(canonical, @"PreferTimestampTz");
    }

    /// Unknown names are rejected rather than mapped to some default rule.
    #[test]
    fn invalid_rule_name() {
        assert!("invalid-rule-name".parse::<Rule>().is_err());
    }
}