Skip to main content

synaptic_tools/
filter.rs

1use std::collections::{HashMap, HashSet};
2
3use synaptic_core::ToolDefinition;
4
5/// Context available when filtering tools.
6#[derive(Debug, Clone, Default)]
7pub struct FilterContext {
8    /// Number of agent turns completed so far.
9    pub turn_count: usize,
10    /// Name of the last tool that was called, if any.
11    pub last_tool: Option<String>,
12    /// Arbitrary metadata for custom filter logic.
13    pub metadata: HashMap<String, serde_json::Value>,
14}
15
16/// Trait for filtering available tools based on context.
17pub trait ToolFilter: Send + Sync {
18    /// Filter the list of tool definitions based on the current context.
19    fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition>;
20}
21
/// Keeps only the tools whose names appear in an allow list.
///
/// Every tool not named in the list is dropped. The filter does not look at
/// the [`FilterContext`] it is given.
#[derive(Debug, Clone)]
pub struct AllowListFilter {
    /// Exact tool names that are permitted to pass through.
    allowed: HashSet<String>,
}

impl AllowListFilter {
    /// Creates a filter that lets through exactly the tool names in `allowed`.
    ///
    /// Duplicate names collapse into one entry. An empty `allowed` produces a
    /// filter that removes every tool.
    pub fn new(allowed: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            allowed: allowed.into_iter().map(Into::into).collect(),
        }
    }
}
34
35impl ToolFilter for AllowListFilter {
36    fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
37        tools
38            .into_iter()
39            .filter(|t| self.allowed.contains(&t.name))
40            .collect()
41    }
42}
43
/// Removes every tool whose name appears in a deny list.
///
/// Tools not named in the list pass through untouched. The filter does not
/// look at the [`FilterContext`] it is given.
#[derive(Debug, Clone)]
pub struct DenyListFilter {
    /// Exact tool names that must be removed.
    denied: HashSet<String>,
}

impl DenyListFilter {
    /// Creates a filter that removes exactly the tool names in `denied`.
    ///
    /// Duplicate names collapse into one entry. An empty `denied` produces a
    /// filter that keeps every tool.
    pub fn new(denied: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            denied: denied.into_iter().map(Into::into).collect(),
        }
    }
}
56
57impl ToolFilter for DenyListFilter {
58    fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
59        tools
60            .into_iter()
61            .filter(|t| !self.denied.contains(&t.name))
62            .collect()
63    }
64}
65
/// Filters tools with simple state-machine rules.
///
/// Two kinds of rule are supported:
/// - *after-tool* rules: once a given tool has been called, only a fixed set
///   of follow-up tools stays available;
/// - *turn thresholds*: certain tools stay hidden until the agent has
///   completed a minimum number of turns.
///
/// Build one with [`StateMachineFilter::new`] followed by chained calls to
/// [`after_tool`](Self::after_tool) and [`turn_threshold`](Self::turn_threshold).
#[derive(Debug, Clone, Default)]
pub struct StateMachineFilter {
    /// Allowed follow-up tools, keyed by the name of the last tool called.
    after_tool_rules: HashMap<String, HashSet<String>>,
    /// Rules that gate tools behind a turn count threshold.
    turn_thresholds: Vec<TurnThreshold>,
}

/// A set of tools that stays hidden until `min_turns` turns have completed.
#[derive(Debug, Clone)]
struct TurnThreshold {
    /// Minimum turn count at which the gated tools become available.
    min_turns: usize,
    /// Names of the tools gated behind this threshold.
    add_tools: HashSet<String>,
}

impl StateMachineFilter {
    /// Creates a filter with no rules. Until rules are added it keeps every tool.
    pub fn new() -> Self {
        Self::default()
    }

    /// Adds a rule: after `tool_name` is called, only `allowed_next` tools are available.
    ///
    /// Calling this again with the same `tool_name` replaces the earlier rule
    /// for that tool; rules are not merged.
    #[must_use]
    pub fn after_tool(
        mut self,
        tool_name: impl Into<String>,
        allowed_next: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.after_tool_rules.insert(
            tool_name.into(),
            allowed_next.into_iter().map(Into::into).collect(),
        );
        self
    }

    /// Adds a rule: tools in `add_tools` stay hidden until `min_turns` is reached.
    ///
    /// Thresholds accumulate. A tool named by several thresholds becomes
    /// available as soon as any one of them is met.
    #[must_use]
    pub fn turn_threshold(
        mut self,
        min_turns: usize,
        add_tools: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.turn_thresholds.push(TurnThreshold {
            min_turns,
            add_tools: add_tools.into_iter().map(Into::into).collect(),
        });
        self
    }
}
124
125impl ToolFilter for StateMachineFilter {
126    fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition> {
127        let mut result = tools;
128
129        // Apply after-tool rules: restrict to only allowed_next
130        if let Some(last) = &context.last_tool {
131            if let Some(allowed) = self.after_tool_rules.get(last) {
132                result = result
133                    .into_iter()
134                    .filter(|t| allowed.contains(&t.name))
135                    .collect();
136            }
137        }
138
139        // Apply turn thresholds: collect tools that are gated by turn count.
140        // If a tool appears in ANY threshold, it is only included when that
141        // threshold is met.
142        let mut gated_tools: HashMap<&str, bool> = HashMap::new();
143        for threshold in &self.turn_thresholds {
144            let met = context.turn_count >= threshold.min_turns;
145            for tool_name in &threshold.add_tools {
146                let entry = gated_tools.entry(tool_name.as_str()).or_insert(false);
147                if met {
148                    *entry = true;
149                }
150            }
151        }
152
153        if !gated_tools.is_empty() {
154            result = result
155                .into_iter()
156                .filter(|t| {
157                    match gated_tools.get(t.name.as_str()) {
158                        Some(&met) => met, // Gated tool: only include if threshold met
159                        None => true,      // Not gated: always include
160                    }
161                })
162                .collect();
163        }
164
165        result
166    }
167}
168
169/// Composes multiple filters, applying them in sequence.
170pub struct CompositeFilter(pub Vec<Box<dyn ToolFilter>>);
171
172impl CompositeFilter {
173    pub fn new(filters: Vec<Box<dyn ToolFilter>>) -> Self {
174        Self(filters)
175    }
176}
177
178impl ToolFilter for CompositeFilter {
179    fn filter(
180        &self,
181        mut tools: Vec<ToolDefinition>,
182        context: &FilterContext,
183    ) -> Vec<ToolDefinition> {
184        for f in &self.0 {
185            tools = f.filter(tools, context);
186        }
187        tools
188    }
189}