// saorsa_core/mcp/security.rs

// Copyright 2024 Saorsa Labs Limited
//
// This software is dual-licensed under:
// - GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later)
// - Commercial License
//
// For AGPL-3.0 license, see LICENSE-AGPL-3.0
// For commercial licensing, contact: saorsalabs@gmail.com
//
// Unless required by applicable law or agreed to in writing, software
// distributed under these licenses is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

//! MCP Security Module
//!
//! This module provides comprehensive security features for the MCP server including:
//! - JWT-based authentication
//! - Peer identity verification
//! - Access control and permissions
//! - Rate limiting and abuse prevention
//! - Message integrity and encryption

use crate::{P2PError, PeerId, Result};
use base64::prelude::*;
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::RwLock;

32/// JWT-like token structure for MCP authentication
33#[derive(Debug, Clone, Serialize, Deserialize)]
34pub struct MCPToken {
35    /// Token header
36    pub header: TokenHeader,
37    /// Token payload
38    pub payload: TokenPayload,
39    /// Token signature
40    pub signature: String,
41}
42
43/// Token header information
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct TokenHeader {
46    /// Algorithm used for signing
47    pub alg: String,
48    /// Token type
49    pub typ: String,
50    /// Key ID
51    pub kid: Option<String>,
52}
53
54/// Token payload with claims
55#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct TokenPayload {
57    /// Issuer (peer ID)
58    pub iss: PeerId,
59    /// Subject (target peer ID or tool)
60    pub sub: String,
61    /// Audience (intended recipient)
62    pub aud: String,
63    /// Expiration time (Unix timestamp)
64    pub exp: u64,
65    /// Not before time (Unix timestamp)
66    pub nbf: u64,
67    /// Issued at time (Unix timestamp)
68    pub iat: u64,
69    /// JWT ID
70    pub jti: String,
71    /// Custom claims
72    pub claims: HashMap<String, serde_json::Value>,
73}
74
/// Security level for MCP operations.
///
/// Variants are declared from least to most restrictive, so the derived
/// ordering gives `Public < Basic < Strong < Admin`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
pub enum SecurityLevel {
    /// Public access - no authentication required
    Public,
    /// Basic authentication required
    Basic,
    /// Strong authentication required
    Strong,
    /// Administrative access required
    Admin,
}

/// Permission for MCP operations.
///
/// Each built-in variant has a fixed textual name (see [`MCPPermission::as_str`]);
/// anything else is carried verbatim in [`MCPPermission::Custom`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum MCPPermission {
    /// Read access to tools
    ReadTools,
    /// Execute tools
    ExecuteTools,
    /// Register new tools
    RegisterTools,
    /// Modify existing tools
    ModifyTools,
    /// Delete tools
    DeleteTools,
    /// Access prompts
    AccessPrompts,
    /// Access resources
    AccessResources,
    /// Administrative access
    Admin,
    /// Custom permission, identified by an arbitrary name
    Custom(String),
}

impl MCPPermission {
    /// Get permission string representation.
    ///
    /// Built-in variants map to fixed names such as `"read:tools"`; a
    /// `Custom` permission yields its inner string unchanged.
    pub fn as_str(&self) -> &str {
        match self {
            Self::ReadTools => "read:tools",
            Self::ExecuteTools => "execute:tools",
            Self::RegisterTools => "register:tools",
            Self::ModifyTools => "modify:tools",
            Self::DeleteTools => "delete:tools",
            Self::AccessPrompts => "access:prompts",
            Self::AccessResources => "access:resources",
            Self::Admin => "admin",
            Self::Custom(name) => name.as_str(),
        }
    }

    /// Parse permission from string.
    ///
    /// Known names map to their built-in variant; every other input becomes
    /// `Custom(s)`. The result is therefore always `Some(_)`; the `Option`
    /// return type is kept for interface compatibility.
    ///
    /// Note that a `Custom` permission whose name equals a built-in name
    /// (e.g. `Custom("admin")`) does not round-trip: its string form parses
    /// back as the built-in variant.
    pub fn from_str_name(s: &str) -> Option<Self> {
        let parsed = match s {
            "read:tools" => Self::ReadTools,
            "execute:tools" => Self::ExecuteTools,
            "register:tools" => Self::RegisterTools,
            "modify:tools" => Self::ModifyTools,
            "delete:tools" => Self::DeleteTools,
            "access:prompts" => Self::AccessPrompts,
            "access:resources" => Self::AccessResources,
            "admin" => Self::Admin,
            other => Self::Custom(other.to_string()),
        };
        Some(parsed)
    }
}

