1use std::collections::HashMap;
8use std::sync::Arc;
9use std::time::{Duration, Instant};
10
11use async_trait::async_trait;
12use pingora_timeout::timeout;
13use sentinel_agent_protocol::{
14 GuardrailDetection, GuardrailInspectEvent, GuardrailInspectionType, GuardrailResponse,
15};
16use sentinel_config::{
17 GuardrailAction, GuardrailFailureMode, PiiDetectionConfig, PromptInjectionConfig,
18};
19use tracing::{debug, trace, warn};
20
21use crate::agents::AgentManager;
22
23#[derive(Debug)]
25pub enum PromptInjectionResult {
26 Clean,
28 Blocked {
30 status: u16,
31 message: String,
32 detections: Vec<GuardrailDetection>,
33 },
34 Detected { detections: Vec<GuardrailDetection> },
36 Warning { detections: Vec<GuardrailDetection> },
38 Error { message: String },
40}
41
42#[derive(Debug)]
44pub enum PiiCheckResult {
45 Clean,
47 Detected {
49 detections: Vec<GuardrailDetection>,
50 redacted_content: Option<String>,
51 },
52 Error { message: String },
54}
55
56#[async_trait]
60pub trait GuardrailAgentCaller: Send + Sync {
61 async fn call_guardrail_agent(
63 &self,
64 agent_name: &str,
65 event: GuardrailInspectEvent,
66 ) -> Result<GuardrailResponse, String>;
67}
68
69pub struct AgentManagerCaller {
71 #[allow(dead_code)]
72 agent_manager: Arc<AgentManager>,
73}
74
75impl AgentManagerCaller {
76 pub fn new(agent_manager: Arc<AgentManager>) -> Self {
78 Self { agent_manager }
79 }
80}
81
82#[async_trait]
83impl GuardrailAgentCaller for AgentManagerCaller {
84 async fn call_guardrail_agent(
85 &self,
86 agent_name: &str,
87 event: GuardrailInspectEvent,
88 ) -> Result<GuardrailResponse, String> {
89 trace!(
96 agent = agent_name,
97 inspection_type = ?event.inspection_type,
98 "Calling guardrail agent"
99 );
100
101 Err(format!(
104 "Agent '{}' not configured for guardrail inspection",
105 agent_name
106 ))
107 }
108}
109
110pub struct GuardrailProcessor {
115 agent_caller: Arc<dyn GuardrailAgentCaller>,
116}
117
118impl GuardrailProcessor {
119 pub fn new(agent_manager: Arc<AgentManager>) -> Self {
121 Self {
122 agent_caller: Arc::new(AgentManagerCaller::new(agent_manager)),
123 }
124 }
125
126 pub fn with_caller(agent_caller: Arc<dyn GuardrailAgentCaller>) -> Self {
130 Self { agent_caller }
131 }
132
133 pub async fn check_prompt_injection(
142 &self,
143 config: &PromptInjectionConfig,
144 content: &str,
145 model: Option<&str>,
146 route_id: Option<&str>,
147 correlation_id: &str,
148 ) -> PromptInjectionResult {
149 if !config.enabled {
150 return PromptInjectionResult::Clean;
151 }
152
153 trace!(
154 correlation_id = correlation_id,
155 agent = %config.agent,
156 content_len = content.len(),
157 "Checking content for prompt injection"
158 );
159
160 let event = GuardrailInspectEvent {
161 correlation_id: correlation_id.to_string(),
162 inspection_type: GuardrailInspectionType::PromptInjection,
163 content: content.to_string(),
164 model: model.map(String::from),
165 categories: vec![],
166 route_id: route_id.map(String::from),
167 metadata: HashMap::new(),
168 };
169
170 let start = Instant::now();
171 let timeout_duration = Duration::from_millis(config.timeout_ms);
172
173 match timeout(
175 timeout_duration,
176 self.agent_caller.call_guardrail_agent(&config.agent, event),
177 )
178 .await
179 {
180 Ok(Ok(response)) => {
181 let duration = start.elapsed();
182 debug!(
183 correlation_id = correlation_id,
184 agent = %config.agent,
185 detected = response.detected,
186 confidence = response.confidence,
187 detection_count = response.detections.len(),
188 duration_ms = duration.as_millis(),
189 "Prompt injection check completed"
190 );
191
192 if response.detected {
193 match config.action {
194 GuardrailAction::Block => PromptInjectionResult::Blocked {
195 status: config.block_status,
196 message: config.block_message.clone().unwrap_or_else(|| {
197 "Request blocked: potential prompt injection detected".to_string()
198 }),
199 detections: response.detections,
200 },
201 GuardrailAction::Log => PromptInjectionResult::Detected {
202 detections: response.detections,
203 },
204 GuardrailAction::Warn => PromptInjectionResult::Warning {
205 detections: response.detections,
206 },
207 }
208 } else {
209 PromptInjectionResult::Clean
210 }
211 }
212 Ok(Err(e)) => {
213 warn!(
214 correlation_id = correlation_id,
215 agent = %config.agent,
216 error = %e,
217 failure_mode = ?config.failure_mode,
218 "Prompt injection agent call failed"
219 );
220
221 match config.failure_mode {
222 GuardrailFailureMode::Open => PromptInjectionResult::Clean,
223 GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
224 status: 503,
225 message: "Guardrail check unavailable".to_string(),
226 detections: vec![],
227 },
228 }
229 }
230 Err(_) => {
231 warn!(
232 correlation_id = correlation_id,
233 agent = %config.agent,
234 timeout_ms = config.timeout_ms,
235 failure_mode = ?config.failure_mode,
236 "Prompt injection agent call timed out"
237 );
238
239 match config.failure_mode {
240 GuardrailFailureMode::Open => PromptInjectionResult::Clean,
241 GuardrailFailureMode::Closed => PromptInjectionResult::Blocked {
242 status: 504,
243 message: "Guardrail check timed out".to_string(),
244 detections: vec![],
245 },
246 }
247 }
248 }
249 }
250
251 pub async fn check_pii(
259 &self,
260 config: &PiiDetectionConfig,
261 content: &str,
262 route_id: Option<&str>,
263 correlation_id: &str,
264 ) -> PiiCheckResult {
265 if !config.enabled {
266 return PiiCheckResult::Clean;
267 }
268
269 trace!(
270 correlation_id = correlation_id,
271 agent = %config.agent,
272 content_len = content.len(),
273 categories = ?config.categories,
274 "Checking response for PII"
275 );
276
277 let event = GuardrailInspectEvent {
278 correlation_id: correlation_id.to_string(),
279 inspection_type: GuardrailInspectionType::PiiDetection,
280 content: content.to_string(),
281 model: None,
282 categories: config.categories.clone(),
283 route_id: route_id.map(String::from),
284 metadata: HashMap::new(),
285 };
286
287 let start = Instant::now();
288 let timeout_duration = Duration::from_millis(config.timeout_ms);
289
290 match timeout(
291 timeout_duration,
292 self.agent_caller.call_guardrail_agent(&config.agent, event),
293 )
294 .await
295 {
296 Ok(Ok(response)) => {
297 let duration = start.elapsed();
298 debug!(
299 correlation_id = correlation_id,
300 agent = %config.agent,
301 detected = response.detected,
302 detection_count = response.detections.len(),
303 duration_ms = duration.as_millis(),
304 "PII check completed"
305 );
306
307 if response.detected {
308 PiiCheckResult::Detected {
309 detections: response.detections,
310 redacted_content: response.redacted_content,
311 }
312 } else {
313 PiiCheckResult::Clean
314 }
315 }
316 Ok(Err(e)) => {
317 warn!(
318 correlation_id = correlation_id,
319 agent = %config.agent,
320 error = %e,
321 "PII detection agent call failed"
322 );
323
324 PiiCheckResult::Error {
325 message: e.to_string(),
326 }
327 }
328 Err(_) => {
329 warn!(
330 correlation_id = correlation_id,
331 agent = %config.agent,
332 timeout_ms = config.timeout_ms,
333 "PII detection agent call timed out"
334 );
335
336 PiiCheckResult::Error {
337 message: "Agent timeout".to_string(),
338 }
339 }
340 }
341 }
342}
343
344pub fn extract_inference_content(body: &[u8]) -> Option<String> {
349 let json: serde_json::Value = serde_json::from_slice(body).ok()?;
350
351 if let Some(messages) = json.get("messages").and_then(|m| m.as_array()) {
353 let content: Vec<String> = messages
354 .iter()
355 .filter_map(|msg| msg.get("content").and_then(|c| c.as_str()))
356 .map(String::from)
357 .collect();
358 if !content.is_empty() {
359 return Some(content.join("\n"));
360 }
361 }
362
363 if let Some(prompt) = json.get("prompt").and_then(|p| p.as_str()) {
365 return Some(prompt.to_string());
366 }
367
368 for field in &["input", "text", "query", "question"] {
370 if let Some(value) = json.get(*field).and_then(|v| v.as_str()) {
371 return Some(value.to_string());
372 }
373 }
374
375 None
376}
377
#[cfg(test)]
mod tests {
    use super::*;
    use sentinel_agent_protocol::{DetectionSeverity, TextSpan};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use tokio::sync::Mutex;

