Skip to main content

steer_tools/
schema.rs

1use crate::error::ToolExecutionError;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize, de::DeserializeOwned};
4use serde_json::Value;
5use std::error::Error as StdError;
6
7#[derive(Debug, Clone, Serialize, Deserialize)]
8#[serde(transparent)]
9pub struct InputSchema(Value);
10
11impl InputSchema {
12    pub fn new(schema: Value) -> Self {
13        Self(schema)
14    }
15
16    pub fn as_value(&self) -> &Value {
17        &self.0
18    }
19
20    pub fn into_value(self) -> Value {
21        self.0
22    }
23
24    pub fn summary(&self) -> InputSchemaSummary {
25        InputSchemaSummary::from_value(&self.0)
26    }
27
28    pub fn object(properties: serde_json::Map<String, Value>, required: Vec<String>) -> Self {
29        let mut schema = serde_json::Map::new();
30        schema.insert("type".to_string(), Value::String("object".to_string()));
31        schema.insert("properties".to_string(), Value::Object(properties));
32        if !required.is_empty() {
33            let required_values = required.into_iter().map(Value::String).collect::<Vec<_>>();
34            schema.insert("required".to_string(), Value::Array(required_values));
35        }
36        Self(Value::Object(schema))
37    }
38
39    pub fn empty_object() -> Self {
40        Self::object(Default::default(), Vec::new())
41    }
42}
43
44impl From<Value> for InputSchema {
45    fn from(schema: Value) -> Self {
46        Self(schema)
47    }
48}
49
50impl From<schemars::Schema> for InputSchema {
51    fn from(schema: schemars::Schema) -> Self {
52        let schema_value = serde_json::to_value(&schema).unwrap_or(serde_json::Value::Null);
53        Self(ensure_object_properties(schema_value))
54    }
55}
56
57fn ensure_object_properties(schema: Value) -> Value {
58    let mut schema = schema;
59    if let Value::Object(obj) = &mut schema {
60        let is_object = obj
61            .get("type")
62            .and_then(|v| v.as_str())
63            .is_some_and(|t| t == "object");
64        if is_object && !obj.contains_key("properties") {
65            obj.insert(
66                "properties".to_string(),
67                Value::Object(serde_json::Map::new()),
68            );
69        }
70    }
71    schema
72}
73
74#[derive(Debug, Clone)]
75pub struct InputSchemaSummary {
76    pub properties: serde_json::Map<String, Value>,
77    pub required: Vec<String>,
78    pub schema_type: String,
79}
80
81impl InputSchemaSummary {
82    fn from_value(schema: &Value) -> Self {
83        let mut properties = serde_json::Map::new();
84        let mut required = std::collections::BTreeSet::new();
85
86        Self::merge_schema(schema, &mut properties, &mut required);
87
88        let mut schema_type = schema
89            .as_object()
90            .and_then(|obj| obj.get("type"))
91            .and_then(|v| v.as_str())
92            .unwrap_or_default()
93            .to_string();
94
95        if schema_type.is_empty() && (!properties.is_empty() || !required.is_empty()) {
96            schema_type = "object".to_string();
97        }
98
99        Self {
100            properties,
101            required: required.into_iter().collect(),
102            schema_type,
103        }
104    }
105
106    fn merge_schema(
107        schema: &Value,
108        properties: &mut serde_json::Map<String, Value>,
109        required: &mut std::collections::BTreeSet<String>,
110    ) {
111        let Some(obj) = schema.as_object() else {
112            return;
113        };
114
115        if let Some(prop_obj) = obj.get("properties").and_then(|v| v.as_object()) {
116            for (key, value) in prop_obj {
117                merge_property(properties, key, value);
118            }
119        }
120
121        if let Some(req) = obj.get("required").and_then(|v| v.as_array()) {
122            for item in req {
123                if let Some(name) = item.as_str() {
124                    required.insert(name.to_string());
125                }
126            }
127        }
128
129        if let Some(all_of) = obj.get("allOf").and_then(|v| v.as_array()) {
130            for sub in all_of {
131                let summary = InputSchemaSummary::from_value(sub);
132                for (key, value) in summary.properties {
133                    merge_property(properties, &key, &value);
134                }
135                required.extend(summary.required);
136            }
137        }
138
139        if let Some(one_of) = obj.get("oneOf").and_then(|v| v.as_array()) {
140            Self::merge_one_of(one_of, properties, required);
141        }
142
143        if let Some(any_of) = obj.get("anyOf").and_then(|v| v.as_array()) {
144            Self::merge_one_of(any_of, properties, required);
145        }
146    }
147
148    fn merge_one_of(
149        subschemas: &[Value],
150        properties: &mut serde_json::Map<String, Value>,
151        required: &mut std::collections::BTreeSet<String>,
152    ) {
153        let mut intersection: Option<std::collections::BTreeSet<String>> = None;
154
155        for sub in subschemas {
156            let summary = InputSchemaSummary::from_value(sub);
157            for (key, value) in summary.properties {
158                merge_property(properties, &key, &value);
159            }
160
161            let required_set: std::collections::BTreeSet<String> =
162                summary.required.into_iter().collect();
163
164            intersection = match intersection.take() {
165                None => Some(required_set),
166                Some(existing) => Some(
167                    existing
168                        .intersection(&required_set)
169                        .cloned()
170                        .collect::<std::collections::BTreeSet<String>>(),
171                ),
172            };
173        }
174
175        if let Some(required_set) = intersection {
176            required.extend(required_set);
177        }
178    }
179}
180
181fn merge_property(properties: &mut serde_json::Map<String, Value>, key: &str, value: &Value) {
182    match properties.get_mut(key) {
183        None => {
184            properties.insert(key.to_string(), value.clone());
185        }
186        Some(existing) => {
187            if existing == value {
188                return;
189            }
190            let existing_values = extract_enum_values(existing);
191            let incoming_values = extract_enum_values(value);
192            if incoming_values.is_empty() && existing_values.is_empty() {
193                return;
194            }
195
196            let mut combined = existing_values;
197            for item in incoming_values {
198                if !combined.contains(&item) {
199                    combined.push(item);
200                }
201            }
202
203            if combined.is_empty() {
204                return;
205            }
206
207            if let Some(obj) = existing.as_object_mut() {
208                obj.remove("const");
209                obj.insert("enum".to_string(), Value::Array(combined));
210            }
211        }
212    }
213}
214
215fn extract_enum_values(value: &Value) -> Vec<Value> {
216    let Some(obj) = value.as_object() else {
217        return Vec::new();
218    };
219
220    if let Some(enum_values) = obj.get("enum").and_then(|v| v.as_array()) {
221        return enum_values.clone();
222    }
223
224    if let Some(const_value) = obj.get("const") {
225        return vec![const_value.clone()];
226    }
227
228    Vec::new()
229}
230
#[cfg(test)]
mod tests {
    use super::{InputSchema, extract_enum_values};
    use schemars::schema_for;
    use serde_json::Value;