143/// Access control list for a peer
144#[derive(Debug, Clone)]
145pub struct PeerACL {
146    /// Peer ID
147    pub peer_id: PeerId,
148    /// Granted permissions
149    pub permissions: Vec<MCPPermission>,
150    /// Security level
151    pub security_level: SecurityLevel,
152    /// Reputation score (0.0 to 1.0)
153    pub reputation: f64,
154    /// Last access time
155    pub last_access: SystemTime,
156    /// Access count
157    pub access_count: u64,
158    /// Rate limit violations
159    pub rate_violations: u32,
160    /// Banned until (if applicable)
161    pub banned_until: Option<SystemTime>,
162}
163
164impl PeerACL {
165    /// Create new peer ACL with default permissions
166    pub fn new(peer_id: PeerId) -> Self {
167        Self {
168            peer_id,
169            permissions: vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools],
170            security_level: SecurityLevel::Basic,
171            reputation: 0.5, // Start with neutral reputation
172            last_access: SystemTime::now(),
173            access_count: 0,
174            rate_violations: 0,
175            banned_until: None,
176        }
177    }
178
179    /// Check if peer has specific permission
180    pub fn has_permission(&self, permission: &MCPPermission) -> bool {
181        if self.is_banned() {
182            return false;
183        }
184
185        // Admin permission grants all access
186        if self.permissions.contains(&MCPPermission::Admin) {
187            return true;
188        }
189
190        self.permissions.contains(permission)
191    }
192
193    /// Check if peer is currently banned
194    pub fn is_banned(&self) -> bool {
195        if let Some(banned_until) = self.banned_until {
196            SystemTime::now() < banned_until
197        } else {
198            false
199        }
200    }
201
202    /// Update access statistics
203    pub fn record_access(&mut self) {
204        self.last_access = SystemTime::now();
205        self.access_count += 1;
206    }
207
208    /// Record rate limit violation
209    pub fn record_rate_violation(&mut self) {
210        self.rate_violations += 1;
211
212        // Auto-ban after too many violations
213        if self.rate_violations >= 10 {
214            self.banned_until = Some(SystemTime::now() + Duration::from_secs(3600)); // 1 hour
215        }
216    }
217
218    /// Grant permission to peer
219    pub fn grant_permission(&mut self, permission: MCPPermission) {
220        if !self.permissions.contains(&permission) {
221            self.permissions.push(permission);
222        }
223    }
224
225    /// Revoke permission from peer
226    pub fn revoke_permission(&mut self, permission: &MCPPermission) {
227        self.permissions.retain(|p| p != permission);
228    }
229}
230
231/// Rate limiter for controlling request frequency
232#[derive(Debug, Clone)]
233pub struct RateLimiter {
234    /// Requests per minute limit
235    pub rpm_limit: u32,
236    /// Request timestamps for each peer
237    requests: Arc<RwLock<HashMap<PeerId, Vec<SystemTime>>>>,
238}
239
240impl RateLimiter {
241    /// Create new rate limiter
242    pub fn new(rpm_limit: u32) -> Self {
243        Self {
244            rpm_limit,
245            requests: Arc::new(RwLock::new(HashMap::new())),
246        }
247    }
248
249    /// Check if request is allowed for peer
250    pub async fn is_allowed(&self, peer_id: &PeerId) -> bool {
251        let mut requests = self.requests.write().await;
252        let now = SystemTime::now();
253        let minute_ago = now - Duration::from_secs(60);
254
255        // Get or create request history for peer
256        let peer_requests = requests.entry(peer_id.clone()).or_insert_with(Vec::new);
257
258        // Remove old requests (older than 1 minute)
259        peer_requests.retain(|&req_time| req_time > minute_ago);
260
261        // Check if under limit
262        if peer_requests.len() < self.rpm_limit as usize {
263            peer_requests.push(now);
264            true
265        } else {
266            false
267        }
268    }
269
270    /// Reset rate limit for peer (admin function)
271    pub async fn reset_peer(&self, peer_id: &PeerId) {
272        let mut requests = self.requests.write().await;
273        requests.remove(peer_id);
274    }
275
276    /// Clean up old entries periodically
277    pub async fn cleanup(&self) {
278        let mut requests = self.requests.write().await;
279        let minute_ago = SystemTime::now() - Duration::from_secs(60);
280
281        for peer_requests in requests.values_mut() {
282            peer_requests.retain(|&req_time| req_time > minute_ago);
283        }
284
285        // Remove empty entries
286        requests.retain(|_, reqs| !reqs.is_empty());
287    }
288}
289
290/// MCP Security Manager
291pub struct MCPSecurityManager {
292    /// Access control lists
293    acls: Arc<RwLock<HashMap<PeerId, PeerACL>>>,
294    /// Rate limiter
295    rate_limiter: RateLimiter,
296    /// Shared secret for token signing
297    secret_key: Vec<u8>,
298    /// Tool security policies
299    tool_policies: Arc<RwLock<HashMap<String, SecurityLevel>>>,
300    /// Trusted peer list
301    trusted_peers: Arc<RwLock<Vec<PeerId>>>,
302}
303
304impl MCPSecurityManager {
305    /// Create new security manager
306    pub fn new(secret_key: Vec<u8>, rpm_limit: u32) -> Self {
307        Self {
308            acls: Arc::new(RwLock::new(HashMap::new())),
309            rate_limiter: RateLimiter::new(rpm_limit),
310            secret_key,
311            tool_policies: Arc::new(RwLock::new(HashMap::new())),
312            trusted_peers: Arc::new(RwLock::new(Vec::new())),
313        }
314    }
315
316    /// Generate authentication token for peer
317    pub async fn generate_token(
318        &self,
319        peer_id: &PeerId,
320        permissions: Vec<MCPPermission>,
321        ttl: Duration,
322    ) -> Result<String> {
323        let now = SystemTime::now().duration_since(UNIX_EPOCH).map_err(|e| {
324            P2PError::Identity(crate::error::IdentityError::SystemTime(
325                format!("Time error: {e}").into(),
326            ))
327        })?;
328
329        let payload = TokenPayload {
330            iss: peer_id.clone(),
331            sub: peer_id.clone(),
332            aud: "mcp-server".to_string(),
333            exp: (now + ttl).as_secs(),
334            nbf: now.as_secs(),
335            iat: now.as_secs(),
336            jti: uuid::Uuid::new_v4().to_string(),
337            claims: {
338                let mut claims = HashMap::new();
339                claims.insert(
340                    "permissions".to_string(),
341                    serde_json::to_value(
342                        permissions.iter().map(|p| p.as_str()).collect::<Vec<_>>(),
343                    )
344                    .map_err(|e| P2PError::Serialization(e.to_string().into()))?,
345                );
346                claims
347            },
348        };
349
350        let header = TokenHeader {
351            alg: "HS256".to_string(),
352            typ: "JWT".to_string(),
353            kid: None,
354        };
355
356        // Create token without signature first
357        let header_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(
358            serde_json::to_vec(&header)
359                .map_err(|e| P2PError::Serialization(e.to_string().into()))?,
360        );
361        let payload_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(
362            serde_json::to_vec(&payload)
363                .map_err(|e| P2PError::Serialization(e.to_string().into()))?,
364        );
365
366        // Sign the token
367        let signing_input = format!("{header_b64}.{payload_b64}");
368        let signature = self.sign_data(signing_input.as_bytes());
369        let signature_b64 = base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(signature);
370
371        Ok(format!("{header_b64}.{payload_b64}.{signature_b64}"))
372    }
373
374    /// Verify authentication token
375    pub async fn verify_token(&self, token: &str) -> Result<TokenPayload> {
376        let parts: Vec<&str> = token.split('.').collect();
377        if parts.len() != 3 {
378            return Err(P2PError::Mcp(crate::error::McpError::InvalidRequest(
379                "Invalid token format".to_string().into(),
380            )));
381        }
382
383        let _header_data = base64::prelude::BASE64_URL_SAFE_NO_PAD
384            .decode(parts[0])
385            .map_err(|e| {
386                P2PError::Mcp(crate::error::McpError::InvalidRequest(
387                    format!("Invalid header encoding: {e}").into(),
388                ))
389            })?;
390        let payload_data = base64::prelude::BASE64_URL_SAFE_NO_PAD
391            .decode(parts[1])
392            .map_err(|e| {
393                P2PError::Mcp(crate::error::McpError::InvalidRequest(
394                    format!("Invalid payload encoding: {e}").into(),
395                ))
396            })?;
397        let signature = base64::prelude::BASE64_URL_SAFE_NO_PAD
398            .decode(parts[2])
399            .map_err(|e| {
400                P2PError::Mcp(crate::error::McpError::InvalidRequest(
401                    format!("Invalid signature encoding: {e}").into(),
402                ))
403            })?;
404
405        // Verify signature
406        let signing_input = format!("{}.{}", parts[0], parts[1]);
407        let expected_signature = self.sign_data(signing_input.as_bytes());
408
409        if signature != expected_signature {
410            return Err(P2PError::Mcp(crate::error::McpError::InvalidRequest(
411                "Invalid token signature".to_string().into(),
412            )));
413        }
414
415        // Parse payload
416        let payload: TokenPayload = serde_json::from_slice(&payload_data).map_err(|e| {
417            P2PError::Mcp(crate::error::McpError::InvalidRequest(
418                format!("Invalid payload: {e}").into(),
419            ))
420        })?;
421
422        // Check expiration
423        let now = SystemTime::now()
424            .duration_since(UNIX_EPOCH)
425            .map_err(|e| {
426                P2PError::Identity(crate::error::IdentityError::SystemTime(
427                    format!("Time error: {e}").into(),
428                ))
429            })?
430            .as_secs();
431
432        if payload.exp < now {
433            return Err(P2PError::Mcp(crate::error::McpError::InvalidRequest(
434                "Token expired".to_string().into(),
435            )));
436        }
437
438        if payload.nbf > now {
439            return Err(P2PError::Mcp(crate::error::McpError::InvalidRequest(
440                "Token not yet valid".to_string().into(),
441            )));
442        }
443
444        Ok(payload)
445    }
446
447    /// Check if peer has permission for operation
448    pub async fn check_permission(
449        &self,
450        peer_id: &PeerId,
451        permission: &MCPPermission,
452    ) -> Result<bool> {
453        let acls = self.acls.read().await;
454
455        if let Some(acl) = acls.get(peer_id) {
456            Ok(acl.has_permission(permission))
457        } else {
458            // Create default ACL for new peer
459            drop(acls);
460            let mut acls = self.acls.write().await;
461            acls.insert(peer_id.clone(), PeerACL::new(peer_id.clone()));
462            Ok(false) // New peers start with no permissions by default
463        }
464    }
465
466    /// Check rate limit for peer
467    pub async fn check_rate_limit(&self, peer_id: &PeerId) -> Result<bool> {
468        if self.rate_limiter.is_allowed(peer_id).await {
469            Ok(true)
470        } else {
471            // Record violation
472            let mut acls = self.acls.write().await;
473            if let Some(acl) = acls.get_mut(peer_id) {
474                acl.record_rate_violation();
475            }
476            Ok(false)
477        }
478    }
479
480    /// Grant permission to peer
481    pub async fn grant_permission(
482        &self,
483        peer_id: &PeerId,
484        permission: MCPPermission,
485    ) -> Result<()> {
486        let mut acls = self.acls.write().await;
487        let acl = acls
488            .entry(peer_id.clone())
489            .or_insert_with(|| PeerACL::new(peer_id.clone()));
490        acl.grant_permission(permission);
491        Ok(())
492    }
493
494    /// Revoke permission from peer
495    pub async fn revoke_permission(
496        &self,
497        peer_id: &PeerId,
498        permission: &MCPPermission,
499    ) -> Result<()> {
500        let mut acls = self.acls.write().await;
501        if let Some(acl) = acls.get_mut(peer_id) {
502            acl.revoke_permission(permission);
503        }
504        Ok(())
505    }
506
507    /// Add trusted peer
508    pub async fn add_trusted_peer(&self, peer_id: PeerId) -> Result<()> {
509        let mut trusted = self.trusted_peers.write().await;
510        if !trusted.contains(&peer_id) {
511            trusted.push(peer_id);
512        }
513        Ok(())
514    }
515
516    /// Check if peer is trusted
517    pub async fn is_trusted_peer(&self, peer_id: &PeerId) -> bool {
518        let trusted = self.trusted_peers.read().await;
519        trusted.contains(peer_id)
520    }
521
522    /// Set security policy for tool
523    pub async fn set_tool_policy(&self, tool_name: String, level: SecurityLevel) -> Result<()> {
524        let mut policies = self.tool_policies.write().await;
525        policies.insert(tool_name, level);
526        Ok(())
527    }
528
529    /// Get security policy for tool
530    pub async fn get_tool_policy(&self, tool_name: &str) -> SecurityLevel {
531        let policies = self.tool_policies.read().await;
532        policies
533            .get(tool_name)
534            .cloned()
535            .unwrap_or(SecurityLevel::Basic)
536    }
537
538    /// Sign data with secret key
539    fn sign_data(&self, data: &[u8]) -> Vec<u8> {
540        let mut hasher = Sha256::new();
541        hasher.update(&self.secret_key);
542        hasher.update(data);
543        hasher.finalize().to_vec()
544    }
545
546    /// Update peer reputation based on behavior
547    pub async fn update_reputation(&self, peer_id: &PeerId, delta: f64) -> Result<()> {
548        let mut acls = self.acls.write().await;
549        if let Some(acl) = acls.get_mut(peer_id) {
550            acl.reputation = (acl.reputation + delta).clamp(0.0, 1.0);
551        }
552        Ok(())
553    }
554
555    /// Get peer statistics
556    pub async fn get_peer_stats(&self, peer_id: &PeerId) -> Option<PeerACL> {
557        let acls = self.acls.read().await;
558        acls.get(peer_id).cloned()
559    }
560
561    /// Clean up expired data
562    pub async fn cleanup(&self) -> Result<()> {
563        self.rate_limiter.cleanup().await;
564
565        // Clean up old ACLs (remove entries not accessed in 24 hours)
566        let mut acls = self.acls.write().await;
567        let day_ago = SystemTime::now() - Duration::from_secs(24 * 3600);
568        acls.retain(|_, acl| acl.last_access > day_ago);
569
570        Ok(())
571    }
572}
573
574/// Security audit log entry
575#[derive(Debug, Clone)]
576pub struct SecurityAuditEntry {
577    /// Timestamp
578    pub timestamp: SystemTime,
579    /// Event type
580    pub event_type: String,
581    /// Peer ID involved
582    pub peer_id: PeerId,
583    /// Event details
584    pub details: HashMap<String, String>,
585    /// Severity level
586    pub severity: AuditSeverity,
587}
588
/// Audit severity levels.
///
/// Variants are declared from least to most severe, so the derived ordering
/// gives `Info < Warning < Error < Critical`.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub enum AuditSeverity {
    /// Informational
    Info,
    /// Warning
    Warning,
    /// Error
    Error,
    /// Critical security event
    Critical,
}

