// tensorlogic_compiler/passes/validation.rs

//! Validation passes for `TLExpr`.
//!
//! This module provides comprehensive validation for logical expressions before compilation.
//! Validation catches errors early and provides helpful error messages.

6use anyhow::{anyhow, Result};
7use tensorlogic_ir::{PredicateSignature, TLExpr};
8
9use super::diagnostics::{diagnose_expression, DiagnosticLevel};
10use super::scope_analysis::{analyze_scopes, suggest_quantifiers};
11use super::type_checking::TypeChecker;
12
13/// Validate that all predicates with the same name have the same arity.
14///
15/// This is a basic validation that checks for consistency in predicate usage.
16/// Predicates must have the same number of arguments everywhere they appear.
17///
18/// # Errors
19///
20/// Returns an error if any predicate is used with different arities.
21///
22/// # Examples
23///
24/// ```
25/// use tensorlogic_compiler::passes::validate_arity;
26/// use tensorlogic_ir::{TLExpr, Term};
27///
28/// // Valid: knows/2 used consistently
29/// let expr = TLExpr::and(
30///     TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
31///     TLExpr::pred("knows", vec![Term::var("y"), Term::var("z")]),
32/// );
33/// assert!(validate_arity(&expr).is_ok());
34/// ```
35pub fn validate_arity(expr: &TLExpr) -> Result<()> {
36    expr.validate_arity().map_err(|e| anyhow!("{}", e))
37}
38
/// Result of pre-compilation validation.
///
/// Contains all validation errors, warnings, and suggestions found during validation.
/// `diagnostics` holds every message in the order the validation passes produced them.
///
/// Derives `PartialEq`/`Eq` so results can be compared directly, e.g. with `assert_eq!`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ValidationResult {
    /// Whether validation passed (no errors)
    pub passed: bool,
    /// Number of errors found
    pub error_count: usize,
    /// Number of warnings found
    pub warning_count: usize,
    /// All diagnostic messages (errors, warnings, hints, suggestions)
    pub diagnostics: Vec<String>,
}

impl ValidationResult {
    /// Returns `true` if validation passed (no errors).
    ///
    /// This reads the `passed` flag. The validation functions in this module set it
    /// consistently with `error_count`.
    pub fn is_ok(&self) -> bool {
        self.passed
    }

    /// Returns `true` if at least one error was recorded (`error_count > 0`).
    pub fn has_errors(&self) -> bool {
        self.error_count > 0
    }

    /// Returns every diagnostic joined with `'\n'`.
    ///
    /// Despite the name, warnings, hints and suggestions are included as well.
    /// The result is an empty string when there are no diagnostics.
    pub fn error_message(&self) -> String {
        self.diagnostics.join("\n")
    }
}

