1use crate::mcp;
2use crate::mcp::server::tool::{IntoResponse, Response};
3use crate::mcp::server::{Notification, Request};
4use crate::mcp::{Map, Schema, Value};
5
6use futures::SinkExt;
7use futures::channel::mpsc;
8use serde::Serialize;
9use tokio::task;
10
11use std::collections::BTreeMap;
12use std::io;
13use std::marker::PhantomData;
14
15pub struct Tool<Name = String, Description = String> {
16 pub name: Name,
17 pub description: Description,
18 input: Schema,
19 output: Option<Schema>,
20 call: Box<dyn Fn(Value) -> io::Result<mpsc::Receiver<Action>> + Send + Sync>,
21}
22
23pub enum Action {
24 Request(Request),
25 Notify(Notification),
26 Finish(io::Result<Response>),
27}
28
29impl Tool<(), ()> {
30 pub unsafe fn new(
33 input: Schema,
34 output: Option<Schema>,
35 call: impl Fn(Value) -> io::Result<mpsc::Receiver<Action>> + Send + Sync + 'static,
36 ) -> Self {
37 Self {
38 name: (),
39 description: (),
40 input,
41 output,
42 call: Box::new(call),
43 }
44 }
45}
46
47impl<Name, Description> Tool<Name, Description> {
48 pub fn name(self, name: impl AsRef<str>) -> Tool<String, Description> {
49 Tool {
50 name: name.as_ref().to_owned(),
51 description: self.description,
52 input: self.input,
53 output: self.output,
54 call: self.call,
55 }
56 }
57
58 pub fn description(self, description: impl AsRef<str>) -> Tool<Name> {
59 Tool {
60 name: self.name,
61 description: description.as_ref().to_owned(),
62 input: self.input,
63 output: self.output,
64 call: self.call,
65 }
66 }
67}
68
69impl Tool {
70 pub fn input(&self) -> &Schema {
71 &self.input
72 }
73
74 pub fn output(&self) -> Option<&Schema> {
75 self.output.as_ref()
76 }
77
78 pub fn call(&self, json: Value) -> io::Result<mpsc::Receiver<Action>> {
79 (self.call)(json)
80 }
81}
82
83pub fn tool<A, O, F>(
84 f: impl Fn(A) -> F + Send + Sync + 'static,
85 a: impl Argument<A> + Send + Sync + 'static,
86) -> Tool<(), ()>
87where
88 O: IntoResponse,
89 O::Content: Serialize + Send,
90 F: Future<Output = O> + Send + 'static,
91{
92 let input = Schema::Object {
93 description: None,
94 properties: BTreeMap::from_iter([property(&a)]),
95 required: Vec::from_iter([required(&a)].into_iter().flatten()),
96 };
97
98 let call = move |json| {
99 let mut object = object(json)?;
100 let a = deserialize(&a, &mut object)?;
101
102 Ok(spawn(f(a)))
103 };
104
105 Tool {
106 name: (),
107 description: (),
108 input,
109 output: None, call: Box::new(call),
111 }
112}
113
114pub fn tool_2<A, B, O, F>(
115 f: impl Fn(A, B) -> F + Send + Sync + 'static,
116 a: impl Argument<A> + Send + Sync + 'static,
117 b: impl Argument<B> + Send + Sync + 'static,
118) -> Tool<(), ()>
119where
120 O: IntoResponse,
121 O::Content: Serialize + Send,
122 F: Future<Output = O> + Send + 'static,
123{
124 let input = Schema::Object {
125 description: None,
126 properties: BTreeMap::from_iter([property(&a), property(&b)]),
127 required: Vec::from_iter([required(&a), required(&b)].into_iter().flatten()),
128 };
129
130 let call = move |json| {
131 let mut object = object(json)?;
132 let a = deserialize(&a, &mut object)?;
133 let b = deserialize(&b, &mut object)?;
134
135 Ok(spawn(f(a, b)))
136 };
137
138 Tool {
139 name: (),
140 description: (),
141 input,
142 output: None, call: Box::new(call),
144 }
145}
146
147fn spawn<O>(execution: impl Future<Output = O> + Send + 'static) -> mpsc::Receiver<Action>
148where
149 O: IntoResponse,
150 O::Content: Serialize + Send,
151{
152 let (mut sender, receiver) = mpsc::channel(1);
153
154 task::spawn(async move {
155 let output = execution.await;
156
157 let result = output
158 .into_outcome()
159 .serialize()
160 .await
161 .map_err(io::Error::from);
162
163 let _ = sender.send(Action::Finish(result)).await;
164 });
165
166 receiver
167}
168
169pub trait Argument<T> {
170 fn name(&self) -> &str;
171
172 fn schema(&self) -> Schema;
173
174 fn deserialize(&self, json: Value) -> io::Result<T>;
175
176 fn is_required(&self) -> bool {
177 true
178 }
179}
180
181pub fn string(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<String> {
182 NamedArg::new(name, description)
183}
184
185pub fn u32(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<u32> {
186 NamedArg::new(name, description)
187}
188
189pub fn f32(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<f32> {
190 NamedArg::new(name, description)
191}
192
193pub fn bool(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<bool> {
194 NamedArg::new(name, description)
195}
196
197pub fn optional<T>(argument: impl Argument<T>) -> impl Argument<Option<T>> {
198 struct Optional<A, T> {
199 argument: A,
200 _output: PhantomData<T>,
201 }
202
203 impl<A, T> Argument<Option<T>> for Optional<A, T>
204 where
205 A: Argument<T>,
206 {
207 fn name(&self) -> &str {
208 self.argument.name()
209 }
210
211 fn schema(&self) -> Schema {
212 self.argument.schema()
213 }
214
215 fn deserialize(&self, json: Value) -> io::Result<Option<T>> {
216 if json.is_null() {
217 return Ok(None);
218 }
219
220 self.argument.deserialize(json).map(Some)
221 }
222
223 fn is_required(&self) -> bool {
224 false
225 }
226 }
227
228 Optional {
229 argument,
230 _output: PhantomData,
231 }
232}
233
/// The concrete argument type returned by `string`, `u32`, `f32` and `bool`.
///
/// It implements `Argument<T>` for each of those four `T`s. The return type of
/// the constructor selects which implementation applies.
struct NamedArg {
    /// Text placed in the generated schema's `description`.
    description: String,
    /// Property name in the tool's input object.
    name: String,
}
238
239impl NamedArg {
240 fn new(name: impl AsRef<str>, description: impl AsRef<str>) -> Self {
241 Self {
242 name: name.as_ref().to_owned(),
243 description: description.as_ref().to_owned(),
244 }
245 }
246}
247
248impl Argument<String> for NamedArg {
249 fn name(&self) -> &str {
250 &self.name
251 }
252
253 fn schema(&self) -> Schema {
254 Schema::String {
255 description: Some(self.description.clone()),
256 }
257 }
258
259 fn deserialize(&self, json: Value) -> io::Result<String> {
260 mcp::from_value(json)
261 }
262}
263
264impl Argument<u32> for NamedArg {
265 fn name(&self) -> &str {
266 &self.name
267 }
268
269 fn schema(&self) -> Schema {
270 Schema::Integer {
271 description: Some(self.description.clone()),
272 }
273 }
274
275 fn deserialize(&self, json: Value) -> io::Result<u32> {
276 mcp::from_value(json)
277 }
278}
279
280impl Argument<f32> for NamedArg {
281 fn name(&self) -> &str {
282 &self.name
283 }
284
285 fn schema(&self) -> Schema {
286 Schema::Number {
287 description: Some(self.description.clone()),
288 }
289 }
290
291 fn deserialize(&self, json: Value) -> io::Result<f32> {
292 mcp::from_value(json)
293 }
294}
295
296impl Argument<bool> for NamedArg {
297 fn name(&self) -> &str {
298 &self.name
299 }
300
301 fn schema(&self) -> Schema {
302 Schema::Boolean {
303 description: Some(self.description.clone()),
304 }
305 }
306
307 fn deserialize(&self, json: Value) -> io::Result<bool> {
308 mcp::from_value(json)
309 }
310}
311
312fn property<T>(arg: &impl Argument<T>) -> (String, Schema) {
313 (arg.name().to_owned(), arg.schema().clone())
314}
315
316fn required<T>(arg: &impl Argument<T>) -> Option<String> {
317 arg.is_required().then(|| arg.name().to_owned())
318}
319
320fn object(json: Value) -> io::Result<Map<String, Value>> {
321 mcp::from_value(json)
322}
323
324fn deserialize<T>(arg: &impl Argument<T>, object: &mut Map<String, Value>) -> io::Result<T> {
325 arg.deserialize(object.remove(arg.name()).unwrap_or(Value::Null))
326}