techne_client/
lib.rs

1pub use techne_mcp as mcp;
2
3pub mod transport;
4
5mod connection;
6#[cfg(feature = "http")]
7mod http;
8mod stdio;
9
10#[cfg(feature = "http")]
11pub use http::Http;
12pub use stdio::Stdio;
13pub use transport::Transport;
14
15use connection::Connection;
16
17use crate::mcp::client::request;
18use crate::mcp::client::{Capabilities, Notification, Request, Response};
19use crate::mcp::server;
20use crate::mcp::server::tool;
21
22use sipper::{Straw, sipper};
23
24use std::fmt;
25use std::io;
26use std::sync::Arc;
27
28#[derive(Debug)]
29pub struct Client {
30    session: Session,
31    server: Server,
32}
33
34impl Client {
35    pub async fn new(
36        name: impl AsRef<str>,
37        version: impl AsRef<str>,
38        transport: impl Transport + Send + Sync + 'static,
39    ) -> io::Result<Self> {
40        let mut session = Session {
41            transport: Arc::new(transport),
42            next_request: mcp::Id::default(),
43        };
44
45        let initialize = session
46            .request(request::Initialize {
47                protocol_version: mcp::VERSION.to_owned(),
48                capabilities: Capabilities {},
49                client_info: mcp::Client {
50                    name: name.as_ref().to_owned(),
51                    title: None, // TODO
52                    version: version.as_ref().to_owned(),
53                },
54            })
55            .await?
56            .response::<server::response::Initialize>()
57            .await?;
58
59        if initialize.result.protocol_version != mcp::VERSION {
60            return Err(io::Error::new(
61                io::ErrorKind::Unsupported,
62                format!(
63                    "protocol mismatch (supported: {supported}, given: {given})",
64                    supported = mcp::VERSION,
65                    given = initialize.result.protocol_version,
66                ),
67            ));
68        }
69
70        let _ = session.notify(Notification::Initialized).await;
71
72        Ok(Self {
73            session,
74            server: Server {
75                capabilities: initialize.result.capabilities,
76                information: initialize.result.server_info,
77            },
78        })
79    }
80
81    pub fn server(&self) -> &Server {
82        &self.server
83    }
84
85    pub async fn list_tools(&mut self) -> io::Result<Vec<server::Tool>> {
86        let list = self.session.request(Request::ToolsList).await?;
87
88        let mcp::Response {
89            result: server::response::ToolsList { tools },
90            ..
91        } = list.response().await?;
92
93        Ok(tools)
94    }
95
96    pub fn call_tool(
97        &mut self,
98        name: impl AsRef<str>,
99        arguments: mcp::Value,
100    ) -> impl Straw<tool::Response, Event, io::Error> {
101        sipper(async move |mut sender| {
102            let mut call = self
103                .session
104                .request(Request::ToolsCall {
105                    params: request::ToolCall {
106                        name: name.as_ref().to_owned(),
107                        arguments,
108                    },
109                })
110                .await?;
111
112            loop {
113                match call.next().await? {
114                    server::Message::Request(request) => {
115                        sender
116                            .send(Event::Request(request.id, request.payload))
117                            .await;
118                    }
119                    server::Message::Notification(notification) => {
120                        sender.send(Event::Notification(notification.payload)).await;
121                    }
122                    server::Message::Response(response) => {
123                        return Ok(response.result);
124                    }
125                    server::Message::Error(error) => {
126                        log::warn!("{error}");
127                    }
128                }
129            }
130        })
131    }
132}
133
134#[derive(Debug, Clone)]
135pub enum Event {
136    Notification(server::Notification),
137    Request(mcp::Id, server::Request),
138}
139
140struct Session {
141    transport: Arc<dyn Transport + Send + Sync>,
142    next_request: mcp::Id,
143}
144
145impl Session {
146    async fn request(&mut self, request: impl Into<Request>) -> io::Result<Connection> {
147        let request = request.into();
148
149        self.transport
150            .send(mcp::Request::new(self.next_request.increment(), request).serialize()?)
151            .await
152            .map(Connection::new)
153    }
154
155    async fn notify(&self, notification: impl Into<Notification>) -> io::Result<()> {
156        let notification = notification.into();
157
158        self.transport
159            .send(mcp::Notification::new(notification).serialize()?)
160            .await?;
161
162        Ok(())
163    }
164
165    #[allow(unused)]
166    async fn response(&self, id: mcp::Id, response: Response) -> io::Result<()> {
167        self.transport
168            .send(mcp::Response::new(id, response).serialize()?)
169            .await;
170
171        Ok(())
172    }
173}
174
175impl fmt::Debug for Session {
176    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
177        f.debug_struct("Connection")
178            .field("next_request", &self.next_request) // TODO: Debug transport
179            .finish()
180    }
181}
182
183#[derive(Debug)]
184pub struct Server {
185    capabilities: server::Capabilities,
186    information: mcp::Server,
187}
188
189impl Server {
190    pub fn capabilities(&self) -> &server::Capabilities {
191        &self.capabilities
192    }
193
194    pub fn information(&self) -> &mcp::Server {
195        &self.information
196    }
197}