1pub use techne_mcp as mcp;
2
3pub mod tool;
4pub mod transport;
5
6mod connection;
7#[cfg(feature = "http")]
8mod http;
9mod stdio;
10
11#[cfg(feature = "http")]
12pub use http::Http;
13pub use stdio::Stdio;
14pub use tool::Tool;
15pub use transport::Transport;
16
17use crate::connection::{Connection, Receipt};
18use crate::mcp::client;
19use crate::mcp::server;
20use crate::mcp::server::response::{self, Response};
21use crate::transport::{Action, Channel};
22
23use tokio::task;
24
25use std::collections::BTreeMap;
26use std::env;
27use std::io;
28use std::sync::Arc;
29
30#[derive(Default)]
31pub struct Server {
32    name: String,
33    version: String,
34    tools: BTreeMap<String, Tool>,
35}
36
37impl Server {
38    pub fn new(name: impl AsRef<str>, version: impl AsRef<str>) -> Self {
39        Self {
40            name: name.as_ref().to_owned(),
41            version: version.as_ref().to_owned(),
42            tools: BTreeMap::new(),
43        }
44    }
45
46    pub fn tools(mut self, tools: impl IntoIterator<Item = Tool>) -> Self {
47        self.tools = tools
48            .into_iter()
49            .map(|tool| (tool.name.clone(), tool))
50            .collect();
51
52        self
53    }
54
55    pub async fn run(self, mut transport: impl Transport) -> io::Result<()> {
56        let server = Arc::new(self);
57
58        loop {
59            let action = transport.accept().await?;
60
61            match action {
62                Action::Subscribe(channel) => {
63                    let _ = channel.send(transport::Result::Reject);
64                }
65                Action::Handle(bytes, channel) => {
66                    let server = server.clone();
67
68                    drop(task::spawn(async move {
69                        if let Err(error) = server.handle(bytes, channel).await {
70                            log::error!("{error}");
71                        }
72                    }));
73                }
74                Action::Quit => return Ok(()),
75            }
76        }
77    }
78
79    async fn handle(&self, bytes: mcp::Bytes, channel: Channel) -> io::Result<()> {
80        match client::Message::<mcp::Value>::deserialize(&bytes) {
81            Ok(message) => match message {
82                client::Message::Request(request) => {
83                    self.serve(Connection::new(request.id, channel), request.payload)
84                        .await
85                }
86                client::Message::Notification(notification) => {
87                    self.deliver_notification(Receipt::new(channel), notification.payload)
88                        .await
89                }
90                client::Message::Response(response) => {
91                    self.deliver_response(Receipt::new(channel), response).await
92                }
93                client::Message::Error(error) => {
94                    self.deliver_error(Receipt::new(channel), error).await
95                }
96            },
97            Err(error) => {
98                let bytes = mcp::Error::invalid_json(error.to_string()).serialize()?;
99                let _ = channel.send(transport::Result::Send(bytes));
100
101                Ok(())
102            }
103        }
104    }
105
106    async fn serve(&self, connection: Connection, request: client::Request) -> io::Result<()> {
107        log::debug!("Serving {request:?}");
108
109        match request {
110            client::Request::Initialize { .. } => self.initialize(connection).await,
111            client::Request::Ping => self.ping(connection).await,
112            client::Request::ToolsList => self.list_tools(connection).await,
113            client::Request::ToolsCall { params: call } => self.call_tool(connection, call).await,
114        }
115    }
116
117    async fn initialize(&self, connection: Connection) -> io::Result<()> {
118        use crate::mcp::server::capabilities::{self, Capabilities};
119
120        connection
121            .finish(response::Initialize {
122                protocol_version: mcp::VERSION.to_owned(),
123                capabilities: Capabilities {
124                    tools: (!self.tools.is_empty()).then_some(capabilities::Tools {
125                        list_changed: false, }),
127                },
128                server_info: mcp::Server {
129                    name: self.name.clone(),
130                    version: self.version.clone(),
131                },
132            })
133            .await
134    }
135
136    async fn ping(&self, connection: Connection) -> io::Result<()> {
137        connection.finish(Response::Ping {}).await
138    }
139
140    async fn list_tools(&self, connection: Connection) -> io::Result<()> {
141        connection
142            .finish(response::ToolsList {
143                tools: self
144                    .tools
145                    .values()
146                    .map(|tool| server::Tool {
147                        name: tool.name.clone(),
148                        title: None,
149                        description: tool.description.clone(),
150                        input_schema: tool.input().clone(),
151                        output_schema: tool.output().cloned(),
152                    })
153                    .collect(),
154            })
155            .await
156    }
157
158    async fn call_tool(
159        &self,
160        mut connection: Connection,
161        call: client::request::ToolCall,
162    ) -> io::Result<()> {
163        use futures::StreamExt;
164
165        let Some(tool) = self.tools.get(&call.name) else {
166            return connection
167                .error(mcp::ErrorKind::invalid_params(format!(
168                    "Unknown tool: {}",
169                    &call.name
170                )))
171                .await;
172        };
173
174        let mut output = tool.call(call.arguments)?.boxed();
175
176        while let Some(action) = output.next().await {
177            match action {
178                crate::tool::Action::Request(request) => connection.request(request).await?,
179                crate::tool::Action::Notify(notification) => {
180                    connection.notify(notification).await?
181                }
182                crate::tool::Action::Finish(outcome) => return connection.finish(outcome?).await,
183            }
184        }
185
186        Ok(())
187    }
188
189    async fn deliver_notification(
190        &self,
191        receipt: Receipt,
192        _notification: client::Notification,
193    ) -> io::Result<()> {
194        receipt.reject();
196
197        Ok(())
198    }
199
200    async fn deliver_response(&self, receipt: Receipt, _response: mcp::Response) -> io::Result<()> {
201        receipt.reject();
203
204        Ok(())
205    }
206
207    async fn deliver_error(&self, receipt: Receipt, _error: mcp::Error) -> io::Result<()> {
208        receipt.reject();
210
211        Ok(())
212    }
213}
214
215pub async fn transport(mut args: env::Args) -> io::Result<impl Transport> {
216    enum HttpOrStdio {
217        #[cfg(feature = "http")]
218        Http(Http),
219        Stdio(Stdio),
220    }
221
222    impl Transport for HttpOrStdio {
223        fn accept(&mut self) -> impl Future<Output = io::Result<Action>> {
224            use futures::FutureExt;
225
226            match self {
227                #[cfg(feature = "http")]
228                HttpOrStdio::Http(http) => http.accept().boxed(),
229                HttpOrStdio::Stdio(stdio) => stdio.accept().boxed(),
230            }
231        }
232    }
233
234    let _executable = args.next();
235
236    let protocol = args.next();
237    let protocol = protocol.as_deref();
238
239    if protocol == Some("--http") {
240        #[cfg(feature = "http")]
241        {
242            let address = args.next();
243            let address = address.as_deref().unwrap_or("127.0.0.1:8080");
244
245            let rest = args.next();
246
247            if let Some(rest) = rest {
248                return Err(io::Error::new(
249                    io::ErrorKind::InvalidInput,
250                    format!("Unknown argument: {rest}"),
251                ));
252            }
253
254            return Ok(HttpOrStdio::Http(Http::bind(address).await?));
255        }
256
257        #[cfg(not(feature = "http"))]
258        return Err(io::Error::new(
259            io::ErrorKind::InvalidInput,
260            format!("Streamable HTTP is not supported for this server"),
261        ));
262    }
263
264    if let Some(protocol) = protocol {
265        return Err(io::Error::new(
266            io::ErrorKind::InvalidInput,
267            format!("unknown argument: {protocol}"),
268        ));
269    }
270
271    Ok(HttpOrStdio::Stdio(Stdio::current()))
272}