1use std::collections::HashSet;
2use std::fmt;
3
4use enum_iterator::Sequence;
5use enum_iterator::all;
6pub use ignore::Ignore;
7use ignore::find_ignores;
8use ignore_index::IgnoreIndex;
9use rowan::TextRange;
10use rowan::TextSize;
11use serde::Deserialize;
12
13use squawk_syntax::SyntaxNode;
14use squawk_syntax::{Parse, SourceFile};
15
16pub use version::Version;
17
18pub mod analyze;
19pub mod ignore;
20mod ignore_index;
21mod version;
22mod visitors;
23
24mod rules;
25
26#[cfg(test)]
27mod test_utils;
28use rules::adding_field_with_default;
29use rules::adding_foreign_key_constraint;
30use rules::adding_not_null_field;
31use rules::adding_primary_key_constraint;
32use rules::adding_required_field;
33use rules::ban_alter_domain_with_add_constraint;
34use rules::ban_char_field;
35use rules::ban_concurrent_index_creation_in_transaction;
36use rules::ban_create_domain_with_constraint;
37use rules::ban_drop_column;
38use rules::ban_drop_database;
39use rules::ban_drop_not_null;
40use rules::ban_drop_table;
41use rules::ban_truncate_cascade;
42use rules::ban_uncommitted_transaction;
43use rules::changing_column_type;
44use rules::constraint_missing_not_valid;
45use rules::disallow_unique_constraint;
46use rules::prefer_bigint_over_int;
47use rules::prefer_bigint_over_smallint;
48use rules::prefer_identity;
49use rules::prefer_robust_stmts;
50use rules::prefer_text_field;
51use rules::prefer_timestamptz;
52use rules::renaming_column;
53use rules::renaming_table;
54use rules::require_concurrent_index_creation;
55use rules::require_concurrent_index_deletion;
56use rules::require_timeout_settings;
57use rules::transaction_nesting;
/// Every lint rule known to the linter.
///
/// Each variant's user-facing kebab-case name comes from the `Display` impl for `Rule`.
/// Deriving `Sequence` lets `enum_iterator::all::<Rule>()` enumerate every variant, which
/// `Linter::with_all_rules` and `Linter::without_rules` use to build their rule sets.
#[derive(Debug, PartialEq, Clone, Copy, Hash, Eq, Sequence)]
pub enum Rule {
    /// Displayed as `require-concurrent-index-creation`.
    RequireConcurrentIndexCreation,
    /// Displayed as `require-concurrent-index-deletion`.
    RequireConcurrentIndexDeletion,
    /// Displayed as `constraint-missing-not-valid`.
    ConstraintMissingNotValid,
    /// Displayed as `adding-field-with-default`.
    AddingFieldWithDefault,
    /// Displayed as `adding-foreign-key-constraint`.
    AddingForeignKeyConstraint,
    /// Displayed as `changing-column-type`.
    ChangingColumnType,
    /// Displayed as `adding-not-nullable-field`.
    AddingNotNullableField,
    /// Displayed as `adding-serial-primary-key-field`.
    AddingSerialPrimaryKeyField,
    /// Displayed as `renaming-column`.
    RenamingColumn,
    /// Displayed as `renaming-table`.
    RenamingTable,
    /// Displayed as `disallowed-unique-constraint`.
    DisallowedUniqueConstraint,
    /// Displayed as `ban-drop-database`.
    BanDropDatabase,
    /// Displayed as `prefer-bigint-over-int`.
    PreferBigintOverInt,
    /// Displayed as `prefer-bigint-over-smallint`.
    PreferBigintOverSmallint,
    /// Displayed as `prefer-identity`.
    PreferIdentity,
    /// Displayed as `prefer-robust-stmts`.
    PreferRobustStmts,
    /// Displayed as `prefer-text-field`.
    PreferTextField,
    /// Displayed as `prefer-timestamp-tz`.
    PreferTimestampTz,
    /// Displayed as `ban-char-field`.
    BanCharField,
    /// Displayed as `ban-drop-column`.
    BanDropColumn,
    /// Displayed as `ban-drop-table`.
    BanDropTable,
    /// Displayed as `ban-drop-not-null`.
    BanDropNotNull,
    /// Displayed as `transaction-nesting`.
    TransactionNesting,
    /// Displayed as `adding-required-field`.
    AddingRequiredField,
    /// Displayed as `ban-concurrent-index-creation-in-transaction`.
    BanConcurrentIndexCreationInTransaction,
    /// Displayed as `unused-ignore`.
    UnusedIgnore,
    /// Displayed as `ban-create-domain-with-constraint`.
    BanCreateDomainWithConstraint,
    /// Displayed as `ban-alter-domain-with-add-constraint`.
    BanAlterDomainWithAddConstraint,
    /// Displayed as `ban-truncate-cascade`.
    BanTruncateCascade,
    /// Displayed as `require-timeout-settings`.
    RequireTimeoutSettings,
    /// Displayed as `ban-uncommitted-transaction`.
    BanUncommittedTransaction,
}
95
96impl TryFrom<&str> for Rule {
97 type Error = String;
98
99 fn try_from(s: &str) -> Result<Self, Self::Error> {
100 match s {
101 "require-concurrent-index-creation" => Ok(Rule::RequireConcurrentIndexCreation),
102 "require-concurrent-index-deletion" => Ok(Rule::RequireConcurrentIndexDeletion),
103 "constraint-missing-not-valid" => Ok(Rule::ConstraintMissingNotValid),
104 "adding-field-with-default" => Ok(Rule::AddingFieldWithDefault),
105 "adding-foreign-key-constraint" => Ok(Rule::AddingForeignKeyConstraint),
106 "changing-column-type" => Ok(Rule::ChangingColumnType),
107 "adding-not-nullable-field" => Ok(Rule::AddingNotNullableField),
108 "adding-serial-primary-key-field" => Ok(Rule::AddingSerialPrimaryKeyField),
109 "renaming-column" => Ok(Rule::RenamingColumn),
110 "renaming-table" => Ok(Rule::RenamingTable),
111 "disallowed-unique-constraint" => Ok(Rule::DisallowedUniqueConstraint),
112 "ban-drop-database" => Ok(Rule::BanDropDatabase),
113 "prefer-bigint-over-int" => Ok(Rule::PreferBigintOverInt),
114 "prefer-bigint-over-smallint" => Ok(Rule::PreferBigintOverSmallint),
115 "prefer-identity" => Ok(Rule::PreferIdentity),
116 "prefer-robust-stmts" => Ok(Rule::PreferRobustStmts),
117 "prefer-text-field" => Ok(Rule::PreferTextField),
118 "prefer-timestamptz" => Ok(Rule::PreferTimestampTz),
120 "prefer-timestamp-tz" => Ok(Rule::PreferTimestampTz),
121 "ban-char-field" => Ok(Rule::BanCharField),
122 "ban-drop-column" => Ok(Rule::BanDropColumn),
123 "ban-drop-table" => Ok(Rule::BanDropTable),
124 "ban-drop-not-null" => Ok(Rule::BanDropNotNull),
125 "transaction-nesting" => Ok(Rule::TransactionNesting),
126 "adding-required-field" => Ok(Rule::AddingRequiredField),
127 "ban-concurrent-index-creation-in-transaction" => {
128 Ok(Rule::BanConcurrentIndexCreationInTransaction)
129 }
130 "ban-create-domain-with-constraint" => Ok(Rule::BanCreateDomainWithConstraint),
131 "ban-alter-domain-with-add-constraint" => Ok(Rule::BanAlterDomainWithAddConstraint),
132 "ban-truncate-cascade" => Ok(Rule::BanTruncateCascade),
133 "require-timeout-settings" => Ok(Rule::RequireTimeoutSettings),
134 "ban-uncommitted-transaction" => Ok(Rule::BanUncommittedTransaction),
135 _ => Err(format!("Unknown violation name: {s}")),
137 }
138 }
139}
140
/// Error returned by `Rule::from_str` when the input names no known rule.
///
/// The rejected input is kept so it can be echoed back in the message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct UnknownRuleName {
    /// The rule name that failed to parse, exactly as supplied.
    val: String,
}

