1use std::collections::{HashMap, HashSet};
2
3use synaptic_core::ToolDefinition;
4
5#[derive(Debug, Clone, Default)]
7pub struct FilterContext {
8 pub turn_count: usize,
10 pub last_tool: Option<String>,
12 pub metadata: HashMap<String, serde_json::Value>,
14}
15
16pub trait ToolFilter: Send + Sync {
18 fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition>;
20}
21
/// Keeps only the tools whose names appear in a fixed allow list.
///
/// Names are compared exactly (case-sensitive) against each tool's `name`, so
/// an empty allow list removes every tool.
#[derive(Debug, Clone)]
pub struct AllowListFilter {
    /// Names of the tools permitted to pass through.
    allowed: HashSet<String>,
}

impl AllowListFilter {
    /// Builds the filter from any collection of names (`&str`, `String`, ...).
    ///
    /// Duplicate names are collapsed.
    #[must_use]
    pub fn new(allowed: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            allowed: allowed.into_iter().map(Into::into).collect(),
        }
    }
}
34
35impl ToolFilter for AllowListFilter {
36 fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
37 tools
38 .into_iter()
39 .filter(|t| self.allowed.contains(&t.name))
40 .collect()
41 }
42}
43
/// Removes the tools whose names appear in a fixed deny list.
///
/// Names are compared exactly (case-sensitive) against each tool's `name`, so
/// an empty deny list lets every tool through.
#[derive(Debug, Clone)]
pub struct DenyListFilter {
    /// Names of the tools that must be filtered out.
    denied: HashSet<String>,
}

impl DenyListFilter {
    /// Builds the filter from any collection of names (`&str`, `String`, ...).
    ///
    /// Duplicate names are collapsed.
    #[must_use]
    pub fn new(denied: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            denied: denied.into_iter().map(Into::into).collect(),
        }
    }
}
56
57impl ToolFilter for DenyListFilter {
58 fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
59 tools
60 .into_iter()
61 .filter(|t| !self.denied.contains(&t.name))
62 .collect()
63 }
64}
65
/// A filter whose output depends on the state carried in [`FilterContext`]:
/// the tool used last and the current turn count.
///
/// Rules are registered with the builder methods:
///
/// * [`after_tool`](Self::after_tool): when `context.last_tool` names a
///   registered trigger, only that trigger's follow-up tools are kept.
/// * [`turn_threshold`](Self::turn_threshold): the listed tools are hidden
///   until `context.turn_count >= min_turns`. A tool listed under several
///   thresholds is shown as soon as any one of them is met. Tools not listed
///   under any threshold are unaffected.
///
/// A filter with no rules returns its input unchanged.
#[derive(Debug, Clone, Default)]
pub struct StateMachineFilter {
    /// Maps each trigger tool name to the names of the tools allowed right after it.
    after_tool_rules: HashMap<String, HashSet<String>>,
    /// Turn-count gates, in registration order.
    turn_thresholds: Vec<TurnThreshold>,
}

/// One turn-count gate registered via `StateMachineFilter::turn_threshold`.
#[derive(Debug, Clone)]
struct TurnThreshold {
    /// Inclusive lower bound on `FilterContext::turn_count` that unlocks `add_tools`.
    min_turns: usize,
    /// Names of the tools gated by this threshold.
    add_tools: HashSet<String>,
}

impl StateMachineFilter {
    /// Creates a filter with no rules (same as [`Default::default`]).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts the tools offered after `tool_name` to `allowed_next`.
    ///
    /// Registering the same `tool_name` again replaces its previous set. The
    /// two sets are not merged.
    #[must_use]
    pub fn after_tool(
        mut self,
        tool_name: impl Into<String>,
        allowed_next: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.after_tool_rules.insert(
            tool_name.into(),
            allowed_next.into_iter().map(Into::into).collect(),
        );
        self
    }

    /// Hides `add_tools` until `FilterContext::turn_count` reaches `min_turns`.
    ///
    /// Thresholds accumulate, and the same tool may appear in several of them.
    #[must_use]
    pub fn turn_threshold(
        mut self,
        min_turns: usize,
        add_tools: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.turn_thresholds.push(TurnThreshold {
            min_turns,
            add_tools: add_tools.into_iter().map(Into::into).collect(),
        });
        self
    }
}
124
125impl ToolFilter for StateMachineFilter {
126 fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition> {
127 let mut result = tools;
128
129 if let Some(last) = &context.last_tool {
131 if let Some(allowed) = self.after_tool_rules.get(last) {
132 result = result
133 .into_iter()
134 .filter(|t| allowed.contains(&t.name))
135 .collect();
136 }
137 }
138
139 let mut gated_tools: HashMap<&str, bool> = HashMap::new();
143 for threshold in &self.turn_thresholds {
144 let met = context.turn_count >= threshold.min_turns;
145 for tool_name in &threshold.add_tools {
146 let entry = gated_tools.entry(tool_name.as_str()).or_insert(false);
147 if met {
148 *entry = true;
149 }
150 }
151 }
152
153 if !gated_tools.is_empty() {
154 result = result
155 .into_iter()
156 .filter(|t| {
157 match gated_tools.get(t.name.as_str()) {
158 Some(&met) => met, None => true, }
161 })
162 .collect();
163 }
164
165 result
166 }
167}
168
169pub struct CompositeFilter(pub Vec<Box<dyn ToolFilter>>);
171
172impl CompositeFilter {
173 pub fn new(filters: Vec<Box<dyn ToolFilter>>) -> Self {
174 Self(filters)
175 }
176}
177
178impl ToolFilter for CompositeFilter {
179 fn filter(
180 &self,
181 mut tools: Vec<ToolDefinition>,
182 context: &FilterContext,
183 ) -> Vec<ToolDefinition> {
184 for f in &self.0 {
185 tools = f.filter(tools, context);
186 }
187 tools
188 }
189}