602/// Security audit logger
603pub struct SecurityAuditLogger {
604    /// Audit entries
605    entries: Arc<RwLock<Vec<SecurityAuditEntry>>>,
606    /// Maximum entries to keep
607    max_entries: usize,
608}
609
610impl SecurityAuditLogger {
611    /// Create new audit logger
612    pub fn new(max_entries: usize) -> Self {
613        Self {
614            entries: Arc::new(RwLock::new(Vec::new())),
615            max_entries,
616        }
617    }
618
619    /// Log security event
620    pub async fn log_event(
621        &self,
622        event_type: String,
623        peer_id: PeerId,
624        details: HashMap<String, String>,
625        severity: AuditSeverity,
626    ) {
627        let entry = SecurityAuditEntry {
628            timestamp: SystemTime::now(),
629            event_type,
630            peer_id,
631            details,
632            severity,
633        };
634
635        let mut entries = self.entries.write().await;
636        entries.push(entry);
637
638        // Keep only recent entries
639        if entries.len() > self.max_entries {
640            let excess = entries.len() - self.max_entries;
641            entries.drain(0..excess);
642        }
643    }
644
645    /// Get recent audit entries
646    pub async fn get_recent_entries(&self, limit: Option<usize>) -> Vec<SecurityAuditEntry> {
647        let entries = self.entries.read().await;
648        let limit = limit.unwrap_or(entries.len());
649        entries.iter().rev().take(limit).cloned().collect()
650    }
651
652    /// Get entries by severity
653    pub async fn get_entries_by_severity(
654        &self,
655        severity: AuditSeverity,
656    ) -> Vec<SecurityAuditEntry> {
657        let entries = self.entries.read().await;
658        entries
659            .iter()
660            .filter(|e| e.severity == severity)
661            .cloned()
662            .collect()
663    }
664}
665
666#[cfg(test)]
667mod tests {
668    use super::*;
669    use std::time::Duration;
670
671    /// Helper function to create a test PeerId
672    fn create_test_peer() -> PeerId {
673        format!("test_peer_{}", rand::random::<u32>())
674    }
675
676    /// Helper function to create a test security manager
677    fn create_test_security_manager() -> MCPSecurityManager {
678        let secret_key = b"test_secret_key_1234567890123456".to_vec();
679        MCPSecurityManager::new(secret_key, 60) // 60 RPM limit
680    }
681
682    #[test]
683    fn test_mcp_permission_string_conversion() {
684        let permissions = vec![
685            (MCPPermission::ReadTools, "read:tools"),
686            (MCPPermission::ExecuteTools, "execute:tools"),
687            (MCPPermission::RegisterTools, "register:tools"),
688            (MCPPermission::ModifyTools, "modify:tools"),
689            (MCPPermission::DeleteTools, "delete:tools"),
690            (MCPPermission::AccessPrompts, "access:prompts"),
691            (MCPPermission::AccessResources, "access:resources"),
692            (MCPPermission::Admin, "admin"),
693        ];
694
695        for (permission, expected_str) in permissions {
696            assert_eq!(permission.as_str(), expected_str);
697            assert_eq!(MCPPermission::from_str_name(expected_str), Some(permission));
698        }
699
700        // Test custom permission
701        let custom = MCPPermission::Custom("custom:action".to_string());
702        assert_eq!(custom.as_str(), "custom:action");
703        assert_eq!(MCPPermission::from_str_name("custom:action"), Some(custom));
704
705        // Test unknown permission defaults to custom
706        let unknown = MCPPermission::from_str_name("unknown:permission");
707        match unknown {
708            Some(MCPPermission::Custom(s)) => assert_eq!(s, "unknown:permission"),
709            _ => panic!("Expected custom permission"),
710        }
711    }
712
713    #[test]
714    fn test_security_level_ordering() {
715        // Test security level ordering
716        assert!(SecurityLevel::Public < SecurityLevel::Basic);
717        assert!(SecurityLevel::Basic < SecurityLevel::Strong);
718        assert!(SecurityLevel::Strong < SecurityLevel::Admin);
719
720        // Test equality
721        assert_eq!(SecurityLevel::Public, SecurityLevel::Public);
722        assert_eq!(SecurityLevel::Basic, SecurityLevel::Basic);
723        assert_eq!(SecurityLevel::Strong, SecurityLevel::Strong);
724        assert_eq!(SecurityLevel::Admin, SecurityLevel::Admin);
725    }
726
727    #[test]
728    fn test_peer_acl_creation() {
729        let peer_id = create_test_peer();
730        let acl = PeerACL::new(peer_id.clone());
731
732        assert_eq!(acl.peer_id, peer_id);
733        assert_eq!(acl.permissions.len(), 2); // Default: ReadTools, ExecuteTools
734        assert!(acl.permissions.contains(&MCPPermission::ReadTools));
735        assert!(acl.permissions.contains(&MCPPermission::ExecuteTools));
736        assert_eq!(acl.security_level, SecurityLevel::Basic);
737        assert_eq!(acl.reputation, 0.5);
738        assert_eq!(acl.access_count, 0);
739        assert_eq!(acl.rate_violations, 0);
740        assert!(acl.banned_until.is_none());
741        assert!(!acl.is_banned());
742    }
743
744    #[test]
745    fn test_peer_acl_permissions() {
746        let peer_id = create_test_peer();
747        let mut acl = PeerACL::new(peer_id);
748
749        // Test default permissions
750        assert!(acl.has_permission(&MCPPermission::ReadTools));
751        assert!(acl.has_permission(&MCPPermission::ExecuteTools));
752        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
753        assert!(!acl.has_permission(&MCPPermission::Admin));
754
755        // Grant admin permission
756        acl.grant_permission(MCPPermission::Admin);
757        // Admin permission grants all access
758        assert!(acl.has_permission(&MCPPermission::ReadTools));
759        assert!(acl.has_permission(&MCPPermission::ExecuteTools));
760        assert!(acl.has_permission(&MCPPermission::RegisterTools));
761        assert!(acl.has_permission(&MCPPermission::DeleteTools));
762        assert!(acl.has_permission(&MCPPermission::Admin));
763
764        // Revoke admin permission
765        acl.revoke_permission(&MCPPermission::Admin);
766        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
767        assert!(!acl.has_permission(&MCPPermission::Admin));
768
769        // Grant specific permission
770        acl.grant_permission(MCPPermission::RegisterTools);
771        assert!(acl.has_permission(&MCPPermission::RegisterTools));
772
773        // Revoke specific permission
774        acl.revoke_permission(&MCPPermission::RegisterTools);
775        assert!(!acl.has_permission(&MCPPermission::RegisterTools));
776    }
777
778    #[test]
779    fn test_peer_acl_ban_functionality() {
780        let peer_id = create_test_peer();
781        let mut acl = PeerACL::new(peer_id);
782
783        // Initially not banned
784        assert!(!acl.is_banned());
785        assert!(acl.has_permission(&MCPPermission::ReadTools));
786
787        // Record violations (but not enough to trigger auto-ban)
788        for _ in 0..5 {
789            acl.record_rate_violation();
790        }
791        assert_eq!(acl.rate_violations, 5);
792        assert!(!acl.is_banned());
793
794        // Record enough violations to trigger auto-ban
795        for _ in 0..5 {
796            acl.record_rate_violation();
797        }
798        assert_eq!(acl.rate_violations, 10);
799        assert!(acl.is_banned());
800
801        // Banned peers have no permissions
802        assert!(!acl.has_permission(&MCPPermission::ReadTools));
803        assert!(!acl.has_permission(&MCPPermission::ExecuteTools));
804    }
805
806    #[test]
807    fn test_peer_acl_access_tracking() {
808        let peer_id = create_test_peer();
809        let mut acl = PeerACL::new(peer_id);
810
811        let initial_time = acl.last_access;
812        assert_eq!(acl.access_count, 0);
813
814        // Record access
815        std::thread::sleep(std::time::Duration::from_millis(10));
816        acl.record_access();
817
818        assert_eq!(acl.access_count, 1);
819        assert!(acl.last_access > initial_time);
820
821        // Record more access
822        acl.record_access();
823        assert_eq!(acl.access_count, 2);
824    }
825
826    #[tokio::test]
827    async fn test_rate_limiter_creation() {
828        let limiter = RateLimiter::new(60);
829        assert_eq!(limiter.rpm_limit, 60);
830    }
831
832    #[tokio::test]
833    async fn test_rate_limiter_basic_functionality() {
834        let limiter = RateLimiter::new(2); // 2 requests per minute
835        let peer_id = create_test_peer();
836
837        // First request should be allowed
838        assert!(limiter.is_allowed(&peer_id).await);
839
840        // Second request should be allowed
841        assert!(limiter.is_allowed(&peer_id).await);
842
843        // Third request should be denied (over limit)
844        assert!(!limiter.is_allowed(&peer_id).await);
845    }
846
847    #[tokio::test]
848    async fn test_rate_limiter_different_peers() {
849        let limiter = RateLimiter::new(1); // 1 request per minute
850        let peer1 = create_test_peer();
851        let peer2 = create_test_peer();
852
853        // Each peer should have their own limit
854        assert!(limiter.is_allowed(&peer1).await);
855        assert!(limiter.is_allowed(&peer2).await);
856
857        // Both should be over their individual limits now
858        assert!(!limiter.is_allowed(&peer1).await);
859        assert!(!limiter.is_allowed(&peer2).await);
860    }
861
862    #[tokio::test]
863    async fn test_rate_limiter_reset() {
864        let limiter = RateLimiter::new(1);
865        let peer_id = create_test_peer();
866
867        // Use up the limit
868        assert!(limiter.is_allowed(&peer_id).await);
869        assert!(!limiter.is_allowed(&peer_id).await);
870
871        // Reset the peer
872        limiter.reset_peer(&peer_id).await;
873
874        // Should be allowed again
875        assert!(limiter.is_allowed(&peer_id).await);
876    }
877
878    #[tokio::test]
879    async fn test_rate_limiter_cleanup() {
880        let limiter = RateLimiter::new(10);
881        let peer_id = create_test_peer();
882
883        // Make some requests
884        limiter.is_allowed(&peer_id).await;
885        limiter.is_allowed(&peer_id).await;
886
887        // Cleanup shouldn't affect recent requests
888        limiter.cleanup().await;
889
890        // Should still have request history
891        let requests = limiter.requests.read().await;
892        assert!(requests.contains_key(&peer_id));
893        let peer_requests = requests.get(&peer_id).expect("valid security operation");
894        assert_eq!(peer_requests.len(), 2);
895    }
896
897    #[tokio::test]
898    async fn test_security_manager_creation() {
899        let secret_key = b"test_secret_key".to_vec();
900        let manager = MCPSecurityManager::new(secret_key.clone(), 60);
901
902        // Verify configuration
903        assert_eq!(manager.secret_key, secret_key);
904        assert_eq!(manager.rate_limiter.rpm_limit, 60);
905    }
906
907    #[tokio::test]
908    async fn test_token_generation_and_verification() -> Result<()> {
909        let manager = create_test_security_manager();
910        let peer_id = create_test_peer();
911        let permissions = vec![MCPPermission::ReadTools, MCPPermission::ExecuteTools];
912        let ttl = Duration::from_secs(3600); // 1 hour
913
914        // Generate token
915        let token = manager
916            .generate_token(&peer_id, permissions.clone(), ttl)
917            .await?;
918        assert!(!token.is_empty());
919
920        // Verify token
921        let payload = manager.verify_token(&token).await?;
922        assert_eq!(payload.iss, peer_id);
923        assert_eq!(payload.sub, peer_id);
924        assert_eq!(payload.aud, "mcp-server");
925
926        // Check permissions in claims
927        let permissions_claim = payload
928            .claims
929            .get("permissions")
930            .expect("valid security operation");
931        let permission_strings: Vec<String> =
932            serde_json::from_value(permissions_claim.clone()).expect("valid security operation");
933        assert_eq!(permission_strings.len(), 2);
934        assert!(permission_strings.contains(&"read:tools".to_string()));
935        assert!(permission_strings.contains(&"execute:tools".to_string()));
936
937        Ok(())
938    }
939
940    #[tokio::test]
941    async fn test_token_verification_invalid() {
942        let manager = create_test_security_manager();
943
944        // Test invalid token format
945        let result = manager.verify_token("invalid.token").await;
946        assert!(result.is_err());
947
948        // Test malformed token
949        let result = manager.verify_token("invalid.token.format.extra").await;
950        assert!(result.is_err());
951
952        // Test empty token
953        let result = manager.verify_token("").await;
954        assert!(result.is_err());
955    }
956
957    #[tokio::test]
958    async fn test_token_signature_verification() -> Result<()> {
959        let manager1 = create_test_security_manager();
960        let manager2 = MCPSecurityManager::new(b"different_secret".to_vec(), 60);
961
962        let peer_id = create_test_peer();
963        let permissions = vec![MCPPermission::ReadTools];
964        let ttl = Duration::from_secs(3600);
965
966        // Generate token with manager1
967        let token = manager1.generate_token(&peer_id, permissions, ttl).await?;
968
969        // Verify with manager1 should succeed
970        assert!(manager1.verify_token(&token).await.is_ok());
971
972        // Verify with manager2 should fail (different secret)
973        assert!(manager2.verify_token(&token).await.is_err());
974
975        Ok(())
976    }
977
978    #[tokio::test]
979    async fn test_permission_management() -> Result<()> {
980        let manager = create_test_security_manager();
981        let peer_id = create_test_peer();
982
983        // Initially should have no permissions (new peer starts with false)
984        assert!(
985            !manager
986                .check_permission(&peer_id, &MCPPermission::ExecuteTools)
987                .await?
988        );
989
990        // Grant permission
991        manager
992            .grant_permission(&peer_id, MCPPermission::ExecuteTools)
993            .await?;
994        assert!(
995            manager
996                .check_permission(&peer_id, &MCPPermission::ExecuteTools)
997                .await?
998        );
999
1000        // Revoke permission
1001        manager
1002            .revoke_permission(&peer_id, &MCPPermission::ExecuteTools)
1003            .await?;
1004        assert!(
1005            !manager
1006                .check_permission(&peer_id, &MCPPermission::ExecuteTools)
1007                .await?
1008        );
1009
1010        Ok(())
1011    }
1012
1013    #[tokio::test]
1014    async fn test_rate_limit_checking() -> Result<()> {
1015        let manager = MCPSecurityManager::new(b"test_key".to_vec(), 2); // 2 RPM limit
1016        let peer_id = create_test_peer();
1017
1018        // Grant permission first to create ACL entry
1019        manager
1020            .grant_permission(&peer_id, MCPPermission::ReadTools)
1021            .await?;
1022
1023        // First two requests should pass
1024        assert!(manager.check_rate_limit(&peer_id).await?);
1025        assert!(manager.check_rate_limit(&peer_id).await?);
1026
1027        // Third request should fail
1028        assert!(!manager.check_rate_limit(&peer_id).await?);
1029
1030        // Check that violation was recorded
1031        let stats = manager.get_peer_stats(&peer_id).await;
1032        assert!(stats.is_some());
1033        let acl = stats.expect("valid security operation");
1034        assert_eq!(acl.rate_violations, 1);
1035
1036        Ok(())
1037    }
1038
1039    #[tokio::test]
1040    async fn test_trusted_peer_management() -> Result<()> {
1041        let manager = create_test_security_manager();
1042        let peer_id = create_test_peer();
1043
1044        // Initially not trusted
1045        assert!(!manager.is_trusted_peer(&peer_id).await);
1046
1047        // Add as trusted
1048        manager.add_trusted_peer(peer_id.clone()).await?;
1049        assert!(manager.is_trusted_peer(&peer_id).await);
1050
1051        // Adding same peer again should be idempotent
1052        manager.add_trusted_peer(peer_id.clone()).await?;
1053        assert!(manager.is_trusted_peer(&peer_id).await);
1054
1055        Ok(())
1056    }
1057
1058    #[tokio::test]
1059    async fn test_tool_security_policies() -> Result<()> {
1060        let manager = create_test_security_manager();
1061
1062        // Default policy should be Basic
1063        let policy = manager.get_tool_policy("test_tool").await;
1064        assert_eq!(policy, SecurityLevel::Basic);
1065
1066        // Set custom policy
1067        manager
1068            .set_tool_policy("test_tool".to_string(), SecurityLevel::Strong)
1069            .await?;
1070        let policy = manager.get_tool_policy("test_tool").await;
1071        assert_eq!(policy, SecurityLevel::Strong);
1072
1073        // Set admin policy
1074        manager
1075            .set_tool_policy("admin_tool".to_string(), SecurityLevel::Admin)
1076            .await?;
1077        let policy = manager.get_tool_policy("admin_tool").await;
1078        assert_eq!(policy, SecurityLevel::Admin);
1079
1080        Ok(())
1081    }
1082
1083    #[tokio::test]
1084    async fn test_reputation_management() -> Result<()> {
1085        let manager = create_test_security_manager();
1086        let peer_id = create_test_peer();
1087
1088        // Grant permission to create ACL entry
1089        manager
1090            .grant_permission(&peer_id, MCPPermission::ReadTools)
1091            .await?;
1092
1093        let stats = manager
1094            .get_peer_stats(&peer_id)
1095            .await
1096            .expect("valid security operation");
1097        assert_eq!(stats.reputation, 0.5); // Default reputation
1098
1099        // Increase reputation
1100        manager.update_reputation(&peer_id, 0.2).await?;
1101        let stats = manager
1102            .get_peer_stats(&peer_id)
1103            .await
1104            .expect("valid security operation");
1105        assert_eq!(stats.reputation, 0.7);
1106
1107        // Decrease reputation
1108        manager.update_reputation(&peer_id, -0.3).await?;
1109        let stats = manager
1110            .get_peer_stats(&peer_id)
1111            .await
1112            .expect("valid security operation");
1113        assert!((stats.reputation - 0.4).abs() < 0.001); // Use epsilon for float comparison
1114
1115        // Test bounds (should clamp to 0.0-1.0)
1116        manager.update_reputation(&peer_id, -1.0).await?;
1117        let stats = manager
1118            .get_peer_stats(&peer_id)
1119            .await
1120            .expect("valid security operation");
1121        assert_eq!(stats.reputation, 0.0);
1122
1123        manager.update_reputation(&peer_id, 2.0).await?;
1124        let stats = manager
1125            .get_peer_stats(&peer_id)
1126            .await
1127            .expect("valid security operation");
1128        assert_eq!(stats.reputation, 1.0);
1129
1130        Ok(())
1131    }
1132
1133    #[tokio::test]
1134    async fn test_security_manager_cleanup() -> Result<()> {
1135        let manager = create_test_security_manager();
1136        let peer_id = create_test_peer();
1137
1138        // Create some data
1139        manager
1140            .grant_permission(&peer_id, MCPPermission::ReadTools)
1141            .await?;
1142        manager.check_rate_limit(&peer_id).await?;
1143
1144        // Cleanup should work without errors
1145        manager.cleanup().await?;
1146
1147        Ok(())
1148    }
1149
1150    #[tokio::test]
1151    async fn test_audit_logger_creation() {
1152        let logger = SecurityAuditLogger::new(100);
1153        assert_eq!(logger.max_entries, 100);
1154
1155        let entries = logger.get_recent_entries(None).await;
1156        assert!(entries.is_empty());
1157    }
1158
1159    #[tokio::test]
1160    async fn test_audit_logger_logging() {
1161        let logger = SecurityAuditLogger::new(10);
1162        let peer_id = create_test_peer();
1163
1164        let mut details = HashMap::new();
1165        details.insert("action".to_string(), "test_action".to_string());
1166        details.insert("result".to_string(), "success".to_string());
1167
1168        // Log an event
1169        logger
1170            .log_event(
1171                "test_event".to_string(),
1172                peer_id.clone(),
1173                details.clone(),
1174                AuditSeverity::Info,
1175            )
1176            .await;
1177
1178        let entries = logger.get_recent_entries(None).await;
1179        assert_eq!(entries.len(), 1);
1180
1181        let entry = &entries[0];
1182        assert_eq!(entry.event_type, "test_event");
1183        assert_eq!(entry.peer_id, peer_id);
1184        assert_eq!(entry.severity, AuditSeverity::Info);
1185        assert_eq!(
1186            entry.details.get("action"),
1187            Some(&"test_action".to_string())
1188        );
1189    }
1190
1191    #[tokio::test]
1192    async fn test_audit_logger_severity_filtering() {
1193        let logger = SecurityAuditLogger::new(10);
1194        let peer_id = create_test_peer();
1195
1196        // Log events with different severities
1197        logger
1198            .log_event(
1199                "info_event".to_string(),
1200                peer_id.clone(),
1201                HashMap::new(),
1202                AuditSeverity::Info,
1203            )
1204            .await;
1205        logger
1206            .log_event(
1207                "warning_event".to_string(),
1208                peer_id.clone(),
1209                HashMap::new(),
1210                AuditSeverity::Warning,
1211            )
1212            .await;
1213        logger
1214            .log_event(
1215                "error_event".to_string(),
1216                peer_id.clone(),
1217                HashMap::new(),
1218                AuditSeverity::Error,
1219            )
1220            .await;
1221        logger
1222            .log_event(
1223                "critical_event".to_string(),
1224                peer_id.clone(),
1225                HashMap::new(),
1226                AuditSeverity::Critical,
1227            )
1228            .await;
1229
1230        // Test filtering by severity
1231        let info_entries = logger.get_entries_by_severity(AuditSeverity::Info).await;
1232        assert_eq!(info_entries.len(), 1);
1233        assert_eq!(info_entries[0].event_type, "info_event");
1234
1235        let warning_entries = logger.get_entries_by_severity(AuditSeverity::Warning).await;
1236        assert_eq!(warning_entries.len(), 1);
1237        assert_eq!(warning_entries[0].event_type, "warning_event");
1238
1239        let error_entries = logger.get_entries_by_severity(AuditSeverity::Error).await;
1240        assert_eq!(error_entries.len(), 1);
1241
1242        let critical_entries = logger
1243            .get_entries_by_severity(AuditSeverity::Critical)
1244            .await;
1245        assert_eq!(critical_entries.len(), 1);
1246    }
1247
1248    #[tokio::test]
1249    async fn test_audit_logger_max_entries() {
1250        let logger = SecurityAuditLogger::new(3); // Limit to 3 entries
1251        let peer_id = create_test_peer();
1252
1253        // Log 5 events
1254        for i in 0..5 {
1255            logger
1256                .log_event(
1257                    format!("event_{}", i).into(),
1258                    peer_id.clone(),
1259                    HashMap::new(),
1260                    AuditSeverity::Info,
1261                )
1262                .await;
1263        }
1264
1265        let entries = logger.get_recent_entries(None).await;
1266        assert_eq!(entries.len(), 3); // Should only keep 3 most recent
1267
1268        // Check that we have the most recent events (2, 3, 4)
1269        assert_eq!(entries[0].event_type, "event_4"); // Most recent first
1270        assert_eq!(entries[1].event_type, "event_3");
1271        assert_eq!(entries[2].event_type, "event_2");
1272    }
1273
1274    #[tokio::test]
1275    async fn test_audit_logger_recent_entries_limit() {
1276        let logger = SecurityAuditLogger::new(10);
1277        let peer_id = create_test_peer();
1278
1279        // Log 5 events
1280        for i in 0..5 {
1281            logger
1282                .log_event(
1283                    format!("event_{}", i).into(),
1284                    peer_id.clone(),
1285                    HashMap::new(),
1286                    AuditSeverity::Info,
1287                )
1288                .await;
1289        }
1290
1291        // Get limited number of recent entries
1292        let entries = logger.get_recent_entries(Some(3)).await;
1293        assert_eq!(entries.len(), 3);
1294
1295        // Should be most recent first
1296        assert_eq!(entries[0].event_type, "event_4");
1297        assert_eq!(entries[1].event_type, "event_3");
1298        assert_eq!(entries[2].event_type, "event_2");
1299    }
1300
1301    #[test]
1302    fn test_audit_severity_equality() {
1303        assert_eq!(AuditSeverity::Info, AuditSeverity::Info);
1304        assert_eq!(AuditSeverity::Warning, AuditSeverity::Warning);
1305        assert_eq!(AuditSeverity::Error, AuditSeverity::Error);
1306        assert_eq!(AuditSeverity::Critical, AuditSeverity::Critical);
1307
1308        assert_ne!(AuditSeverity::Info, AuditSeverity::Warning);
1309        assert_ne!(AuditSeverity::Warning, AuditSeverity::Error);
1310        assert_ne!(AuditSeverity::Error, AuditSeverity::Critical);
1311    }
1312
1313    #[test]
1314    fn test_token_header_structure() {
1315        let header = TokenHeader {
1316            alg: "HS256".to_string(),
1317            typ: "JWT".to_string(),
1318            kid: Some("key123".to_string()),
1319        };
1320
1321        assert_eq!(header.alg, "HS256");
1322        assert_eq!(header.typ, "JWT");
1323        assert_eq!(header.kid, Some("key123".to_string()));
1324    }
1325
1326    #[test]
1327    fn test_token_payload_structure() {
1328        let peer_id = create_test_peer();
1329        let now = SystemTime::now()
1330            .duration_since(std::time::UNIX_EPOCH)
1331            .expect("valid security operation")
1332            .as_secs();
1333
1334        let mut claims = HashMap::new();
1335        claims.insert("custom".to_string(), serde_json::json!("value"));
1336
1337        let payload = TokenPayload {
1338            iss: peer_id.clone(),
1339            sub: peer_id.to_string(),
1340            aud: "test-audience".to_string(),
1341            exp: now + 3600,
1342            nbf: now,
1343            iat: now,
1344            jti: "unique-id".to_string(),
1345            claims,
1346        };
1347
1348        assert_eq!(payload.iss, peer_id);
1349        assert_eq!(payload.aud, "test-audience");
1350        assert_eq!(payload.jti, "unique-id");
1351        assert!(payload.exp > payload.iat);
1352        assert_eq!(
1353            payload.claims.get("custom"),
1354            Some(&serde_json::json!("value"))
1355        );
1356    }
1357
1358    #[test]
1359    fn test_mcp_token_structure() {
1360        let peer_id = create_test_peer();
1361
1362        let header = TokenHeader {
1363            alg: "HS256".to_string(),
1364            typ: "JWT".to_string(),
1365            kid: None,
1366        };
1367
1368        let payload = TokenPayload {
1369            iss: peer_id.clone(),
1370            sub: peer_id.to_string(),
1371            aud: "test".to_string(),
1372            exp: 1234567890,
1373            nbf: 1234567800,
1374            iat: 1234567800,
1375            jti: "test-id".to_string(),
1376            claims: HashMap::new(),
1377        };
1378
1379        let token = MCPToken {
1380            header: header.clone(),
1381            payload: payload.clone(),
1382            signature: "test-signature".to_string(),
1383        };
1384
1385        assert_eq!(token.header.alg, header.alg);
1386        assert_eq!(token.payload.iss, payload.iss);
1387        assert_eq!(token.signature, "test-signature");
1388    }
1389}