Skip to main content

tiller_sync/mcp/
mod.rs

1//! MCP (Model Context Protocol) server implementation.
2//!
3//! This module provides an MCP server that exposes tiller functionality as tools
4//! for AI agent integration. The server communicates via JSON-RPC over stdio.
5
/// Guards a tool method: unless `$self.check_initialized()` reports `true`, returns
/// `Self::uninitialized()` from the enclosing function.
///
/// The expansion contains an `.await`, so it can only be used inside an `async` method. That
/// method's return type must match `Self::uninitialized()`, i.e.
/// `Result<CallToolResult, McpError>`.
macro_rules! require_init {
    ($self:expr) => {{
        // `ready` is hygienic, so it cannot clash with names at the call site.
        let ready = $self.check_initialized().await;
        if !ready {
            return Self::uninitialized();
        }
    }};
}
14
15mod mcp_utils;
16mod tools;
17
18use crate::{Config, Mode};
19use rmcp::handler::server::tool::ToolRouter;
20use rmcp::model::{
21    CallToolResult, Implementation, ProtocolVersion, ServerCapabilities, ServerInfo,
22};
23use rmcp::transport::stdio;
24use rmcp::ErrorData as McpError;
25use rmcp::{tool_handler, ServerHandler, ServiceExt};
26use std::sync::Arc;
27use tokio::sync::Mutex;
28use tracing::info;
29
30/// The tiller MCP server.
31///
32/// This server exposes tiller sync functionality as MCP tools.
33#[derive(Debug, Clone)]
34pub struct TillerServer {
35    initialized: Arc<Mutex<bool>>,
36    mode: Mode,
37    config: Arc<Config>,
38    tool_router: ToolRouter<TillerServer>,
39}
40
41impl TillerServer {
42    /// Creates a new TillerServer with the given configuration.
43    pub fn new(config: Config, mode: Mode) -> Self {
44        Self {
45            initialized: Arc::new(Mutex::new(false)),
46            mode,
47            config: Arc::new(config),
48            tool_router: Self::tool_router(),
49        }
50    }
51
52    async fn check_initialized(&self) -> bool {
53        *self.initialized.lock().await
54    }
55
56    fn uninitialized() -> Result<CallToolResult, McpError> {
57        Ok(CallToolResult::error(vec![rmcp::model::Content::text(
58            "You have not yet initialized the service. Please call __initialize_service__ first.",
59        )]))
60    }
61}
62
63#[tool_handler]
64impl ServerHandler for TillerServer {
65    /// Returns server information sent to the MCP client during initialization.
66    ///
67    /// The `instructions` field is intended by the specification to be the primary way to
68    /// communicate the server's purpose and usage to AI agents like Claude Code. This text is shown
69    /// to the AI to help it understand when and how to use this server's tools. However, it has
70    /// been noted that agents tend to consider this reading as optional. We have solved this
71    /// problem by requiring agents to call an `__initialize_service__` tool before anything else.
72    fn get_info(&self) -> ServerInfo {
73        ServerInfo {
74            protocol_version: ProtocolVersion::V_2024_11_05,
75            capabilities: ServerCapabilities::builder().enable_tools().build(),
76            server_info: Implementation {
77                name: "tiller".into(),
78                version: env!("CARGO_PKG_VERSION").into(),
79                ..Default::default()
80            },
81            instructions: Some(include_str!("docs/INTRO.md").into()),
82        }
83    }
84}
85
/// Selects which transport the MCP server communicates over.
#[derive(Default, Debug)]
pub(crate) enum Io {
    /// Standard input/output (the default).
    #[default]
    Stdio,
    /// In-memory transport for tests: the server's end of a `tokio::io::duplex` channel.
    #[cfg(test)]
    Mock(tokio::io::DuplexStream),
}
95
96/// Runs the MCP server with stdio transport or mock transport. This function starts the MCP server
97/// and blocks until the client disconnects or an error occurs.
98///
99/// # Arguments
100/// - `config`: The `Config` object
101/// - `mode`: Whether we are running with a live Google sheet or with a test sheet
102/// - `io`: Whether we are using stdio as the transport or using mock io for testing
103///
104pub(crate) async fn run_server(config: Config, mode: Mode, io: Io) -> crate::Result<()> {
105    use crate::error::{ErrorType, IntoResult};
106    let server = TillerServer::new(config, mode);
107    info!("Starting MCP server...");
108
109    let service = match io {
110        Io::Stdio => server
111            .serve(stdio())
112            .await
113            .map_err(|e| anyhow::anyhow!("Failed to start MCP server: {e}"))
114            .pub_result(ErrorType::Service)?,
115        #[cfg(test)]
116        Io::Mock(stream) => server
117            .serve(stream)
118            .await
119            .map_err(|e| anyhow::anyhow!("Failed to start MCP server: {e}"))
120            .pub_result(ErrorType::Service)?,
121    };
122
123    info!("MCP server running, waiting for requests...");
124
125    // Wait for the server to complete (client disconnects or error)
126    service
127        .waiting()
128        .await
129        .map_err(|e| anyhow::anyhow!("MCP server error: {e}"))
130        .pub_result(ErrorType::Service)?;
131
132    info!("MCP server shut down");
133    Ok(())
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::args::UpdateTransactionsArgs;
    use crate::test::TestEnv;
    use rmcp::ServiceExt;
    use tokio::io::duplex;

    /// Integration test for the MCP server using an in-memory transport.
    /// Exercises the `initialize_service`, `sync_down`, `sync_up`, and `update_transactions`
    /// tools, then checks that the server shuts down cleanly once the client disconnects.
    #[tokio::test]
    async fn test_mcp_server_integration() {
        // Create duplex channel - one end for server, one for client
        let (client_io, server_io) = duplex(4096);

        // Create test environment (holds TempDir alive for duration of test)
        let env = TestEnv::new().await;
        let config = env.config();

        // Spawn server in background task
        let server_handle =
            tokio::spawn(
                async move { run_server(config, Mode::Testing, Io::Mock(server_io)).await },
            );

        // Create MCP client connected to the other end
        let client = ().serve(client_io).await.expect("Failed to create client");

        // Test 1: Call initialize_service tool
        let init_result = client
            .call_tool(rmcp::model::CallToolRequestParam {
                name: "initialize_service".into(),
                arguments: None,
            })
            .await
            .expect("initialize_service call failed");

        assert!(
            !init_result.is_error.unwrap_or(false),
            "initialize_service returned error: {:?}",
            init_result.content
        );

        // Test 2: Call sync_down tool
        let sync_down_result = client
            .call_tool(rmcp::model::CallToolRequestParam {
                name: "sync_down".into(),
                arguments: None,
            })
            .await
            .expect("sync_down call failed");

        assert!(
            !sync_down_result.is_error.unwrap_or(false),
            "sync_down returned error: {:?}",
            sync_down_result.content
        );

        // Test 3: Call sync_up tool with force and formulas params
        let mut args = serde_json::Map::new();
        args.insert("force".into(), true.into());
        args.insert("formulas".into(), "ignore".into());

        let sync_up_result = client
            .call_tool(rmcp::model::CallToolRequestParam {
                name: "sync_up".into(),
                arguments: Some(args),
            })
            .await
            .expect("sync_up call failed");

        assert!(
            !sync_up_result.is_error.unwrap_or(false),
            "sync_up returned error: {:?}",
            sync_up_result.content
        );

        // Test 4: Call update_transactions tool
        // After sync_down, we have transactions in the database. Get one to update.
        let tiller_data = env.config().db().get_tiller_data().await.unwrap();
        // `.first()` instead of `[0]` so an empty database fails with a diagnostic message
        // rather than an opaque index-out-of-bounds panic.
        let first_txn = tiller_data
            .transactions
            .data()
            .first()
            .expect("sync_down should have populated at least one transaction");
        let txn_id = first_txn.transaction_id.clone();
        let updates = crate::model::TransactionUpdates {
            note: Some("Updated via MCP".to_string()),
            ..Default::default()
        };
        let updates = UpdateTransactionsArgs::new(vec![txn_id], updates).unwrap();
        // Move the JSON object out of the `Value` instead of borrowing and cloning it.
        let serde_json::Value::Object(updates_json) =
            serde_json::to_value(&updates).expect("UpdateTransactionsArgs should serialize")
        else {
            panic!("UpdateTransactionsArgs should serialize to a JSON object");
        };

        let update_result = client
            .call_tool(rmcp::model::CallToolRequestParam {
                name: "update_transactions".into(),
                arguments: Some(updates_json),
            })
            .await
            .expect("update_transactions call failed");

        assert!(
            !update_result.is_error.unwrap_or(false),
            "update_transactions returned error: {:?}",
            update_result.content
        );

        // Drop client to trigger server shutdown
        drop(client);

        // Wait for server to finish (with timeout)
        let server_result = tokio::time::timeout(std::time::Duration::from_secs(5), server_handle)
            .await
            .expect("Server timed out")
            .expect("Server task panicked");

        assert!(
            server_result.is_ok(),
            "Server returned error: {:?}",
            server_result
        );
    }

    /// Queries MCP tool definitions and writes them to `.ignore/mcp_tools.txt`.
    /// This provides a human-readable dump of the tool schemas for inspection.
    #[tokio::test]
    async fn write_mcp_tools_to_file() {
        use std::fmt::Write as _;
        use std::fs;
        use std::path::PathBuf;

        /// Horizontal rule that frames each tool's header in the dump.
        const RULE: &str =
            "────────────────────────────────────────────────────────────────────────────────";

        fn project_root() -> PathBuf {
            PathBuf::from(env!("CARGO_MANIFEST_DIR"))
        }

        // Create duplex channel
        let (client_io, server_io) = duplex(4096);

        // Create test environment
        let env = TestEnv::new().await;
        let config = env.config();

        // Spawn server in background
        let _server_handle =
            tokio::spawn(
                async move { run_server(config, Mode::Testing, Io::Mock(server_io)).await },
            );

        // Create MCP client
        let client = ().serve(client_io).await.expect("Failed to create client");

        // Get the list of tools
        let tools_response = client
            .list_tools(Default::default())
            .await
            .expect("Failed to list tools");

        // Build output string. Writing to a `String` cannot fail, so the `fmt::Result`s are
        // unwrapped. `writeln!` appends in place instead of allocating a temporary per call
        // as `push_str(&format!(..))` does.
        let mut output = String::new();
        writeln!(
            output,
            "=== MCP Tools ({} total) ===\n",
            tools_response.tools.len()
        )
        .unwrap();

        for tool in &tools_response.tools {
            writeln!(output, "{RULE}\nTOOL: {}\n{RULE}", tool.name).unwrap();
            output.push_str("\nDescription:\n");
            if let Some(desc) = &tool.description {
                for desc_line in desc.lines() {
                    writeln!(output, "  {desc_line}").unwrap();
                }
            }
            output.push_str("\nInput Schema:\n");
            output.push_str(&serde_json::to_string_pretty(&tool.input_schema).unwrap());
            output.push_str("\n\n");
        }

        // Write to .ignore/mcp_tools.txt
        let ignore_dir = project_root().join(".ignore");
        fs::create_dir_all(&ignore_dir).expect("Failed to create .ignore directory");

        let output_path = ignore_dir.join("mcp_tools.txt");
        fs::write(&output_path, output).expect("Failed to write output");
    }
}