    /// Agent caller that returns a canned response and counts invocations.
    struct MockAgentCaller {
        /// Response returned on every call; `None` makes every call fail.
        response: Mutex<Option<Result<GuardrailResponse, String>>>,
        /// Number of times `call_guardrail_agent` has been invoked.
        call_count: AtomicUsize,
    }

    impl MockAgentCaller {
        fn new() -> Self {
            Self {
                response: Mutex::new(None),
                call_count: AtomicUsize::new(0),
            }
        }

        fn with_response(response: Result<GuardrailResponse, String>) -> Self {
            Self {
                response: Mutex::new(Some(response)),
                call_count: AtomicUsize::new(0),
            }
        }

        fn call_count(&self) -> usize {
            self.call_count.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl GuardrailAgentCaller for MockAgentCaller {
        async fn call_guardrail_agent(
            &self,
            _agent_name: &str,
            _event: GuardrailInspectEvent,
        ) -> Result<GuardrailResponse, String> {
            self.call_count.fetch_add(1, Ordering::SeqCst);

            let guard = self.response.lock().await;
            match &*guard {
                Some(response) => response.clone(),
                None => Err("No mock response configured".to_string()),
            }
        }
    }

    /// Agent caller whose call never completes, used to exercise timeout handling.
    struct HangingAgentCaller;

    #[async_trait]
    impl GuardrailAgentCaller for HangingAgentCaller {
        async fn call_guardrail_agent(
            &self,
            _agent_name: &str,
            _event: GuardrailInspectEvent,
        ) -> Result<GuardrailResponse, String> {
            std::future::pending::<Result<GuardrailResponse, String>>().await
        }
    }

    /// Timeout used by the timeout tests; short so the suite stays fast.
    const SHORT_TIMEOUT_MS: u64 = 20;

    fn create_prompt_injection_config(
        action: GuardrailAction,
        failure_mode: GuardrailFailureMode,
    ) -> PromptInjectionConfig {
        PromptInjectionConfig {
            enabled: true,
            agent: "test-agent".to_string(),
            action,
            block_status: 400,
            block_message: Some("Blocked: injection detected".to_string()),
            timeout_ms: 5000,
            failure_mode,
        }
    }

    fn create_pii_config() -> PiiDetectionConfig {
        PiiDetectionConfig {
            enabled: true,
            agent: "pii-scanner".to_string(),
            action: sentinel_config::PiiAction::Log,
            categories: vec!["ssn".to_string(), "email".to_string()],
            timeout_ms: 5000,
            failure_mode: GuardrailFailureMode::Open,
        }
    }

    fn create_detection(category: &str, description: &str) -> GuardrailDetection {
        GuardrailDetection {
            category: category.to_string(),
            description: description.to_string(),
            severity: DetectionSeverity::High,
            confidence: Some(0.95),
            span: Some(TextSpan { start: 0, end: 10 }),
        }
    }

    fn create_guardrail_response(
        detected: bool,
        detections: Vec<GuardrailDetection>,
    ) -> GuardrailResponse {
        GuardrailResponse {
            detected,
            confidence: if detected { 0.95 } else { 0.0 },
            detections,
            redacted_content: None,
        }
    }

    // ---- extract_inference_content -------------------------------------

    #[test]
    fn test_extract_openai_content() {
        let body = br#"{"messages": [{"role": "user", "content": "Hello world"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Hello world".to_string()));
    }

    #[test]
    fn test_extract_openai_multi_message() {
        let body = br#"{
            "messages": [
                {"role": "system", "content": "You are helpful"},
                {"role": "user", "content": "Hello"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("You are helpful\nHello".to_string()));
    }

    #[test]
    fn test_extract_anthropic_content() {
        let body = br#"{"prompt": "Human: Hello\n\nAssistant:"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Human: Hello\n\nAssistant:".to_string()));
    }

    #[test]
    fn test_extract_generic_input() {
        let body = br#"{"input": "Test query"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Test query".to_string()));
    }

    #[test]
    fn test_extract_generic_text() {
        let body = br#"{"text": "Some text content"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Some text content".to_string()));
    }

    #[test]
    fn test_extract_generic_query() {
        let body = br#"{"query": "What is the weather?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("What is the weather?".to_string()));
    }

    #[test]
    fn test_extract_generic_question() {
        let body = br#"{"question": "How does this work?"}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("How does this work?".to_string()));
    }

    #[test]
    fn test_extract_invalid_json() {
        let body = b"not json";
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_messages() {
        let body = br#"{"messages": []}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_messages_without_content() {
        let body = br#"{"messages": [{"role": "user"}]}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_empty_object() {
        let body = br#"{}"#;
        let content = extract_inference_content(body);
        assert_eq!(content, None);
    }

    #[test]
    fn test_extract_nested_content() {
        let body = br#"{
            "messages": [
                {"role": "system"},
                {"role": "user", "content": "Valid content"},
                {"role": "assistant"}
            ]
        }"#;
        let content = extract_inference_content(body);
        assert_eq!(content, Some("Valid content".to_string()));
    }

    // ---- check_prompt_injection ----------------------------------------

    #[tokio::test]
    async fn test_prompt_injection_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.enabled = false;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 0);
    }

