rust_network_scanner/
lib.rs

1//! # Rust Network Scanner
2//!
3//! A memory-safe, asynchronous network security scanner for vulnerability assessment
4//! and network monitoring.
5//!
6//! ## Features
7//!
8//! - **Memory Safety**: Built with Rust to prevent buffer overflows and memory corruption
9//! - **Async/Await**: High-performance concurrent scanning using Tokio
10//! - **Port Scanning**: Detect open ports and services
11//! - **Service Detection**: Banner grabbing and service fingerprinting
12//! - **Vulnerability Detection**: Known version vulnerability checking
13//! - **SIEM Integration**: JSON export for security monitoring
14//! - **Security Focus**: Designed for financial infrastructure security assessment
15//!
16//! ## Alignment with Federal Guidance
17//!
18//! Implements network security tools using memory-safe Rust, aligning with
19//! 2024 CISA/FBI guidance for critical infrastructure security tools.
20
21pub mod service_detection;
22pub use service_detection::{BannerGrabber, ServiceInfo, ServiceSignatures};
23
24use chrono::{DateTime, Utc};
25use futures::future::join_all;
26use serde::{Deserialize, Serialize};
27use std::net::{IpAddr, Ipv4Addr, SocketAddr};
28use std::time::Duration;
29use thiserror::Error;
30use tokio::io::{AsyncReadExt, AsyncWriteExt};
31use tokio::net::TcpStream;
32use tokio::time::timeout;
33
34/// Scanner errors
35#[derive(Error, Debug)]
36pub enum ScanError {
37    #[error("Connection timeout")]
38    Timeout,
39
40    #[error("Connection failed: {0}")]
41    ConnectionFailed(String),
42
43    #[error("Invalid IP address")]
44    InvalidIpAddress,
45
46    #[error("Invalid port range")]
47    InvalidPortRange,
48
49    #[error("Invalid subnet mask")]
50    InvalidSubnetMask,
51
52    #[error("Banner grab failed: {0}")]
53    BannerGrabFailed(String),
54}
55
56/// Port risk level for security assessment
57#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
58pub enum PortRiskLevel {
59    /// Critical risk (telnet, FTP, unencrypted protocols)
60    Critical,
61    /// High risk (database ports, RDP)
62    High,
63    /// Medium risk (HTTP, SMTP)
64    Medium,
65    /// Low risk (HTTPS, SSH with proper config)
66    Low,
67    /// Unknown risk
68    Unknown,
69}
70
71/// Port status
72#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
73pub enum PortStatus {
74    Open,
75    Closed,
76    Filtered,
77}
78
79/// Scan result for a single port
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct PortScanResult {
82    pub port: u16,
83    pub status: PortStatus,
84    pub service: Option<String>,
85    pub banner: Option<String>,
86    pub risk_level: PortRiskLevel,
87    pub timestamp: DateTime<Utc>,
88    pub response_time_ms: Option<u64>,
89}
90
91/// Complete scan result for a target
92#[derive(Debug, Clone, Serialize, Deserialize)]
93pub struct ScanResult {
94    pub target: String,
95    pub scan_start: DateTime<Utc>,
96    pub scan_end: DateTime<Utc>,
97    pub ports_scanned: usize,
98    pub open_ports: Vec<PortScanResult>,
99    pub closed_ports: usize,
100    pub filtered_ports: usize,
101}
102
103impl ScanResult {
104    /// Export scan results as JSON
105    pub fn to_json(&self) -> Result<String, serde_json::Error> {
106        serde_json::to_string_pretty(self)
107    }
108
109    /// Get summary statistics
110    pub fn summary(&self) -> String {
111        format!(
112            "Target: {} | Scanned: {} ports | Open: {} | Closed: {} | Filtered: {}",
113            self.target,
114            self.ports_scanned,
115            self.open_ports.len(),
116            self.closed_ports,
117            self.filtered_ports
118        )
119    }
120
121    /// Get high-risk open ports (Critical and High risk levels)
122    pub fn get_high_risk_ports(&self) -> Vec<&PortScanResult> {
123        self.open_ports
124            .iter()
125            .filter(|p| {
126                matches!(p.risk_level, PortRiskLevel::Critical | PortRiskLevel::High)
127            })
128            .collect()
129    }
130
131    /// Get ports by risk level
132    pub fn get_ports_by_risk(&self, risk: PortRiskLevel) -> Vec<&PortScanResult> {
133        self.open_ports
134            .iter()
135            .filter(|p| p.risk_level == risk)
136            .collect()
137    }
138
139    /// Calculate scan duration in seconds
140    pub fn scan_duration_secs(&self) -> f64 {
141        (self.scan_end - self.scan_start).num_milliseconds() as f64 / 1000.0
142    }
143}
144
/// Tunable settings for a [`NetworkScanner`].
#[derive(Clone, Debug)]
pub struct ScannerConfig {
    /// Per-port TCP connect timeout, in milliseconds.
    pub timeout_ms: u64,
    /// Maximum number of port probes awaited together in one batch.
    pub concurrent_scans: usize,
    /// Look up a well-known service name for each open port.
    pub detect_services: bool,
    /// Try to read a banner from each open port (an extra, time-bounded read per open port).
    pub grab_banners: bool,
}
153
154impl Default for ScannerConfig {
155    fn default() -> Self {
156        Self {
157            timeout_ms: 1000,
158            concurrent_scans: 100,
159            detect_services: true,
160            grab_banners: false, // Disabled by default for performance
161        }
162    }
163}
164
165/// Network scanner
166pub struct NetworkScanner {
167    config: ScannerConfig,
168}
169
170impl NetworkScanner {
171    /// Create a new network scanner with default configuration
172    pub fn new() -> Self {
173        Self {
174            config: ScannerConfig::default(),
175        }
176    }
177
178    /// Create a new network scanner with custom configuration
179    pub fn with_config(config: ScannerConfig) -> Self {
180        Self { config }
181    }
182
183    /// Scan a single port
184    pub async fn scan_port(&self, ip: IpAddr, port: u16) -> PortScanResult {
185        let start = std::time::Instant::now();
186        let addr = SocketAddr::new(ip, port);
187
188        let mut stream_opt = None;
189        let status = match timeout(
190            Duration::from_millis(self.config.timeout_ms),
191            TcpStream::connect(addr),
192        )
193        .await
194        {
195            Ok(Ok(stream)) => {
196                stream_opt = Some(stream);
197                PortStatus::Open
198            }
199            Ok(Err(_)) => PortStatus::Closed,
200            Err(_) => PortStatus::Filtered,
201        };
202
203        let response_time = if status == PortStatus::Open {
204            Some(start.elapsed().as_millis() as u64)
205        } else {
206            None
207        };
208
209        let service = if status == PortStatus::Open && self.config.detect_services {
210            Self::detect_service(port)
211        } else {
212            None
213        };
214
215        let banner = if status == PortStatus::Open && self.config.grab_banners {
216            if let Some(mut stream) = stream_opt {
217                Self::grab_banner(&mut stream, port).await.ok()
218            } else {
219                None
220            }
221        } else {
222            None
223        };
224
225        let risk_level = Self::assess_port_risk(port);
226
227        PortScanResult {
228            port,
229            status,
230            service,
231            banner,
232            risk_level,
233            timestamp: Utc::now(),
234            response_time_ms: response_time,
235        }
236    }
237
238    /// Scan a range of ports
239    pub async fn scan_ports(
240        &self,
241        ip: IpAddr,
242        start_port: u16,
243        end_port: u16,
244    ) -> Result<ScanResult, ScanError> {
245        if start_port > end_port {
246            return Err(ScanError::InvalidPortRange);
247        }
248
249        let scan_start = Utc::now();
250        let target = ip.to_string();
251
252        // Create tasks for all ports
253        let mut tasks = Vec::new();
254        for port in start_port..=end_port {
255            let task = self.scan_port(ip, port);
256            tasks.push(task);
257
258            // Limit concurrent scans
259            if tasks.len() >= self.config.concurrent_scans {
260                let results = join_all(tasks).await;
261                tasks = Vec::new();
262                // Process results
263                for _ in results {
264                    // Results processed below
265                }
266            }
267        }
268
269        // Process remaining tasks
270        let mut all_results = join_all(tasks).await;
271
272        let scan_end = Utc::now();
273
274        // Separate results by status
275        let open_ports: Vec<PortScanResult> = all_results
276            .iter()
277            .filter(|r| r.status == PortStatus::Open)
278            .cloned()
279            .collect();
280
281        let closed_ports = all_results
282            .iter()
283            .filter(|r| r.status == PortStatus::Closed)
284            .count();
285
286        let filtered_ports = all_results
287            .iter()
288            .filter(|r| r.status == PortStatus::Filtered)
289            .count();
290
291        Ok(ScanResult {
292            target,
293            scan_start,
294            scan_end,
295            ports_scanned: all_results.len(),
296            open_ports,
297            closed_ports,
298            filtered_ports,
299        })
300    }
301
302    /// Scan common ports (top 20)
303    pub async fn scan_common_ports(&self, ip: IpAddr) -> Result<ScanResult, ScanError> {
304        let common_ports = vec![
305            20, 21, 22, 23, 25, 53, 80, 110, 143, 443, 445, 993, 995, 3306, 3389, 5432, 5900,
306            8080, 8443, 27017,
307        ];
308
309        let scan_start = Utc::now();
310        let target = ip.to_string();
311
312        let tasks: Vec<_> = common_ports
313            .iter()
314            .map(|&port| self.scan_port(ip, port))
315            .collect();
316
317        let results = join_all(tasks).await;
318        let scan_end = Utc::now();
319
320        let open_ports: Vec<PortScanResult> = results
321            .iter()
322            .filter(|r| r.status == PortStatus::Open)
323            .cloned()
324            .collect();
325
326        let closed_ports = results
327            .iter()
328            .filter(|r| r.status == PortStatus::Closed)
329            .count();
330
331        let filtered_ports = results
332            .iter()
333            .filter(|r| r.status == PortStatus::Filtered)
334            .count();
335
336        Ok(ScanResult {
337            target,
338            scan_start,
339            scan_end,
340            ports_scanned: results.len(),
341            open_ports,
342            closed_ports,
343            filtered_ports,
344        })
345    }
346
347    /// Simple service detection based on port number
348    fn detect_service(port: u16) -> Option<String> {
349        let service = match port {
350            20 => "FTP-DATA",
351            21 => "FTP",
352            22 => "SSH",
353            23 => "Telnet",
354            25 => "SMTP",
355            53 => "DNS",
356            80 => "HTTP",
357            110 => "POP3",
358            143 => "IMAP",
359            443 => "HTTPS",
360            445 => "SMB",
361            993 => "IMAPS",
362            995 => "POP3S",
363            3306 => "MySQL",
364            3389 => "RDP",
365            5432 => "PostgreSQL",
366            5900 => "VNC",
367            8080 => "HTTP-Proxy",
368            8443 => "HTTPS-Alt",
369            27017 => "MongoDB",
370            _ => "Unknown",
371        };
372
373        Some(service.to_string())
374    }
375
376    /// Grab service banner from open port
377    async fn grab_banner(stream: &mut TcpStream, port: u16) -> Result<String, ScanError> {
378        // Send protocol-specific probes
379        let probe: &[u8] = match port {
380            80 | 8080 => b"HEAD / HTTP/1.0\r\n\r\n",
381            21 | 22 | 23 | 25 => b"", // These typically send banner on connect
382            _ => b"", // Default: just read
383        };
384
385        if !probe.is_empty() {
386            let _ = timeout(
387                Duration::from_millis(500),
388                stream.write_all(probe),
389            )
390            .await;
391        }
392
393        let mut buffer = vec![0u8; 1024];
394        match timeout(Duration::from_millis(500), stream.read(&mut buffer)).await {
395            Ok(Ok(n)) if n > 0 => {
396                let banner = String::from_utf8_lossy(&buffer[..n])
397                    .trim()
398                    .to_string();
399                if !banner.is_empty() {
400                    Ok(banner)
401                } else {
402                    Err(ScanError::BannerGrabFailed("Empty response".to_string()))
403                }
404            }
405            _ => Err(ScanError::BannerGrabFailed("No response".to_string())),
406        }
407    }
408
409    /// Assess security risk level of a port
410    fn assess_port_risk(port: u16) -> PortRiskLevel {
411        match port {
412            // Critical: Unencrypted, legacy protocols
413            21 | 23 | 69 | 512..=514 => PortRiskLevel::Critical, // FTP, Telnet, TFTP, rlogin/rsh/rexec
414
415            // High: Database ports, RDP, administrative services
416            3306 | 5432 | 27017 | 6379 | // MySQL, PostgreSQL, MongoDB, Redis
417            3389 | 5900 | // RDP, VNC
418            445 | 139 | 135 | // SMB, NetBIOS
419            1433 | 1521 => PortRiskLevel::High, // MS-SQL, Oracle
420
421            // Medium: HTTP, mail servers
422            80 | 8080 | 8000 | // HTTP
423            25 | 110 | 143 => PortRiskLevel::Medium, // SMTP, POP3, IMAP
424
425            // Low: Encrypted protocols
426            22 | 443 | 8443 | 465 | 587 | 993 | 995 => PortRiskLevel::Low, // SSH, HTTPS, SMTPS, IMAPS, POP3S
427
428            _ => PortRiskLevel::Unknown,
429        }
430    }
431
432    /// Scan a subnet (CIDR notation, e.g., "192.168.1.0/24")
433    pub async fn scan_subnet(
434        &self,
435        subnet: &str,
436        ports: Vec<u16>,
437    ) -> Result<Vec<ScanResult>, ScanError> {
438        let (base_ip, mask) = Self::parse_cidr(subnet)?;
439        let hosts = Self::generate_host_ips(base_ip, mask);
440
441        let mut results = Vec::new();
442        for host_ip in hosts {
443            let scan_start = Utc::now();
444            let target = host_ip.to_string();
445
446            let tasks: Vec<_> = ports
447                .iter()
448                .map(|&port| self.scan_port(IpAddr::V4(host_ip), port))
449                .collect();
450
451            let port_results = join_all(tasks).await;
452            let scan_end = Utc::now();
453
454            let open_ports: Vec<PortScanResult> = port_results
455                .iter()
456                .filter(|r| r.status == PortStatus::Open)
457                .cloned()
458                .collect();
459
460            // Only include hosts with open ports
461            if !open_ports.is_empty() {
462                let closed_ports = port_results
463                    .iter()
464                    .filter(|r| r.status == PortStatus::Closed)
465                    .count();
466
467                let filtered_ports = port_results
468                    .iter()
469                    .filter(|r| r.status == PortStatus::Filtered)
470                    .count();
471
472                results.push(ScanResult {
473                    target,
474                    scan_start,
475                    scan_end,
476                    ports_scanned: port_results.len(),
477                    open_ports,
478                    closed_ports,
479                    filtered_ports,
480                });
481            }
482        }
483
484        Ok(results)
485    }
486
487    /// Parse CIDR notation (e.g., "192.168.1.0/24")
488    fn parse_cidr(cidr: &str) -> Result<(Ipv4Addr, u8), ScanError> {
489        let parts: Vec<&str> = cidr.split('/').collect();
490        if parts.len() != 2 {
491            return Err(ScanError::InvalidSubnetMask);
492        }
493
494        let ip = parts[0]
495            .parse::<Ipv4Addr>()
496            .map_err(|_| ScanError::InvalidIpAddress)?;
497
498        let mask = parts[1]
499            .parse::<u8>()
500            .map_err(|_| ScanError::InvalidSubnetMask)?;
501
502        if mask > 32 {
503            return Err(ScanError::InvalidSubnetMask);
504        }
505
506        Ok((ip, mask))
507    }
508
509    /// Generate list of host IPs in a subnet
510    fn generate_host_ips(base_ip: Ipv4Addr, mask: u8) -> Vec<Ipv4Addr> {
511        let ip_u32 = u32::from(base_ip);
512        let network_mask = !((1u32 << (32 - mask)) - 1);
513        let network_addr = ip_u32 & network_mask;
514        let host_count = (1u32 << (32 - mask)).saturating_sub(2); // Exclude network and broadcast
515
516        let mut ips = Vec::new();
517        for i in 1..=host_count.min(254) {
518            // Limit to prevent huge scans
519            let host_ip = Ipv4Addr::from(network_addr + i);
520            ips.push(host_ip);
521        }
522
523        ips
524    }
525}
526
527impl Default for NetworkScanner {
528    fn default() -> Self {
529        Self::new()
530    }
531}
532
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::net::TcpListener;

    /// IPv4 loopback, the only address these tests ever probe.
    const LOCALHOST: IpAddr = IpAddr::V4(Ipv4Addr::LOCALHOST);

    /// Returns a loopback port that was free a moment ago.
    ///
    /// The listener is dropped before returning, so connecting to the port should
    /// be refused, unless another process claims it in between.
    fn recently_freed_port() -> u16 {
        let listener = std::net::TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
            .expect("bind ephemeral loopback port");
        listener.local_addr().expect("query local address").port()
    }

    /// Binds a loopback listener on an ephemeral port.
    ///
    /// The port completes TCP handshakes for as long as the returned listener
    /// is kept alive.
    async fn open_listener() -> (TcpListener, u16) {
        let listener = TcpListener::bind((Ipv4Addr::LOCALHOST, 0))
            .await
            .expect("bind loopback listener");
        let port = listener.local_addr().expect("query local address").port();
        (listener, port)
    }

    /// Builds an open-port record for the `ScanResult` fixtures below.
    fn open_port(port: u16, service: &str, risk_level: PortRiskLevel) -> PortScanResult {
        PortScanResult {
            port,
            status: PortStatus::Open,
            service: Some(service.to_string()),
            banner: None,
            risk_level,
            timestamp: Utc::now(),
            response_time_ms: Some(10),
        }
    }

    #[tokio::test]
    async fn test_scan_single_port() {
        let scanner = NetworkScanner::new();
        let port = recently_freed_port();

        let result = scanner.scan_port(LOCALHOST, port).await;

        assert_eq!(result.port, port);
        // Normally refused (Closed); Filtered is tolerated on hosts that drop SYNs.
        assert!(matches!(
            result.status,
            PortStatus::Closed | PortStatus::Filtered
        ));
        assert!(result.service.is_none());
        assert!(result.response_time_ms.is_none());
    }

    #[tokio::test]
    async fn test_scan_open_port() {
        let (_listener, port) = open_listener().await;
        let scanner = NetworkScanner::new();

        let result = scanner.scan_port(LOCALHOST, port).await;

        assert_eq!(result.port, port);
        assert_eq!(result.status, PortStatus::Open);
        assert!(result.response_time_ms.is_some());
        assert!(result.service.is_some()); // detect_services is on by default
        assert!(result.banner.is_none()); // grab_banners is off by default
    }

    #[tokio::test]
    async fn test_scan_ports_reports_open_port() {
        let (_listener, port) = open_listener().await;
        let scanner = NetworkScanner::new();

        let result = scanner
            .scan_ports(LOCALHOST, port, port)
            .await
            .expect("single-port range is valid");

        assert_eq!(result.target, "127.0.0.1");
        assert_eq!(result.ports_scanned, 1);
        assert_eq!(result.open_ports.len(), 1);
        assert_eq!(result.open_ports[0].port, port);
        assert_eq!(result.closed_ports + result.filtered_ports, 0);
    }

    #[test]
    fn test_service_detection() {
        assert_eq!(NetworkScanner::detect_service(80), Some("HTTP".to_string()));
        assert_eq!(NetworkScanner::detect_service(443), Some("HTTPS".to_string()));
        assert_eq!(NetworkScanner::detect_service(22), Some("SSH".to_string()));
        assert_eq!(
            NetworkScanner::detect_service(12345),
            Some("Unknown".to_string())
        );
    }

    #[test]
    fn test_scan_result_summary() {
        let result = ScanResult {
            target: "192.168.1.1".to_string(),
            scan_start: Utc::now(),
            scan_end: Utc::now(),
            ports_scanned: 100,
            open_ports: vec![],
            closed_ports: 95,
            filtered_ports: 5,
        };

        assert_eq!(
            result.summary(),
            "Target: 192.168.1.1 | Scanned: 100 ports | Open: 0 | Closed: 95 | Filtered: 5"
        );
    }

    #[tokio::test]
    async fn test_invalid_port_range() {
        let scanner = NetworkScanner::new();

        let result = scanner.scan_ports(LOCALHOST, 100, 50).await;

        assert!(matches!(result, Err(ScanError::InvalidPortRange)));
    }

    #[test]
    fn test_port_risk_assessment() {
        // Critical risk
        assert_eq!(NetworkScanner::assess_port_risk(23), PortRiskLevel::Critical); // Telnet
        assert_eq!(NetworkScanner::assess_port_risk(21), PortRiskLevel::Critical); // FTP

        // High risk
        assert_eq!(NetworkScanner::assess_port_risk(3389), PortRiskLevel::High); // RDP
        assert_eq!(NetworkScanner::assess_port_risk(3306), PortRiskLevel::High); // MySQL

        // Medium risk
        assert_eq!(NetworkScanner::assess_port_risk(80), PortRiskLevel::Medium); // HTTP

        // Low risk
        assert_eq!(NetworkScanner::assess_port_risk(443), PortRiskLevel::Low); // HTTPS
        assert_eq!(NetworkScanner::assess_port_risk(22), PortRiskLevel::Low); // SSH

        // Unlisted
        assert_eq!(NetworkScanner::assess_port_risk(12345), PortRiskLevel::Unknown);
    }

    #[test]
    fn test_cidr_parsing() {
        let (ip, mask) = NetworkScanner::parse_cidr("192.168.1.0/24").expect("valid CIDR");
        assert_eq!(ip, Ipv4Addr::new(192, 168, 1, 0));
        assert_eq!(mask, 24);

        assert!(matches!(
            NetworkScanner::parse_cidr("192.168.1.0"),
            Err(ScanError::InvalidSubnetMask)
        ));
        assert!(matches!(
            NetworkScanner::parse_cidr("192.168.1.0/33"),
            Err(ScanError::InvalidSubnetMask)
        ));
        assert!(matches!(
            NetworkScanner::parse_cidr("invalid/24"),
            Err(ScanError::InvalidIpAddress)
        ));
    }

    #[test]
    fn test_host_ip_generation() {
        // /30 = 2 usable hosts
        let ips = NetworkScanner::generate_host_ips(Ipv4Addr::new(192, 168, 1, 0), 30);
        assert_eq!(
            ips,
            vec![Ipv4Addr::new(192, 168, 1, 1), Ipv4Addr::new(192, 168, 1, 2)]
        );

        // Host bits of the base address are masked off: a /24 yields .1 through .254.
        let ips = NetworkScanner::generate_host_ips(Ipv4Addr::new(10, 0, 0, 77), 24);
        assert_eq!(ips.len(), 254);
        assert_eq!(ips.first(), Some(&Ipv4Addr::new(10, 0, 0, 1)));
        assert_eq!(ips.last(), Some(&Ipv4Addr::new(10, 0, 0, 254)));

        // Larger blocks are capped at 254 hosts.
        let ips = NetworkScanner::generate_host_ips(Ipv4Addr::new(10, 0, 0, 0), 16);
        assert_eq!(ips.len(), 254);
    }

    #[test]
    fn test_scan_result_high_risk_ports() {
        let result = ScanResult {
            target: "192.168.1.1".to_string(),
            scan_start: Utc::now(),
            scan_end: Utc::now(),
            ports_scanned: 5,
            open_ports: vec![
                open_port(23, "Telnet", PortRiskLevel::Critical),
                open_port(3389, "RDP", PortRiskLevel::High),
                open_port(443, "HTTPS", PortRiskLevel::Low),
            ],
            closed_ports: 2,
            filtered_ports: 0,
        };

        let high_risk: Vec<u16> = result.get_high_risk_ports().iter().map(|p| p.port).collect();
        assert_eq!(high_risk, vec![23, 3389]); // Telnet (Critical) + RDP (High)

        let critical: Vec<u16> = result
            .get_ports_by_risk(PortRiskLevel::Critical)
            .iter()
            .map(|p| p.port)
            .collect();
        assert_eq!(critical, vec![23]);
        assert!(result.get_ports_by_risk(PortRiskLevel::Medium).is_empty());
    }

    #[test]
    fn test_scan_duration_calculation() {
        // Fixed timestamps instead of sleeping keep this fast and deterministic.
        let scan_start = Utc::now();
        let elapsed = chrono::Duration::from_std(Duration::from_millis(250))
            .expect("250 ms fits in chrono::Duration");

        let result = ScanResult {
            target: "127.0.0.1".to_string(),
            scan_start,
            scan_end: scan_start + elapsed,
            ports_scanned: 10,
            open_ports: vec![],
            closed_ports: 10,
            filtered_ports: 0,
        };

        assert!((result.scan_duration_secs() - 0.25).abs() < 1e-9);
    }

    #[test]
    fn test_json_export() {
        let mut http = open_port(80, "HTTP", PortRiskLevel::Medium);
        http.banner = Some("Server: nginx/1.18.0".to_string());

        let result = ScanResult {
            target: "192.168.1.100".to_string(),
            scan_start: Utc::now(),
            scan_end: Utc::now(),
            ports_scanned: 3,
            open_ports: vec![http],
            closed_ports: 2,
            filtered_ports: 0,
        };

        let json = result.to_json().expect("ScanResult serializes to JSON");
        assert!(json.contains("192.168.1.100"));
        assert!(json.contains("HTTP"));
        assert!(json.contains("nginx"));

        // The export must round-trip through the derived Deserialize impl.
        let parsed: ScanResult = serde_json::from_str(&json).expect("JSON parses back");
        assert_eq!(parsed.target, "192.168.1.100");
        assert_eq!(
            parsed.open_ports[0].banner.as_deref(),
            Some("Server: nginx/1.18.0")
        );
    }

    #[test]
    fn test_scanner_config_defaults() {
        let config = ScannerConfig::default();
        assert_eq!(config.timeout_ms, 1000);
        assert_eq!(config.concurrent_scans, 100);
        assert!(config.detect_services);
        assert!(!config.grab_banners); // Disabled by default
    }

    #[tokio::test]
    async fn test_scanner_with_custom_config() {
        let config = ScannerConfig {
            timeout_ms: 500,
            concurrent_scans: 50,
            detect_services: true,
            grab_banners: false,
        };
        let scanner = NetworkScanner::with_config(config);
        let port = recently_freed_port();

        let result = scanner.scan_port(LOCALHOST, port).await;

        assert_eq!(result.port, port);
    }
}