1pub use techne_mcp as mcp;
2
3pub mod tool;
4pub mod transport;
5
6mod connection;
7#[cfg(feature = "http")]
8mod http;
9mod stdio;
10
11#[cfg(feature = "http")]
12pub use http::Http;
13pub use stdio::Stdio;
14pub use tool::Tool;
15pub use transport::Transport;
16
17use crate::connection::{Connection, Receipt};
18use crate::mcp::client;
19use crate::mcp::server;
20use crate::mcp::server::response::{self, Response};
21use crate::transport::{Action, Channel};
22
23use tokio::task;
24
25use std::collections::BTreeMap;
26use std::env;
27use std::io;
28use std::sync::Arc;
29
30#[derive(Default)]
31pub struct Server {
32 name: String,
33 version: String,
34 tools: BTreeMap<String, Tool>,
35}
36
37impl Server {
38 pub fn new(name: impl AsRef<str>, version: impl AsRef<str>) -> Self {
39 Self {
40 name: name.as_ref().to_owned(),
41 version: version.as_ref().to_owned(),
42 tools: BTreeMap::new(),
43 }
44 }
45
46 pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
47 self.tools = tools
48 .into_iter()
49 .map(|tool| (tool.name.clone(), tool))
50 .collect();
51
52 self
53 }
54
55 pub async fn run(self, mut transport: impl Transport) -> io::Result<()> {
56 let server = Arc::new(self);
57
58 loop {
59 let action = transport.accept().await?;
60
61 match action {
62 Action::Subscribe(channel) => {
63 let _ = channel.send(transport::Result::Reject);
64 }
65 Action::Handle(bytes, channel) => {
66 let server = server.clone();
67
68 drop(task::spawn(async move {
69 if let Err(error) = server.handle(bytes, channel).await {
70 log::error!("{error}");
71 }
72 }));
73 }
74 Action::Quit => return Ok(()),
75 }
76 }
77 }
78
79 async fn handle(&self, bytes: mcp::Bytes, channel: Channel) -> io::Result<()> {
80 match client::Message::<mcp::Value>::deserialize(&bytes) {
81 Ok(message) => match message {
82 client::Message::Request(request) => {
83 self.serve(Connection::new(request.id, channel), request.payload)
84 .await
85 }
86 client::Message::Notification(notification) => {
87 self.deliver_notification(Receipt::new(channel), notification.payload)
88 .await
89 }
90 client::Message::Response(response) => {
91 self.deliver_response(Receipt::new(channel), response).await
92 }
93 client::Message::Error(error) => {
94 self.deliver_error(Receipt::new(channel), error).await
95 }
96 },
97 Err(error) => {
98 let bytes = mcp::Error::invalid_json(error.to_string()).serialize()?;
99 let _ = channel.send(transport::Result::Send(bytes));
100
101 Ok(())
102 }
103 }
104 }
105
106 async fn serve(&self, connection: Connection, request: client::Request) -> io::Result<()> {
107 log::debug!("Serving {request:?}");
108
109 match request {
110 client::Request::Initialize { .. } => self.initialize(connection).await,
111 client::Request::Ping => self.ping(connection).await,
112 client::Request::ToolsList => self.list_tools(connection).await,
113 client::Request::ToolsCall { params: call } => self.call_tool(connection, call).await,
114 }
115 }
116
117 async fn initialize(&self, connection: Connection) -> io::Result<()> {
118 use crate::mcp::server::capabilities::{self, Capabilities};
119
120 connection
121 .finish(response::Initialize {
122 protocol_version: mcp::VERSION.to_owned(),
123 capabilities: Capabilities {
124 tools: (!self.tools.is_empty()).then_some(capabilities::Tools {
125 list_changed: false, }),
127 },
128 server_info: mcp::Server {
129 name: self.name.clone(),
130 version: self.version.clone(),
131 },
132 })
133 .await
134 }
135
136 async fn ping(&self, connection: Connection) -> io::Result<()> {
137 connection.finish(Response::Ping {}).await
138 }
139
140 async fn list_tools(&self, connection: Connection) -> io::Result<()> {
141 connection
142 .finish(response::ToolsList {
143 tools: self
144 .tools
145 .values()
146 .map(|tool| server::Tool {
147 name: tool.name.clone(),
148 title: None,
149 description: tool.description.clone(),
150 input_schema: tool.input().clone(),
151 output_schema: tool.output().cloned(),
152 })
153 .collect(),
154 })
155 .await
156 }
157
158 async fn call_tool(
159 &self,
160 mut connection: Connection,
161 call: client::request::ToolCall,
162 ) -> io::Result<()> {
163 use futures::StreamExt;
164
165 let Some(tool) = self.tools.get(&call.name) else {
166 return connection
167 .error(mcp::ErrorKind::invalid_params(format!(
168 "Unknown tool: {}",
169 &call.name
170 )))
171 .await;
172 };
173
174 let mut output = tool.call(call.arguments)?.boxed();
175
176 while let Some(action) = output.next().await {
177 match action {
178 crate::tool::Action::Request(request) => connection.request(request).await?,
179 crate::tool::Action::Notify(notification) => {
180 connection.notify(notification).await?
181 }
182 crate::tool::Action::Finish(outcome) => return connection.finish(outcome?).await,
183 }
184 }
185
186 Ok(())
187 }
188
189 async fn deliver_notification(
190 &self,
191 receipt: Receipt,
192 _notification: client::Notification,
193 ) -> io::Result<()> {
194 receipt.reject();
196
197 Ok(())
198 }
199
200 async fn deliver_response(&self, receipt: Receipt, _response: mcp::Response) -> io::Result<()> {
201 receipt.reject();
203
204 Ok(())
205 }
206
207 async fn deliver_error(&self, receipt: Receipt, _error: mcp::Error) -> io::Result<()> {
208 receipt.reject();
210
211 Ok(())
212 }
213}
214
215pub async fn transport(mut args: env::Args) -> io::Result<impl Transport> {
216 enum HttpOrStdio {
217 #[cfg(feature = "http")]
218 Http(Http),
219 Stdio(Stdio),
220 }
221
222 impl Transport for HttpOrStdio {
223 fn accept(&mut self) -> impl Future<Output = io::Result<Action>> {
224 use futures::FutureExt;
225
226 match self {
227 #[cfg(feature = "http")]
228 HttpOrStdio::Http(http) => http.accept().boxed(),
229 HttpOrStdio::Stdio(stdio) => stdio.accept().boxed(),
230 }
231 }
232 }
233
234 let _executable = args.next();
235
236 let protocol = args.next();
237 let protocol = protocol.as_deref();
238
239 if protocol == Some("--http") {
240 #[cfg(feature = "http")]
241 {
242 let address = args.next();
243 let address = address.as_deref().unwrap_or("127.0.0.1:8080");
244
245 let rest = args.next();
246
247 if let Some(rest) = rest {
248 return Err(io::Error::new(
249 io::ErrorKind::InvalidInput,
250 format!("Unknown argument: {rest}"),
251 ));
252 }
253
254 return Ok(HttpOrStdio::Http(Http::bind(address).await?));
255 }
256
257 #[cfg(not(feature = "http"))]
258 return Err(io::Error::new(
259 io::ErrorKind::InvalidInput,
260 format!("Streamable HTTP is not supported for this server"),
261 ));
262 }
263
264 if let Some(protocol) = protocol {
265 return Err(io::Error::new(
266 io::ErrorKind::InvalidInput,
267 format!("unknown argument: {protocol}"),
268 ));
269 }
270
271 Ok(HttpOrStdio::Stdio(Stdio::current()))
272}