squawk_linter/
lib.rs

1use std::collections::HashSet;
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::ban_uncommitted_transaction;
43use rules::changing_column_type;
44use rules::constraint_missing_not_valid;
45use rules::disallow_unique_constraint;
46use rules::prefer_bigint_over_int;
47use rules::prefer_bigint_over_smallint;
48use rules::prefer_identity;
49use rules::prefer_robust_stmts;
50use rules::prefer_text_field;
51use rules::prefer_timestamptz;
52use rules::renaming_column;
53use rules::renaming_table;
54use rules::require_concurrent_index_creation;
55use rules::require_concurrent_index_deletion;
56use rules::require_timeout_settings;
57use rules::transaction_nesting;
58// xtask:new-rule:rule-import
59
/// A lint rule that the [`Linter`] can check for.
///
/// Each variant has a kebab-case name, produced by its `Display` impl
/// (for example `Rule::BanDropTable` is written as `ban-drop-table`).
///
/// The declaration order below is the order in which
/// `enum_iterator::all::<Rule>()` yields the variants.
#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
pub enum Rule {
    RequireConcurrentIndexCreation,
    RequireConcurrentIndexDeletion,
    ConstraintMissingNotValid,
    AddingFieldWithDefault,
    AddingForeignKeyConstraint,
    ChangingColumnType,
    AddingNotNullableField,
    AddingSerialPrimaryKeyField,
    RenamingColumn,
    RenamingTable,
    DisallowedUniqueConstraint,
    BanDropDatabase,
    PreferBigintOverInt,
    PreferBigintOverSmallint,
    PreferIdentity,
    PreferRobustStmts,
    PreferTextField,
    PreferTimestampTz,
    BanCharField,
    BanDropColumn,
    BanDropTable,
    BanDropNotNull,
    TransactionNesting,
    AddingRequiredField,
    BanConcurrentIndexCreationInTransaction,
    // No rule function for this variant is invoked by `Linter::lint`;
    // NOTE(review): presumably reserved for reporting unused ignore comments.
    UnusedIgnore,
    BanCreateDomainWithConstraint,
    BanAlterDomainWithAddConstraint,
    BanTruncateCascade,
    RequireTimeoutSettings,
    BanUncommittedTransaction,
    // xtask:new-rule:error-name
}
95
96impl TryFrom<&str> for Rule {
97    type Error = String;
98
99    fn try_from(s: &str) -> Result<Self, Self::Error> {
100        match s {
101            "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
102            "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
103            "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
104            "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
105            "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
106            "changing-column-type" => Ok(Rule::ChangingColumnType),
107            "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
108            "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
109            "renaming-column" => Ok(Rule::RenamingColumn),
110            "renaming-table" => Ok(Rule::RenamingTable),
111            "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
112            "ban-drop-database" => Ok(Rule::BanDropDatabase),
113            "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
114            "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
115            "prefer-identity" => Ok(Rule::PreferIdentity),
116            "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
117            "prefer-text-field" => Ok(Rule::PreferTextField),
118            // this is typo'd so we just support both
119            "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
120            "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
121            "ban-char-field" => Ok(Rule::BanCharField),
122            "ban-drop-column" => Ok(Rule::BanDropColumn),
123            "ban-drop-table" => Ok(Rule::BanDropTable),
124            "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
125            "transaction-nesting" => Ok(Rule::TransactionNesting),
126            "adding-required-field" => Ok(Rule::AddingRequiredField),
127            "ban-concurrent-index-creation-in-transaction" => {
128                Ok(Rule::BanConcurrentIndexCreationInTransaction)
129            }
130            "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
131            "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
132            "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
133            "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
134            "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
135            // xtask:new-rule:str-name
136            _ => Err(format!("Unknown violation name: {s}")),
137        }
138    }
139}
140
/// Error produced when parsing a [`Rule`] via `FromStr` fails because the
/// input is not a recognised rule name.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    /// The rejected input, echoed back in the error message.
    val: String,
}

impl std::fmt::Display for UnknownRuleName {
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("invalid rule name ")?;
        f.write_str(&self.val)
    }
}

