1use std::collections::{HashMap, HashSet};
2
3use synaptic_core::ToolDefinition;
4
5#[derive(Debug, Clone, Default)]
7pub struct FilterContext {
8 pub turn_count: usize,
10 pub last_tool: Option<String>,
12 pub metadata: HashMap<String, serde_json::Value>,
14}
15
16pub trait ToolFilter: Send + Sync {
18 fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition>;
20}
21
/// Filter that keeps only tools whose name appears in an allow list.
///
/// Names are compared exactly (plain `String` equality), so matching is
/// case-sensitive. Tools not named in the list are removed.
#[derive(Debug, Clone)]
pub struct AllowListFilter {
    /// Tool names permitted to pass through.
    allowed: HashSet<String>,
}
26
27impl AllowListFilter {
28 pub fn new(allowed: impl IntoIterator<Item = impl Into<String>>) -> Self {
29 Self {
30 allowed: allowed.into_iter().map(|s| s.into()).collect(),
31 }
32 }
33}
34
35impl ToolFilter for AllowListFilter {
36 fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
37 tools
38 .into_iter()
39 .filter(|t| self.allowed.contains(&t.name))
40 .collect()
41 }
42}
43
/// Filter that removes tools whose name appears in a deny list.
///
/// Names are compared exactly (plain `String` equality), so matching is
/// case-sensitive. Tools not named in the list pass through untouched.
#[derive(Debug, Clone)]
pub struct DenyListFilter {
    /// Tool names that must be hidden.
    denied: HashSet<String>,
}
48
49impl DenyListFilter {
50 pub fn new(denied: impl IntoIterator<Item = impl Into<String>>) -> Self {
51 Self {
52 denied: denied.into_iter().map(|s| s.into()).collect(),
53 }
54 }
55}
56
57impl ToolFilter for DenyListFilter {
58 fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
59 tools
60 .into_iter()
61 .filter(|t| !self.denied.contains(&t.name))
62 .collect()
63 }
64}
65
66pub struct StateMachineFilter {
69 after_tool_rules: HashMap<String, HashSet<String>>,
71 turn_thresholds: Vec<TurnThreshold>,
73}
74
/// One turn-count gate used by `StateMachineFilter`.
#[derive(Clone, Debug)]
struct TurnThreshold {
    /// Smallest `FilterContext::turn_count` at which this gate counts as reached.
    min_turns: usize,
    /// Tool names this gate unlocks. While no gate listing a tool has been
    /// reached, that tool is hidden.
    add_tools: HashSet<String>,
}
83
84impl StateMachineFilter {
85 pub fn new() -> Self {
86 Self {
87 after_tool_rules: HashMap::new(),
88 turn_thresholds: Vec::new(),
89 }
90 }
91
92 pub fn after_tool(
94 mut self,
95 tool_name: impl Into<String>,
96 allowed_next: impl IntoIterator<Item = impl Into<String>>,
97 ) -> Self {
98 self.after_tool_rules.insert(
99 tool_name.into(),
100 allowed_next.into_iter().map(|s| s.into()).collect(),
101 );
102 self
103 }
104
105 pub fn turn_threshold(
107 mut self,
108 min_turns: usize,
109 add_tools: impl IntoIterator<Item = impl Into<String>>,
110 ) -> Self {
111 self.turn_thresholds.push(TurnThreshold {
112 min_turns,
113 add_tools: add_tools.into_iter().map(|s| s.into()).collect(),
114 });
115 self
116 }
117}
118
119impl Default for StateMachineFilter {
120 fn default() -> Self {
121 Self::new()
122 }
123}
124
125impl ToolFilter for StateMachineFilter {
126 fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition> {
127 let mut result = tools;
128
129 if let Some(last) = &context.last_tool {
131 if let Some(allowed) = self.after_tool_rules.get(last) {
132 result.retain(|t| allowed.contains(&t.name));
133 }
134 }
135
136 let mut gated_tools: HashMap<&str, bool> = HashMap::new();
140 for threshold in &self.turn_thresholds {
141 let met = context.turn_count >= threshold.min_turns;
142 for tool_name in &threshold.add_tools {
143 let entry = gated_tools.entry(tool_name.as_str()).or_insert(false);
144 if met {
145 *entry = true;
146 }
147 }
148 }
149
150 if !gated_tools.is_empty() {
151 result.retain(|t| {
152 match gated_tools.get(t.name.as_str()) {
153 Some(&met) => met, None => true, }
156 });
157 }
158
159 result
160 }
161}
162
163pub struct CompositeFilter(pub Vec<Box<dyn ToolFilter>>);
165
166impl CompositeFilter {
167 pub fn new(filters: Vec<Box<dyn ToolFilter>>) -> Self {
168 Self(filters)
169 }
170}
171
172impl ToolFilter for CompositeFilter {
173 fn filter(
174 &self,
175 mut tools: Vec<ToolDefinition>,
176 context: &FilterContext,
177 ) -> Vec<ToolDefinition> {
178 for f in &self.0 {
179 tools = f.filter(tools, context);
180 }
181 tools
182 }
183}