1use std::collections::HashMap;
2use std::sync::Arc;
3use std::time::Duration;
4
5use hyper_util::{
6 rt::{TokioExecutor, TokioIo},
7 server::conn::auto::Builder,
8 service::TowerToHyperService,
9};
10use nu_protocol::{ShellError, engine::EngineState, shell_error::generic::GenericError};
11use rmcp::{
12 ServiceExt,
13 transport::{
14 stdio,
15 streamable_http_server::{
16 StreamableHttpServerConfig, StreamableHttpService,
17 session::local::{LocalSessionManager, SessionConfig},
18 },
19 },
20};
21use server::NushellMcpServer;
22use tokio::runtime::Runtime;
23use tokio::sync::RwLock;
24use tokio_util::sync::CancellationToken;
25use tracing_subscriber::EnvFilter;
26
27mod evaluation;
28mod history;
29mod server;
30
/// Transport over which the MCP server talks to its client.
///
/// Defaults to [`McpTransport::Stdio`].
#[derive(Clone, Debug, Default)]
pub enum McpTransport {
    /// Exchange MCP messages over the process's standard input and output.
    #[default]
    Stdio,
    /// Serve MCP over streamable HTTP.
    Http {
        /// TCP port the HTTP listener binds to.
        port: u16,
    },
}
43
44pub fn initialize_mcp_server(
45 mut engine_state: EngineState,
46 transport: McpTransport,
47) -> Result<(), ShellError> {
48 tracing_subscriber::fmt()
49 .with_env_filter(EnvFilter::from_default_env().add_directive(tracing::Level::DEBUG.into()))
50 .with_writer(std::io::stderr)
51 .with_ansi(false)
52 .init();
53
54 #[cfg(unix)]
63 {
64 let _ = nix::unistd::setsid();
66 }
67
68 engine_state.is_mcp = true;
71
72 tracing::info!(?transport, "Starting MCP server");
73 let runtime = Runtime::new().map_err(|e| {
74 ShellError::Generic(GenericError::new_internal(
75 format!("Could not instantiate tokio: {e}"),
76 "",
77 ))
78 })?;
79
80 runtime.block_on(async {
81 let result = match transport {
82 McpTransport::Stdio => run_stdio_server(engine_state).await,
83 McpTransport::Http { port } => run_http_server(engine_state, port).await,
84 };
85 if let Err(e) = result {
86 tracing::error!("Error running MCP server: {:?}", e);
87 }
88 });
89 Ok(())
90}
91
92async fn run_stdio_server(engine_state: EngineState) -> Result<(), Box<dyn std::error::Error>> {
93 NushellMcpServer::new(engine_state)
94 .serve(stdio())
95 .await
96 .inspect_err(|e| {
97 tracing::error!("serving error: {:?}", e);
98 })?
99 .waiting()
100 .await?;
101 Ok(())
102}
103
/// Value passed as `SessionConfig::keep_alive` for HTTP MCP sessions: 30 minutes.
const SESSION_KEEP_ALIVE: Duration = Duration::from_secs(60 * 30);

/// Value passed as `SessionConfig::channel_capacity` for HTTP MCP sessions.
const SESSION_CHANNEL_CAPACITY: usize = 16;
109
110async fn run_http_server(
111 engine_state: EngineState,
112 port: u16,
113) -> Result<(), Box<dyn std::error::Error>> {
114 let engine_state = Arc::new(engine_state);
115
116 let cancellation_token = CancellationToken::new();
118
119 let session_manager = Arc::new(LocalSessionManager {
120 sessions: RwLock::new(HashMap::new()),
121 session_config: SessionConfig {
122 channel_capacity: SESSION_CHANNEL_CAPACITY,
123 keep_alive: Some(SESSION_KEEP_ALIVE),
124 },
125 });
126
127 let service = TowerToHyperService::new(StreamableHttpService::new(
128 {
129 let engine_state = engine_state.clone();
130 move || Ok(NushellMcpServer::new((*engine_state).clone()))
131 },
132 session_manager,
133 StreamableHttpServerConfig::default(),
134 ));
135
136 let addr = format!("0.0.0.0:{port}");
137 let listener = tokio::net::TcpListener::bind(&addr).await?;
138 tracing::info!("MCP HTTP server listening on http://{addr}");
139 eprintln!("MCP HTTP server listening on http://{addr}");
140
141 loop {
142 let io = tokio::select! {
143 _ = tokio::signal::ctrl_c() => {
144 tracing::info!("Received Ctrl-C, shutting down...");
145 cancellation_token.cancel();
146 break;
147 }
148 accept = listener.accept() => {
149 TokioIo::new(accept?.0)
150 }
151 };
152 let service = service.clone();
153 tokio::spawn(async move {
154 let _ = Builder::new(TokioExecutor::new())
155 .serve_connection(io, service)
156 .await;
157 });
158 }
159 Ok(())
160}