impl std::error::Error for UnknownRuleName {}
153
154impl std::str::FromStr for Rule {
155    type Err = UnknownRuleName;
156    fn from_str(s: &str) -> Result<Self, Self::Err> {
157        Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
158    }
159}
160
161impl fmt::Display for Rule {
162    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163        let val = match &self {
164            Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
165            Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
166            Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
167            Rule::AddingFieldWithDefault => "adding-field-with-default",
168            Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
169            Rule::ChangingColumnType => "changing-column-type",
170            Rule::AddingNotNullableField => "adding-not-nullable-field",
171            Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
172            Rule::RenamingColumn => "renaming-column",
173            Rule::RenamingTable => "renaming-table",
174            Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
175            Rule::BanDropDatabase => "ban-drop-database",
176            Rule::PreferBigintOverInt => "prefer-bigint-over-int",
177            Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
178            Rule::PreferIdentity => "prefer-identity",
179            Rule::PreferRobustStmts => "prefer-robust-stmts",
180            Rule::PreferTextField => "prefer-text-field",
181            Rule::PreferTimestampTz => "prefer-timestamp-tz",
182            Rule::BanCharField => "ban-char-field",
183            Rule::BanDropColumn => "ban-drop-column",
184            Rule::BanDropTable => "ban-drop-table",
185            Rule::BanDropNotNull => "ban-drop-not-null",
186            Rule::TransactionNesting => "transaction-nesting",
187            Rule::AddingRequiredField => "adding-required-field",
188            Rule::BanConcurrentIndexCreationInTransaction => {
189                "ban-concurrent-index-creation-in-transaction"
190            }
191            Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
192            Rule::UnusedIgnore => "unused-ignore",
193            Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
194            Rule::BanTruncateCascade => "ban-truncate-cascade",
195            Rule::RequireTimeoutSettings => "require-timeout-settings",
196            Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
197            // xtask:new-rule:variant-to-name
198        };
199        write!(f, "{val}")
200    }
201}
202
203impl<'de> Deserialize<'de> for Rule {
204    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
205    where
206        D: serde::Deserializer<'de>,
207    {
208        let s = String::deserialize(deserializer)?;
209        s.parse().map_err(serde::de::Error::custom)
210    }
211}
212
213#[derive(Debug, Clone, PartialEq, Eq)]
214pub struct Fix {
215    pub title: String,
216    pub edits: Vec<Edit>,
217}
218
219impl Fix {
220    fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
221        Fix {
222            title: title.into(),
223            edits,
224        }
225    }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq)]
229pub struct Edit {
230    pub text_range: TextRange,
231    // TODO: does this need to be an Option?
232    pub text: Option<String>,
233}
234impl Edit {
235    pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
236        Self {
237            text_range: TextRange::new(at, at),
238            text: Some(text.into()),
239        }
240    }
241    pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
242        Self {
243            text_range,
244            text: Some(text.into()),
245        }
246    }
247    pub fn delete(text_range: TextRange) -> Self {
248        Self {
249            text_range,
250            text: None,
251        }
252    }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq)]
256pub struct Violation {
257    // TODO: should this be String instead?
258    pub code: Rule,
259    pub message: String,
260    pub text_range: TextRange,
261    pub help: Option<String>,
262    pub fix: Option<Fix>,
263}
264
265impl Violation {
266    #[must_use]
267    pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
268        let range = node.text_range();
269
270        let start = node
271            .children_with_tokens()
272            .find(|x| !x.kind().is_trivia())
273            .map(|x| x.text_range().start())
274            // Not sure we actually hit this, but just being safe
275            .unwrap_or_else(|| range.start());
276
277        Self {
278            code,
279            text_range: TextRange::new(start, range.end()),
280            message,
281            help: None,
282            fix: None,
283        }
284    }
285
286    #[must_use]
287    pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
288        Self {
289            code,
290            text_range,
291            message,
292            help: None,
293            fix: None,
294        }
295    }
296
297    fn fix(mut self, fix: Option<Fix>) -> Violation {
298        self.fix = fix;
299        self
300    }
301    fn help(mut self, help: impl Into<String>) -> Violation {
302        self.help = Some(help.into());
303        self
304    }
305}
306
307#[derive(Default)]
308pub struct LinterSettings {
309    pub pg_version: Version,
310    pub assume_in_transaction: bool,
311}
312
313pub struct Linter {
314    errors: Vec<Violation>,
315    ignores: Vec<Ignore>,
316    pub rules: HashSet<Rule>,
317    pub settings: LinterSettings,
318}
319
320impl Linter {
321    fn report(&mut self, error: Violation) {
322        self.errors.push(error);
323    }
324
325    fn ignore(&mut self, ignore: Ignore) {
326        self.ignores.push(ignore);
327    }
328
329    #[must_use]
330    pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
331        if self.rules.contains(&Rule::AddingFieldWithDefault) {
332            adding_field_with_default(self, file);
333        }
334        if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
335            adding_foreign_key_constraint(self, file);
336        }
337        if self.rules.contains(&Rule::AddingNotNullableField) {
338            adding_not_null_field(self, file);
339        }
340        if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
341            adding_primary_key_constraint(self, file);
342        }
343        if self.rules.contains(&Rule::AddingRequiredField) {
344            adding_required_field(self, file);
345        }
346        if self.rules.contains(&Rule::BanDropDatabase) {
347            ban_drop_database(self, file);
348        }
349        if self.rules.contains(&Rule::BanCharField) {
350            ban_char_field(self, file);
351        }
352        if self
353            .rules
354            .contains(&Rule::BanConcurrentIndexCreationInTransaction)
355        {
356            ban_concurrent_index_creation_in_transaction(self, file);
357        }
358        if self.rules.contains(&Rule::BanDropColumn) {
359            ban_drop_column(self, file);
360        }
361        if self.rules.contains(&Rule::BanDropNotNull) {
362            ban_drop_not_null(self, file);
363        }
364        if self.rules.contains(&Rule::BanDropTable) {
365            ban_drop_table(self, file);
366        }
367        if self.rules.contains(&Rule::ChangingColumnType) {
368            changing_column_type(self, file);
369        }
370        if self.rules.contains(&Rule::ConstraintMissingNotValid) {
371            constraint_missing_not_valid(self, file);
372        }
373        if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
374            disallow_unique_constraint(self, file);
375        }
376        if self.rules.contains(&Rule::PreferBigintOverInt) {
377            prefer_bigint_over_int(self, file);
378        }
379        if self.rules.contains(&Rule::PreferBigintOverSmallint) {
380            prefer_bigint_over_smallint(self, file);
381        }
382        if self.rules.contains(&Rule::PreferIdentity) {
383            prefer_identity(self, file);
384        }
385        if self.rules.contains(&Rule::PreferRobustStmts) {
386            prefer_robust_stmts(self, file);
387        }
388        if self.rules.contains(&Rule::PreferTextField) {
389            prefer_text_field(self, file);
390        }
391        if self.rules.contains(&Rule::PreferTimestampTz) {
392            prefer_timestamptz(self, file);
393        }
394        if self.rules.contains(&Rule::RenamingColumn) {
395            renaming_column(self, file);
396        }
397        if self.rules.contains(&Rule::RenamingTable) {
398            renaming_table(self, file);
399        }
400        if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
401            require_concurrent_index_creation(self, file);
402        }
403        if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
404            require_concurrent_index_deletion(self, file);
405        }
406        if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
407            ban_create_domain_with_constraint(self, file);
408        }
409        if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
410            ban_alter_domain_with_add_constraint(self, file);
411        }
412        if self.rules.contains(&Rule::TransactionNesting) {
413            transaction_nesting(self, file);
414        }
415        if self.rules.contains(&Rule::BanTruncateCascade) {
416            ban_truncate_cascade(self, file);
417        }
418        if self.rules.contains(&Rule::RequireTimeoutSettings) {
419            require_timeout_settings(self, file);
420        }
421        if self.rules.contains(&Rule::BanUncommittedTransaction) {
422            ban_uncommitted_transaction(self, file);
423        }
424        // xtask:new-rule:rule-call
425
426        // locate any ignores in the file
427        find_ignores(self, &file.syntax_node());
428
429        self.errors(text)
430    }
431
432    fn errors(&mut self, text: &str) -> Vec<Violation> {
433        let ignore_index = IgnoreIndex::new(text, &self.ignores);
434        let mut errors: Vec<Violation> = self
435            .errors
436            .iter()
437            // TODO: we should have errors for when there was an ignore but that
438            // ignore didn't actually ignore anything
439            .filter(|err| !ignore_index.contains(err.text_range, err.code))
440            .cloned()
441            .collect::<Vec<_>>();
442        // ensure we order them by where they appear in the file
443        errors.sort_by_key(|x| x.text_range.start());
444        errors
445    }
446
447    pub fn with_all_rules() -> Self {
448        let rules = all::<Rule>().collect::<HashSet<_>>();
449        Linter::from(rules)
450    }
451
452    pub fn without_rules(exclude: &[Rule]) -> Self {
453        let all_rules = all::<Rule>().collect::<HashSet<_>>();
454        let mut exclude_set = HashSet::with_capacity(exclude.len());
455        for e in exclude {
456            exclude_set.insert(e);
457        }
458
459        let rules = all_rules
460            .into_iter()
461            .filter(|x| !exclude_set.contains(x))
462            .collect::<HashSet<_>>();
463
464        Linter::from(rules)
465    }
466
467    pub fn from(rules: impl Into<HashSet<Rule>>) -> Self {
468        Self {
469            errors: vec![],
470            ignores: vec![],
471            rules: rules.into(),
472            settings: Default::default(),
473        }
474    }
475}
476
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;

    /// Both spellings of the timestamptz rule name parse to the same variant.
    #[test]
    fn prefer_timestamp_aliases() {
        let [dashed, joined] =
            ["prefer-timestamp-tz", "prefer-timestamptz"].map(|name| name.parse::<Rule>().unwrap());
        assert_eq!(dashed, joined);
        assert_debug_snapshot!(dashed, @"PreferTimestampTz");
    }

    /// Unrecognised names are rejected rather than mapped to some rule.
    #[test]
    fn invalid_rule_name() {
        assert!("invalid-rule-name".parse::<Rule>().is_err());
    }
}