71/// Performs comprehensive pre-compilation validation.
72///
73/// This function runs all available validation passes:
74/// 1. Arity validation (predicate consistency)
75/// 2. Scope analysis (unbound variables)
76/// 3. Enhanced diagnostics (unused bindings, type conflicts)
77///
78/// # Arguments
79///
80/// * `expr` - The expression to validate
81///
82/// # Returns
83///
84/// A `ValidationResult` containing all errors, warnings, and suggestions.
85///
86/// # Examples
87///
88/// ```
89/// use tensorlogic_compiler::passes::validate_expression;
90/// use tensorlogic_ir::{TLExpr, Term};
91///
92/// // Valid expression (fully quantified)
93/// let expr = TLExpr::exists(
94///     "x",
95///     "Person",
96///     TLExpr::exists(
97///         "y",
98///         "Person",
99///         TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
100///     ),
101/// );
102///
103/// let result = validate_expression(&expr);
104/// assert!(result.is_ok());
105/// ```
106///
107/// ```
108/// use tensorlogic_compiler::passes::validate_expression;
109/// use tensorlogic_ir::{TLExpr, Term};
110///
111/// // Expression with unbound variables
112/// let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);
113///
114/// let result = validate_expression(&expr);
115/// assert!(result.has_errors());
116/// assert!(result.error_count >= 2); // x and y unbound
117/// ```
118pub fn validate_expression(expr: &TLExpr) -> ValidationResult {
119    let mut diagnostics = Vec::new();
120    let mut error_count = 0;
121    let mut warning_count = 0;
122
123    // 1. Arity validation
124    if let Err(e) = validate_arity(expr) {
125        diagnostics.push(format!("Arity error: {}", e));
126        error_count += 1;
127    }
128
129    // 2. Scope analysis
130    match analyze_scopes(expr) {
131        Ok(scope_result) => {
132            // Check for unbound variables
133            if !scope_result.unbound_variables.is_empty() {
134                for var in &scope_result.unbound_variables {
135                    diagnostics.push(format!("Unbound variable: '{}'", var));
136                    error_count += 1;
137                }
138
139                // Provide helpful suggestion
140                if let Ok(suggestions) = suggest_quantifiers(expr) {
141                    if !suggestions.is_empty() {
142                        diagnostics.push(format!("Suggestion: {}", suggestions.join(", ")));
143                    }
144                }
145            }
146
147            // Check for type conflicts
148            for conflict in &scope_result.type_conflicts {
149                diagnostics.push(format!(
150                    "Type conflict: variable '{}' has conflicting types '{}' and '{}'",
151                    conflict.variable, conflict.type1, conflict.type2
152                ));
153                error_count += 1;
154            }
155        }
156        Err(e) => {
157            diagnostics.push(format!("Scope analysis error: {}", e));
158            error_count += 1;
159        }
160    }
161
162    // 3. Enhanced diagnostics (warnings and hints)
163    let diag_messages = diagnose_expression(expr);
164    for diag in diag_messages {
165        let formatted = diag.format();
166        match diag.level {
167            DiagnosticLevel::Error => {
168                // Skip if we already reported this error above
169                if !diagnostics.iter().any(|d| d.contains(&diag.message)) {
170                    diagnostics.push(formatted);
171                    error_count += 1;
172                }
173            }
174            DiagnosticLevel::Warning => {
175                diagnostics.push(formatted);
176                warning_count += 1;
177            }
178            DiagnosticLevel::Info | DiagnosticLevel::Hint => {
179                diagnostics.push(formatted);
180            }
181        }
182    }
183
184    ValidationResult {
185        passed: error_count == 0,
186        error_count,
187        warning_count,
188        diagnostics,
189    }
190}
191
192/// Validates an expression with type signatures.
193///
194/// This is an extended validation that includes type checking against
195/// registered predicate signatures.
196///
197/// # Arguments
198///
199/// * `expr` - The expression to validate
200/// * `signatures` - Predicate signatures for type checking
201///
202/// # Returns
203///
204/// A `ValidationResult` with type checking errors included.
205///
206/// # Examples
207///
208/// ```
209/// use tensorlogic_compiler::passes::validate_expression_with_types;
210/// use tensorlogic_ir::{PredicateSignature, TLExpr, Term, TypeAnnotation};
211///
212/// let signatures = vec![
213///     PredicateSignature::new(
214///         "knows",
215///         vec![
216///             TypeAnnotation { type_name: "Person".to_string() },
217///             TypeAnnotation { type_name: "Person".to_string() },
218///         ],
219///     )
220/// ];
221///
222/// // Fully quantified expression
223/// let expr = TLExpr::exists(
224///     "x",
225///     "Person",
226///     TLExpr::exists(
227///         "y",
228///         "Person",
229///         TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
230///     ),
231/// );
232///
233/// let result = validate_expression_with_types(&expr, &signatures);
234/// assert!(result.is_ok());
235/// ```
236pub fn validate_expression_with_types(
237    expr: &TLExpr,
238    signatures: &[PredicateSignature],
239) -> ValidationResult {
240    let mut result = validate_expression(expr);
241
242    // Add type checking
243    use tensorlogic_ir::SignatureRegistry;
244    let mut registry = SignatureRegistry::new();
245    for sig in signatures {
246        registry.register(sig.clone());
247    }
248
249    let checker = TypeChecker::new(registry);
250    if let Err(e) = checker.check_expr(expr) {
251        result.diagnostics.push(format!("Type error: {}", e));
252        result.error_count += 1;
253        result.passed = false;
254    }
255
256    result
257}
258
#[cfg(test)]
mod tests {
    use super::*;
    use tensorlogic_ir::Term;

