1#![doc(html_root_url = "https://docs.rs/sery-mcp/0.4.3")]
34#![cfg_attr(docsrs, feature(doc_cfg))]
35#![allow(
45 clippy::doc_markdown,
46 clippy::items_after_statements,
47 clippy::case_sensitive_file_extension_comparisons
48)]
49
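//! MCP server that exposes read-only tools over the files under a configured
//! `--root`: `list_folder`, `search_files`, `get_schema`, `sample_rows`,
//! `read_document`, and `query_sql`. All path arguments are validated to stay
//! under the root (no `..`, no absolute paths).
//!
//! A minimal construction sketch (the root path here is illustrative):
//!
//! ```
//! let server = sery_mcp::SeryMcpServer::new(std::path::PathBuf::from("/data"));
//! assert_eq!(server.root(), std::path::Path::new("/data"));
//! ```
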
50use std::path::{Component, Path, PathBuf};
51use std::sync::OnceLock;
52
53use rmcp::{
54 handler::server::{router::tool::ToolRouter, wrapper::Parameters},
55 model::{
56 CallToolResult, Content, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo,
57 },
58 schemars, tool, tool_handler, tool_router, ErrorData as McpError, ServerHandler,
59};
60
61pub const VERSION: &str = env!("CARGO_PKG_VERSION");
63
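/// Hard size cap enforced by `read_document`; files larger than this are
/// rejected with an invalid-params error instead of being extracted.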
64const MAX_DOCUMENT_BYTES: u64 = 50 * 1024 * 1024;
70
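/// Process-wide mdkit engine, built lazily on first use via `OnceLock`
/// (same pattern as `tabkit_engine` below) so repeated tool calls share one
/// configured instance instead of re-initialising per request.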
71fn mdkit_engine() -> &'static mdkit::Engine {
76 static ENGINE: OnceLock<mdkit::Engine> = OnceLock::new();
77 ENGINE.get_or_init(mdkit::Engine::with_defaults)
78}
79
80fn tabkit_engine() -> &'static tabkit::Engine {
81 static ENGINE: OnceLock<tabkit::Engine> = OnceLock::new();
82 ENGINE.get_or_init(tabkit::Engine::with_defaults)
83}
84
85#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
91pub struct ListFolderRequest {
92 #[serde(default)]
94 #[schemars(
95 description = "Subdirectory under the configured --root. Must be relative — no '..' segments, no absolute paths. Defaults to the root."
96 )]
97 pub path: Option<String>,
98 #[serde(default)]
100 #[schemars(description = "Maximum entries to return. Defaults to 1000.")]
101 pub limit: Option<usize>,
102}
103
104#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
106pub struct SearchFilesRequest {
107 #[schemars(
109 description = "Search term (case-insensitive). Matched against file basenames; whole-path matches score lower."
110 )]
111 pub query: String,
112 #[serde(default)]
115 #[schemars(
116 description = "Restrict to files whose extension matches one of these (lowercase, no leading dot, e.g. ['csv','parquet'])."
117 )]
118 pub extensions: Option<Vec<String>>,
119 #[serde(default)]
121 #[schemars(description = "Maximum results to return. Defaults to 50.")]
122 pub limit: Option<usize>,
123}
124
125#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
127pub struct GetSchemaRequest {
128 #[schemars(
130 description = "Relative path to a tabular file (CSV / TSV / Parquet / XLSX / XLS / XLSB / XLSM / ODS) under --root."
131 )]
132 pub path: String,
133 #[serde(default)]
135 #[schemars(
136 description = "For multi-sheet XLSX / ODS files: which sheet to inspect. Defaults to the first non-empty sheet."
137 )]
138 pub sheet: Option<String>,
139}
140
141#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
143pub struct SampleRowsRequest {
144 #[schemars(description = "Relative path to a tabular file under --root.")]
146 pub path: String,
147 #[serde(default)]
149 #[schemars(description = "Sample-row count. Defaults to 5; capped at 100.")]
150 pub limit: Option<usize>,
151 #[serde(default)]
153 #[schemars(description = "For multi-sheet XLSX / ODS files: which sheet to sample.")]
154 pub sheet: Option<String>,
155}
156
157#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
159pub struct ReadDocumentRequest {
160 #[schemars(
162 description = "Relative path to a document file (DOCX / PDF / PPTX / HTML / IPYNB / EPUB / RTF / ODT) under --root. 50 MB cap."
163 )]
164 pub path: String,
165}
166
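/// Arguments for the `query_sql` tool.
///
/// Exactly one of `path` (single-file mode, exposed as table `data`) or
/// `tables` (multi-file mode, each entry exposed under its own name) must be
/// set. A hypothetical multi-file call could look like:
///
/// ```json
/// {
///   "tables": { "customers": "crm/customers.csv", "orders": "2024/*.parquet" },
///   "sql": "SELECT c.name, SUM(o.amount) AS total FROM customers c JOIN orders o ON c.id = o.customer_id GROUP BY c.name",
///   "limit": 50
/// }
/// ```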
167#[derive(Debug, serde::Deserialize, schemars::JsonSchema)]
178pub struct QuerySqlRequest {
179 #[serde(default)]
181 #[schemars(
182 description = "Single-file mode. Relative path (or glob pattern like '2024/*.csv') under --root. The file(s) are registered as table `data` for the duration of this query. Mutually exclusive with `tables`."
183 )]
184 pub path: Option<String>,
185 #[serde(default)]
189 #[schemars(
190 description = "Multi-file mode. Map of {table_name: relative_path} — each path becomes a SQL table you can JOIN. Names must be valid SQL identifiers ([a-zA-Z_][a-zA-Z0-9_]*). Cap of 16 tables per call. Mutually exclusive with `path`."
191 )]
192 pub tables: Option<std::collections::HashMap<String, String>>,
193 #[schemars(
195 description = "SQL query — supports window functions, CTEs, glob reads, JOINs across the registered tables. Read-only — INSERT/UPDATE/DELETE/DDL/ATTACH/COPY/PRAGMA all rejected at validation time."
196 )]
197 pub sql: String,
198 #[serde(default)]
    #[schemars(
        description = "Maximum rows to return. Defaults to 100, capped at 1000. For fewer rows, lower this or add a SQL LIMIT."
    )]
203 pub limit: Option<usize>,
204}
205
206#[derive(Debug, serde::Serialize)]
212pub struct FileEntry {
213 pub relative_path: String,
215 pub size_bytes: u64,
217 pub modified: Option<String>,
219 pub extension: String,
221}
222
223#[derive(Debug, serde::Serialize)]
225pub struct SearchHit {
226 pub relative_path: String,
228 pub size_bytes: u64,
230 pub extension: String,
232 pub score: f64,
234 pub why_matched: &'static str,
236}
237
238#[derive(Debug, serde::Serialize)]
240pub struct ColumnInfo {
241 pub name: String,
243 #[serde(rename = "type")]
245 pub data_type: &'static str,
246 pub nullable: bool,
248}
249
250#[derive(Debug, serde::Serialize)]
252pub struct SchemaResponse {
253 pub relative_path: String,
255 pub format: String,
257 pub columns: Vec<ColumnInfo>,
259 pub row_count: Option<u64>,
262 #[serde(skip_serializing_if = "std::collections::HashMap::is_empty")]
266 pub metadata: std::collections::HashMap<String, String>,
267}
268
269#[derive(Debug, serde::Serialize)]
271pub struct SamplesResponse {
272 pub relative_path: String,
274 pub format: String,
276 pub columns: Vec<String>,
278 pub rows: Vec<serde_json::Map<String, serde_json::Value>>,
280 pub row_count: Option<u64>,
282}
283
284#[derive(Debug, serde::Serialize)]
286pub struct DocumentResponse {
287 pub relative_path: String,
289 pub format: String,
291 pub markdown: String,
293 pub title: Option<String>,
295 pub char_count: usize,
297 pub size_bytes: u64,
299}
300
301#[derive(Debug, serde::Serialize)]
303pub struct QueryResponse {
304 pub input: String,
308 pub format: String,
311 pub columns: Vec<String>,
313 pub rows: Vec<serde_json::Map<String, serde_json::Value>>,
315 pub row_count: usize,
317 pub truncated: bool,
322}
323
324#[derive(Clone)]
331pub struct SeryMcpServer {
332 root: PathBuf,
333 #[allow(dead_code)]
339 tool_router: ToolRouter<SeryMcpServer>,
340}
341
342#[tool_router]
343impl SeryMcpServer {
344 pub fn new(root: PathBuf) -> Self {
346 Self {
347 root,
348 tool_router: Self::tool_router(),
349 }
350 }
351
352 pub fn root(&self) -> &Path {
354 &self.root
355 }
356
357 #[tool(
        description = "List files under the configured --root (or a sub-path). Returns one JSON object per file with relative_path, size_bytes, modified (ISO 8601), and extension. Read-only; never returns file contents. Path traversal ('..' or absolute paths) is rejected."
361 )]
362 fn list_folder(
363 &self,
364 Parameters(req): Parameters<ListFolderRequest>,
365 ) -> Result<CallToolResult, McpError> {
366 let target = self.resolve_subpath(req.path.as_deref())?;
367 let limit = req.limit.unwrap_or(1000);
368 let entries = self.walk_entries(&target, limit)?;
369 as_json_result(&entries)
370 }
371
372 #[tool(
373 description = "Search files by name. Case-insensitive substring match against the basename, ranked: exact basename match (1.0), basename startswith (0.8), basename contains (0.5), path contains (0.2). Optional `extensions` filter restricts to specific file types. Returns up to `limit` hits sorted by score then path."
374 )]
375 fn search_files(
376 &self,
377 Parameters(req): Parameters<SearchFilesRequest>,
378 ) -> Result<CallToolResult, McpError> {
379 let limit = req.limit.unwrap_or(50);
380 let query = req.query.trim().to_lowercase();
381 if query.is_empty() {
382 return Err(McpError::invalid_params("'query' must not be empty", None));
383 }
384 let ext_filter: Option<Vec<String>> = req
385 .extensions
386 .map(|v| v.into_iter().map(|s| s.to_ascii_lowercase()).collect());
387
388 let scanner = scankit::Scanner::new(scankit::ScanConfig::default().follow_symlinks(false))
389 .map_err(|e| McpError::internal_error(format!("scankit init: {e}"), None))?;
390
391 let mut hits: Vec<SearchHit> = Vec::new();
392 for result in scanner.walk(&self.root) {
393 let Ok(entry) = result else { continue };
394 if let Some(filter) = ext_filter.as_ref() {
395 if !filter.iter().any(|e| e == &entry.extension) {
396 continue;
397 }
398 }
399 let basename = entry
400 .path
401 .file_name()
402 .and_then(|s| s.to_str())
403 .map(str::to_lowercase)
404 .unwrap_or_default();
405 let stem = entry
406 .path
407 .file_stem()
408 .and_then(|s| s.to_str())
409 .map(str::to_lowercase)
410 .unwrap_or_default();
411 let relative_path =
412 path_to_forward_slash(entry.path.strip_prefix(&self.root).unwrap_or(&entry.path));
413 let relative_lower = relative_path.to_lowercase();
414
415 let (score, why) = if stem == query || basename == query {
416 (1.0, "exact basename match")
417 } else if basename.starts_with(&query) {
418 (0.8, "basename starts with query")
419 } else if basename.contains(&query) {
420 (0.5, "basename contains query")
421 } else if relative_lower.contains(&query) {
422 (0.2, "path contains query")
423 } else {
424 continue;
425 };
426
427 hits.push(SearchHit {
428 relative_path,
429 size_bytes: entry.size_bytes,
430 extension: entry.extension,
431 score,
432 why_matched: why,
433 });
434 }
435 hits.sort_by(|a, b| {
436 b.score
437 .partial_cmp(&a.score)
438 .unwrap_or(std::cmp::Ordering::Equal)
439 .then_with(|| a.relative_path.cmp(&b.relative_path))
440 });
441 hits.truncate(limit);
442 as_json_result(&hits)
443 }
444
445 #[tool(
446 description = "Return column names + inferred types + row count for a tabular file (CSV / TSV / Parquet / XLSX / XLS / XLSB / XLSM / ODS). Backed by tabkit. row_count is null for very large files where a full scan was skipped. Specify `sheet` for multi-sheet workbooks."
447 )]
448 fn get_schema(
449 &self,
450 Parameters(req): Parameters<GetSchemaRequest>,
451 ) -> Result<CallToolResult, McpError> {
452 let path = self.resolve_required_file(&req.path)?;
453 let mut options = tabkit::ReadOptions::default().max_sample_rows(0);
454 if let Some(sheet) = req.sheet {
455 options = options.sheet_name(sheet);
456 }
457 let table = tabkit_engine()
458 .read(&path, &options)
459 .map_err(|e| McpError::internal_error(format!("tabkit read: {e}"), None))?;
460 let response = SchemaResponse {
461 relative_path: req.path,
462 format: extension_of(&path),
463 columns: table
464 .columns
465 .iter()
466 .map(|c| ColumnInfo {
467 name: c.name.clone(),
468 data_type: data_type_str(c.data_type),
469 nullable: c.nullable,
470 })
471 .collect(),
472 row_count: table.row_count,
473 metadata: table.metadata,
474 };
475 as_json_result(&response)
476 }
477
478 #[tool(
479 description = "Return the first N rows of a tabular file as header-keyed JSON objects. Defaults to 5 rows; capped at 100. Specify `sheet` for multi-sheet workbooks. Use sparingly — sample rows can contain PII; this tool returns raw cell values without redaction."
480 )]
481 fn sample_rows(
482 &self,
483 Parameters(req): Parameters<SampleRowsRequest>,
484 ) -> Result<CallToolResult, McpError> {
485 let path = self.resolve_required_file(&req.path)?;
486 let limit = req.limit.unwrap_or(5).min(100);
487 let mut options = tabkit::ReadOptions::default().max_sample_rows(limit);
488 if let Some(sheet) = req.sheet {
489 options = options.sheet_name(sheet);
490 }
491 let table = tabkit_engine()
492 .read(&path, &options)
493 .map_err(|e| McpError::internal_error(format!("tabkit read: {e}"), None))?;
494 let column_names: Vec<String> = table.columns.iter().map(|c| c.name.clone()).collect();
495 let rows = table
496 .sample_rows
497 .iter()
498 .map(|row| {
499 let mut obj = serde_json::Map::new();
500 for (i, col) in column_names.iter().enumerate() {
501 let v = row.get(i).map_or(serde_json::Value::Null, value_to_json);
502 obj.insert(col.clone(), v);
503 }
504 obj
505 })
506 .collect();
507 let response = SamplesResponse {
508 relative_path: req.path,
509 format: extension_of(&path),
510 columns: column_names,
511 rows,
512 row_count: table.row_count,
513 };
514 as_json_result(&response)
515 }
516
    #[tool(
        description = "Run a read-only SQL query against one or more CSV / TSV / Parquet files. \
        Single-file: pass `path` and reference it as table `data` in your SQL. \
        Multi-file: pass `tables` (a {name: path} map) and reference each name as a SQL table; this lets you JOIN across files. \
        Glob patterns (`*`, `?`) are supported in both modes and expanded at read time. \
        Full SQL dialect with window functions and CTEs; the CSV dialect is sniffed automatically. \
        XLSX / ODS are not supported here; use get_schema or sample_rows for those. \
        Read-only by design: INSERT/UPDATE/DELETE/DDL/ATTACH/COPY/PRAGMA are rejected at validation time. \
        Returns header-keyed JSON rows, capped at 1000 (default 100). Sets `truncated: true` when more rows exist."
    )]
526 fn query_sql(
527 &self,
528 Parameters(req): Parameters<QuerySqlRequest>,
529 ) -> Result<CallToolResult, McpError> {
530 let limit = req.limit.unwrap_or(100).min(1000);
531 let table_specs = self.resolve_table_specs(&req)?;
532 validate_query_sql(&req.sql)?;
533
534 let conn = duckdb::Connection::open_in_memory()
535 .map_err(|e| McpError::internal_error(format!("sql backend open: {e}"), None))?;
536
537 for spec in &table_specs {
540 let setup = build_register_view(&spec.table, &spec.path_for_sql, spec.format)?;
541 conn.execute_batch(&setup).map_err(|e| {
542 McpError::internal_error(format!("register table {}: {e}", spec.table), None)
543 })?;
544 }
545
546 let wrapped_sql = format!("SELECT * FROM ({}) LIMIT {}", req.sql, limit + 1);
551
552 let mut stmt = conn
553 .prepare(&wrapped_sql)
554 .map_err(|e| McpError::invalid_params(format!("sql prepare: {e}"), None))?;
555
556 let arrow_iter = stmt
561 .query_arrow(duckdb::params![])
562 .map_err(|e| McpError::invalid_params(format!("sql execute: {e}"), None))?;
563 let schema = arrow_iter.get_schema();
564 let columns: Vec<String> = schema.fields().iter().map(|f| f.name().clone()).collect();
565
566 let mut rows: Vec<serde_json::Map<String, serde_json::Value>> = Vec::with_capacity(limit);
567 let mut truncated = false;
568 'outer: for batch in arrow_iter {
569 for row_idx in 0..batch.num_rows() {
570 if rows.len() == limit {
571 truncated = true;
572 break 'outer;
573 }
574 let mut obj = serde_json::Map::with_capacity(columns.len());
575 for (col_idx, col_name) in columns.iter().enumerate() {
576 let array = batch.column(col_idx);
577 obj.insert(
578 col_name.clone(),
579 arrow_value_to_json(array.as_ref(), row_idx),
580 );
581 }
582 rows.push(obj);
583 }
584 }
585
586 let response = QueryResponse {
587 input: describe_input(&table_specs),
588 format: table_specs
589 .first()
590 .map(|s| s.format.to_string())
591 .unwrap_or_default(),
592 row_count: rows.len(),
593 columns,
594 rows,
595 truncated,
596 };
597 as_json_result(&response)
598 }
599
600 #[tool(
601 description = "Convert a document file (DOCX / PDF / PPTX / HTML / IPYNB / EPUB / RTF / ODT) to markdown. Backed by mdkit (libpdfium for PDF, pandoc for office formats, html2md for HTML). 50 MB file size cap; larger files return an error. Returns the full extracted text — pair with a chunk-aware caller if your LLM context window can't hold the whole document."
602 )]
603 fn read_document(
604 &self,
605 Parameters(req): Parameters<ReadDocumentRequest>,
606 ) -> Result<CallToolResult, McpError> {
607 let path = self.resolve_required_file(&req.path)?;
608 let metadata = std::fs::metadata(&path)
609 .map_err(|e| McpError::internal_error(format!("stat: {e}"), None))?;
610 if metadata.len() > MAX_DOCUMENT_BYTES {
611 return Err(McpError::invalid_params(
612 format!(
613 "file is {} bytes; read_document caps at {} bytes (50 MB)",
614 metadata.len(),
615 MAX_DOCUMENT_BYTES
616 ),
617 None,
618 ));
619 }
620 let document = mdkit_engine()
621 .extract(&path)
622 .map_err(|e| McpError::internal_error(format!("mdkit extract: {e}"), None))?;
623 let format = extension_of(&path);
624 let response = DocumentResponse {
625 char_count: document.markdown.chars().count(),
626 relative_path: req.path,
627 format,
628 title: document.title,
629 markdown: document.markdown,
630 size_bytes: metadata.len(),
631 };
632 as_json_result(&response)
633 }
634
635 fn resolve_subpath(&self, sub: Option<&str>) -> Result<PathBuf, McpError> {
640 let raw = match sub {
641 None => return Ok(self.root.clone()),
642 Some(s) if s.is_empty() || s == "." => return Ok(self.root.clone()),
643 Some(s) => s,
644 };
645 validate_relative_components(raw)?;
646 Ok(self.root.join(raw))
647 }
648
649 fn resolve_required_file(&self, sub: &str) -> Result<PathBuf, McpError> {
653 if sub.is_empty() {
654 return Err(McpError::invalid_params("'path' must not be empty", None));
655 }
656 validate_relative_components(sub)?;
657 let joined = self.root.join(sub);
658 let metadata = std::fs::metadata(&joined)
659 .map_err(|e| McpError::invalid_params(format!("path not readable: {e}"), None))?;
660 if !metadata.is_file() {
            return Err(McpError::invalid_params(
                "'path' must refer to a regular file (not a directory)",
                None,
            ));
665 }
666 Ok(joined)
667 }
668
669 fn resolve_required_path_or_glob(&self, sub: &str) -> Result<PathBuf, McpError> {
673 if sub.is_empty() {
674 return Err(McpError::invalid_params("path must not be empty", None));
675 }
676 validate_relative_components(sub)?;
677 let joined = self.root.join(sub);
678 if !is_glob_pattern(sub) {
679 let metadata = std::fs::metadata(&joined)
683 .map_err(|e| McpError::invalid_params(format!("path not readable: {e}"), None))?;
684 if !metadata.is_file() {
685 return Err(McpError::invalid_params(
686 "path must refer to a regular file or a glob pattern",
687 None,
688 ));
689 }
690 }
691 Ok(joined)
692 }
693
694 fn resolve_table_specs(&self, req: &QuerySqlRequest) -> Result<Vec<TableSpec>, McpError> {
705 match (&req.path, &req.tables) {
706 (Some(_), Some(_)) => Err(McpError::invalid_params(
707 "pass either `path` (single-file) or `tables` (multi-file), not both",
708 None,
709 )),
710 (None, None) => Err(McpError::invalid_params(
711 "must pass either `path` or `tables`",
712 None,
713 )),
714 (Some(path), None) => {
715 let resolved = self.resolve_required_path_or_glob(path)?;
716 let format = format_for_query_sql(path)?;
717 Ok(vec![TableSpec {
718 table: "data".to_string(),
719 path_for_sql: resolved.to_string_lossy().into_owned(),
720 relative_path: path.clone(),
721 format,
722 }])
723 }
724 (None, Some(tables)) => {
725 if tables.is_empty() {
726 return Err(McpError::invalid_params("`tables` must not be empty", None));
727 }
728 if tables.len() > 16 {
729 return Err(McpError::invalid_params("at most 16 tables per call", None));
730 }
731 let mut specs: Vec<TableSpec> = Vec::with_capacity(tables.len());
732 for (name, path) in tables {
733 if !is_valid_sql_identifier(name) {
734 return Err(McpError::invalid_params(
735 format!(
736 "table name '{name}' is not a valid SQL identifier \
737 ([a-zA-Z_][a-zA-Z0-9_]*)"
738 ),
739 None,
740 ));
741 }
742 let resolved = self.resolve_required_path_or_glob(path)?;
743 let format = format_for_query_sql(path)?;
744 specs.push(TableSpec {
745 table: name.clone(),
746 path_for_sql: resolved.to_string_lossy().into_owned(),
747 relative_path: path.clone(),
748 format,
749 });
750 }
751 specs.sort_by(|a, b| a.table.cmp(&b.table));
755 Ok(specs)
756 }
757 }
758 }
759
760 fn walk_entries(&self, target: &Path, limit: usize) -> Result<Vec<FileEntry>, McpError> {
764 let scanner = scankit::Scanner::new(scankit::ScanConfig::default().follow_symlinks(false))
765 .map_err(|e| McpError::internal_error(format!("scankit init: {e}"), None))?;
766
767 let mut out = Vec::new();
768 for result in scanner.walk(target) {
769 if out.len() >= limit {
770 break;
771 }
772 let Ok(entry) = result else { continue };
773 let relative =
774 path_to_forward_slash(entry.path.strip_prefix(&self.root).unwrap_or(&entry.path));
775 out.push(FileEntry {
776 relative_path: relative,
777 size_bytes: entry.size_bytes,
778 modified: entry
779 .modified
780 .map(|t| chrono::DateTime::<chrono::Utc>::from(t).to_rfc3339()),
781 extension: entry.extension,
782 });
783 }
784 Ok(out)
785 }
786}
787
788#[tool_handler]
793impl ServerHandler for SeryMcpServer {
794 fn get_info(&self) -> ServerInfo {
795 let mut server_info =
802 Implementation::new(env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION"));
803 server_info.description = Some(env!("CARGO_PKG_DESCRIPTION").to_string());
804 let homepage = env!("CARGO_PKG_HOMEPAGE");
805 if !homepage.is_empty() {
806 server_info.website_url = Some(homepage.to_string());
807 }
808
809 ServerInfo::new(ServerCapabilities::builder().enable_tools().build())
810 .with_server_info(server_info)
811 .with_protocol_version(ProtocolVersion::V_2024_11_05)
            .with_instructions(
                "sery-mcp exposes the local files under the configured --root as MCP tools. \
                 All tools are read-only. Path arguments are validated to fall under --root \
                 (no .. escape, no absolute paths). Six tools are exposed: list_folder, \
                 search_files, get_schema, sample_rows, read_document (DOCX/PDF/PPTX/HTML/\
                 IPYNB/EPUB/RTF/ODT → markdown), and query_sql (DuckDB-backed read-only SQL \
                 on CSV/TSV/Parquet; a single `path` is registered as table `data`, or pass \
                 `tables` to register several files and JOIN across them). \
                 See https://github.com/seryai/sery-mcp."
                    .to_string(),
            )
822 }
823}
824
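/// Rejects absolute paths, path prefixes, and any `..` component so that a
/// caller-supplied path can only resolve to something under `--root`. The
/// check is purely lexical: it inspects components and does not touch the
/// filesystem or canonicalise symlinks.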
825fn validate_relative_components(raw: &str) -> Result<(), McpError> {
832 let p = Path::new(raw);
833 if p.is_absolute() || raw.starts_with('/') || raw.starts_with('\\') {
840 return Err(McpError::invalid_params(
841 "'path' must be relative to --root (no absolute paths)",
842 None,
843 ));
844 }
845 for component in p.components() {
846 match component {
847 Component::ParentDir => {
848 return Err(McpError::invalid_params(
849 "'path' must not contain '..' (no escaping the configured --root)",
850 None,
851 ));
852 }
853 Component::Prefix(_) | Component::RootDir => {
854 return Err(McpError::invalid_params(
859 "'path' must be relative to --root (no absolute paths)",
860 None,
861 ));
862 }
863 _ => {}
864 }
865 }
866 Ok(())
867}
868
869fn path_to_forward_slash(path: &Path) -> String {
882 let s = path.to_string_lossy().into_owned();
883 if std::path::MAIN_SEPARATOR == '/' {
884 s
885 } else {
886 s.replace(std::path::MAIN_SEPARATOR, "/")
887 }
888}
889
890fn extension_of(path: &Path) -> String {
893 path.extension()
894 .and_then(|s| s.to_str())
895 .map(str::to_ascii_lowercase)
896 .unwrap_or_default()
897}
898
899fn data_type_str(t: tabkit::DataType) -> &'static str {
901 match t {
902 tabkit::DataType::Bool => "boolean",
903 tabkit::DataType::Integer => "integer",
904 tabkit::DataType::Float => "float",
905 tabkit::DataType::Date => "date",
906 tabkit::DataType::DateTime => "datetime",
907 tabkit::DataType::Text => "text",
908 _ => "unknown",
911 }
912}
913
914fn value_to_json(v: &tabkit::Value) -> serde_json::Value {
916 match v {
917 tabkit::Value::Bool(b) => serde_json::Value::Bool(*b),
918 tabkit::Value::Integer(i) => serde_json::Value::Number((*i).into()),
919 tabkit::Value::Float(f) => serde_json::Number::from_f64(*f)
920 .map_or(serde_json::Value::Null, serde_json::Value::Number),
921 tabkit::Value::Date(s) | tabkit::Value::DateTime(s) | tabkit::Value::Text(s) => {
922 serde_json::Value::String(s.clone())
923 }
924 _ => serde_json::Value::Null,
927 }
928}
929
930#[derive(Debug)]
935struct TableSpec {
936 table: String,
937 path_for_sql: String,
938 relative_path: String,
939 format: &'static str,
940}
941
942fn validate_query_sql(sql: &str) -> Result<(), McpError> {
953 let trimmed = sql.trim();
954 if trimmed.is_empty() {
955 return Err(McpError::invalid_params("`sql` must not be empty", None));
956 }
957 let upper = trimmed.to_ascii_uppercase();
958 if !upper.starts_with("SELECT") && !upper.starts_with("WITH") {
959 return Err(McpError::invalid_params(
960 "sql must start with SELECT or WITH (read-only queries only)",
961 None,
962 ));
963 }
964
965 const FORBIDDEN: &[&str] = &[
966 "INSERT",
967 "UPDATE",
968 "DELETE",
969 "CREATE",
970 "DROP",
971 "ALTER",
972 "ATTACH",
973 "DETACH",
974 "COPY",
975 "PRAGMA",
976 "INSTALL",
977 "LOAD",
978 "EXPORT",
979 "IMPORT",
980 "CHECKPOINT",
981 "VACUUM",
982 "ANALYZE",
983 "TRUNCATE",
984 "GRANT",
985 "REVOKE",
986 "BEGIN",
987 "COMMIT",
988 "ROLLBACK",
989 "SAVEPOINT",
990 ];
991 let tokens: std::collections::HashSet<&str> = upper
992 .split(|c: char| !c.is_ascii_alphanumeric() && c != '_')
993 .filter(|t| !t.is_empty())
994 .collect();
995 for kw in FORBIDDEN {
996 if tokens.contains(*kw) {
997 return Err(McpError::invalid_params(
998 format!("forbidden SQL keyword: {kw} (query_sql is read-only)"),
999 None,
1000 ));
1001 }
1002 }
1003 Ok(())
1004}
1005
1006fn build_register_view(table: &str, path_for_sql: &str, format: &str) -> Result<String, McpError> {
1010 let escaped_path = sql_string_literal(path_for_sql);
1011 let read_call = match format {
1012 "csv" => format!("read_csv_auto({escaped_path})"),
1013 "tsv" => format!("read_csv_auto({escaped_path}, delim='\\t')"),
1014 "parquet" => format!("read_parquet({escaped_path})"),
1015 other => {
1016 return Err(McpError::invalid_params(
1017 format!(
1018 "query_sql supports csv / tsv / parquet only — got '{other}'. \
1019 Use get_schema or sample_rows for XLSX/ODS files."
1020 ),
1021 None,
1022 ));
1023 }
1024 };
1025 Ok(format!(
1028 "CREATE OR REPLACE VIEW {table} AS SELECT * FROM {read_call}"
1029 ))
1030}
1031
1032fn describe_input(specs: &[TableSpec]) -> String {
1036 if specs.len() == 1 && specs[0].table == "data" {
1037 return specs[0].relative_path.clone();
1038 }
1039 specs
1040 .iter()
1041 .map(|s| format!("{}={}", s.table, s.relative_path))
1042 .collect::<Vec<_>>()
1043 .join(", ")
1044}
1045
1046fn is_glob_pattern(s: &str) -> bool {
1049 s.contains('*') || s.contains('?') || s.contains('[')
1050}
1051
1052fn is_valid_sql_identifier(name: &str) -> bool {
1056 let mut chars = name.chars();
1057 let Some(first) = chars.next() else {
1058 return false;
1059 };
1060 if !(first.is_ascii_alphabetic() || first == '_') {
1061 return false;
1062 }
1063 chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
1064}
1065
1066fn format_for_query_sql(path: &str) -> Result<&'static str, McpError> {
1068 let lower = path.to_ascii_lowercase();
1069 if lower.ends_with(".csv") {
1070 Ok("csv")
1071 } else if lower.ends_with(".tsv") {
1072 Ok("tsv")
1073 } else if lower.ends_with(".parquet") {
1074 Ok("parquet")
1075 } else {
1076 Err(McpError::invalid_params(
1077 format!(
1078 "query_sql expects a path / glob ending in .csv, .tsv, or .parquet — got '{path}'"
1079 ),
1080 None,
1081 ))
1082 }
1083}
1084
1085fn sql_string_literal(s: &str) -> String {
1089 format!("'{}'", s.replace('\'', "''"))
1090}
1091
#[allow(clippy::too_many_lines)]
fn arrow_value_to_json(array: &dyn duckdb::arrow::array::Array, row: usize) -> serde_json::Value {
1104 use duckdb::arrow::array::{
1105 BooleanArray, Date32Array, Date64Array, Decimal128Array, Float32Array, Float64Array,
1106 Int16Array, Int32Array, Int64Array, Int8Array, LargeStringArray, StringArray,
1107 TimestampMicrosecondArray, TimestampMillisecondArray, TimestampNanosecondArray,
1108 TimestampSecondArray, UInt16Array, UInt32Array, UInt64Array, UInt8Array,
1109 };
1110 use duckdb::arrow::datatypes::DataType;
1111
1112 if array.is_null(row) {
1113 return serde_json::Value::Null;
1114 }
1115
1116 macro_rules! number {
1117 ($arr:ty) => {{
1118 let typed = array
1119 .as_any()
1120 .downcast_ref::<$arr>()
1121 .expect("downcast matches the matched DataType");
1122 serde_json::json!(typed.value(row))
1123 }};
1124 }
1125
1126 match array.data_type() {
1127 DataType::Boolean => {
1128 let typed = array
1129 .as_any()
1130 .downcast_ref::<BooleanArray>()
1131 .expect("BooleanArray");
1132 serde_json::Value::Bool(typed.value(row))
1133 }
1134 DataType::Int8 => number!(Int8Array),
1135 DataType::Int16 => number!(Int16Array),
1136 DataType::Int32 => number!(Int32Array),
1137 DataType::Int64 => number!(Int64Array),
1138 DataType::UInt8 => number!(UInt8Array),
1139 DataType::UInt16 => number!(UInt16Array),
1140 DataType::UInt32 => number!(UInt32Array),
1141 DataType::UInt64 => number!(UInt64Array),
1142 DataType::Float32 => {
1143 let typed = array
1144 .as_any()
1145 .downcast_ref::<Float32Array>()
1146 .expect("Float32Array");
1147 serde_json::Number::from_f64(f64::from(typed.value(row)))
1148 .map_or(serde_json::Value::Null, serde_json::Value::Number)
1149 }
1150 DataType::Float64 => {
1151 let typed = array
1152 .as_any()
1153 .downcast_ref::<Float64Array>()
1154 .expect("Float64Array");
1155 serde_json::Number::from_f64(typed.value(row))
1156 .map_or(serde_json::Value::Null, serde_json::Value::Number)
1157 }
1158 DataType::Utf8 => {
1159 let typed = array
1160 .as_any()
1161 .downcast_ref::<StringArray>()
1162 .expect("StringArray");
1163 serde_json::Value::String(typed.value(row).to_string())
1164 }
1165 DataType::LargeUtf8 => {
1166 let typed = array
1167 .as_any()
1168 .downcast_ref::<LargeStringArray>()
1169 .expect("LargeStringArray");
1170 serde_json::Value::String(typed.value(row).to_string())
1171 }
1172 DataType::Date32 => {
1173 let typed = array
1174 .as_any()
1175 .downcast_ref::<Date32Array>()
1176 .expect("Date32Array");
1177 typed
1178 .value_as_date(row)
1179 .map_or(serde_json::Value::Null, |d| {
1180 serde_json::Value::String(d.format("%Y-%m-%d").to_string())
1181 })
1182 }
1183 DataType::Date64 => {
1184 let typed = array
1185 .as_any()
1186 .downcast_ref::<Date64Array>()
1187 .expect("Date64Array");
1188 typed
1189 .value_as_date(row)
1190 .map_or(serde_json::Value::Null, |d| {
1191 serde_json::Value::String(d.format("%Y-%m-%d").to_string())
1192 })
1193 }
1194 DataType::Decimal128(_, scale) => {
1195 let typed = array
1202 .as_any()
1203 .downcast_ref::<Decimal128Array>()
1204 .expect("Decimal128Array");
1205 let raw = typed.value(row);
1206 if *scale == 0 {
1207 if let Ok(fits) = i64::try_from(raw) {
1208 return serde_json::Value::Number(fits.into());
1209 }
1210 }
1211 use duckdb::arrow::util::display::{ArrayFormatter, FormatOptions};
1212 ArrayFormatter::try_new(array, &FormatOptions::default()).map_or_else(
1213 |_| serde_json::Value::String(format!("(decimal {raw})")),
1214 |fmt| serde_json::Value::String(fmt.value(row).to_string()),
1215 )
1216 }
1217 DataType::Timestamp(_, _) => {
1218 if let Some(typed) = array.as_any().downcast_ref::<TimestampMicrosecondArray>() {
1222 return typed
1223 .value_as_datetime(row)
1224 .map_or(serde_json::Value::Null, |d| {
1225 serde_json::Value::String(d.and_utc().to_rfc3339())
1226 });
1227 }
1228 if let Some(typed) = array.as_any().downcast_ref::<TimestampMillisecondArray>() {
1229 return typed
1230 .value_as_datetime(row)
1231 .map_or(serde_json::Value::Null, |d| {
1232 serde_json::Value::String(d.and_utc().to_rfc3339())
1233 });
1234 }
1235 if let Some(typed) = array.as_any().downcast_ref::<TimestampNanosecondArray>() {
1236 return typed
1237 .value_as_datetime(row)
1238 .map_or(serde_json::Value::Null, |d| {
1239 serde_json::Value::String(d.and_utc().to_rfc3339())
1240 });
1241 }
1242 if let Some(typed) = array.as_any().downcast_ref::<TimestampSecondArray>() {
1243 return typed
1244 .value_as_datetime(row)
1245 .map_or(serde_json::Value::Null, |d| {
1246 serde_json::Value::String(d.and_utc().to_rfc3339())
1247 });
1248 }
1249 serde_json::Value::String(format!("(unsupported timestamp at row {row})"))
1250 }
1251 _ => {
1256 use duckdb::arrow::util::display::{ArrayFormatter, FormatOptions};
1257 ArrayFormatter::try_new(array, &FormatOptions::default()).map_or_else(
1258 |_| {
1259 serde_json::Value::String(format!("(unrenderable {} value)", array.data_type()))
1260 },
1261 |fmt| serde_json::Value::String(fmt.value(row).to_string()),
1262 )
1263 }
1264 }
1265}
1266
1267fn as_json_result<T: serde::Serialize>(value: &T) -> Result<CallToolResult, McpError> {
1271 let json = serde_json::to_string_pretty(value)
1272 .map_err(|e| McpError::internal_error(format!("serialize result: {e}"), None))?;
1273 Ok(CallToolResult::success(vec![Content::text(json)]))
1274}
1275
1276#[cfg(test)]
1281mod tests {
1282 use super::*;
1283 use std::fs;
1284 use tempfile::TempDir;
1285
1286 fn make_server(root: &Path) -> SeryMcpServer {
1287 SeryMcpServer::new(root.canonicalize().expect("temp dir must canonicalise"))
1288 }
1289
1290 #[test]
1293 fn resolve_subpath_defaults_to_root() {
1294 let dir = TempDir::new().unwrap();
1295 let server = make_server(dir.path());
1296 for input in [None, Some(""), Some(".")] {
1297 let resolved = server.resolve_subpath(input).unwrap();
1298 assert_eq!(resolved, server.root);
1299 }
1300 }
1301
1302 #[test]
1303 fn resolve_subpath_rejects_absolute() {
1304 let dir = TempDir::new().unwrap();
1305 let server = make_server(dir.path());
1306 let err = server.resolve_subpath(Some("/etc/passwd")).unwrap_err();
1307 assert!(format!("{err:?}").contains("absolute"));
1308 }
1309
1310 #[test]
1311 fn resolve_subpath_rejects_parent_dir() {
1312 let dir = TempDir::new().unwrap();
1313 let server = make_server(dir.path());
1314 let err = server.resolve_subpath(Some("../etc")).unwrap_err();
1315 assert!(format!("{err:?}").contains(".."));
1316 }
1317
1318 #[test]
1319 fn resolve_required_file_rejects_directory() {
1320 let dir = TempDir::new().unwrap();
1321 fs::create_dir(dir.path().join("sub")).unwrap();
1322 let server = make_server(dir.path());
1323 let err = server.resolve_required_file("sub").unwrap_err();
1324 assert!(format!("{err:?}").contains("regular file"));
1325 }
1326
1327 #[test]
1328 fn resolve_required_file_rejects_missing() {
1329 let dir = TempDir::new().unwrap();
1330 let server = make_server(dir.path());
1331 let err = server.resolve_required_file("nope.csv").unwrap_err();
1332 assert!(format!("{err:?}").contains("not readable"));
1333 }
1334
1335 #[test]
1336 fn resolve_required_file_accepts_real_file() {
1337 let dir = TempDir::new().unwrap();
1338 fs::write(dir.path().join("a.csv"), "x,y\n").unwrap();
1339 let server = make_server(dir.path());
1340 let resolved = server.resolve_required_file("a.csv").unwrap();
1341 assert_eq!(resolved, server.root.join("a.csv"));
1342 }
1343
1344 #[test]
1347 fn walk_entries_emits_files_under_root() {
1348 let dir = TempDir::new().unwrap();
1349 fs::write(dir.path().join("a.csv"), "x,y\n1,2\n").unwrap();
1350 fs::write(dir.path().join("b.txt"), "hello").unwrap();
1351 let server = make_server(dir.path());
1352 let entries = server.walk_entries(server.root(), 100).unwrap();
1353 assert_eq!(entries.len(), 2);
1354 let names: Vec<_> = entries.iter().map(|e| e.relative_path.clone()).collect();
1355 assert!(names.contains(&"a.csv".to_string()));
1356 assert!(names.contains(&"b.txt".to_string()));
1357 }
1358
1359 #[test]
1360 fn walk_entries_respects_limit() {
1361 let dir = TempDir::new().unwrap();
1362 for i in 0..10 {
1363 fs::write(dir.path().join(format!("f{i}.txt")), "x").unwrap();
1364 }
1365 let server = make_server(dir.path());
1366 let entries = server.walk_entries(server.root(), 3).unwrap();
1367 assert_eq!(entries.len(), 3);
1368 }
1369
1370 #[test]
1371 fn walk_entries_lowercases_extension() {
1372 let dir = TempDir::new().unwrap();
1373 fs::write(dir.path().join("REPORT.PDF"), "%PDF-").unwrap();
1374 let server = make_server(dir.path());
1375 let entries = server.walk_entries(server.root(), 100).unwrap();
1376 assert_eq!(entries.len(), 1);
1377 assert_eq!(entries[0].extension, "pdf");
1378 }
1379
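    #[test]
    fn path_helpers_normalise_output() {
        // Small checks on the pure path helpers: extensions are lowercased and
        // relative paths are reported with forward slashes on every platform.
        assert_eq!(extension_of(Path::new("dir/REPORT.PDF")), "pdf");
        assert_eq!(extension_of(Path::new("dir/noext")), "");
        let joined = Path::new("a").join("b.csv");
        assert_eq!(path_to_forward_slash(&joined), "a/b.csv");
    }
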
1380 #[test]
1383 fn get_schema_returns_csv_columns() {
1384 let dir = TempDir::new().unwrap();
1385 fs::write(
1386 dir.path().join("sales.csv"),
1387 "id,name,amount\n1,alice,99.5\n2,bob,150.0\n",
1388 )
1389 .unwrap();
1390 let server = make_server(dir.path());
1391 let result = server
1392 .get_schema(Parameters(GetSchemaRequest {
1393 path: "sales.csv".into(),
1394 sheet: None,
1395 }))
1396 .unwrap();
1397 let payload = result_text(&result);
1398 let parsed: SchemaResponseDe = serde_json::from_str(&payload).unwrap();
1399 assert_eq!(parsed.format, "csv");
1400 assert_eq!(parsed.columns.len(), 3);
1401 let names: Vec<_> = parsed.columns.iter().map(|c| c.name.as_str()).collect();
1402 assert_eq!(names, vec!["id", "name", "amount"]);
1403 }
1404
1405 #[test]
1408 fn sample_rows_returns_header_keyed_objects() {
1409 let dir = TempDir::new().unwrap();
1410 fs::write(
1411 dir.path().join("sales.csv"),
1412 "id,name,amount\n1,alice,99.5\n2,bob,150.0\n3,eve,200.0\n",
1413 )
1414 .unwrap();
1415 let server = make_server(dir.path());
1416 let result = server
1417 .sample_rows(Parameters(SampleRowsRequest {
1418 path: "sales.csv".into(),
1419 limit: Some(2),
1420 sheet: None,
1421 }))
1422 .unwrap();
1423 let payload = result_text(&result);
1424 let parsed: SamplesResponseDe = serde_json::from_str(&payload).unwrap();
1425 assert_eq!(parsed.columns, vec!["id", "name", "amount"]);
1426 assert_eq!(parsed.rows.len(), 2);
1427 assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("alice"));
1428 }
1429
1430 #[test]
1433 fn search_files_ranks_basename_match_above_path_match() {
1434 let dir = TempDir::new().unwrap();
1435 fs::create_dir_all(dir.path().join("data/finance")).unwrap();
1436 fs::write(dir.path().join("data/finance/sales.csv"), "x").unwrap();
1437 fs::write(dir.path().join("salesreport.csv"), "x").unwrap();
1438 fs::write(dir.path().join("revenue.csv"), "x").unwrap();
1439 let server = make_server(dir.path());
1440 let result = server
1441 .search_files(Parameters(SearchFilesRequest {
1442 query: "sales".into(),
1443 extensions: None,
1444 limit: None,
1445 }))
1446 .unwrap();
1447 let payload = result_text(&result);
1448 let hits: Vec<SearchHitDe> = serde_json::from_str(&payload).unwrap();
1449 assert_eq!(hits.len(), 2);
1450 assert_eq!(hits[0].relative_path, "data/finance/sales.csv");
1452 assert!(hits[0].score > hits[1].score);
1453 }
1454
1455 #[test]
1456 fn search_files_extension_filter() {
1457 let dir = TempDir::new().unwrap();
1458 fs::write(dir.path().join("notes.csv"), "x").unwrap();
1459 fs::write(dir.path().join("notes.txt"), "x").unwrap();
1460 let server = make_server(dir.path());
1461 let result = server
1462 .search_files(Parameters(SearchFilesRequest {
1463 query: "notes".into(),
1464 extensions: Some(vec!["csv".into()]),
1465 limit: None,
1466 }))
1467 .unwrap();
1468 let hits: Vec<SearchHitDe> = serde_json::from_str(&result_text(&result)).unwrap();
1469 assert_eq!(hits.len(), 1);
1470 assert_eq!(hits[0].extension, "csv");
1471 }
1472
1473 fn query_req(
1476 path: Option<&str>,
1477 tables: Option<std::collections::HashMap<String, String>>,
1478 sql: &str,
1479 limit: Option<usize>,
1480 ) -> QuerySqlRequest {
1481 QuerySqlRequest {
1482 path: path.map(String::from),
1483 tables,
1484 sql: sql.into(),
1485 limit,
1486 }
1487 }
1488
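    #[test]
    fn query_sql_request_deserializes_multi_file_json() {
        // A hypothetical wire payload for multi-file mode; confirms that
        // `path` and `limit` really are optional thanks to #[serde(default)].
        let req: QuerySqlRequest = serde_json::from_str(
            r#"{"tables":{"orders":"orders.csv"},"sql":"SELECT 1"}"#,
        )
        .unwrap();
        assert!(req.path.is_none());
        assert_eq!(req.limit, None);
        let tables = req.tables.unwrap();
        assert_eq!(tables.get("orders").map(String::as_str), Some("orders.csv"));
    }
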
1489 #[test]
1490 fn query_sql_csv_happy_path() {
1491 let dir = TempDir::new().unwrap();
1492 fs::write(
1493 dir.path().join("sales.csv"),
1494 "id,name,amount\n1,alice,100\n2,bob,250\n3,eve,50\n",
1495 )
1496 .unwrap();
1497 let server = make_server(dir.path());
1498 let result = server
1499 .query_sql(Parameters(query_req(
1500 Some("sales.csv"),
1501 None,
1502 "SELECT name, amount FROM data WHERE amount > 75 ORDER BY amount",
1503 None,
1504 )))
1505 .unwrap();
1506 let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1507 assert_eq!(parsed.format, "csv");
1508 assert_eq!(parsed.columns, vec!["name", "amount"]);
1509 assert_eq!(parsed.row_count, 2);
1510 assert!(!parsed.truncated);
1511 assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("alice"));
1512 assert_eq!(parsed.rows[1].get("name").unwrap().as_str(), Some("bob"));
1513 assert_eq!(parsed.input, "sales.csv");
1514 }
1515
1516 #[test]
1517 fn query_sql_truncates_at_limit() {
1518 use std::fmt::Write as _;
1519 let dir = TempDir::new().unwrap();
1520 let mut csv = String::from("n\n");
1521 for i in 0..20 {
1522 writeln!(csv, "{i}").unwrap();
1523 }
1524 fs::write(dir.path().join("nums.csv"), csv).unwrap();
1525 let server = make_server(dir.path());
1526 let result = server
1527 .query_sql(Parameters(query_req(
1528 Some("nums.csv"),
1529 None,
1530 "SELECT n FROM data",
1531 Some(5),
1532 )))
1533 .unwrap();
1534 let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1535 assert_eq!(parsed.row_count, 5);
1536 assert!(parsed.truncated);
1537 }
1538
1539 #[test]
1540 fn query_sql_rejects_unsupported_format() {
1541 let dir = TempDir::new().unwrap();
1542 fs::write(dir.path().join("notes.txt"), "hi").unwrap();
1543 let server = make_server(dir.path());
1544 let err = server
1545 .query_sql(Parameters(query_req(
1546 Some("notes.txt"),
1547 None,
1548 "SELECT 1",
1549 None,
1550 )))
1551 .unwrap_err();
1552 assert!(format!("{err:?}").to_lowercase().contains(".csv"));
1553 }
1554
1555 #[test]
1556 fn query_sql_surfaces_sql_parse_errors() {
1557 let dir = TempDir::new().unwrap();
1558 fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1559 let server = make_server(dir.path());
1560 let err = server
1561 .query_sql(Parameters(query_req(
1562 Some("a.csv"),
1563 None,
1564 "SELEKT * FROM data",
1565 None,
1566 )))
1567 .unwrap_err();
1568 let msg = format!("{err:?}").to_lowercase();
1569 assert!(msg.contains("sql") || msg.contains("read-only"));
1570 }
1571
1572 #[test]
1573 fn query_sql_blocks_ddl() {
1574 let dir = TempDir::new().unwrap();
1575 fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1576 let server = make_server(dir.path());
1577 for evil in [
1578 "DROP TABLE data",
1579 "ATTACH '/etc/passwd' AS p",
1580 "INSERT INTO data VALUES (1)",
1581 "PRAGMA table_info('data')",
1582 ] {
1583 let err = server
1584 .query_sql(Parameters(query_req(Some("a.csv"), None, evil, None)))
1585 .unwrap_err();
1586 let msg = format!("{err:?}").to_lowercase();
1587 assert!(
1588 msg.contains("forbidden") || msg.contains("read-only"),
1589 "expected SQL '{evil}' to be rejected; got {err:?}"
1590 );
1591 }
1592 }
1593
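    #[test]
    fn validate_query_sql_is_token_based() {
        // CTEs pass, forbidden statements fail, and because the keyword check
        // matches whole tokens, identifiers that merely contain a keyword
        // (e.g. `created_at`) are not rejected.
        assert!(validate_query_sql("WITH t AS (SELECT 1) SELECT * FROM t").is_ok());
        assert!(validate_query_sql("SELECT created_at FROM data").is_ok());
        assert!(validate_query_sql("").is_err());
        assert!(validate_query_sql("SELECT 1; DROP TABLE data").is_err());
    }
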
1594 #[test]
1595 fn query_sql_multi_file_join() {
1596 let dir = TempDir::new().unwrap();
1597 fs::write(
1598 dir.path().join("customers.csv"),
1599 "id,name\n1,Alice\n2,Bob\n",
1600 )
1601 .unwrap();
1602 fs::write(
1603 dir.path().join("orders.csv"),
1604 "customer_id,amount\n1,100\n1,50\n2,200\n",
1605 )
1606 .unwrap();
1607 let server = make_server(dir.path());
1608 let mut tables = std::collections::HashMap::new();
1609 tables.insert("customers".into(), "customers.csv".into());
1610 tables.insert("orders".into(), "orders.csv".into());
1611 let result = server
1612 .query_sql(Parameters(query_req(
1613 None,
1614 Some(tables),
1615 "SELECT c.name, SUM(o.amount) AS total \
1616 FROM customers c JOIN orders o ON c.id = o.customer_id \
1617 GROUP BY c.name ORDER BY total DESC",
1618 None,
1619 )))
1620 .unwrap();
1621 let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1622 assert_eq!(parsed.columns, vec!["name", "total"]);
1623 assert_eq!(parsed.row_count, 2);
1624 assert_eq!(parsed.rows[0].get("name").unwrap().as_str(), Some("Bob"));
1625 assert_eq!(parsed.rows[1].get("name").unwrap().as_str(), Some("Alice"));
1626 assert!(parsed.input.contains("customers=customers.csv"));
1628 assert!(parsed.input.contains("orders=orders.csv"));
1629 }
1630
1631 #[test]
1632 fn query_sql_glob_pattern() {
1633 let dir = TempDir::new().unwrap();
1634 fs::write(dir.path().join("jan.csv"), "amt\n10\n20\n").unwrap();
1635 fs::write(dir.path().join("feb.csv"), "amt\n30\n40\n").unwrap();
1636 let server = make_server(dir.path());
1637 let result = server
1638 .query_sql(Parameters(query_req(
1639 Some("*.csv"),
1640 None,
1641 "SELECT SUM(amt) AS total FROM data",
1642 None,
1643 )))
1644 .unwrap();
1645 let parsed: QueryResponseDe = serde_json::from_str(&result_text(&result)).unwrap();
1646 assert_eq!(parsed.row_count, 1);
1647 assert_eq!(
1648 parsed.rows[0]
1649 .get("total")
1650 .and_then(serde_json::Value::as_i64),
1651 Some(100)
1652 );
1653 }
1654
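    #[test]
    fn build_register_view_quotes_path_and_rejects_other_formats() {
        // Each registered table is just a view over a DuckDB table function;
        // single quotes in the path are doubled for the SQL string literal.
        let sql = build_register_view("data", "/tmp/it's.csv", "csv").unwrap();
        assert_eq!(
            sql,
            "CREATE OR REPLACE VIEW data AS SELECT * FROM read_csv_auto('/tmp/it''s.csv')"
        );
        assert!(build_register_view("data", "/tmp/a.xlsx", "xlsx").is_err());
    }
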
1655 #[test]
1656 fn query_sql_rejects_both_path_and_tables() {
1657 let dir = TempDir::new().unwrap();
1658 fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1659 let server = make_server(dir.path());
1660 let mut tables = std::collections::HashMap::new();
1661 tables.insert("t".into(), "a.csv".into());
1662 let err = server
1663 .query_sql(Parameters(query_req(
1664 Some("a.csv"),
1665 Some(tables),
1666 "SELECT 1",
1667 None,
1668 )))
1669 .unwrap_err();
1670 assert!(format!("{err:?}").contains("either"));
1671 }
1672
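    #[test]
    fn describe_input_reports_single_file_path_directly() {
        // Single-file mode (table `data`) echoes the relative path as-is;
        // multi-file mode is covered by query_sql_multi_file_join above.
        let specs = vec![TableSpec {
            table: "data".into(),
            path_for_sql: "/root/sales.csv".into(),
            relative_path: "sales.csv".into(),
            format: "csv",
        }];
        assert_eq!(describe_input(&specs), "sales.csv");
    }
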
1673 #[test]
1674 fn query_sql_rejects_invalid_table_name() {
1675 let dir = TempDir::new().unwrap();
1676 fs::write(dir.path().join("a.csv"), "x\n1\n").unwrap();
1677 let server = make_server(dir.path());
1678 let mut tables = std::collections::HashMap::new();
1679 tables.insert("evil; DROP TABLE x".into(), "a.csv".into());
1680 let err = server
1681 .query_sql(Parameters(query_req(None, Some(tables), "SELECT 1", None)))
1682 .unwrap_err();
1683 assert!(format!("{err:?}").contains("identifier"));
1684 }
1685
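    #[test]
    fn identifier_glob_and_literal_helpers() {
        // Edge cases for the small pure helpers that back query_sql's validation.
        assert!(is_valid_sql_identifier("orders_2024"));
        assert!(!is_valid_sql_identifier("2024_orders"));
        assert!(!is_valid_sql_identifier(""));
        assert!(is_glob_pattern("2024/*.csv"));
        assert!(!is_glob_pattern("2024/jan.csv"));
        assert_eq!(sql_string_literal("it's"), "'it''s'");
        assert_eq!(format_for_query_sql("x/y.PARQUET").unwrap(), "parquet");
        assert!(format_for_query_sql("x/y.xlsx").is_err());
    }
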
1686 #[test]
1687 fn search_files_rejects_empty_query() {
1688 let dir = TempDir::new().unwrap();
1689 let server = make_server(dir.path());
1690 let err = server
1691 .search_files(Parameters(SearchFilesRequest {
1692 query: " ".into(),
1693 extensions: None,
1694 limit: None,
1695 }))
1696 .unwrap_err();
1697 assert!(format!("{err:?}").contains("empty"));
1698 }
1699
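    #[test]
    fn arrow_value_to_json_maps_nulls_and_integers() {
        // Minimal check of the Arrow-to-JSON bridge using an in-memory array:
        // null slots become JSON null, integers become JSON numbers.
        use duckdb::arrow::array::Int64Array;
        let arr = Int64Array::from(vec![Some(7), None]);
        assert_eq!(arrow_value_to_json(&arr, 0), serde_json::json!(7));
        assert_eq!(arrow_value_to_json(&arr, 1), serde_json::Value::Null);
    }
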
1700 fn result_text(result: &CallToolResult) -> String {
1703 let first = result.content.first().expect("at least one content item");
1704 if let Some(text) = first.as_text() {
1708 text.text.clone()
1709 } else {
1710 serde_json::to_string(&first).unwrap()
1711 }
1712 }
1713
1714 #[derive(serde::Deserialize)]
1718 struct SchemaResponseDe {
1719 #[allow(dead_code)]
1720 relative_path: String,
1721 format: String,
1722 columns: Vec<ColumnInfoDe>,
1723 #[allow(dead_code)]
1724 row_count: Option<u64>,
1725 }
1726
1727 #[derive(serde::Deserialize)]
1728 struct ColumnInfoDe {
1729 name: String,
1730 #[serde(rename = "type")]
1731 #[allow(dead_code)]
1732 data_type: String,
1733 #[allow(dead_code)]
1734 nullable: bool,
1735 }
1736
1737 #[derive(serde::Deserialize)]
1738 struct SamplesResponseDe {
1739 #[allow(dead_code)]
1740 relative_path: String,
1741 #[allow(dead_code)]
1742 format: String,
1743 columns: Vec<String>,
1744 rows: Vec<serde_json::Map<String, serde_json::Value>>,
1745 #[allow(dead_code)]
1746 row_count: Option<u64>,
1747 }
1748
1749 #[derive(serde::Deserialize)]
1750 struct SearchHitDe {
1751 relative_path: String,
1752 #[allow(dead_code)]
1753 size_bytes: u64,
1754 extension: String,
1755 score: f64,
1756 #[allow(dead_code)]
1757 why_matched: String,
1758 }
1759
1760 #[derive(serde::Deserialize)]
1761 struct QueryResponseDe {
1762 input: String,
1763 format: String,
1764 columns: Vec<String>,
1765 rows: Vec<serde_json::Map<String, serde_json::Value>>,
1766 row_count: usize,
1767 truncated: bool,
1768 }
1769}