// techne_server/tool.rs

1use crate::mcp;
2use crate::mcp::server::tool::{IntoResponse, Response};
3use crate::mcp::server::{Notification, Request};
4use crate::mcp::{Map, Schema, Value};
5
6use futures::SinkExt;
7use futures::channel::mpsc;
8use serde::Serialize;
9use tokio::task;
10
11use std::collections::BTreeMap;
12use std::io;
13use std::marker::PhantomData;
14
15pub struct Tool<Name = String, Description = String> {
16    pub name: Name,
17    pub description: Description,
18    input: Schema,
19    output: Option<Schema>,
20    call: Box<dyn Fn(Value) -> io::Result<mpsc::Receiver<Action>> + Send + Sync>,
21}
22
23pub enum Action {
24    Request(Request),
25    Notify(Notification),
26    Finish(io::Result<Response>),
27}
28
29impl Tool<(), ()> {
30    /// # Safety
31    /// The input and output schemas must match the `call` implementation.
32    pub unsafe fn new(
33        input: Schema,
34        output: Option<Schema>,
35        call: impl Fn(Value) -> io::Result<mpsc::Receiver<Action>> + Send + Sync + 'static,
36    ) -> Self {
37        Self {
38            name: (),
39            description: (),
40            input,
41            output,
42            call: Box::new(call),
43        }
44    }
45}
46
47impl<Name, Description> Tool<Name, Description> {
48    pub fn name(self, name: impl AsRef<str>) -> Tool<String, Description> {
49        Tool {
50            name: name.as_ref().to_owned(),
51            description: self.description,
52            input: self.input,
53            output: self.output,
54            call: self.call,
55        }
56    }
57
58    pub fn description(self, description: impl AsRef<str>) -> Tool<Name> {
59        Tool {
60            name: self.name,
61            description: description.as_ref().to_owned(),
62            input: self.input,
63            output: self.output,
64            call: self.call,
65        }
66    }
67}
68
69impl Tool {
70    pub fn input(&self) -> &Schema {
71        &self.input
72    }
73
74    pub fn output(&self) -> Option<&Schema> {
75        self.output.as_ref()
76    }
77
78    pub fn call(&self, json: Value) -> io::Result<mpsc::Receiver<Action>> {
79        (self.call)(json)
80    }
81}
82
83pub fn tool<A, O, F>(
84    f: impl Fn(A) -> F + Send + Sync + 'static,
85    a: impl Argument<A> + Send + Sync + 'static,
86) -> Tool<(), ()>
87where
88    O: IntoResponse,
89    O::Content: Serialize + Send,
90    F: Future<Output = O> + Send + 'static,
91{
92    let input = Schema::Object {
93        description: None,
94        properties: BTreeMap::from_iter([property(&a)]),
95        required: Vec::from_iter([required(&a)].into_iter().flatten()),
96    };
97
98    let call = move |json| {
99        let mut object = object(json)?;
100        let a = deserialize(&a, &mut object)?;
101
102        Ok(spawn(f(a)))
103    };
104
105    Tool {
106        name: (),
107        description: (),
108        input,
109        output: None, // TODO
110        call: Box::new(call),
111    }
112}
113
114pub fn tool_2<A, B, O, F>(
115    f: impl Fn(A, B) -> F + Send + Sync + 'static,
116    a: impl Argument<A> + Send + Sync + 'static,
117    b: impl Argument<B> + Send + Sync + 'static,
118) -> Tool<(), ()>
119where
120    O: IntoResponse,
121    O::Content: Serialize + Send,
122    F: Future<Output = O> + Send + 'static,
123{
124    let input = Schema::Object {
125        description: None,
126        properties: BTreeMap::from_iter([property(&a), property(&b)]),
127        required: Vec::from_iter([required(&a), required(&b)].into_iter().flatten()),
128    };
129
130    let call = move |json| {
131        let mut object = object(json)?;
132        let a = deserialize(&a, &mut object)?;
133        let b = deserialize(&b, &mut object)?;
134
135        Ok(spawn(f(a, b)))
136    };
137
138    Tool {
139        name: (),
140        description: (),
141        input,
142        output: None, // TODO
143        call: Box::new(call),
144    }
145}
146
147fn spawn<O>(execution: impl Future<Output = O> + Send + 'static) -> mpsc::Receiver<Action>
148where
149    O: IntoResponse,
150    O::Content: Serialize + Send,
151{
152    let (mut sender, receiver) = mpsc::channel(1);
153
154    task::spawn(async move {
155        let output = execution.await;
156
157        let result = output
158            .into_outcome()
159            .serialize()
160            .await
161            .map_err(io::Error::from);
162
163        let _ = sender.send(Action::Finish(result)).await;
164    });
165
166    receiver
167}
168
169pub trait Argument<T> {
170    fn name(&self) -> &str;
171
172    fn schema(&self) -> Schema;
173
174    fn deserialize(&self, json: Value) -> io::Result<T>;
175
176    fn is_required(&self) -> bool {
177        true
178    }
179}
180
181pub fn string(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<String> {
182    NamedArg::new(name, description)
183}
184
185pub fn u32(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<u32> {
186    NamedArg::new(name, description)
187}
188
189pub fn f32(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<f32> {
190    NamedArg::new(name, description)
191}
192
193pub fn bool(name: impl AsRef<str>, description: impl AsRef<str>) -> impl Argument<bool> {
194    NamedArg::new(name, description)
195}
196
197pub fn optional<T>(argument: impl Argument<T>) -> impl Argument<Option<T>> {
198    struct Optional<A, T> {
199        argument: A,
200        _output: PhantomData<T>,
201    }
202
203    impl<A, T> Argument<Option<T>> for Optional<A, T>
204    where
205        A: Argument<T>,
206    {
207        fn name(&self) -> &str {
208            self.argument.name()
209        }
210
211        fn schema(&self) -> Schema {
212            self.argument.schema()
213        }
214
215        fn deserialize(&self, json: Value) -> io::Result<Option<T>> {
216            if json.is_null() {
217                return Ok(None);
218            }
219
220            self.argument.deserialize(json).map(Some)
221        }
222
223        fn is_required(&self) -> bool {
224            false
225        }
226    }
227
228    Optional {
229        argument,
230        _output: PhantomData,
231    }
232}
233
/// Backing type for the `string`, `u32`, `f32` and `bool` argument constructors: just a
/// name and a description. The JSON type advertised depends on which `Argument<T>` impl
/// is selected.
struct NamedArg {
    /// Key of the argument in the arguments object.
    name: String,
    /// Human-readable description placed in the argument's schema.
    description: String,
}
238
239impl NamedArg {
240    fn new(name: impl AsRef<str>, description: impl AsRef<str>) -> Self {
241        Self {
242            name: name.as_ref().to_owned(),
243            description: description.as_ref().to_owned(),
244        }
245    }
246}
247
248impl Argument<String> for NamedArg {
249    fn name(&self) -> &str {
250        &self.name
251    }
252
253    fn schema(&self) -> Schema {
254        Schema::String {
255            description: Some(self.description.clone()),
256        }
257    }
258
259    fn deserialize(&self, json: Value) -> io::Result<String> {
260        mcp::from_value(json)
261    }
262}
263
264impl Argument<u32> for NamedArg {
265    fn name(&self) -> &str {
266        &self.name
267    }
268
269    fn schema(&self) -> Schema {
270        Schema::Integer {
271            description: Some(self.description.clone()),
272        }
273    }
274
275    fn deserialize(&self, json: Value) -> io::Result<u32> {
276        mcp::from_value(json)
277    }
278}
279
280impl Argument<f32> for NamedArg {
281    fn name(&self) -> &str {
282        &self.name
283    }
284
285    fn schema(&self) -> Schema {
286        Schema::Number {
287            description: Some(self.description.clone()),
288        }
289    }
290
291    fn deserialize(&self, json: Value) -> io::Result<f32> {
292        mcp::from_value(json)
293    }
294}
295
296impl Argument<bool> for NamedArg {
297    fn name(&self) -> &str {
298        &self.name
299    }
300
301    fn schema(&self) -> Schema {
302        Schema::Boolean {
303            description: Some(self.description.clone()),
304        }
305    }
306
307    fn deserialize(&self, json: Value) -> io::Result<bool> {
308        mcp::from_value(json)
309    }
310}
311
312fn property<T>(arg: &impl Argument<T>) -> (String, Schema) {
313    (arg.name().to_owned(), arg.schema().clone())
314}
315
316fn required<T>(arg: &impl Argument<T>) -> Option<String> {
317    arg.is_required().then(|| arg.name().to_owned())
318}
319
320fn object(json: Value) -> io::Result<Map<String, Value>> {
321    mcp::from_value(json)
322}
323
324fn deserialize<T>(arg: &impl Argument<T>, object: &mut Map<String, Value>) -> io::Result<T> {
325    arg.deserialize(object.remove(arg.name()).unwrap_or(Value::Null))
326}