    /// Builds `exists x:Person. exists y:Person. knows(x, y)` — no unbound variables.
    fn fully_quantified_knows() -> TLExpr {
        TLExpr::exists(
            "x",
            "Person",
            TLExpr::exists(
                "y",
                "Person",
                TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
            ),
        )
    }

    #[test]
    fn test_validate_expression_ok() {
        let result = validate_expression(&fully_quantified_knows());
        assert!(
            result.is_ok(),
            "validation unexpectedly failed: {:?}",
            result.diagnostics
        );
        assert_eq!(result.error_count, 0);
    }

    #[test]
    fn test_validate_expression_partial_binding() {
        // y is bound by the quantifier, x is not.
        let expr = TLExpr::exists(
            "y",
            "Person",
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
        );

        let result = validate_expression(&expr);
        // Scope analysis reports the unbound x; diagnose_expression may report it
        // again in its own wording, so only a lower bound on the count is stable.
        assert!(result.has_errors(), "diagnostics: {:?}", result.diagnostics);
        assert!(result.error_count >= 1, "diagnostics: {:?}", result.diagnostics);
        assert!(
            result
                .diagnostics
                .iter()
                .any(|d| d.as_str() == "Unbound variable: 'x'"),
            "diagnostics: {:?}",
            result.diagnostics
        );
        assert!(
            !result
                .diagnostics
                .iter()
                .any(|d| d.as_str() == "Unbound variable: 'y'"),
            "bound variable reported as unbound: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_expression_unbound_vars() {
        let expr = TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]);

        let result = validate_expression(&expr);
        assert!(result.has_errors());
        // Both x and y are unbound - at least 2 errors
        assert!(
            result.error_count >= 2,
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_expression_arity_mismatch() {
        let expr = TLExpr::and(
            TLExpr::pred("knows", vec![Term::var("x"), Term::var("y")]),
            TLExpr::pred("knows", vec![Term::var("z")]),
        );

        let result = validate_expression(&expr);
        assert!(result.has_errors());
        assert!(
            result.diagnostics.iter().any(|d| d.contains("Arity")),
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_expression_with_warnings() {
        // Expression with unused binding
        let expr = TLExpr::exists(
            "x",
            "Person",
            TLExpr::pred("p", vec![Term::var("y")]), // x not used
        );

        let result = validate_expression(&expr);
        assert!(
            result.warning_count > 0,
            "diagnostics: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validate_with_types() {
        use tensorlogic_ir::TypeAnnotation;

        let signatures = vec![PredicateSignature {
            name: "knows".to_string(),
            arity: 2,
            arg_types: vec![
                TypeAnnotation {
                    type_name: "Person".to_string(),
                },
                TypeAnnotation {
                    type_name: "Person".to_string(),
                },
            ],
            parametric_types: None,
        }];

        let result = validate_expression_with_types(&fully_quantified_knows(), &signatures);
        assert!(
            result.is_ok(),
            "typed validation unexpectedly failed: {:?}",
            result.diagnostics
        );
    }

    #[test]
    fn test_validation_result_message() {
        let expr = TLExpr::pred("knows", vec![Term::var("x")]);

        let result = validate_expression(&expr);
        let message = result.error_message();
        assert!(!message.is_empty());
        assert!(message.contains("Unbound"), "message: {}", message);
    }
}