1use anyhow::{anyhow, Result};
17use std::collections::HashMap;
18use tensorlogic_ir::{TLExpr, Term};
19
/// A single triple pattern made of a subject, a predicate and an object,
/// each of which is either a variable or a constant.
#[derive(Debug, Clone, PartialEq)]
pub struct TriplePattern {
    /// Element in the subject position.
    pub subject: PatternElement,
    /// Element in the predicate position.
    pub predicate: PatternElement,
    /// Element in the object position.
    pub object: PatternElement,
}
27
/// One position of a triple pattern.
///
/// `Eq` and `Hash` are derived in addition to `PartialEq` (both variants only
/// hold a `String`) so elements can be used as keys in hashed collections.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum PatternElement {
    /// A query variable; the stored name has no leading `?`.
    Variable(String),
    /// A constant term (IRI, literal or any other token) stored as text.
    Constant(String),
}
34
/// A condition taken from a `FILTER` (or `HAVING`) expression.
///
/// For the two-field variants the first string is the left operand (a
/// variable name without its `?`) and the second is the right operand as text.
/// `Eq` and `Hash` are derived because every field is a `String`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum FilterCondition {
    /// `left = right`
    Equals(String, String),
    /// `left != right`
    NotEquals(String, String),
    /// `left > right`
    GreaterThan(String, String),
    /// `left < right`
    LessThan(String, String),
    /// `left >= right`
    GreaterOrEqual(String, String),
    /// `left <= right`
    LessOrEqual(String, String),
    /// `regex(variable, pattern)`
    Regex(String, String),
    /// `BOUND(variable)`
    Bound(String),
    /// `isIRI(variable)` / `isURI(variable)`
    IsIri(String),
    /// `isLiteral(variable)`
    IsLiteral(String),
}
49
/// An aggregate function used in a `SELECT` projection.
///
/// `Eq` and `Hash` are derived because every field is a `String`,
/// `Option<String>` or `bool`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum AggregateFunction {
    /// `COUNT(...)`
    Count {
        /// Counted variable, or `None` for `COUNT(*)`.
        variable: Option<String>,
        /// Whether `DISTINCT` was given.
        distinct: bool,
    },
    /// `SUM(...)`
    Sum { variable: String, distinct: bool },
    /// `AVG(...)`
    Avg { variable: String, distinct: bool },
    /// `MIN(...)`
    Min { variable: String },
    /// `MAX(...)`
    Max { variable: String },
    /// `GROUP_CONCAT(...)`
    GroupConcat {
        /// Concatenated variable.
        variable: String,
        /// Text given after `SEPARATOR =`, if any.
        separator: Option<String>,
        /// Whether `DISTINCT` was given.
        distinct: bool,
    },
    /// `SAMPLE(...)`
    Sample { variable: String },
}
75
/// One entry of a `SELECT` projection list.
#[derive(Debug, Clone, PartialEq)]
pub enum SelectElement {
    /// A plain variable (name without `?`); the parser also stores `*` here.
    Variable(String),
    /// An aggregate expression.
    Aggregate {
        /// The aggregate being computed.
        function: AggregateFunction,
        /// Name the aggregate result is bound to, if any.
        alias: Option<String>,
    },
}
87
/// A graph pattern from the `WHERE` clause of a query.
#[derive(Debug, Clone, PartialEq)]
pub enum GraphPattern {
    /// A single triple pattern.
    Triple(TriplePattern),
    /// Several patterns that apply together.
    Group(Vec<GraphPattern>),
    /// The body of an `OPTIONAL { ... }` block.
    Optional(Box<GraphPattern>),
    /// The left and right sides of a `UNION`.
    Union(Box<GraphPattern>, Box<GraphPattern>),
    /// A `FILTER` condition.
    Filter(FilterCondition),
}
102
/// The form of a SPARQL query together with its form-specific data.
#[derive(Debug, Clone, PartialEq)]
pub enum QueryType {
    /// A `SELECT` query.
    Select {
        /// Projection entries in the order they were written.
        projections: Vec<SelectElement>,
        /// Flat list of projected names: variable names without `?`,
        /// aggregate aliases, or `*`.
        select_vars: Vec<String>,
        /// Whether `SELECT DISTINCT` was used.
        distinct: bool,
    },
    /// An `ASK` query.
    Ask,
    /// A `DESCRIBE` query and the resources listed after the keyword.
    Describe { resources: Vec<String> },
    /// A `CONSTRUCT` query and its template triples.
    Construct { template: Vec<TriplePattern> },
}
121
122#[derive(Debug, Clone)]
124pub struct SparqlQuery {
125 pub query_type: QueryType,
127 pub where_pattern: GraphPattern,
129 pub group_by: Vec<String>,
131 pub having: Vec<FilterCondition>,
133 pub limit: Option<usize>,
135 pub offset: Option<usize>,
136 pub order_by: Vec<String>,
137}
138
/// Parses SPARQL text and compiles the result into TensorLogic expressions.
///
/// `Debug` and `Clone` are derived so the compiler can be logged and
/// duplicated together with its predicate mapping table.
#[derive(Debug, Clone)]
pub struct SparqlCompiler {
    /// Maps a predicate IRI to the predicate name emitted for it.
    predicate_mapping: HashMap<String, String>,
}
144
145impl SparqlCompiler {
146 pub fn new() -> Self {
147 SparqlCompiler {
148 predicate_mapping: HashMap::new(),
149 }
150 }
151
152 pub fn add_predicate_mapping(&mut self, iri: String, predicate_name: String) {
156 self.predicate_mapping.insert(iri, predicate_name);
157 }
158
159 pub fn parse_query(&self, sparql: &str) -> Result<SparqlQuery> {
188 let normalized = sparql
190 .lines()
191 .map(|l| l.trim())
192 .filter(|l| !l.is_empty())
193 .collect::<Vec<_>>()
194 .join(" ");
195
196 let query_type = self.parse_query_type(&normalized)?;
198
199 let where_pattern = self.parse_where_clause(&normalized)?;
201
202 let group_by = self.parse_group_by(&normalized);
204 let having = self.parse_having(&normalized)?;
205
206 let limit = self.parse_limit(&normalized);
208 let offset = self.parse_offset(&normalized);
209 let order_by = self.parse_order_by(&normalized);
210
211 Ok(SparqlQuery {
212 query_type,
213 where_pattern,
214 group_by,
215 having,
216 limit,
217 offset,
218 order_by,
219 })
220 }
221
222 fn parse_group_by(&self, normalized: &str) -> Vec<String> {
224 let mut group_by = Vec::new();
225
226 if let Some(group_pos) = normalized.find("GROUP BY") {
227 let remaining = &normalized[group_pos + 8..];
229 let end_pos = remaining
230 .find("HAVING")
231 .or_else(|| remaining.find("ORDER BY"))
232 .or_else(|| remaining.find("LIMIT"))
233 .or_else(|| remaining.find("OFFSET"))
234 .unwrap_or(remaining.len());
235
236 let group_part = remaining[..end_pos].trim();
237 for token in group_part.split_whitespace() {
238 if let Some(var_name) = token.strip_prefix('?') {
239 group_by.push(var_name.to_string());
240 }
241 }
242 }
243
244 group_by
245 }
246
247 fn parse_having(&self, normalized: &str) -> Result<Vec<FilterCondition>> {
249 let mut conditions = Vec::new();
250
251 if let Some(having_pos) = normalized.find("HAVING") {
252 let remaining = &normalized[having_pos + 6..];
254 let end_pos = remaining
255 .find("ORDER BY")
256 .or_else(|| remaining.find("LIMIT"))
257 .or_else(|| remaining.find("OFFSET"))
258 .unwrap_or(remaining.len());
259
260 let having_part = remaining[..end_pos].trim();
261
262 if !having_part.is_empty() {
264 if let Some(filter) = self.parse_filter(&format!("FILTER{}", having_part))? {
265 conditions.push(filter);
266 }
267 }
268 }
269
270 Ok(conditions)
271 }
272
273 fn parse_aggregate(&self, text: &str) -> Option<(AggregateFunction, String)> {
275 let text = text.trim();
276
277 let (func_part, alias) = if let Some(as_pos) = text.to_uppercase().find(" AS ") {
279 let alias_start = as_pos + 4;
280 let alias = text[alias_start..]
281 .trim()
282 .trim_matches(|c| c == '?' || c == ')')
283 .to_string();
284 (text[..as_pos].trim(), Some(alias))
285 } else {
286 (text, None)
287 };
288
289 let upper = func_part.to_uppercase();
291
292 if upper.starts_with("COUNT(") {
293 let inner = func_part[6..].trim_end_matches(')').trim();
294 let distinct = inner.to_uppercase().starts_with("DISTINCT");
295 let var_part = if distinct { inner[8..].trim() } else { inner };
296 let variable = if var_part == "*" {
297 None
298 } else {
299 Some(var_part.trim_start_matches('?').to_string())
300 };
301 return Some((
302 AggregateFunction::Count { variable, distinct },
303 alias.unwrap_or_else(|| "count".to_string()),
304 ));
305 }
306
307 if upper.starts_with("SUM(") {
308 let inner = func_part[4..].trim_end_matches(')').trim();
309 let distinct = inner.to_uppercase().starts_with("DISTINCT");
310 let var_part = if distinct { inner[8..].trim() } else { inner };
311 let variable = var_part.trim_start_matches('?').to_string();
312 return Some((
313 AggregateFunction::Sum { variable, distinct },
314 alias.unwrap_or_else(|| "sum".to_string()),
315 ));
316 }
317
318 if upper.starts_with("AVG(") {
319 let inner = func_part[4..].trim_end_matches(')').trim();
320 let distinct = inner.to_uppercase().starts_with("DISTINCT");
321 let var_part = if distinct { inner[8..].trim() } else { inner };
322 let variable = var_part.trim_start_matches('?').to_string();
323 return Some((
324 AggregateFunction::Avg { variable, distinct },
325 alias.unwrap_or_else(|| "avg".to_string()),
326 ));
327 }
328
329 if upper.starts_with("MIN(") {
330 let inner = func_part[4..].trim_end_matches(')').trim();
331 let variable = inner.trim_start_matches('?').to_string();
332 return Some((
333 AggregateFunction::Min { variable },
334 alias.unwrap_or_else(|| "min".to_string()),
335 ));
336 }
337
338 if upper.starts_with("MAX(") {
339 let inner = func_part[4..].trim_end_matches(')').trim();
340 let variable = inner.trim_start_matches('?').to_string();
341 return Some((
342 AggregateFunction::Max { variable },
343 alias.unwrap_or_else(|| "max".to_string()),
344 ));
345 }
346
347 if upper.starts_with("GROUP_CONCAT(") {
348 let inner = func_part[13..].trim_end_matches(')').trim();
349 let distinct = inner.to_uppercase().starts_with("DISTINCT");
350 let var_part = if distinct { inner[8..].trim() } else { inner };
351 let (variable, separator) =
353 if let Some(sep_pos) = var_part.to_uppercase().find("; SEPARATOR") {
354 let var = var_part[..sep_pos]
355 .trim()
356 .trim_start_matches('?')
357 .to_string();
358 let sep_start = var_part.find('=').map(|p| p + 1).unwrap_or(sep_pos);
359 let sep = var_part[sep_start..].trim().trim_matches('"').to_string();
360 (var, Some(sep))
361 } else {
362 (var_part.trim_start_matches('?').to_string(), None)
363 };
364 return Some((
365 AggregateFunction::GroupConcat {
366 variable,
367 separator,
368 distinct,
369 },
370 alias.unwrap_or_else(|| "group_concat".to_string()),
371 ));
372 }
373
374 if upper.starts_with("SAMPLE(") {
375 let inner = func_part[7..].trim_end_matches(')').trim();
376 let variable = inner.trim_start_matches('?').to_string();
377 return Some((
378 AggregateFunction::Sample { variable },
379 alias.unwrap_or_else(|| "sample".to_string()),
380 ));
381 }
382
383 None
384 }
385
386 fn parse_query_type(&self, normalized: &str) -> Result<QueryType> {
388 if normalized.contains("ASK") {
389 Ok(QueryType::Ask)
390 } else if let Some(describe_pos) = normalized.find("DESCRIBE") {
391 let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
393 let describe_part = normalized[describe_pos + 8..where_pos].trim();
394 let mut resources = Vec::new();
395
396 for token in describe_part.split_whitespace() {
397 if token.starts_with('?') || token.starts_with('<') {
398 resources.push(
399 token
400 .trim_matches(|c| c == '?' || c == '<' || c == '>')
401 .to_string(),
402 );
403 }
404 }
405
406 Ok(QueryType::Describe { resources })
407 } else if normalized.contains("CONSTRUCT") {
408 let template = self.parse_construct_template(normalized)?;
410 Ok(QueryType::Construct { template })
411 } else if let Some(select_pos) = normalized.find("SELECT") {
412 let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
414 let select_part = normalized[select_pos + 6..where_pos].trim();
415
416 let distinct = select_part.starts_with("DISTINCT");
417 let vars_part = if distinct {
418 &select_part[8..]
419 } else {
420 select_part
421 };
422
423 let mut select_vars = Vec::new();
424 let mut projections = Vec::new();
425
426 let mut current_token = String::new();
428 let mut paren_depth = 0;
429
430 for c in vars_part.chars() {
431 match c {
432 '(' => {
433 paren_depth += 1;
434 current_token.push(c);
435 }
436 ')' => {
437 paren_depth -= 1;
438 current_token.push(c);
439 }
440 ' ' | ',' if paren_depth == 0 => {
441 if !current_token.trim().is_empty() {
442 let token = current_token.trim();
443 let token = if token.starts_with('(') && token.ends_with(')') {
445 &token[1..token.len() - 1]
446 } else {
447 token
448 };
449 if let Some((agg_func, alias)) = self.parse_aggregate(token) {
450 projections.push(SelectElement::Aggregate {
451 function: agg_func,
452 alias: Some(alias.clone()),
453 });
454 select_vars.push(alias);
455 } else if let Some(var_name) = token.strip_prefix('?') {
456 projections.push(SelectElement::Variable(var_name.to_string()));
457 select_vars.push(var_name.to_string());
458 } else if token == "*" {
459 projections.push(SelectElement::Variable("*".to_string()));
460 select_vars.push("*".to_string());
461 }
462 }
463 current_token.clear();
464 }
465 _ => current_token.push(c),
466 }
467 }
468
469 if !current_token.trim().is_empty() {
471 let token = current_token.trim();
472 let token = if token.starts_with('(') && token.ends_with(')') {
474 &token[1..token.len() - 1]
475 } else {
476 token
477 };
478 if let Some((agg_func, alias)) = self.parse_aggregate(token) {
479 projections.push(SelectElement::Aggregate {
480 function: agg_func,
481 alias: Some(alias.clone()),
482 });
483 select_vars.push(alias);
484 } else if let Some(var_name) = token.strip_prefix('?') {
485 projections.push(SelectElement::Variable(var_name.to_string()));
486 select_vars.push(var_name.to_string());
487 } else if token == "*" {
488 projections.push(SelectElement::Variable("*".to_string()));
489 select_vars.push("*".to_string());
490 }
491 }
492
493 Ok(QueryType::Select {
494 projections,
495 select_vars,
496 distinct,
497 })
498 } else {
499 Err(anyhow!("Unable to determine query type"))
500 }
501 }
502
503 fn parse_construct_template(&self, normalized: &str) -> Result<Vec<TriplePattern>> {
505 let construct_pos = normalized
506 .find("CONSTRUCT")
507 .ok_or_else(|| anyhow!("No CONSTRUCT found"))?;
508 let where_pos = normalized.find("WHERE").unwrap_or(normalized.len());
509
510 let template_start = normalized[construct_pos..where_pos]
512 .find('{')
513 .ok_or_else(|| anyhow!("No opening brace in CONSTRUCT template"))?;
514 let template_end = normalized[construct_pos..where_pos]
515 .rfind('}')
516 .ok_or_else(|| anyhow!("No closing brace in CONSTRUCT template"))?;
517
518 let template_content =
519 &normalized[construct_pos + template_start + 1..construct_pos + template_end];
520
521 let mut patterns = Vec::new();
522 for statement in self.split_sparql_statements(template_content) {
523 if let Some(pattern) = self.parse_triple_pattern(statement)? {
524 patterns.push(pattern);
525 }
526 }
527
528 Ok(patterns)
529 }
530
531 fn parse_where_clause(&self, normalized: &str) -> Result<GraphPattern> {
533 if let Some(where_start) = normalized.find("WHERE") {
535 if let Some(brace_start) = normalized[where_start..].find('{') {
536 let content_start = where_start + brace_start + 1;
537
538 let closing_brace = self.find_matching_brace(&normalized[content_start..])?;
540 let where_content = &normalized[content_start..content_start + closing_brace];
541
542 return self.parse_graph_pattern(where_content);
543 }
544 }
545
546 Err(anyhow!("No WHERE clause found"))
547 }
548
549 fn parse_graph_pattern(&self, content: &str) -> Result<GraphPattern> {
551 let content = content.trim();
552
553 if content.is_empty() {
554 return Err(anyhow!("Empty graph pattern"));
555 }
556
557 if let Some(union_pos) = content.find("UNION") {
559 let before_union = &content[..union_pos];
561 let open_braces = before_union.matches('{').count();
562 let close_braces = before_union.matches('}').count();
563
564 if open_braces == close_braces {
565 let left_part = before_union.trim();
567 let right_part = content[union_pos + 5..].trim();
568
569 let left_pattern = self.parse_graph_pattern(left_part)?;
570 let right_pattern = self.parse_graph_pattern(right_part)?;
571
572 return Ok(GraphPattern::Union(
573 Box::new(left_pattern),
574 Box::new(right_pattern),
575 ));
576 }
577 }
578
579 let mut patterns = Vec::new();
581 let statements = self.split_sparql_statements(content);
582
583 for statement in statements {
584 let statement = statement.trim();
585
586 if statement.is_empty() {
587 continue;
588 }
589
590 if statement.starts_with("OPTIONAL") {
592 if let Some(brace_start_pos) = statement.find('{') {
594 let content_start = brace_start_pos + 1;
595 if let Ok(closing_offset) =
596 self.find_matching_brace(&statement[content_start..])
597 {
598 let optional_content =
599 &statement[content_start..content_start + closing_offset];
600 let inner_pattern = self.parse_graph_pattern(optional_content)?;
601 patterns.push(GraphPattern::Optional(Box::new(inner_pattern)));
602 continue;
603 }
604 }
605 }
606
607 if statement.starts_with("FILTER") {
609 if let Some(filter) = self.parse_filter(statement)? {
610 patterns.push(GraphPattern::Filter(filter));
611 }
612 continue;
613 }
614
615 if statement.starts_with('{') && statement.ends_with('}') {
617 let inner = &statement[1..statement.len() - 1];
618 let inner_pattern = self.parse_graph_pattern(inner)?;
619 patterns.push(inner_pattern);
620 continue;
621 }
622
623 if let Some(pattern) = self.parse_triple_pattern(statement)? {
625 patterns.push(GraphPattern::Triple(pattern));
626 }
627 }
628
629 if patterns.is_empty() {
630 Err(anyhow!("Empty graph pattern in content: {}", content))
631 } else if patterns.len() == 1 {
632 Ok(patterns.into_iter().next().unwrap())
633 } else {
634 Ok(GraphPattern::Group(patterns))
635 }
636 }
637
638 fn find_matching_brace(&self, content: &str) -> Result<usize> {
640 let mut depth = 1;
641 let chars: Vec<char> = content.chars().collect();
642
643 for (i, &c) in chars.iter().enumerate() {
644 match c {
645 '{' => depth += 1,
646 '}' => {
647 depth -= 1;
648 if depth == 0 {
649 return Ok(i);
650 }
651 }
652 _ => {}
653 }
654 }
655
656 Err(anyhow!("No matching closing brace found"))
657 }
658
659 fn parse_limit(&self, normalized: &str) -> Option<usize> {
661 if let Some(limit_pos) = normalized.find("LIMIT") {
662 let after_limit = &normalized[limit_pos + 5..].trim();
663 if let Some(num_str) = after_limit.split_whitespace().next() {
664 return num_str.parse().ok();
665 }
666 }
667 None
668 }
669
670 fn parse_offset(&self, normalized: &str) -> Option<usize> {
672 if let Some(offset_pos) = normalized.find("OFFSET") {
673 let after_offset = &normalized[offset_pos + 6..].trim();
674 if let Some(num_str) = after_offset.split_whitespace().next() {
675 return num_str.parse().ok();
676 }
677 }
678 None
679 }
680
681 fn parse_order_by(&self, normalized: &str) -> Vec<String> {
683 if let Some(order_pos) = normalized.find("ORDER BY") {
684 let after_order = &normalized[order_pos + 8..];
685
686 let limit_offset = after_order.find("LIMIT").unwrap_or(after_order.len());
688 let offset_offset = after_order.find("OFFSET").unwrap_or(after_order.len());
689 let end_offset = limit_offset.min(offset_offset);
690
691 let order_part = after_order[..end_offset].trim();
692 return order_part
693 .split_whitespace()
694 .filter_map(|s| s.strip_prefix('?').map(|v| v.to_string()))
695 .collect();
696 }
697 Vec::new()
698 }
699
700 fn split_sparql_statements<'a>(&self, content: &'a str) -> Vec<&'a str> {
704 let mut statements = Vec::new();
705 let mut current_start = 0;
706 let mut inside_uri = false;
707 let mut inside_string = false;
708 let chars: Vec<char> = content.chars().collect();
709
710 for i in 0..chars.len() {
711 match chars[i] {
712 '<' if !inside_string => inside_uri = true,
713 '>' if !inside_string => inside_uri = false,
714 '"' if !inside_uri => inside_string = !inside_string,
715 '.' if !inside_uri && !inside_string => {
716 let statement = &content[current_start..i];
718 if !statement.trim().is_empty() {
719 statements.push(statement);
720 }
721 current_start = i + 1;
722 }
723 _ => {}
724 }
725 }
726
727 if current_start < content.len() {
729 let statement = &content[current_start..];
730 if !statement.trim().is_empty() {
731 statements.push(statement);
732 }
733 }
734
735 statements
736 }
737
738 fn parse_triple_pattern(&self, line: &str) -> Result<Option<TriplePattern>> {
740 let line = line.trim_end_matches('.').trim();
742 let parts: Vec<&str> = line.split_whitespace().collect();
743
744 if parts.len() < 3 {
745 return Ok(None);
746 }
747
748 let subject = self.parse_pattern_element(parts[0])?;
749 let predicate = self.parse_pattern_element(parts[1])?;
750 let object = self.parse_pattern_element(parts[2])?;
751
752 Ok(Some(TriplePattern {
753 subject,
754 predicate,
755 object,
756 }))
757 }
758
759 fn parse_pattern_element(&self, s: &str) -> Result<PatternElement> {
761 if let Some(var_name) = s.strip_prefix('?') {
762 Ok(PatternElement::Variable(var_name.to_string()))
763 } else if let Some(iri) = s.strip_prefix('<').and_then(|s| s.strip_suffix('>')) {
764 Ok(PatternElement::Constant(iri.to_string()))
765 } else if let Some(literal) = s.strip_prefix('"').and_then(|s| s.strip_suffix('"')) {
766 Ok(PatternElement::Constant(literal.to_string()))
767 } else {
768 Ok(PatternElement::Constant(s.to_string()))
769 }
770 }
771
772 fn parse_filter(&self, line: &str) -> Result<Option<FilterCondition>> {
774 let filter_content = line
775 .strip_prefix("FILTER")
776 .and_then(|s| s.trim().strip_prefix('('))
777 .and_then(|s| s.trim().strip_suffix(')'))
778 .map(|s| s.trim());
779
780 if let Some(content) = filter_content {
781 if content.starts_with("BOUND(") {
783 if let Some(var_end) = content.find(')') {
784 let var = &content[6..var_end].trim_start_matches('?');
785 return Ok(Some(FilterCondition::Bound(var.to_string())));
786 }
787 } else if content.starts_with("isIRI(") || content.starts_with("isURI(") {
788 let start_pos = 6;
790 if let Some(var_end) = content.find(')') {
791 let var = &content[start_pos..var_end].trim_start_matches('?');
792 return Ok(Some(FilterCondition::IsIri(var.to_string())));
793 }
794 } else if content.starts_with("isLiteral(") {
795 if let Some(var_end) = content.find(')') {
796 let var = &content[10..var_end].trim_start_matches('?');
797 return Ok(Some(FilterCondition::IsLiteral(var.to_string())));
798 }
799 } else if content.starts_with("regex(") {
800 if let Some(comma_pos) = content.find(',') {
802 let var = content[6..comma_pos].trim().trim_start_matches('?');
803 let pattern_part = content[comma_pos + 1..]
804 .trim()
805 .trim_end_matches(')')
806 .trim_matches('"');
807 return Ok(Some(FilterCondition::Regex(
808 var.to_string(),
809 pattern_part.to_string(),
810 )));
811 }
812 }
813
814 if content.contains(">=") {
816 let parts: Vec<&str> = content.split(">=").map(|s| s.trim()).collect();
817 if parts.len() == 2 {
818 return Ok(Some(FilterCondition::GreaterOrEqual(
819 parts[0].trim_start_matches('?').to_string(),
820 parts[1].trim_matches('"').to_string(),
821 )));
822 }
823 } else if content.contains("<=") {
824 let parts: Vec<&str> = content.split("<=").map(|s| s.trim()).collect();
825 if parts.len() == 2 {
826 return Ok(Some(FilterCondition::LessOrEqual(
827 parts[0].trim_start_matches('?').to_string(),
828 parts[1].trim_matches('"').to_string(),
829 )));
830 }
831 } else if content.contains(">") && !content.contains(">=") {
832 let parts: Vec<&str> = content.split('>').map(|s| s.trim()).collect();
833 if parts.len() == 2 {
834 return Ok(Some(FilterCondition::GreaterThan(
835 parts[0].trim_start_matches('?').to_string(),
836 parts[1].trim_matches('"').to_string(),
837 )));
838 }
839 } else if content.contains("<") && !content.contains("<=") {
840 let parts: Vec<&str> = content.split('<').map(|s| s.trim()).collect();
841 if parts.len() == 2 {
842 return Ok(Some(FilterCondition::LessThan(
843 parts[0].trim_start_matches('?').to_string(),
844 parts[1].trim_matches('"').to_string(),
845 )));
846 }
847 } else if content.contains("!=") {
848 let parts: Vec<&str> = content.split("!=").map(|s| s.trim()).collect();
849 if parts.len() == 2 {
850 return Ok(Some(FilterCondition::NotEquals(
851 parts[0].trim_start_matches('?').to_string(),
852 parts[1].trim_matches('"').to_string(),
853 )));
854 }
855 } else if content.contains("=")
856 && !content.contains("!=")
857 && !content.contains(">=")
858 && !content.contains("<=")
859 {
860 let parts: Vec<&str> = content.split('=').map(|s| s.trim()).collect();
861 if parts.len() == 2 {
862 return Ok(Some(FilterCondition::Equals(
863 parts[0].trim_start_matches('?').to_string(),
864 parts[1].trim_matches('"').to_string(),
865 )));
866 }
867 }
868 }
869
870 Ok(None)
871 }
872
873 pub fn compile_to_tensorlogic(&self, query: &SparqlQuery) -> Result<TLExpr> {
911 let where_expr = self.compile_graph_pattern(&query.where_pattern)?;
913
914 match &query.query_type {
916 QueryType::Ask => {
917 Ok(where_expr) }
920 QueryType::Select { select_vars, .. } => {
921 if select_vars.is_empty() || select_vars.contains(&"*".to_string()) {
924 Ok(where_expr)
925 } else {
926 Ok(where_expr)
928 }
929 }
930 QueryType::Describe { .. } => {
931 Ok(where_expr)
933 }
934 QueryType::Construct { template: _ } => {
935 Ok(where_expr)
939 }
940 }
941 }
942
943 fn compile_graph_pattern(&self, pattern: &GraphPattern) -> Result<TLExpr> {
945 match pattern {
946 GraphPattern::Triple(triple) => self.compile_triple_pattern(triple),
947
948 GraphPattern::Group(patterns) => {
949 if patterns.is_empty() {
950 return Err(anyhow!("Empty pattern group"));
951 }
952
953 let mut exprs: Vec<TLExpr> = Vec::new();
954 for p in patterns {
955 exprs.push(self.compile_graph_pattern(p)?);
956 }
957
958 Ok(exprs.into_iter().reduce(TLExpr::and).unwrap())
960 }
961
962 GraphPattern::Optional(inner) => {
963 let inner_expr = self.compile_graph_pattern(inner)?;
967
968 Ok(TLExpr::or(inner_expr.clone(), TLExpr::pred("true", vec![])))
971 }
972
973 GraphPattern::Union(left, right) => {
974 let left_expr = self.compile_graph_pattern(left)?;
976 let right_expr = self.compile_graph_pattern(right)?;
977 Ok(TLExpr::or(left_expr, right_expr))
978 }
979
980 GraphPattern::Filter(filter_cond) => self.compile_filter_condition(filter_cond),
981 }
982 }
983
984 fn compile_triple_pattern(&self, pattern: &TriplePattern) -> Result<TLExpr> {
986 let pred_name = match &pattern.predicate {
987 PatternElement::Constant(iri) => {
988 self.predicate_mapping
990 .get(iri)
991 .cloned()
992 .unwrap_or_else(|| Self::iri_to_name(iri))
993 }
994 PatternElement::Variable(v) => {
995 return Err(anyhow!("Variable predicates not supported: ?{}", v));
996 }
997 };
998
999 let subj_term = match &pattern.subject {
1000 PatternElement::Variable(v) => Term::var(v),
1001 PatternElement::Constant(c) => Term::constant(c),
1002 };
1003
1004 let obj_term = match &pattern.object {
1005 PatternElement::Variable(v) => Term::var(v),
1006 PatternElement::Constant(c) => Term::constant(c),
1007 };
1008
1009 Ok(TLExpr::pred(&pred_name, vec![subj_term, obj_term]))
1010 }
1011
1012 fn compile_filter_condition(&self, filter: &FilterCondition) -> Result<TLExpr> {
1014 let expr = match filter {
1015 FilterCondition::Equals(var, val) => {
1016 TLExpr::pred("equals", vec![Term::var(var), Term::constant(val)])
1017 }
1018 FilterCondition::NotEquals(var, val) => TLExpr::negate(TLExpr::pred(
1019 "equals",
1020 vec![Term::var(var), Term::constant(val)],
1021 )),
1022 FilterCondition::GreaterThan(var, val) => {
1023 TLExpr::pred("greaterThan", vec![Term::var(var), Term::constant(val)])
1024 }
1025 FilterCondition::LessThan(var, val) => {
1026 TLExpr::pred("lessThan", vec![Term::var(var), Term::constant(val)])
1027 }
1028 FilterCondition::GreaterOrEqual(var, val) => {
1029 TLExpr::pred("greaterOrEqual", vec![Term::var(var), Term::constant(val)])
1030 }
1031 FilterCondition::LessOrEqual(var, val) => {
1032 TLExpr::pred("lessOrEqual", vec![Term::var(var), Term::constant(val)])
1033 }
1034 FilterCondition::Regex(var, pattern) => {
1035 TLExpr::pred("matches", vec![Term::var(var), Term::constant(pattern)])
1036 }
1037 FilterCondition::Bound(var) => TLExpr::pred("bound", vec![Term::var(var)]),
1038 FilterCondition::IsIri(var) => TLExpr::pred("isIri", vec![Term::var(var)]),
1039 FilterCondition::IsLiteral(var) => TLExpr::pred("isLiteral", vec![Term::var(var)]),
1040 };
1041
1042 Ok(expr)
1043 }
1044
1045 fn iri_to_name(iri: &str) -> String {
1047 iri.split(['/', '#']).next_back().unwrap_or(iri).to_string()
1048 }
1049}
1050
1051impl Default for SparqlCompiler {
1052 fn default() -> Self {
1053 Self::new()
1054 }
1055}
1056
1057#[cfg(test)]
1058mod tests {
1059 use super::*;
1060
1061 #[test]
1064 fn test_parse_simple_query() {
1065 let compiler = SparqlCompiler::new();
1066 let query = r#"
1067 SELECT ?x ?y WHERE {
1068 ?x <http://example.org/knows> ?y .
1069 }
1070 "#;
1071
1072 let parsed = compiler.parse_query(query).unwrap();
1073
1074 match &parsed.query_type {
1076 QueryType::Select {
1077 select_vars,
1078 distinct,
1079 ..
1080 } => {
1081 assert_eq!(select_vars, &vec!["x", "y"]);
1082 assert!(!distinct);
1083 }
1084 _ => panic!("Expected SELECT query"),
1085 }
1086
1087 match &parsed.where_pattern {
1089 GraphPattern::Triple(pattern) => {
1090 assert_eq!(pattern.subject, PatternElement::Variable("x".to_string()));
1091 assert_eq!(
1092 pattern.predicate,
1093 PatternElement::Constant("http://example.org/knows".to_string())
1094 );
1095 assert_eq!(pattern.object, PatternElement::Variable("y".to_string()));
1096 }
1097 _ => panic!("Expected Triple pattern"),
1098 }
1099 }
1100
1101 #[test]
1102 fn test_parse_select_distinct() {
1103 let compiler = SparqlCompiler::new();
1104 let query = r#"
1105 SELECT DISTINCT ?x WHERE {
1106 ?x <http://example.org/type> ?t .
1107 }
1108 "#;
1109
1110 let parsed = compiler.parse_query(query).unwrap();
1111
1112 match &parsed.query_type {
1113 QueryType::Select {
1114 select_vars,
1115 distinct,
1116 ..
1117 } => {
1118 assert_eq!(select_vars, &vec!["x"]);
1119 assert!(distinct);
1120 }
1121 _ => panic!("Expected SELECT DISTINCT query"),
1122 }
1123 }
1124
1125 #[test]
1126 fn test_parse_query_with_filter() {
1127 let compiler = SparqlCompiler::new();
1128 let query = r#"
1129 SELECT ?x ?age WHERE {
1130 ?x <http://example.org/age> ?age .
1131 FILTER(?age > 18)
1132 }
1133 "#;
1134
1135 let parsed = compiler.parse_query(query).unwrap();
1136
1137 match &parsed.query_type {
1138 QueryType::Select { select_vars, .. } => {
1139 assert_eq!(select_vars, &vec!["x", "age"]);
1140 }
1141 _ => panic!("Expected SELECT query"),
1142 }
1143
1144 match &parsed.where_pattern {
1146 GraphPattern::Group(patterns) => {
1147 assert_eq!(patterns.len(), 2);
1148 assert!(matches!(patterns[0], GraphPattern::Triple(_)));
1150 assert!(matches!(patterns[1], GraphPattern::Filter(_)));
1151 }
1152 _ => panic!("Expected Group pattern with filter"),
1153 }
1154 }
1155
1156 #[test]
1157 fn test_parse_query_with_limit_offset() {
1158 let compiler = SparqlCompiler::new();
1159 let query = r#"
1160 SELECT ?x WHERE {
1161 ?x <http://example.org/type> ?t .
1162 } LIMIT 10 OFFSET 20
1163 "#;
1164
1165 let parsed = compiler.parse_query(query).unwrap();
1166 assert_eq!(parsed.limit, Some(10));
1167 assert_eq!(parsed.offset, Some(20));
1168 }
1169
1170 #[test]
1171 fn test_parse_query_with_order_by() {
1172 let compiler = SparqlCompiler::new();
1173 let query = r#"
1174 SELECT ?x ?name WHERE {
1175 ?x <http://example.org/name> ?name .
1176 } ORDER BY ?name
1177 "#;
1178
1179 let parsed = compiler.parse_query(query).unwrap();
1180 assert_eq!(parsed.order_by, vec!["name"]);
1181 }
1182
1183 #[test]
1186 fn test_parse_ask_query() {
1187 let compiler = SparqlCompiler::new();
1188 let query = r#"
1189 ASK WHERE {
1190 ?x <http://example.org/knows> ?y .
1191 }
1192 "#;
1193
1194 let parsed = compiler.parse_query(query).unwrap();
1195
1196 match &parsed.query_type {
1197 QueryType::Ask => {
1198 }
1200 _ => panic!("Expected ASK query"),
1201 }
1202 }
1203
1204 #[test]
1205 fn test_compile_ask_query() {
1206 let mut compiler = SparqlCompiler::new();
1207 compiler.add_predicate_mapping("http://example.org/knows".to_string(), "knows".to_string());
1208
1209 let query = r#"
1210 ASK WHERE {
1211 ?x <http://example.org/knows> ?y .
1212 }
1213 "#;
1214
1215 let parsed = compiler.parse_query(query).unwrap();
1216 let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1217
1218 let expr_str = format!("{:?}", tl_expr);
1220 assert!(expr_str.contains("knows"));
1221 }
1222
1223 #[test]
1226 fn test_parse_describe_query() {
1227 let compiler = SparqlCompiler::new();
1228 let query = r#"
1229 DESCRIBE ?x WHERE {
1230 ?x <http://example.org/type> <http://example.org/Person> .
1231 }
1232 "#;
1233
1234 let parsed = compiler.parse_query(query).unwrap();
1235
1236 match &parsed.query_type {
1237 QueryType::Describe { resources } => {
1238 assert_eq!(resources, &vec!["x"]);
1239 }
1240 _ => panic!("Expected DESCRIBE query"),
1241 }
1242 }
1243
1244 #[test]
1245 fn test_compile_describe_query() {
1246 let mut compiler = SparqlCompiler::new();
1247 compiler.add_predicate_mapping("http://example.org/type".to_string(), "type".to_string());
1248
1249 let query = r#"
1250 DESCRIBE ?x WHERE {
1251 ?x <http://example.org/type> ?t .
1252 }
1253 "#;
1254
1255 let parsed = compiler.parse_query(query).unwrap();
1256 let tl_expr = compiler.compile_to_tensorlogic(&parsed).unwrap();
1257
1258 let expr_str = format!("{:?}", tl_expr);
1259 assert!(expr_str.contains("type"));
1260 }
1261
1262 #[test]
1265 fn test_parse_construct_query() {
1266 let compiler = SparqlCompiler::new();
1267 let query = r#"
1268 CONSTRUCT { ?x <http://example.org/friend> ?y }
1269 WHERE {
1270 ?x <http://example.org/knows> ?y .
1271 }
1272 "#;
1273
1274 let parsed = compiler.parse_query(query).unwrap();
1275
1276 match &parsed.query_type {
1277 QueryType::Construct { template } => {
1278 assert_eq!(template.len(), 1);
1279 let pattern = &template[0];
1280 assert_eq!(pattern.subject, PatternElement::Variable("x".to_string()));
1281 assert_eq!(
1282 pattern.predicate,
1283 PatternElement::Constant("http://example.org/friend".to_string())
1284 );
1285 assert_eq!(pattern.object, PatternElement::Variable("y".to_string()));
1286 }
1287 _ => panic!("Expected CONSTRUCT query"),
1288 }
1289 }
1290
/// Compiling a CONSTRUCT query whose WHERE predicate is mapped to `knows`
/// must produce an expression whose debug rendering mentions `knows`.
#[test]
fn test_compile_construct_query() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/knows"),
        String::from("knows"),
    );

    let source = r#"
        CONSTRUCT { ?x <http://example.org/friend> ?y }
        WHERE {
            ?x <http://example.org/knows> ?y .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("knows"));
}
1309
/// A WHERE clause with one triple followed by an OPTIONAL block must parse
/// into a two-element group: a `Triple`, then an `Optional`.
#[test]
fn test_parse_optional_pattern() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?x ?name ?age WHERE {
            ?x <http://example.org/name> ?name .
            OPTIONAL { ?x <http://example.org/age> ?age }
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    if let GraphPattern::Group(members) = &ast.where_pattern {
        assert_eq!(members.len(), 2);
        assert!(matches!(members[0], GraphPattern::Triple(_)));
        assert!(matches!(members[1], GraphPattern::Optional(_)));
    } else {
        panic!("Expected Group with OPTIONAL");
    }
}
1333
/// Compiling a query with an OPTIONAL block must produce an expression whose
/// debug rendering mentions the mapped `name` predicate and an `Or` node.
#[test]
fn test_compile_optional_pattern() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/name"),
        String::from("name"),
    );
    sparql.add_predicate_mapping(
        String::from("http://example.org/age"),
        String::from("age"),
    );

    let source = r#"
        SELECT ?x ?name WHERE {
            ?x <http://example.org/name> ?name .
            OPTIONAL { ?x <http://example.org/age> ?age }
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("name"));
    assert!(rendered.contains("Or"));
}
1355
/// A WHERE clause made of two braced groups joined by UNION must parse into
/// a top-level `GraphPattern::Union`.
#[test]
fn test_parse_union_pattern() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?x ?y WHERE {
            { ?x <http://example.org/knows> ?y }
            UNION
            { ?x <http://example.org/likes> ?y }
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    // Only the variant is checked; the two branches are not inspected.
    if !matches!(&ast.where_pattern, GraphPattern::Union(_, _)) {
        panic!("Expected UNION pattern");
    }
}
1378
/// Compiling a UNION query must produce an expression whose debug rendering
/// mentions at least one of the mapped predicates and an `Or` node.
#[test]
fn test_compile_union_pattern() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/knows"),
        String::from("knows"),
    );
    sparql.add_predicate_mapping(
        String::from("http://example.org/likes"),
        String::from("likes"),
    );

    let source = r#"
        SELECT ?x ?y WHERE {
            { ?x <http://example.org/knows> ?y }
            UNION
            { ?x <http://example.org/likes> ?y }
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    // Either branch predicate appearing is accepted by this check.
    let mentions_branch = rendered.contains("knows") || rendered.contains("likes");
    assert!(mentions_branch);
    assert!(rendered.contains("Or"));
}
1401
/// `FILTER(?age >= 18)` placed after a triple must appear as the second
/// group member, parsed as `GreaterOrEqual("age", "18")`.
#[test]
fn test_filter_greater_or_equal() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?x WHERE {
            ?x <http://example.org/age> ?age .
            FILTER(?age >= 18)
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    if let GraphPattern::Group(members) = &ast.where_pattern {
        // Index 1 is the filter; index 0 is the triple that precedes it.
        match members.get(1) {
            Some(GraphPattern::Filter(FilterCondition::GreaterOrEqual(var, val))) => {
                assert_eq!(var, "age");
                assert_eq!(val, "18");
            }
            _ => panic!("Expected GreaterOrEqual filter"),
        }
    } else {
        panic!("Expected Group pattern");
    }
}
1430
/// `FILTER(BOUND(?x))` must parse into `FilterCondition::Bound("x")`.
#[test]
fn test_filter_bound() {
    let sparql = SparqlCompiler::new();
    let parsed_filter = sparql.parse_filter("FILTER(BOUND(?x))").unwrap();

    if let Some(FilterCondition::Bound(var)) = parsed_filter {
        assert_eq!(var, "x");
    } else {
        panic!("Expected BOUND filter");
    }
}
1443
/// `FILTER(isIRI(?x))` must parse into `FilterCondition::IsIri("x")`.
#[test]
fn test_filter_is_iri() {
    let sparql = SparqlCompiler::new();
    let parsed_filter = sparql.parse_filter("FILTER(isIRI(?x))").unwrap();

    if let Some(FilterCondition::IsIri(var)) = parsed_filter {
        assert_eq!(var, "x");
    } else {
        panic!("Expected isIRI filter");
    }
}
1456
/// `FILTER(regex(?name, "^John"))` must parse into
/// `FilterCondition::Regex("name", "^John")` with the quotes stripped.
#[test]
fn test_filter_regex() {
    let sparql = SparqlCompiler::new();
    let filter_text = r#"FILTER(regex(?name, "^John"))"#;
    let parsed_filter = sparql.parse_filter(filter_text).unwrap();

    if let Some(FilterCondition::Regex(var, pattern)) = parsed_filter {
        assert_eq!(var, "name");
        assert_eq!(pattern, "^John");
    } else {
        panic!("Expected regex filter");
    }
}
1472
/// Compiling a single-triple SELECT whose predicate is mapped to `knows`
/// must produce an expression whose debug rendering mentions `knows`.
#[test]
fn test_compile_simple_query() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/knows"),
        String::from("knows"),
    );

    let source = r#"
        SELECT ?x ?y WHERE {
            ?x <http://example.org/knows> ?y .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("knows"));
}
1493
/// Compiling a SELECT with two chained `knows` triples must produce an
/// expression whose debug rendering mentions `knows` and an `And` node.
#[test]
fn test_compile_query_with_multiple_patterns() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/knows"),
        String::from("knows"),
    );

    let source = r#"
        SELECT ?x ?y ?z WHERE {
            ?x <http://example.org/knows> ?y .
            ?y <http://example.org/knows> ?z .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("knows"));
    assert!(rendered.contains("And"));
}
1514
/// Compiling a SELECT with `FILTER(?a > 18)` must produce an expression whose
/// debug rendering mentions the mapped `age` predicate and `greaterThan`.
#[test]
fn test_compile_query_with_filter() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/age"),
        String::from("age"),
    );

    let source = r#"
        SELECT ?x ?a WHERE {
            ?x <http://example.org/age> ?a .
            FILTER(?a > 18)
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();
    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();

    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("age"));
    assert!(rendered.contains("greaterThan"));
}
1535
/// `SparqlCompiler::iri_to_name` must return the segment after the last `/`
/// or `#` for the sample IRIs, and return a separator-free input unchanged.
#[test]
fn test_iri_to_name() {
    // (input IRI, expected local name)
    let cases = [
        ("http://example.org/knows", "knows"),
        ("http://xmlns.com/foaf/0.1#Person", "Person"),
        ("simple", "simple"),
    ];

    for &(iri, expected) in &cases {
        assert_eq!(SparqlCompiler::iri_to_name(iri), expected);
    }
}
1550
/// End-to-end check of a `SELECT DISTINCT` query combining a triple, an
/// OPTIONAL block with a nested FILTER, and trailing `LIMIT 100 ORDER BY ?name`
/// modifiers: verifies the parsed projection, modifiers and group shape, then
/// that compilation yields an expression mentioning `name` and `And` or `Or`.
#[test]
fn test_complex_query_with_optional_and_filter() {
    let mut sparql = SparqlCompiler::new();
    sparql.add_predicate_mapping(
        String::from("http://example.org/name"),
        String::from("name"),
    );
    sparql.add_predicate_mapping(
        String::from("http://example.org/age"),
        String::from("age"),
    );

    let source = r#"
        SELECT DISTINCT ?x ?name WHERE {
            ?x <http://example.org/name> ?name .
            OPTIONAL {
                ?x <http://example.org/age> ?age .
                FILTER(?age >= 21)
            }
        } LIMIT 100 ORDER BY ?name
    "#;

    let ast = sparql.parse_query(source).unwrap();

    // Projection: two variables, DISTINCT flag set.
    if let QueryType::Select {
        select_vars,
        distinct,
        ..
    } = &ast.query_type
    {
        assert_eq!(select_vars, &vec!["x", "name"]);
        assert!(distinct);
    } else {
        panic!("Expected SELECT DISTINCT");
    }

    // Solution modifiers.
    assert_eq!(ast.limit, Some(100));
    assert_eq!(ast.order_by, vec!["name"]);

    // WHERE clause: a group that starts with the plain triple.
    if let GraphPattern::Group(members) = &ast.where_pattern {
        assert!(members.len() >= 2, "Expected at least 2 patterns in group");
        assert!(matches!(members[0], GraphPattern::Triple(_)));
    } else {
        panic!("Expected Group pattern");
    }

    let compiled = sparql.compile_to_tensorlogic(&ast).unwrap();
    let rendered = format!("{:?}", compiled);
    assert!(rendered.contains("name"));
    let has_connective = rendered.contains("And") || rendered.contains("Or");
    assert!(has_connective);
}
1604
/// `SELECT (COUNT(?x) AS ?count)` must parse into a single aggregate
/// projection using `Count` with alias `count`.
#[test]
fn test_parse_count_aggregate() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT (COUNT(?x) AS ?count) WHERE {
            ?x <http://example.org/type> <http://example.org/Person> .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };
    assert_eq!(projections.len(), 1);

    if let SelectElement::Aggregate { function, alias } = &projections[0] {
        assert!(matches!(function, AggregateFunction::Count { .. }));
        assert_eq!(alias, &Some("count".to_string()));
    } else {
        panic!("Expected Aggregate element");
    }
}
1632
/// `SELECT (SUM(?amount) AS ?total)` must parse into a single aggregate
/// projection using `Sum` over the variable `amount`.
#[test]
fn test_parse_sum_aggregate() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT (SUM(?amount) AS ?total) WHERE {
            ?x <http://example.org/amount> ?amount .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };
    assert_eq!(projections.len(), 1);

    let function = match &projections[0] {
        SelectElement::Aggregate { function, .. } => function,
        _ => panic!("Expected Aggregate element"),
    };

    match function {
        AggregateFunction::Sum { variable, .. } => assert_eq!(variable, "amount"),
        _ => panic!("Expected SUM aggregate"),
    }
}
1661
/// A SELECT with `AVG`, `MIN` and `MAX` projections must parse into three
/// aggregate elements, in that order.
#[test]
fn test_parse_avg_min_max() {
    /// Returns the aggregate function stored at `index`, panicking when that
    /// projection is not an aggregate.
    fn aggregate_at(items: &[SelectElement], index: usize) -> &AggregateFunction {
        match &items[index] {
            SelectElement::Aggregate { function, .. } => function,
            _ => panic!("Expected Aggregate element"),
        }
    }

    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT (AVG(?age) AS ?avg_age) (MIN(?age) AS ?min_age) (MAX(?age) AS ?max_age) WHERE {
            ?x <http://example.org/age> ?age .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };
    assert_eq!(projections.len(), 3);

    assert!(matches!(
        aggregate_at(projections, 0),
        AggregateFunction::Avg { .. }
    ));
    assert!(matches!(
        aggregate_at(projections, 1),
        AggregateFunction::Min { .. }
    ));
    assert!(matches!(
        aggregate_at(projections, 2),
        AggregateFunction::Max { .. }
    ));
}
1701
/// `GROUP BY ?dept` must populate `group_by` with `dept`, and the projection
/// list must hold the plain variable `dept` followed by a `Count` aggregate.
#[test]
fn test_parse_group_by() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?dept (COUNT(?person) AS ?count) WHERE {
            ?person <http://example.org/department> ?dept .
        } GROUP BY ?dept
    "#;

    let ast = sparql.parse_query(source).unwrap();

    assert_eq!(ast.group_by, vec!["dept"]);

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };
    assert_eq!(projections.len(), 2);

    if let SelectElement::Variable(name) = &projections[0] {
        assert_eq!(name, "dept");
    } else {
        panic!("Expected Variable element");
    }

    if let SelectElement::Aggregate { function, .. } = &projections[1] {
        assert!(matches!(function, AggregateFunction::Count { .. }));
    } else {
        panic!("Expected Aggregate element");
    }
}
1734
/// `GROUP BY ?dept HAVING(?count > 10)` must populate `group_by` with `dept`
/// and `having` with a single `GreaterThan("count", "10")` condition.
#[test]
fn test_parse_having() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?dept (COUNT(?person) AS ?count) WHERE {
            ?person <http://example.org/department> ?dept .
        } GROUP BY ?dept HAVING(?count > 10)
    "#;

    let ast = sparql.parse_query(source).unwrap();

    assert_eq!(ast.group_by, vec!["dept"]);
    assert_eq!(ast.having.len(), 1);

    if let FilterCondition::GreaterThan(var, val) = &ast.having[0] {
        assert_eq!(var, "count");
        assert_eq!(val, "10");
    } else {
        panic!("Expected GreaterThan condition");
    }
}
1757
/// `COUNT(DISTINCT ?person)` must parse into a `Count` aggregate whose
/// `distinct` flag is set.
#[test]
fn test_parse_count_distinct() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT (COUNT(DISTINCT ?person) AS ?unique) WHERE {
            ?person <http://example.org/type> <http://example.org/Person> .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };

    let function = match &projections[0] {
        SelectElement::Aggregate { function, .. } => function,
        _ => panic!("Expected Aggregate element"),
    };

    match function {
        AggregateFunction::Count { distinct, .. } => assert!(distinct),
        _ => panic!("Expected COUNT aggregate"),
    }
}
1783
/// `COUNT(*)` must parse into a `Count` aggregate with no variable.
#[test]
fn test_parse_count_star() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT (COUNT(*) AS ?total) WHERE {
            ?x <http://example.org/type> ?type .
        }
    "#;

    let ast = sparql.parse_query(source).unwrap();

    let projections = match &ast.query_type {
        QueryType::Select { projections, .. } => projections,
        _ => panic!("Expected SELECT"),
    };

    let function = match &projections[0] {
        SelectElement::Aggregate { function, .. } => function,
        _ => panic!("Expected Aggregate element"),
    };

    match function {
        AggregateFunction::Count { variable, .. } => assert!(variable.is_none()),
        _ => panic!("Expected COUNT aggregate"),
    }
}
1809
/// A SELECT mixing a plain variable with two aliased aggregates, plus
/// `GROUP BY`, `ORDER BY` and `LIMIT`, must expose three projections, the
/// variable/alias names in `select_vars`, and each modifier on the parsed query.
#[test]
fn test_combined_variables_and_aggregates() {
    let sparql = SparqlCompiler::new();
    let source = r#"
        SELECT ?category (SUM(?price) AS ?total) (AVG(?price) AS ?average) WHERE {
            ?item <http://example.org/category> ?category .
            ?item <http://example.org/price> ?price .
        } GROUP BY ?category ORDER BY ?total LIMIT 10
    "#;

    let ast = sparql.parse_query(source).unwrap();

    if let QueryType::Select {
        projections,
        select_vars,
        ..
    } = &ast.query_type
    {
        assert_eq!(projections.len(), 3);
        // Aggregates contribute their alias to the selected variable names.
        assert_eq!(select_vars, &vec!["category", "total", "average"]);
    } else {
        panic!("Expected SELECT");
    }

    assert_eq!(ast.group_by, vec!["category"]);
    assert_eq!(ast.order_by, vec!["total"]);
    assert_eq!(ast.limit, Some(10));
}
1840}