Skip to main content

tensorlogic_oxirs_bridge/
sparql.rs

1//! Advanced SPARQL query compilation to TensorLogic operations
2//!
3//! This module provides comprehensive support for compiling SPARQL 1.1 queries
4//! into TensorLogic expressions. Supports:
5//! - SELECT queries (basic and complex)
6//! - ASK queries (boolean existence checks)
7//! - DESCRIBE queries (resource descriptions)
8//! - CONSTRUCT queries (RDF graph construction)
9//! - Triple patterns with variables and constants
10//! - Filter constraints (comparison, regex)
11//! - OPTIONAL patterns (left-outer join semantics)
12//! - UNION patterns (disjunction)
13//!
14//! For production SPARQL federation and advanced features, consider using a dedicated SPARQL engine.
15
16use anyhow::{anyhow, Result};
17use std::collections::HashMap;
18use tensorlogic_ir::{TLExpr, Term};
19
/// Represents a SPARQL triple pattern
///
/// Each of the three positions holds either a variable or a constant; see
/// [`PatternElement`].
#[derive(Debug, Clone, PartialEq)]
pub struct TriplePattern {
    /// First term of the triple.
    pub subject: PatternElement,
    /// Second term of the triple.
    pub predicate: PatternElement,
    /// Third term of the triple.
    pub object: PatternElement,
}
27
/// Element in a triple pattern (variable or constant)
#[derive(Debug, Clone, PartialEq)]
pub enum PatternElement {
    /// A variable; the parser stores the name without its leading `?` (`?x` -> `"x"`).
    Variable(String),
    /// A constant term. The parser strips the delimiters of `<iri>` and `"literal"` tokens;
    /// any other token (e.g. a prefixed name) is stored verbatim.
    Constant(String),
}
34
/// Filter condition in SPARQL
///
/// Produced by the FILTER parser, which is also used for HAVING. Variable operands are
/// stored without their leading `?`.
/// NOTE(review): for `>=` / `<=` the parser trims `?` from the left operand and surrounding
/// `"` from the right one; confirm the other comparison variants follow the same convention.
#[derive(Debug, Clone, PartialEq)]
pub enum FilterCondition {
    /// `left = right` (assumed operand layout: variable, value).
    Equals(String, String),
    /// `left != right` (assumed operand layout: variable, value).
    NotEquals(String, String),
    /// `left > right`.
    GreaterThan(String, String),
    /// `left < right`.
    LessThan(String, String),
    /// `left >= right`.
    GreaterOrEqual(String, String),
    /// `left <= right`.
    LessOrEqual(String, String),
    /// `regex(?var, "pattern")`: the variable name and the pattern with its quotes trimmed.
    Regex(String, String),
    /// `BOUND(?var)`.
    Bound(String),
    /// `isIRI(?var)` or `isURI(?var)`.
    IsIri(String),
    /// `isLiteral(?var)`.
    IsLiteral(String),
}
49
/// Aggregate function in SPARQL
///
/// Variable names are stored without their leading `?`. `distinct` is set when the
/// argument list starts with the `DISTINCT` keyword.
#[derive(Debug, Clone, PartialEq)]
pub enum AggregateFunction {
    /// COUNT aggregate - counts solutions
    ///
    /// `variable` is `None` for `COUNT(*)`.
    Count {
        variable: Option<String>,
        distinct: bool,
    },
    /// SUM aggregate - sums numeric values
    Sum { variable: String, distinct: bool },
    /// AVG aggregate - computes average
    Avg { variable: String, distinct: bool },
    /// MIN aggregate - finds minimum value
    Min { variable: String },
    /// MAX aggregate - finds maximum value
    Max { variable: String },
    /// GROUP_CONCAT aggregate - concatenates strings
    ///
    /// `separator` is the value of an optional `; SEPARATOR="..."` argument, quotes removed.
    GroupConcat {
        variable: String,
        separator: Option<String>,
        distinct: bool,
    },
    /// SAMPLE aggregate - returns arbitrary value
    Sample { variable: String },
}
75
/// A projection element that can be a variable or an aggregate expression
#[derive(Debug, Clone, PartialEq)]
pub enum SelectElement {
    /// Simple variable projection
    ///
    /// The name is stored without `?`; `SELECT *` is represented as `Variable("*")`.
    Variable(String),
    /// Aggregate expression with optional alias
    ///
    /// The SELECT parser always fills `alias`: either the name given after `AS` or the
    /// lower-case function name.
    Aggregate {
        function: AggregateFunction,
        alias: Option<String>,
    },
}
87
/// Graph pattern in SPARQL (supports complex patterns)
#[derive(Debug, Clone, PartialEq)]
pub enum GraphPattern {
    /// Basic triple pattern
    Triple(TriplePattern),
    /// Conjunction of patterns (implicit AND)
    ///
    /// Built when a group contains more than one statement; a single statement is returned
    /// unwrapped.
    Group(Vec<GraphPattern>),
    /// OPTIONAL pattern (left-outer join)
    Optional(Box<GraphPattern>),
    /// UNION pattern (disjunction)
    ///
    /// Holds the left and right operands in source order.
    Union(Box<GraphPattern>, Box<GraphPattern>),
    /// FILTER constraint
    Filter(FilterCondition),
}
102
/// Type of SPARQL query
#[derive(Debug, Clone, PartialEq)]
pub enum QueryType {
    /// SELECT query - returns variable bindings
    Select {
        /// Projection elements (variables and aggregates)
        ///
        /// In the order they appear after `SELECT`.
        projections: Vec<SelectElement>,
        /// Legacy field for simple variable names (for backward compatibility)
        ///
        /// One entry per projection: the variable name without `?`, `"*"` for `SELECT *`,
        /// or the output name of an aggregate.
        select_vars: Vec<String>,
        /// Whether the projection was introduced with `SELECT DISTINCT`.
        distinct: bool,
    },
    /// ASK query - returns boolean (existence check)
    Ask,
    /// DESCRIBE query - returns RDF description of resources
    ///
    /// `resources` holds the `?var` / `<iri>` tokens listed after `DESCRIBE`, with the
    /// `?`, `<` and `>` delimiters stripped.
    Describe { resources: Vec<String> },
    /// CONSTRUCT query - constructs new RDF triples
    ///
    /// `template` holds the triple patterns of the CONSTRUCT template.
    Construct { template: Vec<TriplePattern> },
}
121
/// Compiled SPARQL query
///
/// Produced by [`SparqlCompiler::parse_query`].
#[derive(Debug, Clone)]
pub struct SparqlQuery {
    /// Type of query (SELECT, ASK, DESCRIBE, CONSTRUCT)
    pub query_type: QueryType,
    /// WHERE clause graph patterns
    pub where_pattern: GraphPattern,
    /// GROUP BY variables
    ///
    /// Names of the `?var` tokens in the clause, without `?`.
    pub group_by: Vec<String>,
    /// HAVING conditions (applied after grouping)
    pub having: Vec<FilterCondition>,
    /// Solution modifiers
    ///
    /// `limit` / `offset` are `None` when the clause is absent or its value is not a
    /// non-negative integer.
    pub limit: Option<usize>,
    pub offset: Option<usize>,
    /// ORDER BY keys: bare `?var` tokens only (without `?`); keys written as `ASC(...)` /
    /// `DESC(...)` are not captured and no direction is recorded.
    pub order_by: Vec<String>,
}
138
/// SPARQL query parser and compiler
///
/// Create one with [`SparqlCompiler::new`] and register IRI-to-predicate mappings with
/// [`SparqlCompiler::add_predicate_mapping`].
pub struct SparqlCompiler {
    /// Map of predicate IRIs to TensorLogic predicate names
    ///
    /// NOTE(review): the lookup into this map happens in compilation code outside this
    /// view; confirm how unmapped IRIs are treated there.
    predicate_mapping: HashMap<String, String>,
}
144
145impl SparqlCompiler {
146    pub fn new() -> Self {
147        SparqlCompiler {
148            predicate_mapping: HashMap::new(),
149        }
150    }
151
152    /// Add a mapping from IRI to predicate name
153    ///
154    /// Example: map `"http://example.org/knows"` to `"knows"`
155    pub fn add_predicate_mapping(&mut self, iri: String, predicate_name: String) {
156        self.predicate_mapping.insert(iri, predicate_name);
157    }
158
159    /// Parse a SPARQL query (SELECT, ASK, DESCRIBE, or CONSTRUCT)
160    ///
161    /// Supports SPARQL 1.1 syntax including:
162    /// ```sparql
163    /// # SELECT query
164    /// SELECT DISTINCT ?x ?y WHERE {
165    ///   ?x <http://example.org/knows> ?y .
166    ///   OPTIONAL { ?x <http://example.org/age> ?age }
167    ///   FILTER(?age > 18)
168    /// } LIMIT 10
169    ///
170    /// # ASK query
171    /// ASK WHERE {
172    ///   ?x <http://example.org/knows> ?y .
173    /// }
174    ///
175    /// # DESCRIBE query
176    /// DESCRIBE ?x WHERE {
177    ///   ?x <http://example.org/knows> ?y .
178    /// }
179    ///
180    /// # CONSTRUCT query
181    /// CONSTRUCT { ?x <http://example.org/friend> ?y }
182    /// WHERE { ?x <http://example.org/knows> ?y }
183    /// ```
184    ///
185    /// Note: This is a simplified parser for demonstration.
186    /// For production, use a dedicated SPARQL parser.
187    pub fn parse_query(&self, sparql: &str) -> Result<SparqlQuery> {
188        // Normalize the query by collapsing whitespace and removing newlines within clauses
189        let normalized = sparql
190            .lines()
191            .map(|l| l.trim())
192            .filter(|l| !l.is_empty())
193            .collect::<Vec<_>>()
194            .join(" ");
195
196        // Determine query type
197        let query_type = self.parse_query_type(&normalized)?;
198
199        // Parse WHERE clause
200        let where_pattern = self.parse_where_clause(&normalized)?;
201
202        // Parse GROUP BY and HAVING
203        let group_by = self.parse_group_by(&normalized);
204        let having = self.parse_having(&normalized)?;
205
206        // Parse solution modifiers
207        let limit = self.parse_limit(&normalized);
208        let offset = self.parse_offset(&normalized);
209        let order_by = self.parse_order_by(&normalized);
210
211        Ok(SparqlQuery {
212            query_type,
213            where_pattern,
214            group_by,
215            having,
216            limit,
217            offset,
218            order_by,
219        })
220    }
221
222    /// Parse GROUP BY clause
223    fn parse_group_by(&self, normalized: &str) -> Vec<String> {
224        let mut group_by = Vec::new();
225
226        if let Some(group_pos) = normalized.find("GROUP BY") {
227            // Find the end of GROUP BY (next clause or end of query)
228            let remaining = &normalized[group_pos + 8..];
229            let end_pos = remaining
230                .find("HAVING")
231                .or_else(|| remaining.find("ORDER BY"))
232                .or_else(|| remaining.find("LIMIT"))
233                .or_else(|| remaining.find("OFFSET"))
234                .unwrap_or(remaining.len());
235
236            let group_part = remaining[..end_pos].trim();
237            for token in group_part.split_whitespace() {
238                if let Some(var_name) = token.strip_prefix('?') {
239                    group_by.push(var_name.to_string());
240                }
241            }
242        }
243
244        group_by
245    }
246
247    /// Parse HAVING clause
248    fn parse_having(&self, normalized: &str) -> Result<Vec<FilterCondition>> {
249        let mut conditions = Vec::new();
250
251        if let Some(having_pos) = normalized.find("HAVING") {
252            // Find the end of HAVING (next clause or end of query)
253            let remaining = &normalized[having_pos + 6..];
254            let end_pos = remaining
255                .find("ORDER BY")
256                .or_else(|| remaining.find("LIMIT"))
257                .or_else(|| remaining.find("OFFSET"))
258                .unwrap_or(remaining.len());
259
260            let having_part = remaining[..end_pos].trim();
261
262            // Parse conditions similar to FILTER
263            if !having_part.is_empty() {
264                if let Some(filter) = self.parse_filter(&format!("FILTER{}", having_part))? {
265                    conditions.push(filter);
266                }
267            }
268        }
269
270        Ok(conditions)
271    }
272
273    /// Parse an aggregate function
274    fn parse_aggregate(&self, text: &str) -> Option<(AggregateFunction, String)> {
275        let text = text.trim();
276
277        // Check for AS alias
278        let (func_part, alias) = if let Some(as_pos) = text.to_uppercase().find(" AS ") {
279            let alias_start = as_pos + 4;
280            let alias = text[alias_start..]
281                .trim()
282                .trim_matches(|c| c == '?' || c == ')')
283                .to_string();
284            (text[..as_pos].trim(), Some(alias))
285        } else {
286            (text, None)
287        };
288
289        // Parse aggregate function
290        let upper = func_part.to_uppercase();
291
292        if upper.starts_with("COUNT(") {
293            let inner = func_part[6..].trim_end_matches(')').trim();
294            let distinct = inner.to_uppercase().starts_with("DISTINCT");
295            let var_part = if distinct { inner[8..].trim() } else { inner };
296            let variable = if var_part == "*" {
297                None
298            } else {
299                Some(var_part.trim_start_matches('?').to_string())
300            };
301            return Some((
302                AggregateFunction::Count { variable, distinct },
303                alias.unwrap_or_else(|| "count".to_string()),
304            ));
305        }
306
307        if upper.starts_with("SUM(") {
308            let inner = func_part[4..].trim_end_matches(')').trim();
309            let distinct = inner.to_uppercase().starts_with("DISTINCT");
310            let var_part = if distinct { inner[8..].trim() } else { inner };
311            let variable = var_part.trim_start_matches('?').to_string();
312            return Some((
313                AggregateFunction::Sum { variable, distinct },
314                alias.unwrap_or_else(|| "sum".to_string()),
315            ));
316        }
317
318        if upper.starts_with("AVG(") {
319            let inner = func_part[4..].trim_end_matches(')').trim();
320            let distinct = inner.to_uppercase().starts_with("DISTINCT");
321            let var_part = if distinct { inner[8..].trim() } else { inner };
322            let variable = var_part.trim_start_matches('?').to_string();
323            return Some((
324                AggregateFunction::Avg { variable, distinct },
325                alias.unwrap_or_else(|| "avg".to_string()),
326            ));
327        }
328
329        if upper.starts_with("MIN(") {
330            let inner = func_part[4..].trim_end_matches(')').trim();
331            let variable = inner.trim_start_matches('?').to_string();
332            return Some((
333                AggregateFunction::Min { variable },
334                alias.unwrap_or_else(|| "min".to_string()),
335            ));
336        }
337
338        if upper.starts_with("MAX(") {
339            let inner = func_part[4..].trim_end_matches(')').trim();
340            let variable = inner.trim_start_matches('?').to_string();
341            return Some((
342                AggregateFunction::Max { variable },
343                alias.unwrap_or_else(|| "max".to_string()),
344            ));
345        }
346
347        if upper.starts_with("GROUP_CONCAT(") {
348            let inner = func_part[13..].trim_end_matches(')').trim();
349            let distinct = inner.to_uppercase().starts_with("DISTINCT");
350            let var_part = if distinct { inner[8..].trim() } else { inner };
351            // Check for SEPARATOR
352            let (variable, separator) =
353                if let Some(sep_pos) = var_part.to_uppercase().find("; SEPARATOR") {
354                    let var = var_part[..sep_pos]
355                        .trim()
356                        .trim_start_matches('?')
357                        .to_string();
358                    let sep_start = var_part.find('=').map(|p| p + 1).unwrap_or(sep_pos);
359                    let sep = var_part[sep_start..].trim().trim_matches('"').to_string();
360                    (var, Some(sep))
361                } else {
362                    (var_part.trim_start_matches('?').to_string(), None)
363                };
364            return Some((
365                AggregateFunction::GroupConcat {
366                    variable,
367                    separator,
368                    distinct,
369                },
370                alias.unwrap_or_else(|| "group_concat".to_string()),
371            ));
372        }
373
374        if upper.starts_with("SAMPLE(") {
375            let inner = func_part[7..].trim_end_matches(')').trim();
376            let variable = inner.trim_start_matches('?').to_string();
377            return Some((
378                AggregateFunction::Sample { variable },
379                alias.unwrap_or_else(|| "sample".to_string()),
380            ));
381        }
382
383        None
384    }
385
386    /// Parse the query type (SELECT, ASK, DESCRIBE, CONSTRUCT)
387    fn parse_query_type(&self, normalized: &str) -> Result<QueryType> {
388        if normalized.contains("ASK") {
389            Ok(QueryType::Ask)
390        } else if let Some(describe_pos) = normalized.find("DESCRIBE") {
391            // Parse DESCRIBE resources
392            let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
393            let describe_part = normalized[describe_pos + 8..where_pos].trim();
394            let mut resources = Vec::new();
395
396            for token in describe_part.split_whitespace() {
397                if token.starts_with('?') || token.starts_with('<') {
398                    resources.push(
399                        token
400                            .trim_matches(|c| c == '?' || c == '<' || c == '>')
401                            .to_string(),
402                    );
403                }
404            }
405
406            Ok(QueryType::Describe { resources })
407        } else if normalized.contains("CONSTRUCT") {
408            // Parse CONSTRUCT template
409            let template = self.parse_construct_template(normalized)?;
410            Ok(QueryType::Construct { template })
411        } else if let Some(select_pos) = normalized.find("SELECT") {
412            // Parse SELECT variables and aggregates
413            let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
414            let select_part = normalized[select_pos + 6..where_pos].trim();
415
416            let distinct = select_part.starts_with("DISTINCT");
417            let vars_part = if distinct {
418                &select_part[8..]
419            } else {
420                select_part
421            };
422
423            let mut select_vars = Vec::new();
424            let mut projections = Vec::new();
425
426            // Split on commas or parentheses to handle aggregates
427            let mut current_token = String::new();
428            let mut paren_depth = 0;
429
430            for c in vars_part.chars() {
431                match c {
432                    '(' => {
433                        paren_depth += 1;
434                        current_token.push(c);
435                    }
436                    ')' => {
437                        paren_depth -= 1;
438                        current_token.push(c);
439                    }
440                    ' ' | ',' if paren_depth == 0 => {
441                        if !current_token.trim().is_empty() {
442                            let token = current_token.trim();
443                            // Strip outer parentheses for aggregate expressions
444                            let token = if token.starts_with('(') && token.ends_with(')') {
445                                &token[1..token.len() - 1]
446                            } else {
447                                token
448                            };
449                            if let Some((agg_func, alias)) = self.parse_aggregate(token) {
450                                projections.push(SelectElement::Aggregate {
451                                    function: agg_func,
452                                    alias: Some(alias.clone()),
453                                });
454                                select_vars.push(alias);
455                            } else if let Some(var_name) = token.strip_prefix('?') {
456                                projections.push(SelectElement::Variable(var_name.to_string()));
457                                select_vars.push(var_name.to_string());
458                            } else if token == "*" {
459                                projections.push(SelectElement::Variable("*".to_string()));
460                                select_vars.push("*".to_string());
461                            }
462                        }
463                        current_token.clear();
464                    }
465                    _ => current_token.push(c),
466                }
467            }
468
469            // Handle the last token
470            if !current_token.trim().is_empty() {
471                let token = current_token.trim();
472                // Strip outer parentheses for aggregate expressions
473                let token = if token.starts_with('(') && token.ends_with(')') {
474                    &token[1..token.len() - 1]
475                } else {
476                    token
477                };
478                if let Some((agg_func, alias)) = self.parse_aggregate(token) {
479                    projections.push(SelectElement::Aggregate {
480                        function: agg_func,
481                        alias: Some(alias.clone()),
482                    });
483                    select_vars.push(alias);
484                } else if let Some(var_name) = token.strip_prefix('?') {
485                    projections.push(SelectElement::Variable(var_name.to_string()));
486                    select_vars.push(var_name.to_string());
487                } else if token == "*" {
488                    projections.push(SelectElement::Variable("*".to_string()));
489                    select_vars.push("*".to_string());
490                }
491            }
492
493            Ok(QueryType::Select {
494                projections,
495                select_vars,
496                distinct,
497            })
498        } else {
499            Err(anyhow!("Unable to determine query type"))
500        }
501    }
502
503    /// Parse CONSTRUCT template patterns
504    fn parse_construct_template(&self, normalized: &str) -> Result<Vec<TriplePattern>> {
505        let construct_pos = normalized
506            .find("CONSTRUCT")
507            .ok_or_else(|| anyhow!("No CONSTRUCT found"))?;
508        let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
509
510        // Find template content between { and }
511        let template_start = normalized[construct_pos..where_pos]
512            .find('{')
513            .ok_or_else(|| anyhow!("No opening brace in CONSTRUCT template"))?;
514        let template_end = normalized[construct_pos..where_pos]
515            .rfind('}')
516            .ok_or_else(|| anyhow!("No closing brace in CONSTRUCT template"))?;
517
518        let template_content =
519            &normalized[construct_pos + template_start + 1..construct_pos + template_end];
520
521        let mut patterns = Vec::new();
522        for statement in self.split_sparql_statements(template_content) {
523            if let Some(pattern) = self.parse_triple_pattern(statement)? {
524                patterns.push(pattern);
525            }
526        }
527
528        Ok(patterns)
529    }
530
531    /// Parse WHERE clause into graph patterns
532    fn parse_where_clause(&self, normalized: &str) -> Result<GraphPattern> {
533        // Find WHERE clause content (between { and })
534        if let Some(where_start) = normalized.find("WHERE") {
535            if let Some(brace_start) = normalized[where_start..].find('{') {
536                let content_start = where_start + brace_start + 1;
537
538                // Find matching closing brace
539                let closing_brace = self.find_matching_brace(&normalized[content_start..])?;
540                let where_content = &normalized[content_start..content_start + closing_brace];
541
542                return self.parse_graph_pattern(where_content);
543            }
544        }
545
546        Err(anyhow!("No WHERE clause found"))
547    }
548
549    /// Parse a graph pattern (handles OPTIONAL, UNION, FILTER)
550    fn parse_graph_pattern(&self, content: &str) -> Result<GraphPattern> {
551        let content = content.trim();
552
553        if content.is_empty() {
554            return Err(anyhow!("Empty graph pattern"));
555        }
556
557        // Check for UNION (top-level only)
558        if let Some(union_pos) = content.find("UNION") {
559            // Ensure it's not inside braces
560            let before_union = &content[..union_pos];
561            let open_braces = before_union.matches('{').count();
562            let close_braces = before_union.matches('}').count();
563
564            if open_braces == close_braces {
565                // UNION is at top level
566                let left_part = before_union.trim();
567                let right_part = content[union_pos + 5..].trim();
568
569                let left_pattern = self.parse_graph_pattern(left_part)?;
570                let right_pattern = self.parse_graph_pattern(right_part)?;
571
572                return Ok(GraphPattern::Union(
573                    Box::new(left_pattern),
574                    Box::new(right_pattern),
575                ));
576            }
577        }
578
579        // Parse statements using split_sparql_statements
580        let mut patterns = Vec::new();
581        let statements = self.split_sparql_statements(content);
582
583        for statement in statements {
584            let statement = statement.trim();
585
586            if statement.is_empty() {
587                continue;
588            }
589
590            // Check for OPTIONAL
591            if statement.starts_with("OPTIONAL") {
592                // Find the content in braces
593                if let Some(brace_start_pos) = statement.find('{') {
594                    let content_start = brace_start_pos + 1;
595                    if let Ok(closing_offset) =
596                        self.find_matching_brace(&statement[content_start..])
597                    {
598                        let optional_content =
599                            &statement[content_start..content_start + closing_offset];
600                        let inner_pattern = self.parse_graph_pattern(optional_content)?;
601                        patterns.push(GraphPattern::Optional(Box::new(inner_pattern)));
602                        continue;
603                    }
604                }
605            }
606
607            // Check for FILTER
608            if statement.starts_with("FILTER") {
609                if let Some(filter) = self.parse_filter(statement)? {
610                    patterns.push(GraphPattern::Filter(filter));
611                }
612                continue;
613            }
614
615            // Check for nested braces (subgraph pattern)
616            if statement.starts_with('{') && statement.ends_with('}') {
617                let inner = &statement[1..statement.len() - 1];
618                let inner_pattern = self.parse_graph_pattern(inner)?;
619                patterns.push(inner_pattern);
620                continue;
621            }
622
623            // Parse as triple pattern
624            if let Some(pattern) = self.parse_triple_pattern(statement)? {
625                patterns.push(GraphPattern::Triple(pattern));
626            }
627        }
628
629        if patterns.is_empty() {
630            Err(anyhow!("Empty graph pattern in content: {}", content))
631        } else if patterns.len() == 1 {
632            Ok(patterns.into_iter().next().unwrap())
633        } else {
634            Ok(GraphPattern::Group(patterns))
635        }
636    }
637
638    /// Find matching closing brace
639    fn find_matching_brace(&self, content: &str) -> Result<usize> {
640        let mut depth = 1;
641        let chars: Vec<char> = content.chars().collect();
642
643        for (i, &c) in chars.iter().enumerate() {
644            match c {
645                '{' => depth += 1,
646                '}' => {
647                    depth -= 1;
648                    if depth == 0 {
649                        return Ok(i);
650                    }
651                }
652                _ => {}
653            }
654        }
655
656        Err(anyhow!("No matching closing brace found"))
657    }
658
659    /// Parse LIMIT modifier
660    fn parse_limit(&self, normalized: &str) -> Option<usize> {
661        if let Some(limit_pos) = normalized.find("LIMIT") {
662            let after_limit = &normalized[limit_pos + 5..].trim();
663            if let Some(num_str) = after_limit.split_whitespace().next() {
664                return num_str.parse().ok();
665            }
666        }
667        None
668    }
669
670    /// Parse OFFSET modifier
671    fn parse_offset(&self, normalized: &str) -> Option<usize> {
672        if let Some(offset_pos) = normalized.find("OFFSET") {
673            let after_offset = &normalized[offset_pos + 6..].trim();
674            if let Some(num_str) = after_offset.split_whitespace().next() {
675                return num_str.parse().ok();
676            }
677        }
678        None
679    }
680
681    /// Parse ORDER BY modifier
682    fn parse_order_by(&self, normalized: &str) -> Vec<String> {
683        if let Some(order_pos) = normalized.find("ORDER BY") {
684            let after_order = &normalized[order_pos + 8..];
685
686            // Find the end of ORDER BY clause (either LIMIT, OFFSET, or end of string)
687            let limit_offset = after_order.find("LIMIT").unwrap_or(after_order.len());
688            let offset_offset = after_order.find("OFFSET").unwrap_or(after_order.len());
689            let end_offset = limit_offset.min(offset_offset);
690
691            let order_part = after_order[..end_offset].trim();
692            return order_part
693                .split_whitespace()
694                .filter_map(|s| s.strip_prefix('?').map(|v| v.to_string()))
695                .collect();
696        }
697        Vec::new()
698    }
699
700    /// Split SPARQL WHERE content into statements, respecting URI boundaries
701    ///
702    /// This splits on '.' that are statement terminators, not on '.' inside <...> URIs
703    fn split_sparql_statements<'a>(&self, content: &'a str) -> Vec<&'a str> {
704        let mut statements = Vec::new();
705        let mut current_start = 0;
706        let mut inside_uri = false;
707        let mut inside_string = false;
708        let chars: Vec<char> = content.chars().collect();
709
710        for i in 0..chars.len() {
711            match chars[i] {
712                '<' if !inside_string => inside_uri = true,
713                '>' if !inside_string => inside_uri = false,
714                '"' if !inside_uri => inside_string = !inside_string,
715                '.' if !inside_uri && !inside_string => {
716                    // Found a statement-terminating period
717                    let statement = &content[current_start..i];
718                    if !statement.trim().is_empty() {
719                        statements.push(statement);
720                    }
721                    current_start = i + 1;
722                }
723                _ => {}
724            }
725        }
726
727        // Add the last statement if there's anything left
728        if current_start < content.len() {
729            let statement = &content[current_start..];
730            if !statement.trim().is_empty() {
731                statements.push(statement);
732            }
733        }
734
735        statements
736    }
737
738    /// Parse a triple pattern
739    fn parse_triple_pattern(&self, line: &str) -> Result<Option<TriplePattern>> {
740        // Remove trailing dot and split by whitespace
741        let line = line.trim_end_matches('.').trim();
742        let parts: Vec<&str> = line.split_whitespace().collect();
743
744        if parts.len() < 3 {
745            return Ok(None);
746        }
747
748        let subject = self.parse_pattern_element(parts[0])?;
749        let predicate = self.parse_pattern_element(parts[1])?;
750        let object = self.parse_pattern_element(parts[2])?;
751
752        Ok(Some(TriplePattern {
753            subject,
754            predicate,
755            object,
756        }))
757    }
758
759    /// Parse a pattern element (variable or constant)
760    fn parse_pattern_element(&self, s: &str) -> Result<PatternElement> {
761        if let Some(var_name) = s.strip_prefix('?') {
762            Ok(PatternElement::Variable(var_name.to_string()))
763        } else if let Some(iri) = s.strip_prefix('<').and_then(|s| s.strip_suffix('>')) {
764            Ok(PatternElement::Constant(iri.to_string()))
765        } else if let Some(literal) = s.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
766            Ok(PatternElement::Constant(literal.to_string()))
767        } else {
768            Ok(PatternElement::Constant(s.to_string()))
769        }
770    }
771
772    /// Parse a FILTER clause
773    fn parse_filter(&self, line: &str) -> Result<Option<FilterCondition>> {
774        let filter_content = line
775            .strip_prefix("FILTER")
776            .and_then(|s| s.trim().strip_prefix('('))
777            .and_then(|s| s.trim().strip_suffix(')'))
778            .map(|s| s.trim());
779
780        if let Some(content) = filter_content {
781            // Check for built-in functions
782            if content.starts_with("BOUND(") {
783                if let Some(var_end) = content.find(')') {
784                    let var = &content[6..var_end].trim_start_matches('?');
785                    return Ok(Some(FilterCondition::Bound(var.to_string())));
786                }
787            } else if content.starts_with("isIRI(") || content.starts_with("isURI(") {
788                // Both isIRI and isURI have the same length (6 characters including parenthesis)
789                let start_pos = 6;
790                if let Some(var_end) = content.find(')') {
791                    let var = &content[start_pos..var_end].trim_start_matches('?');
792                    return Ok(Some(FilterCondition::IsIri(var.to_string())));
793                }
794            } else if content.starts_with("isLiteral(") {
795                if let Some(var_end) = content.find(')') {
796                    let var = &content[10..var_end].trim_start_matches('?');
797                    return Ok(Some(FilterCondition::IsLiteral(var.to_string())));
798                }
799            } else if content.starts_with("regex(") {
800                // regex(?var, "pattern")
801                if let Some(comma_pos) = content.find(',') {
802                    let var = content[6..comma_pos].trim().trim_start_matches('?');
803                    let pattern_part = content[comma_pos + 1..]
804                        .trim()
805                        .trim_end_matches(')')
806                        .trim_matches('"');
807                    return Ok(Some(FilterCondition::Regex(
808                        var.to_string(),
809                        pattern_part.to_string(),
810                    )));
811                }
812            }
813
814            // Check for comparison operators
815            if content.contains(">=") {
816                let parts: Vec<&str> = content.split(">=").map(|s| s.trim()).collect();
817                if parts.len() == 2 {
818                    return Ok(Some(FilterCondition::GreaterOrEqual(
819                        parts[0].trim_start_matches('?').to_string(),
820                        parts[1].trim_matches('"').to_string(),
821                    )));
822                }
823            } else if content.contains("<=") {
824                let parts: Vec<&str> = content.split("<=").map(|s| s.trim()).collect();
825                if parts.len() == 2 {
826                    return Ok(Some(FilterCondition::LessOrEqual(
827                        parts[0].trim_start_matches('?').to_string(),
828                        parts[1].trim_matches('"').to_string(),
829                    )));
830                }
831            } else if content.contains(">") && !content.contains(">=") {
832                let parts: Vec<&str> = content.split('>').map(|s| s.trim()).collect();
833                if parts.len() == 2 {
834                    return Ok(Some(FilterCondition::GreaterThan(
835                        parts[0].trim_start_matches('?').to_string(),
836                        parts[1].trim_matches('"').to_string(),
837                    )));
838                }
839            } else if content.contains("<") && !content.contains("<=") {
840                let parts: Vec<&str> = content.split('<').map(|s| s.trim()).collect();
841                if parts.len() == 2 {
842                    return Ok(Some(FilterCondition::LessThan(
843                        parts[0].trim_start_matches('?').to_string(),
844                        parts[1].trim_matches('"').to_string(),
845                    )));
846                }
847            } else if content.contains("!=") {
848                let parts: Vec<&str> = content.split("!=").map(|s| s.trim()).collect();
849                if parts.len() == 2 {
850                    return Ok(Some(FilterCondition::NotEquals(
851                        parts[0].trim_start_matches('?').to_string(),
852                        parts[1].trim_matches('"').to_string(),
853                    )));
854                }
855            } else if content.contains("=")
856                && !content.contains("!=")
857                && !content.contains(">=")
858                && !content.contains("<=")
859            {
860                let parts: Vec<&str> = content.split('=').map(|s| s.trim()).collect();
861                if parts.len() == 2 {
862                    return Ok(Some(FilterCondition::Equals(
863                        parts[0].trim_start_matches('?').to_string(),
864                        parts[1].trim_matches('"').to_string(),
865                    )));
866                }
867            }
868        }
869
870        Ok(None)
871    }
872
873    /// Compile a SPARQL query to TensorLogic expression
874    ///
875    /// Converts SPARQL patterns to TLExpr predicates and filters to constraints.
876    /// Supports all query types (SELECT, ASK, DESCRIBE, CONSTRUCT) and advanced
877    /// patterns (OPTIONAL, UNION).
878    ///
879    /// ## Example
880    ///
881    /// ```
882    /// use tensorlogic_oxirs_bridge::sparql::SparqlCompiler;
883    ///
884    /// let mut compiler = SparqlCompiler::new();
885    /// compiler.add_predicate_mapping(
886    ///     "http://example.org/knows".to_string(),
887    ///     "knows".to_string()
888    /// );
889    ///
890    /// // SELECT query
891    /// let query = r#"
892    ///     SELECT ?x ?y WHERE {
893    ///       ?x <http://example.org/knows> ?y .
894    ///     }
895    /// "#;
896    ///
897    /// let sparql_query = compiler.parse_query(query).unwrap();
898    /// let tl_expr = compiler.compile_to_tensorlogic(&sparql_query).unwrap();
899    ///
900    /// // ASK query
901    /// let ask_query = r#"
902    ///     ASK WHERE {
903    ///       ?x <http://example.org/knows> ?y .
904    ///     }
905    /// "#;
906    ///
907    /// let sparql_ask = compiler.parse_query(ask_query).unwrap();
908    /// let ask_expr = compiler.compile_to_tensorlogic(&sparql_ask).unwrap();
909    /// ```
910    pub fn compile_to_tensorlogic(&self, query: &SparqlQuery) -> Result<TLExpr> {
911        // Compile WHERE clause pattern
912        let where_expr = self.compile_graph_pattern(&query.where_pattern)?;
913
914        // For ASK queries, wrap in EXISTS quantifier
915        match &query.query_type {
916            QueryType::Ask => {
917                // ASK is essentially EXISTS over all variables in the pattern
918                Ok(where_expr) // The pattern itself represents existence
919            }
920            QueryType::Select { select_vars, .. } => {
921                // For SELECT, the expression is the WHERE clause
922                // Variable projection happens at execution time
923                if select_vars.is_empty() || select_vars.contains(&"*".to_string()) {
924                    Ok(where_expr)
925                } else {
926                    // Could add quantifiers for non-selected variables here
927                    Ok(where_expr)
928                }
929            }
930            QueryType::Describe { .. } => {
931                // DESCRIBE returns all triples about specified resources
932                Ok(where_expr)
933            }
934            QueryType::Construct { template: _ } => {
935                // CONSTRUCT applies template pattern after WHERE clause matches
936                // For now, we return the WHERE clause; template application
937                // would happen at execution time
938                Ok(where_expr)
939            }
940        }
941    }
942
943    /// Compile a graph pattern to TLExpr
944    fn compile_graph_pattern(&self, pattern: &GraphPattern) -> Result<TLExpr> {
945        match pattern {
946            GraphPattern::Triple(triple) => self.compile_triple_pattern(triple),
947
948            GraphPattern::Group(patterns) => {
949                if patterns.is_empty() {
950                    return Err(anyhow!("Empty pattern group"));
951                }
952
953                let mut exprs: Vec<TLExpr> = Vec::new();
954                for p in patterns {
955                    exprs.push(self.compile_graph_pattern(p)?);
956                }
957
958                // Combine with AND
959                Ok(exprs.into_iter().reduce(TLExpr::and).unwrap())
960            }
961
962            GraphPattern::Optional(inner) => {
963                // OPTIONAL in SPARQL is like left-outer join
964                // In logic, we can represent as: pattern OR TRUE
965                // This ensures the outer pattern succeeds even if inner fails
966                let inner_expr = self.compile_graph_pattern(inner)?;
967
968                // Use OR with a trivially true expression
969                // This gives "optional" semantics - the pattern can match or not
970                Ok(TLExpr::or(inner_expr.clone(), TLExpr::pred("true", vec![])))
971            }
972
973            GraphPattern::Union(left, right) => {
974                // UNION is disjunction
975                let left_expr = self.compile_graph_pattern(left)?;
976                let right_expr = self.compile_graph_pattern(right)?;
977                Ok(TLExpr::or(left_expr, right_expr))
978            }
979
980            GraphPattern::Filter(filter_cond) => self.compile_filter_condition(filter_cond),
981        }
982    }
983
984    /// Compile a triple pattern to TLExpr
985    fn compile_triple_pattern(&self, pattern: &TriplePattern) -> Result<TLExpr> {
986        let pred_name = match &pattern.predicate {
987            PatternElement::Constant(iri) => {
988                // Try to map IRI to predicate name
989                self.predicate_mapping
990                    .get(iri)
991                    .cloned()
992                    .unwrap_or_else(|| Self::iri_to_name(iri))
993            }
994            PatternElement::Variable(v) => {
995                return Err(anyhow!("Variable predicates not supported: ?{}", v));
996            }
997        };
998
999        let subj_term = match &pattern.subject {
1000            PatternElement::Variable(v) => Term::var(v),
1001            PatternElement::Constant(c) => Term::constant(c),
1002        };
1003
1004        let obj_term = match &pattern.object {
1005            PatternElement::Variable(v) => Term::var(v),
1006            PatternElement::Constant(c) => Term::constant(c),
1007        };
1008
1009        Ok(TLExpr::pred(&pred_name, vec![subj_term, obj_term]))
1010    }
1011
1012    /// Compile a filter condition to TLExpr
1013    fn compile_filter_condition(&self, filter: &FilterCondition) -> Result<TLExpr> {
1014        let expr = match filter {
1015            FilterCondition::Equals(var, val) => {
1016                TLExpr::pred("equals", vec![Term::var(var), Term::constant(val)])
1017            }
1018            FilterCondition::NotEquals(var, val) => TLExpr::negate(TLExpr::pred(
1019                "equals",
1020                vec![Term::var(var), Term::constant(val)],
1021            )),
1022            FilterCondition::GreaterThan(var, val) => {
1023                TLExpr::pred("greaterThan", vec![Term::var(var), Term::constant(val)])
1024            }
1025            FilterCondition::LessThan(var, val) => {
1026                TLExpr::pred("lessThan", vec![Term::var(var), Term::constant(val)])
1027            }
1028            FilterCondition::GreaterOrEqual(var, val) => {
1029                TLExpr::pred("greaterOrEqual", vec![Term::var(var), Term::constant(val)])
1030            }
1031            FilterCondition::LessOrEqual(var, val) => {
1032                TLExpr::pred("lessOrEqual", vec![Term::var(var), Term::constant(val)])
1033            }
1034            FilterCondition::Regex(var, pattern) => {
1035                TLExpr::pred("matches", vec![Term::var(var), Term::constant(pattern)])
1036            }
1037            FilterCondition::Bound(var) => TLExpr::pred("bound", vec![Term::var(var)]),
1038            FilterCondition::IsIri(var) => TLExpr::pred("isIri", vec![Term::var(var)]),
1039            FilterCondition::IsLiteral(var) => TLExpr::pred("isLiteral", vec![Term::var(var)]),
1040        };
1041
1042        Ok(expr)
1043    }
1044
1045    /// Extract local name from IRI
1046    fn iri_to_name(iri: &str) -> String {
1047        iri.split(['/', '#']).next_back().unwrap_or(iri).to_string()
1048    }
1049}
1050
1051impl Default for SparqlCompiler {
1052    fn default() -> Self {
1053        Self::new()
1054    }
1055}
1056
1057#[cfg(test)]
1058mod tests {
1059    use super::*;
1060
1061    // ====== Basic SELECT Query Tests ======
1062
1063    #[test]
1064    fn test_parse_simple_query() {
1065        let compiler = SparqlCompiler::new();
1066        let query = r#"
1067            SELECT ?x ?y WHERE {
1068              ?x <http://example.org/knows> ?y .
1069            }
1070        "#;
1071
1072        let parsed = compiler.parse_query(query).unwrap();
1073
1074        // Check query type
1075        match &parsed.query_type {
1076            QueryType::Select {
1077                select_vars,
1078                distinct,
1079                ..
1080            } => {
1081                assert_eq!(select_vars, &vec!["x", "y"]);
1082                assert!(!distinct);
1083            }
1084            _ => panic!("Expected SELECT query"),
1085        }
1086
1087        // Check WHERE pattern
1088        match &parsed.where_pattern {
1089            GraphPattern::Triple(pattern) => {
1090                assert_eq!(pattern.subject, PatternElement::Variable("x".to_string()));
1091                assert_eq!(
1092                    pattern.predicate,
1093                    PatternElement::Constant("http://example.org/knows".to_string())
1094                );
1095                assert_eq!(pattern.object, PatternElement::Variable("y".to_string()));
1096            }
1097            _ => panic!("Expected Triple pattern"),
1098        }
1099    }
1100
1101    #[test]
1102    fn test_parse_select_distinct() {
1103        let compiler = SparqlCompiler::new();
1104        let query = r#"
1105            SELECT DISTINCT ?x WHERE {
1106              ?x <http://example.org/type> ?t .
1107            }
1108        "#;
1109
1110        let parsed = compiler.parse_query(query).unwrap();
1111
1112        match &parsed.query_type {
1113            QueryType::Select {
1114                select_vars,
1115                distinct,
1116                ..
1117            } => {
1118                assert_eq!(select_vars, &vec!["x"]);
1119                assert!(distinct);
1120            }
1121            _ => panic!("Expected SELECT DISTINCT query"),
1122        }
1123    }
1124
1125    #[test]
1126    fn test_parse_query_with_filter() {
1127        let compiler = SparqlCompiler::new();
1128        let query = r#"
1129            SELECT ?x ?age WHERE {
1130              ?x <http://example.org/age> ?age .
1131              FILTER(?age > 18)
1132            }
1133        "#;
1134
1135        let parsed = compiler.parse_query(query).unwrap();
1136
1137        match &parsed.query_type {
1138            QueryType::Select { select_vars, .. } => {
1139                assert_eq!(select_vars, &vec!["x", "age"]);
1140            }
1141            _ => panic!("Expected SELECT query"),
1142        }
1143
1144        // Check WHERE pattern contains filter
1145        match &parsed.where_pattern {
1146            GraphPattern::Group(patterns) => {
1147                assert_eq!(patterns.len(), 2);
1148                // One Triple, one Filter
1149                assert!(matches!(patterns[0], GraphPattern::Triple(_)));
1150                assert!(matches!(patterns[1], GraphPattern::Filter(_)));
1151            }
1152            _ => panic!("Expected Group pattern with filter"),
1153        }
1154    }
1155
1156    #[test]
1157    fn test_parse_query_with_limit_offset() {
1158        let compiler = SparqlCompiler::new();
1159        let query = r#"
1160            SELECT ?x WHERE {
1161              ?x <http://example.org/type> ?t .
1162            } LIMIT 10 OFFSET 20
1163        "#;
1164
1165        let parsed = compiler.parse_query(query).unwrap();
1166        assert_eq!(parsed.limit, Some(10));
1167        assert_eq!(parsed.offset, Some(20));
1168    }
1169
1170    #[test]
1171    fn test_parse_query_with_order_by() {
1172        let compiler = SparqlCompiler::new();
1173        let query = r#"
1174            SELECT ?x ?name WHERE {
1175              ?x <http://example.org/name> ?name .
1176            } ORDER BY ?name
1177        "#;
1178
1179        let parsed = compiler.parse_query(query).unwrap();
1180        assert_eq!(parsed.order_by, vec!["name"]);
1181    }
1182
1183    // ====== ASK Query Tests ======
1184
1185    #[test]
1186    fn test_parse_ask_query() {
1187        let compiler = SparqlCompiler::new();
1188        let query = r#"
1189            ASK WHERE {
1190              ?x <http://example.org/knows> ?y .
1191            }
1192        "#;
1193
1194        let parsed = compiler.parse_query(query).unwrap();
1195
1196        match &parsed.query_type {
1197            QueryType::Ask => {
1198                // Success
1199            }
1200            _ => panic!("Expected ASK query"),
1201        }
1202    }
1203
1204    #[test]
1205    fn test_compile_ask_query() {
1206        let mut compiler = SparqlCompiler::new();
1207        compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1208
1209        let query = r#"
1210            ASK WHERE {
1211              ?x <http://example.org/knows> ?y .
1212            }
1213        "#;
1214
1215        let parsed = compiler.parse_query(query).unwrap();
1216        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1217
1218        // Should generate existence check
1219        let expr_str = format!("{:?}", tl_expr);
1220        assert!(expr_str.contains("knows"));
1221    }
1222
1223    // ====== DESCRIBE Query Tests ======
1224
1225    #[test]
1226    fn test_parse_describe_query() {
1227        let compiler = SparqlCompiler::new();
1228        let query = r#"
1229            DESCRIBE ?x WHERE {
1230              ?x <http://example.org/type> <http://example.org/Person> .
1231            }
1232        "#;
1233
1234        let parsed = compiler.parse_query(query).unwrap();
1235
1236        match &parsed.query_type {
1237            QueryType::Describe { resources } => {
1238                assert_eq!(resources, &vec!["x"]);
1239            }
1240            _ => panic!("Expected DESCRIBE query"),
1241        }
1242    }
1243
1244    #[test]
1245    fn test_compile_describe_query() {
1246        let mut compiler = SparqlCompiler::new();
1247        compiler.add_predicate_mapping("http://example.org/type".to_string(), "type".to_string());
1248
1249        let query = r#"
1250            DESCRIBE ?x WHERE {
1251              ?x <http://example.org/type> ?t .
1252            }
1253        "#;
1254
1255        let parsed = compiler.parse_query(query).unwrap();
1256        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1257
1258        let expr_str = format!("{:?}", tl_expr);
1259        assert!(expr_str.contains("type"));
1260    }
1261
1262    // ====== CONSTRUCT Query Tests ======
1263
1264    #[test]
1265    fn test_parse_construct_query() {
1266        let compiler = SparqlCompiler::new();
1267        let query = r#"
1268            CONSTRUCT { ?x <http://example.org/friend> ?y }
1269            WHERE {
1270              ?x <http://example.org/knows> ?y .
1271            }
1272        "#;
1273
1274        let parsed = compiler.parse_query(query).unwrap();
1275
1276        match &parsed.query_type {
1277            QueryType::Construct { template } => {
1278                assert_eq!(template.len(), 1);
1279                let pattern = &template[0];
1280                assert_eq!(pattern.subject, PatternElement::Variable("x".to_string()));
1281                assert_eq!(
1282                    pattern.predicate,
1283                    PatternElement::Constant("http://example.org/friend".to_string())
1284                );
1285                assert_eq!(pattern.object, PatternElement::Variable("y".to_string()));
1286            }
1287            _ => panic!("Expected CONSTRUCT query"),
1288        }
1289    }
1290
1291    #[test]
1292    fn test_compile_construct_query() {
1293        let mut compiler = SparqlCompiler::new();
1294        compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1295
1296        let query = r#"
1297            CONSTRUCT { ?x <http://example.org/friend> ?y }
1298            WHERE {
1299              ?x <http://example.org/knows> ?y .
1300            }
1301        "#;
1302
1303        let parsed = compiler.parse_query(query).unwrap();
1304        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1305
1306        let expr_str = format!("{:?}", tl_expr);
1307        assert!(expr_str.contains("knows"));
1308    }
1309
1310    // ====== OPTIONAL Pattern Tests ======
1311
1312    #[test]
1313    fn test_parse_optional_pattern() {
1314        let compiler = SparqlCompiler::new();
1315        let query = r#"
1316            SELECT ?x ?name ?age WHERE {
1317              ?x <http://example.org/name> ?name .
1318              OPTIONAL { ?x <http://example.org/age> ?age }
1319            }
1320        "#;
1321
1322        let parsed = compiler.parse_query(query).unwrap();
1323
1324        match &parsed.where_pattern {
1325            GraphPattern::Group(patterns) => {
1326                assert_eq!(patterns.len(), 2);
1327                assert!(matches!(patterns[0], GraphPattern::Triple(_)));
1328                assert!(matches!(patterns[1], GraphPattern::Optional(_)));
1329            }
1330            _ => panic!("Expected Group with OPTIONAL"),
1331        }
1332    }
1333
1334    #[test]
1335    fn test_compile_optional_pattern() {
1336        let mut compiler = SparqlCompiler::new();
1337        compiler.add_predicate_mapping("http://example.org/name".to_string(), "name".to_string());
1338        compiler.add_predicate_mapping("http://example.org/age".to_string(), "age".to_string());
1339
1340        let query = r#"
1341            SELECT ?x ?name WHERE {
1342              ?x <http://example.org/name> ?name .
1343              OPTIONAL { ?x <http://example.org/age> ?age }
1344            }
1345        "#;
1346
1347        let parsed = compiler.parse_query(query).unwrap();
1348        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1349
1350        // Should have OR for optional semantics
1351        let expr_str = format!("{:?}", tl_expr);
1352        assert!(expr_str.contains("name"));
1353        assert!(expr_str.contains("Or"));
1354    }
1355
1356    // ====== UNION Pattern Tests ======
1357
1358    #[test]
1359    fn test_parse_union_pattern() {
1360        let compiler = SparqlCompiler::new();
1361        let query = r#"
1362            SELECT ?x ?y WHERE {
1363              { ?x <http://example.org/knows> ?y }
1364              UNION
1365              { ?x <http://example.org/likes> ?y }
1366            }
1367        "#;
1368
1369        let parsed = compiler.parse_query(query).unwrap();
1370
1371        match &parsed.where_pattern {
1372            GraphPattern::Union(_, _) => {
1373                // Success - found UNION pattern
1374            }
1375            _ => panic!("Expected UNION pattern"),
1376        }
1377    }
1378
1379    #[test]
1380    fn test_compile_union_pattern() {
1381        let mut compiler = SparqlCompiler::new();
1382        compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1383        compiler.add_predicate_mapping("http://example.org/likes".to_string(), "likes".to_string());
1384
1385        let query = r#"
1386            SELECT ?x ?y WHERE {
1387              { ?x <http://example.org/knows> ?y }
1388              UNION
1389              { ?x <http://example.org/likes> ?y }
1390            }
1391        "#;
1392
1393        let parsed = compiler.parse_query(query).unwrap();
1394        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1395
1396        // Should have OR for union
1397        let expr_str = format!("{:?}", tl_expr);
1398        assert!(expr_str.contains("knows") || expr_str.contains("likes"));
1399        assert!(expr_str.contains("Or"));
1400    }
1401
1402    // ====== Filter Conditions Tests ======
1403
1404    #[test]
1405    fn test_filter_greater_or_equal() {
1406        let compiler = SparqlCompiler::new();
1407        let query = r#"
1408            SELECT ?x WHERE {
1409              ?x <http://example.org/age> ?age .
1410              FILTER(?age >= 18)
1411            }
1412        "#;
1413
1414        let parsed = compiler.parse_query(query).unwrap();
1415
1416        match &parsed.where_pattern {
1417            GraphPattern::Group(patterns) => {
1418                if let Some(GraphPattern::Filter(FilterCondition::GreaterOrEqual(var, val))) =
1419                    patterns.get(1)
1420                {
1421                    assert_eq!(var, "age");
1422                    assert_eq!(val, "18");
1423                } else {
1424                    panic!("Expected GreaterOrEqual filter");
1425                }
1426            }
1427            _ => panic!("Expected Group pattern"),
1428        }
1429    }
1430
1431    #[test]
1432    fn test_filter_bound() {
1433        let compiler = SparqlCompiler::new();
1434        let filter = compiler.parse_filter("FILTER(BOUND(?x))").unwrap();
1435
1436        match filter {
1437            Some(FilterCondition::Bound(var)) => {
1438                assert_eq!(var, "x");
1439            }
1440            _ => panic!("Expected BOUND filter"),
1441        }
1442    }
1443
1444    #[test]
1445    fn test_filter_is_iri() {
1446        let compiler = SparqlCompiler::new();
1447        let filter = compiler.parse_filter("FILTER(isIRI(?x))").unwrap();
1448
1449        match filter {
1450            Some(FilterCondition::IsIri(var)) => {
1451                assert_eq!(var, "x");
1452            }
1453            _ => panic!("Expected isIRI filter"),
1454        }
1455    }
1456
1457    #[test]
1458    fn test_filter_regex() {
1459        let compiler = SparqlCompiler::new();
1460        let filter = compiler
1461            .parse_filter(r#"FILTER(regex(?name, "^John"))"#)
1462            .unwrap();
1463
1464        match filter {
1465            Some(FilterCondition::Regex(var, pattern)) => {
1466                assert_eq!(var, "name");
1467                assert_eq!(pattern, "^John");
1468            }
1469            _ => panic!("Expected regex filter"),
1470        }
1471    }
1472
1473    // ====== Compilation Tests ======
1474
1475    #[test]
1476    fn test_compile_simple_query() {
1477        let mut compiler = SparqlCompiler::new();
1478        compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1479
1480        let query = r#"
1481            SELECT ?x ?y WHERE {
1482              ?x <http://example.org/knows> ?y .
1483            }
1484        "#;
1485
1486        let parsed = compiler.parse_query(query).unwrap();
1487        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1488
1489        // Should generate a predicate expression
1490        let expr_str = format!("{:?}", tl_expr);
1491        assert!(expr_str.contains("knows"));
1492    }
1493
1494    #[test]
1495    fn test_compile_query_with_multiple_patterns() {
1496        let mut compiler = SparqlCompiler::new();
1497        compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1498
1499        let query = r#"
1500            SELECT ?x ?y ?z WHERE {
1501              ?x <http://example.org/knows> ?y .
1502              ?y <http://example.org/knows> ?z .
1503            }
1504        "#;
1505
1506        let parsed = compiler.parse_query(query).unwrap();
1507        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1508
1509        // Should generate AND of predicates
1510        let expr_str = format!("{:?}", tl_expr);
1511        assert!(expr_str.contains("knows"));
1512        assert!(expr_str.contains("And"));
1513    }
1514
1515    #[test]
1516    fn test_compile_query_with_filter() {
1517        let mut compiler = SparqlCompiler::new();
1518        compiler.add_predicate_mapping("http://example.org/age".to_string(), "age".to_string());
1519
1520        let query = r#"
1521            SELECT ?x ?a WHERE {
1522              ?x <http://example.org/age> ?a .
1523              FILTER(?a > 18)
1524            }
1525        "#;
1526
1527        let parsed = compiler.parse_query(query).unwrap();
1528        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1529
1530        // Should include both predicate and filter
1531        let expr_str = format!("{:?}", tl_expr);
1532        assert!(expr_str.contains("age"));
1533        assert!(expr_str.contains("greaterThan"));
1534    }
1535
1536    // ====== Utility Tests ======
1537
1538    #[test]
1539    fn test_iri_to_name() {
1540        assert_eq!(
1541            SparqlCompiler::iri_to_name("http://example.org/knows"),
1542            "knows"
1543        );
1544        assert_eq!(
1545            SparqlCompiler::iri_to_name("http://xmlns.com/foaf/0.1#Person"),
1546            "Person"
1547        );
1548        assert_eq!(SparqlCompiler::iri_to_name("simple"), "simple");
1549    }
1550
1551    // ====== Complex Integration Tests ======
1552
1553    #[test]
1554    fn test_complex_query_with_optional_and_filter() {
1555        let mut compiler = SparqlCompiler::new();
1556        compiler.add_predicate_mapping("http://example.org/name".to_string(), "name".to_string());
1557        compiler.add_predicate_mapping("http://example.org/age".to_string(), "age".to_string());
1558
1559        let query = r#"
1560            SELECT DISTINCT ?x ?name WHERE {
1561              ?x <http://example.org/name> ?name .
1562              OPTIONAL {
1563                ?x <http://example.org/age> ?age .
1564                FILTER(?age >= 21)
1565              }
1566            } LIMIT 100 ORDER BY ?name
1567        "#;
1568
1569        let parsed = compiler.parse_query(query).unwrap();
1570
1571        // Check all components
1572        match &parsed.query_type {
1573            QueryType::Select {
1574                select_vars,
1575                distinct,
1576                ..
1577            } => {
1578                assert_eq!(select_vars, &vec!["x", "name"]);
1579                assert!(distinct);
1580            }
1581            _ => panic!("Expected SELECT DISTINCT"),
1582        }
1583
1584        assert_eq!(parsed.limit, Some(100));
1585        assert_eq!(parsed.order_by, vec!["name"]);
1586
1587        // Check WHERE pattern structure - should be a Group with at least 2 patterns
1588        match &parsed.where_pattern {
1589            GraphPattern::Group(patterns) => {
1590                assert!(patterns.len() >= 2, "Expected at least 2 patterns in group");
1591                // First should be a Triple (name predicate)
1592                assert!(matches!(patterns[0], GraphPattern::Triple(_)));
1593            }
1594            _ => panic!("Expected Group pattern"),
1595        }
1596
1597        // Compile and check basic predicates are present
1598        let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1599        let expr_str = format!("{:?}", tl_expr);
1600        assert!(expr_str.contains("name"));
1601        // Should have logical operators combining the patterns
1602        assert!(expr_str.contains("And") || expr_str.contains("Or"));
1603    }
1604
1605    // ====== Aggregate Function Tests ======
1606
1607    #[test]
1608    fn test_parse_count_aggregate() {
1609        let compiler = SparqlCompiler::new();
1610        let query = r#"
1611            SELECT (COUNT(?x) AS ?count) WHERE {
1612              ?x <http://example.org/type> <http://example.org/Person> .
1613            }
1614        "#;
1615
1616        let parsed = compiler.parse_query(query).unwrap();
1617
1618        match &parsed.query_type {
1619            QueryType::Select { projections, .. } => {
1620                assert_eq!(projections.len(), 1);
1621                match &projections[0] {
1622                    SelectElement::Aggregate { function, alias } => {
1623                        assert!(matches!(function, AggregateFunction::Count { .. }));
1624                        assert_eq!(alias, &Some("count".to_string()));
1625                    }
1626                    _ => panic!("Expected Aggregate element"),
1627                }
1628            }
1629            _ => panic!("Expected SELECT"),
1630        }
1631    }
1632
1633    #[test]
1634    fn test_parse_sum_aggregate() {
1635        let compiler = SparqlCompiler::new();
1636        let query = r#"
1637            SELECT (SUM(?amount) AS ?total) WHERE {
1638              ?x <http://example.org/amount> ?amount .
1639            }
1640        "#;
1641
1642        let parsed = compiler.parse_query(query).unwrap();
1643
1644        match &parsed.query_type {
1645            QueryType::Select { projections, .. } => {
1646                assert_eq!(projections.len(), 1);
1647                match &projections[0] {
1648                    SelectElement::Aggregate { function, .. } => {
1649                        if let AggregateFunction::Sum { variable, .. } = function {
1650                            assert_eq!(variable, "amount");
1651                        } else {
1652                            panic!("Expected SUM aggregate");
1653                        }
1654                    }
1655                    _ => panic!("Expected Aggregate element"),
1656                }
1657            }
1658            _ => panic!("Expected SELECT"),
1659        }
1660    }
1661
1662    #[test]
1663    fn test_parse_avg_min_max() {
1664        let compiler = SparqlCompiler::new();
1665        let query = r#"
1666            SELECT (AVG(?age) AS ?avg_age) (MIN(?age) AS ?min_age) (MAX(?age) AS ?max_age) WHERE {
1667              ?x <http://example.org/age> ?age .
1668            }
1669        "#;
1670
1671        let parsed = compiler.parse_query(query).unwrap();
1672
1673        match &parsed.query_type {
1674            QueryType::Select { projections, .. } => {
1675                assert_eq!(projections.len(), 3);
1676                // Check AVG
1677                match &projections[0] {
1678                    SelectElement::Aggregate { function, .. } => {
1679                        assert!(matches!(function, AggregateFunction::Avg { .. }));
1680                    }
1681                    _ => panic!("Expected Aggregate element"),
1682                }
1683                // Check MIN
1684                match &projections[1] {
1685                    SelectElement::Aggregate { function, .. } => {
1686                        assert!(matches!(function, AggregateFunction::Min { .. }));
1687                    }
1688                    _ => panic!("Expected Aggregate element"),
1689                }
1690                // Check MAX
1691                match &projections[2] {
1692                    SelectElement::Aggregate { function, .. } => {
1693                        assert!(matches!(function, AggregateFunction::Max { .. }));
1694                    }
1695                    _ => panic!("Expected Aggregate element"),
1696                }
1697            }
1698            _ => panic!("Expected SELECT"),
1699        }
1700    }
1701
1702    #[test]
1703    fn test_parse_group_by() {
1704        let compiler = SparqlCompiler::new();
1705        let query = r#"
1706            SELECT ?dept (COUNT(?person) AS ?count) WHERE {
1707              ?person <http://example.org/department> ?dept .
1708            } GROUP BY ?dept
1709        "#;
1710
1711        let parsed = compiler.parse_query(query).unwrap();
1712
1713        assert_eq!(parsed.group_by, vec!["dept"]);
1714
1715        match &parsed.query_type {
1716            QueryType::Select { projections, .. } => {
1717                assert_eq!(projections.len(), 2);
1718                // First should be variable
1719                match &projections[0] {
1720                    SelectElement::Variable(name) => assert_eq!(name, "dept"),
1721                    _ => panic!("Expected Variable element"),
1722                }
1723                // Second should be aggregate
1724                match &projections[1] {
1725                    SelectElement::Aggregate { function, .. } => {
1726                        assert!(matches!(function, AggregateFunction::Count { .. }));
1727                    }
1728                    _ => panic!("Expected Aggregate element"),
1729                }
1730            }
1731            _ => panic!("Expected SELECT"),
1732        }
1733    }
1734
1735    #[test]
1736    fn test_parse_having() {
1737        let compiler = SparqlCompiler::new();
1738        let query = r#"
1739            SELECT ?dept (COUNT(?person) AS ?count) WHERE {
1740              ?person <http://example.org/department> ?dept .
1741            } GROUP BY ?dept HAVING(?count > 10)
1742        "#;
1743
1744        let parsed = compiler.parse_query(query).unwrap();
1745
1746        assert_eq!(parsed.group_by, vec!["dept"]);
1747        assert_eq!(parsed.having.len(), 1);
1748
1749        match &parsed.having[0] {
1750            FilterCondition::GreaterThan(var, val) => {
1751                assert_eq!(var, "count");
1752                assert_eq!(val, "10");
1753            }
1754            _ => panic!("Expected GreaterThan condition"),
1755        }
1756    }
1757
1758    #[test]
1759    fn test_parse_count_distinct() {
1760        let compiler = SparqlCompiler::new();
1761        let query = r#"
1762            SELECT (COUNT(DISTINCT ?person) AS ?unique) WHERE {
1763              ?person <http://example.org/type> <http://example.org/Person> .
1764            }
1765        "#;
1766
1767        let parsed = compiler.parse_query(query).unwrap();
1768
1769        match &parsed.query_type {
1770            QueryType::Select { projections, .. } => match &projections[0] {
1771                SelectElement::Aggregate { function, .. } => {
1772                    if let AggregateFunction::Count { distinct, .. } = function {
1773                        assert!(distinct);
1774                    } else {
1775                        panic!("Expected COUNT aggregate");
1776                    }
1777                }
1778                _ => panic!("Expected Aggregate element"),
1779            },
1780            _ => panic!("Expected SELECT"),
1781        }
1782    }
1783
1784    #[test]
1785    fn test_parse_count_star() {
1786        let compiler = SparqlCompiler::new();
1787        let query = r#"
1788            SELECT (COUNT(*) AS ?total) WHERE {
1789              ?x <http://example.org/type> ?type .
1790            }
1791        "#;
1792
1793        let parsed = compiler.parse_query(query).unwrap();
1794
1795        match &parsed.query_type {
1796            QueryType::Select { projections, .. } => match &projections[0] {
1797                SelectElement::Aggregate { function, .. } => {
1798                    if let AggregateFunction::Count { variable, .. } = function {
1799                        assert!(variable.is_none());
1800                    } else {
1801                        panic!("Expected COUNT aggregate");
1802                    }
1803                }
1804                _ => panic!("Expected Aggregate element"),
1805            },
1806            _ => panic!("Expected SELECT"),
1807        }
1808    }
1809
1810    #[test]
1811    fn test_combined_variables_and_aggregates() {
1812        let compiler = SparqlCompiler::new();
1813        let query = r#"
1814            SELECT ?category (SUM(?price) AS ?total) (AVG(?price) AS ?average) WHERE {
1815              ?item <http://example.org/category> ?category .
1816              ?item <http://example.org/price> ?price .
1817            } GROUP BY ?category ORDER BY ?total LIMIT 10
1818        "#;
1819
1820        let parsed = compiler.parse_query(query).unwrap();
1821
1822        // Check projections
1823        match &parsed.query_type {
1824            QueryType::Select {
1825                projections,
1826                select_vars,
1827                ..
1828            } => {
1829                assert_eq!(projections.len(), 3);
1830                assert_eq!(select_vars, &vec!["category", "total", "average"]);
1831            }
1832            _ => panic!("Expected SELECT"),
1833        }
1834
1835        // Check modifiers
1836        assert_eq!(parsed.group_by, vec!["category"]);
1837        assert_eq!(parsed.order_by, vec!["total"]);
1838        assert_eq!(parsed.limit, Some(10));
1839    }
1840}