impl std::fmt::Display for UnknownRuleName {
    /// Writes `invalid rule name <input>`; width/fill flags on the formatter are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
        f.write_str("invalid rule name ")?;
        f.write_str(&self.val)
    }
}

impl std::error::Error for UnknownRuleName {}
153
154impl std::str::FromStr for Rule {
155 type Err = UnknownRuleName;
156 fn from_str(s: &str) -> Result<Self, Self::Err> {
157 Rule::try_from(s).map_err(|_| UnknownRuleName { val: s.to_string() })
158 }
159}
160
161impl fmt::Display for Rule {
162 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
163 let val = match &self {
164 Rule::RequireConcurrentIndexCreation => "require-concurrent-index-creation",
165 Rule::RequireConcurrentIndexDeletion => "require-concurrent-index-deletion",
166 Rule::ConstraintMissingNotValid => "constraint-missing-not-valid",
167 Rule::AddingFieldWithDefault => "adding-field-with-default",
168 Rule::AddingForeignKeyConstraint => "adding-foreign-key-constraint",
169 Rule::ChangingColumnType => "changing-column-type",
170 Rule::AddingNotNullableField => "adding-not-nullable-field",
171 Rule::AddingSerialPrimaryKeyField => "adding-serial-primary-key-field",
172 Rule::RenamingColumn => "renaming-column",
173 Rule::RenamingTable => "renaming-table",
174 Rule::DisallowedUniqueConstraint => "disallowed-unique-constraint",
175 Rule::BanDropDatabase => "ban-drop-database",
176 Rule::PreferBigintOverInt => "prefer-bigint-over-int",
177 Rule::PreferBigintOverSmallint => "prefer-bigint-over-smallint",
178 Rule::PreferIdentity => "prefer-identity",
179 Rule::PreferRobustStmts => "prefer-robust-stmts",
180 Rule::PreferTextField => "prefer-text-field",
181 Rule::PreferTimestampTz => "prefer-timestamp-tz",
182 Rule::BanCharField => "ban-char-field",
183 Rule::BanDropColumn => "ban-drop-column",
184 Rule::BanDropTable => "ban-drop-table",
185 Rule::BanDropNotNull => "ban-drop-not-null",
186 Rule::TransactionNesting => "transaction-nesting",
187 Rule::AddingRequiredField => "adding-required-field",
188 Rule::BanConcurrentIndexCreationInTransaction => {
189 "ban-concurrent-index-creation-in-transaction"
190 }
191 Rule::BanCreateDomainWithConstraint => "ban-create-domain-with-constraint",
192 Rule::UnusedIgnore => "unused-ignore",
193 Rule::BanAlterDomainWithAddConstraint => "ban-alter-domain-with-add-constraint",
194 Rule::BanTruncateCascade => "ban-truncate-cascade",
195 Rule::RequireTimeoutSettings => "require-timeout-settings",
196 Rule::BanUncommittedTransaction => "ban-uncommitted-transaction",
197 };
199 write!(f, "{val}")
200 }
201}
202
203impl<'de> Deserialize<'de> for Rule {
204 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
205 where
206 D: serde::Deserializer<'de>,
207 {
208 let s = String::deserialize(deserializer)?;
209 s.parse().map_err(serde::de::Error::custom)
210 }
211}
212
213#[derive(Debug, Clone, PartialEq, Eq)]
214pub struct Fix {
215 pub title: String,
216 pub edits: Vec<Edit>,
217}
218
219impl Fix {
220 fn new<T: Into<String>>(title: T, edits: Vec<Edit>) -> Fix {
221 Fix {
222 title: title.into(),
223 edits,
224 }
225 }
226}
227
228#[derive(Debug, Clone, PartialEq, Eq)]
229pub struct Edit {
230 pub text_range: TextRange,
231 pub text: Option<String>,
233}
234impl Edit {
235 pub fn insert<T: Into<String>>(text: T, at: TextSize) -> Self {
236 Self {
237 text_range: TextRange::new(at, at),
238 text: Some(text.into()),
239 }
240 }
241 pub fn replace<T: Into<String>>(text_range: TextRange, text: T) -> Self {
242 Self {
243 text_range,
244 text: Some(text.into()),
245 }
246 }
247 pub fn delete(text_range: TextRange) -> Self {
248 Self {
249 text_range,
250 text: None,
251 }
252 }
253}
254
255#[derive(Debug, Clone, PartialEq, Eq)]
256pub struct Violation {
257 pub code: Rule,
259 pub message: String,
260 pub text_range: TextRange,
261 pub help: Option<String>,
262 pub fix: Option<Fix>,
263}
264
265impl Violation {
266 #[must_use]
267 pub fn for_node(code: Rule, message: String, node: &SyntaxNode) -> Self {
268 let range = node.text_range();
269
270 let start = node
271 .children_with_tokens()
272 .find(|x| !x.kind().is_trivia())
273 .map(|x| x.text_range().start())
274 .unwrap_or_else(|| range.start());
276
277 Self {
278 code,
279 text_range: TextRange::new(start, range.end()),
280 message,
281 help: None,
282 fix: None,
283 }
284 }
285
286 #[must_use]
287 pub fn for_range(code: Rule, message: String, text_range: TextRange) -> Self {
288 Self {
289 code,
290 text_range,
291 message,
292 help: None,
293 fix: None,
294 }
295 }
296
297 fn fix(mut self, fix: Option<Fix>) -> Violation {
298 self.fix = fix;
299 self
300 }
301 fn help(mut self, help: impl Into<String>) -> Violation {
302 self.help = Some(help.into());
303 self
304 }
305}
306
/// Settings stored on `Linter::settings`; `Default` gives the default `Version` and
/// `assume_in_transaction = false`.
///
/// NOTE(review): neither field is read in this file — presumably consulted by individual
/// rule implementations; confirm in `rules`.
#[derive(Default)]
pub struct LinterSettings {
    /// Presumably the target PostgreSQL version (inferred from the name) — confirm.
    pub pg_version: Version,
    /// Presumably tells rules to treat the SQL as already running inside a transaction
    /// (inferred from the name) — confirm.
    pub assume_in_transaction: bool,
}
312
313pub struct Linter {
314 errors: Vec<Violation>,
315 ignores: Vec<Ignore>,
316 pub rules: HashSet<Rule>,
317 pub settings: LinterSettings,
318}
319
320impl Linter {
321 fn report(&mut self, error: Violation) {
322 self.errors.push(error);
323 }
324
325 fn ignore(&mut self, ignore: Ignore) {
326 self.ignores.push(ignore);
327 }
328
329 #[must_use]
330 pub fn lint(&mut self, file: &Parse<SourceFile>, text: &str) -> Vec<Violation> {
331 if self.rules.contains(&Rule::AddingFieldWithDefault) {
332 adding_field_with_default(self, file);
333 }
334 if self.rules.contains(&Rule::AddingForeignKeyConstraint) {
335 adding_foreign_key_constraint(self, file);
336 }
337 if self.rules.contains(&Rule::AddingNotNullableField) {
338 adding_not_null_field(self, file);
339 }
340 if self.rules.contains(&Rule::AddingSerialPrimaryKeyField) {
341 adding_primary_key_constraint(self, file);
342 }
343 if self.rules.contains(&Rule::AddingRequiredField) {
344 adding_required_field(self, file);
345 }
346 if self.rules.contains(&Rule::BanDropDatabase) {
347 ban_drop_database(self, file);
348 }
349 if self.rules.contains(&Rule::BanCharField) {
350 ban_char_field(self, file);
351 }
352 if self
353 .rules
354 .contains(&Rule::BanConcurrentIndexCreationInTransaction)
355 {
356 ban_concurrent_index_creation_in_transaction(self, file);
357 }
358 if self.rules.contains(&Rule::BanDropColumn) {
359 ban_drop_column(self, file);
360 }
361 if self.rules.contains(&Rule::BanDropNotNull) {
362 ban_drop_not_null(self, file);
363 }
364 if self.rules.contains(&Rule::BanDropTable) {
365 ban_drop_table(self, file);
366 }
367 if self.rules.contains(&Rule::ChangingColumnType) {
368 changing_column_type(self, file);
369 }
370 if self.rules.contains(&Rule::ConstraintMissingNotValid) {
371 constraint_missing_not_valid(self, file);
372 }
373 if self.rules.contains(&Rule::DisallowedUniqueConstraint) {
374 disallow_unique_constraint(self, file);
375 }
376 if self.rules.contains(&Rule::PreferBigintOverInt) {
377 prefer_bigint_over_int(self, file);
378 }
379 if self.rules.contains(&Rule::PreferBigintOverSmallint) {
380 prefer_bigint_over_smallint(self, file);
381 }
382 if self.rules.contains(&Rule::PreferIdentity) {
383 prefer_identity(self, file);
384 }
385 if self.rules.contains(&Rule::PreferRobustStmts) {
386 prefer_robust_stmts(self, file);
387 }
388 if self.rules.contains(&Rule::PreferTextField) {
389 prefer_text_field(self, file);
390 }
391 if self.rules.contains(&Rule::PreferTimestampTz) {
392 prefer_timestamptz(self, file);
393 }
394 if self.rules.contains(&Rule::RenamingColumn) {
395 renaming_column(self, file);
396 }
397 if self.rules.contains(&Rule::RenamingTable) {
398 renaming_table(self, file);
399 }
400 if self.rules.contains(&Rule::RequireConcurrentIndexCreation) {
401 require_concurrent_index_creation(self, file);
402 }
403 if self.rules.contains(&Rule::RequireConcurrentIndexDeletion) {
404 require_concurrent_index_deletion(self, file);
405 }
406 if self.rules.contains(&Rule::BanCreateDomainWithConstraint) {
407 ban_create_domain_with_constraint(self, file);
408 }
409 if self.rules.contains(&Rule::BanAlterDomainWithAddConstraint) {
410 ban_alter_domain_with_add_constraint(self, file);
411 }
412 if self.rules.contains(&Rule::TransactionNesting) {
413 transaction_nesting(self, file);
414 }
415 if self.rules.contains(&Rule::BanTruncateCascade) {
416 ban_truncate_cascade(self, file);
417 }
418 if self.rules.contains(&Rule::RequireTimeoutSettings) {
419 require_timeout_settings(self, file);
420 }
421 if self.rules.contains(&Rule::BanUncommittedTransaction) {
422 ban_uncommitted_transaction(self, file);
423 }
424 find_ignores(self, &file.syntax_node());
428
429 self.errors(text)
430 }
431
432 fn errors(&mut self, text: &str) -> Vec<Violation> {
433 let ignore_index = IgnoreIndex::new(text, &self.ignores);
434 let mut errors: Vec<Violation> = self
435 .errors
436 .iter()
437 .filter(|err| !ignore_index.contains(err.text_range, err.code))
440 .cloned()
441 .collect::<Vec<_>>();
442 errors.sort_by_key(|x| x.text_range.start());
444 errors
445 }
446
447 pub fn with_all_rules() -> Self {
448 let rules = all::<Rule>().collect::<HashSet<_>>();
449 Linter::from(rules)
450 }
451
452 pub fn without_rules(exclude: &[Rule]) -> Self {
453 let all_rules = all::<Rule>().collect::<HashSet<_>>();
454 let mut exclude_set = HashSet::with_capacity(exclude.len());
455 for e in exclude {
456 exclude_set.insert(e);
457 }
458
459 let rules = all_rules
460 .into_iter()
461 .filter(|x| !exclude_set.contains(x))
462 .collect::<HashSet<_>>();
463
464 Linter::from(rules)
465 }
466
467 pub fn from(rules: impl Into<HashSet<Rule>>) -> Self {
468 Self {
469 errors: vec![],
470 ignores: vec![],
471 rules: rules.into(),
472 settings: Default::default(),
473 }
474 }
475}
476
#[cfg(test)]
mod tests {
    use insta::assert_debug_snapshot;

    use super::*;

    /// Both spellings of the timestamptz rule name resolve to the same variant.
    #[test]
    fn prefer_timestamp_aliases() {
        let hyphenated = "prefer-timestamp-tz".parse::<Rule>().unwrap();
        let compact = "prefer-timestamptz".parse::<Rule>().unwrap();
        assert_eq!(hyphenated, compact);
        assert_debug_snapshot!(hyphenated, @"PreferTimestampTz");
    }

    /// A name that matches no rule is rejected.
    #[test]
    fn invalid_rule_name() {
        let parsed = "invalid-rule-name".parse::<Rule>();
        assert!(parsed.is_err());
    }
}