    /// Follows a local `$ref` (e.g. `#/$defs/Name`) from `root`, one path
    /// segment at a time. Returns `schema` itself when it has no `$ref` or
    /// when any segment cannot be found.
    fn resolve_ref<'a>(root: &'a Value, schema: &'a Value) -> &'a Value {
        let Some(reference) = schema.get("$ref").and_then(Value::as_str) else {
            return schema;
        };

        reference
            .strip_prefix("#/")
            .unwrap_or(reference)
            .split('/')
            .try_fold(root, |node, segment| node.get(segment))
            .unwrap_or(schema)
    }

    /// The dispatch-agent schema must expose `prompt` and `target` as required
    /// properties, and `target` must be a union whose variants cover the
    /// `"new"` and `"resume"` session values.
    #[test]
    fn dispatch_agent_schema_includes_target() {
        let input_schema =
            InputSchema::from(schema_for!(crate::tools::dispatch_agent::DispatchAgentParams));
        let summary = input_schema.summary();

        for field in ["prompt", "target"] {
            assert!(summary.properties.contains_key(field));
            assert!(summary.required.iter().any(|name| name == field));
        }
        assert_eq!(summary.schema_type, "object");

        let root = input_schema.as_value();
        let target_entry = root
            .get("properties")
            .and_then(|props| props.get("target"))
            .expect("target schema should exist");
        let target_schema = resolve_ref(root, target_entry);

        // First composition keyword that is present, in this priority order.
        let variants = ["oneOf", "anyOf", "allOf"]
            .into_iter()
            .find_map(|keyword| target_schema.get(keyword))
            .and_then(Value::as_array)
            .expect("target should be a tagged union");

        let session_values: Vec<Value> = variants
            .iter()
            .flat_map(|variant| {
                extract_enum_values(
                    variant
                        .get("properties")
                        .and_then(|props| props.get("session"))
                        .unwrap_or(&Value::Null),
                )
            })
            .collect();

        for expected in ["new", "resume"] {
            assert!(session_values.contains(&Value::String(expected.to_string())));
        }
    }
}
292
/// Serializable description of a tool: its names, a description, and the
/// schema of the input it accepts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolSchema {
    /// Tool name. Presumably the identifier used to refer to the tool — confirm against callers.
    pub name: String,
    /// Display name. `#[serde(default)]` makes this an empty string when the
    /// field is absent from the deserialized input.
    #[serde(default)]
    pub display_name: String,
    /// Description of the tool.
    pub description: String,
    /// Schema of the tool's input.
    pub input_schema: InputSchema,
}
301
/// Static, type-level description of a tool: its parameter, result and error
/// types, plus its names.
pub trait ToolSpec {
    /// Parameter type; must be deserializable and able to produce a JSON schema.
    type Params: DeserializeOwned + JsonSchema + Send;
    /// Result type; must be convertible into `crate::result::ToolResult`.
    type Result: Into<crate::result::ToolResult> + Send;
    /// Tool-specific error type.
    type Error: StdError + Send + Sync + 'static;

    /// Tool name constant.
    const NAME: &'static str;
    /// Display name constant.
    const DISPLAY_NAME: &'static str;

    /// Converts the tool-specific error into a [`ToolExecutionError`].
    fn execution_error(error: Self::Error) -> ToolExecutionError;
}
312
/// A serializable, comparable record of a tool invocation: which tool, with
/// what parameters, under what id.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ToolCall {
    /// Name of the tool being called.
    pub name: String,
    /// Call parameters as raw JSON. Nothing in this type validates them
    /// against a schema.
    pub parameters: Value,
    /// Identifier of this call. Presumably used to correlate the call with its result — confirm against callers.
    pub id: String,
}