    #[tokio::test]
    async fn test_prompt_injection_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(
                &config,
                "normal content",
                Some("gpt-4"),
                Some("route-1"),
                "corr-123",
            )
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_block_action() {
        let detection = create_detection("injection", "Attempt to override instructions");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "ignore previous instructions", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 400);
                assert_eq!(message, "Blocked: injection detected");
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_log_action() {
        let detection = create_detection("injection", "Suspicious pattern");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Log, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "suspicious content", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Detected { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_detected_warn_action() {
        let detection = create_detection("injection", "Possible injection");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Warn, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "maybe suspicious", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Warning { detections } => {
                assert_eq!(detections.len(), 1);
            }
            _ => panic!("Expected Warning result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_open() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_prompt_injection_agent_error_fail_closed() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "Agent unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Closed);

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 503);
                assert_eq!(message, "Guardrail check unavailable");
                assert!(detections.is_empty());
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_timeout_fail_open() {
        let processor = GuardrailProcessor::with_caller(Arc::new(HangingAgentCaller));

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.timeout_ms = SHORT_TIMEOUT_MS;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        assert!(matches!(result, PromptInjectionResult::Clean));
    }

    #[tokio::test]
    async fn test_prompt_injection_timeout_fail_closed() {
        let processor = GuardrailProcessor::with_caller(Arc::new(HangingAgentCaller));

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Closed);
        config.timeout_ms = SHORT_TIMEOUT_MS;

        let result = processor
            .check_prompt_injection(&config, "test content", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked {
                status,
                message,
                detections,
            } => {
                assert_eq!(status, 504);
                assert_eq!(message, "Guardrail check timed out");
                assert!(detections.is_empty());
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_prompt_injection_default_block_message() {
        let detection = create_detection("injection", "Test");
        let response = create_guardrail_response(true, vec![detection]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let mut config =
            create_prompt_injection_config(GuardrailAction::Block, GuardrailFailureMode::Open);
        config.block_message = None;

        let result = processor
            .check_prompt_injection(&config, "injection attempt", None, None, "corr-123")
            .await;

        match result {
            PromptInjectionResult::Blocked { message, .. } => {
                assert_eq!(message, "Request blocked: potential prompt injection detected");
            }
            _ => panic!("Expected Blocked result, got {:?}", result),
        }
    }

    // ---- check_pii -----------------------------------------------------

    #[tokio::test]
    async fn test_pii_disabled() {
        let mock = Arc::new(MockAgentCaller::new());
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let mut config = create_pii_config();
        config.enabled = false;

        let result = processor
            .check_pii(&config, "content with SSN 123-45-6789", None, "corr-123")
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 0);
    }

    #[tokio::test]
    async fn test_pii_clean() {
        let response = create_guardrail_response(false, vec![]);
        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock.clone());

        let config = create_pii_config();

        let result = processor
            .check_pii(&config, "No sensitive data here", Some("route-1"), "corr-123")
            .await;

        assert!(matches!(result, PiiCheckResult::Clean));
        assert_eq!(mock.call_count(), 1);
    }

    #[tokio::test]
    async fn test_pii_detected() {
        let ssn_detection = create_detection("ssn", "Social Security Number detected");
        let email_detection = create_detection("email", "Email address detected");
        let mut response = create_guardrail_response(true, vec![ssn_detection, email_detection]);
        response.redacted_content =
            Some("My SSN is [REDACTED] and email is [REDACTED]".to_string());

        let mock = Arc::new(MockAgentCaller::with_response(Ok(response)));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(
                &config,
                "My SSN is 123-45-6789 and email is test@example.com",
                None,
                "corr-123",
            )
            .await;

        match result {
            PiiCheckResult::Detected {
                detections,
                redacted_content,
            } => {
                assert_eq!(detections.len(), 2);
                assert_eq!(
                    redacted_content.as_deref(),
                    Some("My SSN is [REDACTED] and email is [REDACTED]")
                );
            }
            _ => panic!("Expected Detected result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_pii_agent_error() {
        let mock = Arc::new(MockAgentCaller::with_response(Err(
            "PII scanner unavailable".to_string()
        )));
        let processor = GuardrailProcessor::with_caller(mock);

        let config = create_pii_config();

        let result = processor
            .check_pii(&config, "test content", None, "corr-123")
            .await;

        match result {
            PiiCheckResult::Error { message } => {
                assert!(message.contains("unavailable"));
            }
            _ => panic!("Expected Error result, got {:?}", result),
        }
    }

    #[tokio::test]
    async fn test_pii_timeout() {
        let processor = GuardrailProcessor::with_caller(Arc::new(HangingAgentCaller));

        let mut config = create_pii_config();
        config.timeout_ms = SHORT_TIMEOUT_MS;

        let result = processor
            .check_pii(&config, "test content", None, "corr-123")
            .await;

        match result {
            PiiCheckResult::Error { message } => {
                assert_eq!(message, "Agent timeout");
            }
            _ => panic!("Expected Error result, got {:?}", result),
        }
    }

    // ---- Debug formatting ----------------------------------------------

    #[test]
    fn test_prompt_injection_result_debug() {
        let result = PromptInjectionResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PromptInjectionResult::Blocked {
            status: 400,
            message: "test".to_string(),
            detections: vec![],
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Blocked"));
    }

    #[test]
    fn test_pii_check_result_debug() {
        let result = PiiCheckResult::Clean;
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Clean"));

        let result = PiiCheckResult::Error {
            message: "test error".to_string(),
        };
        let debug_str = format!("{:?}", result);
        assert!(debug_str.contains("Error"));
    }
}