1use crate::types::Layer3Result;
6use async_trait::async_trait;
7
8#[async_trait]
12pub trait GuardRail: Send + Sync {
13 fn name(&self) -> &str;
15
16 async fn check_input(&self, input: &str) -> Layer3Result<GuardResult>;
18
19 async fn check_output(&self, output: &str) -> Layer3Result<GuardResult>;
21
22 async fn fix_input(&self, input: &str) -> Layer3Result<String>;
24
25 async fn fix_output(&self, output: &str) -> Layer3Result<String>;
27}
28
/// Outcome of running one guard rail against a piece of text.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct GuardResult {
    /// `true` when the text satisfied the rail.
    pub passed: bool,
    /// What went wrong; `LengthGuard` and `RegexGuard` leave this `None` on success.
    pub issue: Option<GuardIssue>,
    /// Optional human-readable hint on how to satisfy the rail.
    pub suggestion: Option<String>,
}

/// Category of problem reported in [`GuardResult::issue`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum GuardIssue {
    /// Sensitive data detected (not emitted by `LengthGuard` or `RegexGuard`).
    SensitiveData,
    /// Text violated a format rule; `RegexGuard` reports every failure this way.
    FormatError,
    /// Text exceeded the maximum length; reported by `LengthGuard`.
    TooLong,
    /// Text was below the minimum length; reported by `LengthGuard`.
    TooShort,
    /// Dangerous instruction detected (not emitted by `LengthGuard` or `RegexGuard`).
    DangerousInstruction,
    /// Off-topic text (not emitted by `LengthGuard` or `RegexGuard`).
    OffTopic,
    /// Implementation-defined issue, described by the contained string.
    Custom(String),
}
58
59pub struct GuardRailsComposite {
61 rails: Vec<Box<dyn GuardRail>>,
62}
63
64impl GuardRailsComposite {
65 pub fn new() -> Self {
66 Self { rails: Vec::new() }
67 }
68
69 pub fn add(&mut self, rail: Box<dyn GuardRail>) {
70 self.rails.push(rail);
71 }
72
73 pub async fn check_input_all(&self, input: &str) -> Layer3Result<Vec<GuardResult>> {
74 let mut results = Vec::new();
75 for rail in &self.rails {
76 results.push(rail.check_input(input).await?);
77 }
78 Ok(results)
79 }
80
81 pub async fn check_output_all(&self, output: &str) -> Layer3Result<Vec<GuardResult>> {
82 let mut results = Vec::new();
83 for rail in &self.rails {
84 results.push(rail.check_output(output).await?);
85 }
86 Ok(results)
87 }
88}
89
90impl Default for GuardRailsComposite {
91 fn default() -> Self {
92 Self::new()
93 }
94}
95
/// Guard that bounds text length to `min_length..=max_length`, where length
/// is `str::len` (UTF-8 bytes, not characters).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LengthGuard {
    /// Smallest accepted length, in bytes.
    min_length: usize,
    /// Largest accepted length, in bytes.
    max_length: usize,
}
101
102impl LengthGuard {
103 pub fn new(min_length: usize, max_length: usize) -> Self {
104 Self {
105 min_length,
106 max_length,
107 }
108 }
109}
110
111impl Default for LengthGuard {
112 fn default() -> Self {
113 Self::new(1, 10000)
114 }
115}
116
117#[async_trait]
118impl GuardRail for LengthGuard {
119 fn name(&self) -> &str {
120 "length"
121 }
122
123 async fn check_input(&self, input: &str) -> Layer3Result<GuardResult> {
124 let len = input.len();
125 if len < self.min_length {
126 return Ok(GuardResult {
127 passed: false,
128 issue: Some(GuardIssue::TooShort),
129 suggestion: Some(format!("Minimum length: {}", self.min_length)),
130 });
131 }
132 if len > self.max_length {
133 return Ok(GuardResult {
134 passed: false,
135 issue: Some(GuardIssue::TooLong),
136 suggestion: Some(format!("Maximum length: {}", self.max_length)),
137 });
138 }
139 Ok(GuardResult {
140 passed: true,
141 issue: None,
142 suggestion: None,
143 })
144 }
145
146 async fn check_output(&self, output: &str) -> Layer3Result<GuardResult> {
147 self.check_input(output).await
148 }
149
150 async fn fix_input(&self, input: &str) -> Layer3Result<String> {
151 Ok(input.to_string())
152 }
153
154 async fn fix_output(&self, output: &str) -> Layer3Result<String> {
155 if output.len() > self.max_length {
156 Ok(output[..self.max_length].to_string())
157 } else {
158 Ok(output.to_string())
159 }
160 }
161}
162
163pub struct RegexGuard {
165 pattern: regex::Regex,
166 block_matches: bool,
167 name: String,
168}
169
170impl RegexGuard {
171 pub fn new(pattern: regex::Regex, block_matches: bool, name: impl Into<String>) -> Self {
172 Self {
173 pattern,
174 block_matches,
175 name: name.into(),
176 }
177 }
178}
179
180#[async_trait]
181impl GuardRail for RegexGuard {
182 fn name(&self) -> &str {
183 &self.name
184 }
185
186 async fn check_input(&self, input: &str) -> Layer3Result<GuardResult> {
187 let matches = self.pattern.is_match(input);
188 let passed = if self.block_matches {
189 !matches
190 } else {
191 matches
192 };
193 Ok(GuardResult {
194 passed,
195 issue: if passed {
196 None
197 } else {
198 Some(GuardIssue::FormatError)
199 },
200 suggestion: None,
201 })
202 }
203
204 async fn check_output(&self, output: &str) -> Layer3Result<GuardResult> {
205 self.check_input(output).await
206 }
207
208 async fn fix_input(&self, input: &str) -> Layer3Result<String> {
209 Ok(self.pattern.replace_all(input, "").to_string())
210 }
211
212 async fn fix_output(&self, output: &str) -> Layer3Result<String> {
213 self.fix_input(output).await
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_length_guard() {
        let guard = LengthGuard::new(5, 100);
        let result = guard.check_input("hello").await.unwrap();
        assert!(result.passed);
        assert!(result.issue.is_none());
    }

    #[tokio::test]
    async fn test_length_guard_too_short() {
        let guard = LengthGuard::new(10, 100);
        let result = guard.check_input("hi").await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.issue, Some(GuardIssue::TooShort));
    }

    #[tokio::test]
    async fn test_length_guard_too_long() {
        let guard = LengthGuard::new(1, 3);
        let result = guard.check_input("hello").await.unwrap();
        assert!(!result.passed);
        assert_eq!(result.issue, Some(GuardIssue::TooLong));
        assert_eq!(result.suggestion.as_deref(), Some("Maximum length: 3"));
    }

    #[tokio::test]
    async fn test_length_guard_fix_output_truncates() {
        let guard = LengthGuard::new(1, 3);
        let fixed = guard.fix_output("hello").await.unwrap();
        assert_eq!(fixed, "hel");
    }

    #[tokio::test]
    async fn test_composite_returns_one_result_per_rail() {
        let mut composite = GuardRailsComposite::new();
        composite.add(Box::new(LengthGuard::new(1, 100)));
        composite.add(Box::new(LengthGuard::new(10, 100)));
        let results = composite.check_input_all("hello").await.unwrap();
        assert_eq!(results.len(), 2);
        assert!(results[0].passed);
        assert!(!results[1].passed);
    }
}