rust_network_scanner/
service_detection.rs

1//! Service detection and banner grabbing
2
3use std::time::Duration;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::TcpStream;
6use tokio::time::timeout;
7
8/// Known service signatures
9pub struct ServiceSignatures;
10
11impl ServiceSignatures {
12    /// Detect service from banner
13    pub fn detect_from_banner(banner: &str) -> Option<ServiceInfo> {
14        let banner_lower = banner.to_lowercase();
15
16        if banner_lower.contains("ssh") {
17            return Some(ServiceInfo {
18                name: "SSH".to_string(),
19                version: Self::extract_ssh_version(banner),
20                vendor: Self::extract_vendor(&banner_lower),
21            });
22        }
23
24        if banner_lower.contains("http") || banner_lower.contains("server:") {
25            return Some(ServiceInfo {
26                name: "HTTP".to_string(),
27                version: Self::extract_http_version(banner),
28                vendor: Self::extract_vendor(&banner_lower),
29            });
30        }
31
32        if banner_lower.contains("ftp") {
33            return Some(ServiceInfo {
34                name: "FTP".to_string(),
35                version: Self::extract_version_generic(banner),
36                vendor: Self::extract_vendor(&banner_lower),
37            });
38        }
39
40        if banner_lower.contains("smtp") || banner_lower.contains("postfix") {
41            return Some(ServiceInfo {
42                name: "SMTP".to_string(),
43                version: Self::extract_version_generic(banner),
44                vendor: Self::extract_vendor(&banner_lower),
45            });
46        }
47
48        if banner_lower.contains("mysql") {
49            return Some(ServiceInfo {
50                name: "MySQL".to_string(),
51                version: Self::extract_mysql_version(banner),
52                vendor: Some("Oracle".to_string()),
53            });
54        }
55
56        if banner_lower.contains("postgresql") || banner_lower.contains("postgres") {
57            return Some(ServiceInfo {
58                name: "PostgreSQL".to_string(),
59                version: Self::extract_version_generic(banner),
60                vendor: Some("PostgreSQL Global Development Group".to_string()),
61            });
62        }
63
64        None
65    }
66
67    fn extract_ssh_version(banner: &str) -> Option<String> {
68        if let Some(start) = banner.find("SSH-") {
69            let version_str = &banner[start..];
70            if let Some(end) =
71                version_str.find(|c: char| c.is_whitespace() || c == '\r' || c == '\n')
72            {
73                return Some(version_str[..end].to_string());
74            }
75        }
76        None
77    }
78
79    fn extract_http_version(banner: &str) -> Option<String> {
80        for line in banner.lines() {
81            if line.to_lowercase().starts_with("server:") {
82                return Some(line[7..].trim().to_string());
83            }
84        }
85        None
86    }
87
88    fn extract_mysql_version(banner: &str) -> Option<String> {
89        // MySQL version typically in format: 5.7.32 or 8.0.23
90        let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
91        re.captures(banner)
92            .and_then(|cap| cap.get(1))
93            .map(|m| m.as_str().to_string())
94    }
95
96    fn extract_version_generic(banner: &str) -> Option<String> {
97        let re = regex::Regex::new(r"(\d+\.\d+(?:\.\d+)?)").ok()?;
98        re.captures(banner)
99            .and_then(|cap| cap.get(1))
100            .map(|m| m.as_str().to_string())
101    }
102
103    fn extract_vendor(banner: &str) -> Option<String> {
104        if banner.contains("apache") {
105            Some("Apache Software Foundation".to_string())
106        } else if banner.contains("nginx") {
107            Some("Nginx Inc.".to_string())
108        } else if banner.contains("microsoft") || banner.contains("iis") {
109            Some("Microsoft".to_string())
110        } else if banner.contains("openssh") {
111            Some("OpenSSH".to_string())
112        } else if banner.contains("postfix") {
113            Some("Postfix".to_string())
114        } else {
115            None
116        }
117    }
118}
119
/// Service information
#[derive(Debug, Clone)]
pub struct ServiceInfo {
    /// Service name, e.g. `"SSH"` or `"HTTP"`.
    pub name: String,
    /// Version text extracted from the banner, if any.
    pub version: Option<String>,
    /// Vendor inferred from the banner, if any.
    pub vendor: Option<String>,
}

impl ServiceInfo {
    /// Report whether this service matches one of the hard-coded outdated
    /// release lines: OpenSSH 7.x for `"SSH"`, Apache 2.2 for `"HTTP"`.
    ///
    /// This is a deliberately simplified substring check on `version`; a
    /// service without a version is never flagged.
    pub fn has_known_vulnerabilities(&self) -> bool {
        let marker = match self.name.as_str() {
            "SSH" => "OpenSSH_7.",
            "HTTP" => "Apache/2.2",
            _ => return false,
        };
        matches!(self.version.as_deref(), Some(version) if version.contains(marker))
    }

    /// Severity label for a flagged service: `Some("HIGH")` when
    /// [`ServiceInfo::has_known_vulnerabilities`] is true, otherwise `None`.
    pub fn vulnerability_severity(&self) -> Option<&'static str> {
        self.has_known_vulnerabilities().then_some("HIGH")
    }
}
154
155/// Banner grabber
156pub struct BannerGrabber;
157
158impl BannerGrabber {
159    /// Grab banner from service
160    pub async fn grab_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
161        let addr = format!("{}:{}", host, port);
162        let connect_timeout = Duration::from_millis(timeout_ms);
163
164        let stream = timeout(connect_timeout, TcpStream::connect(&addr))
165            .await
166            .map_err(|_| "Connection timeout".to_string())?
167            .map_err(|e| format!("Connection failed: {}", e))?;
168
169        let mut stream = stream;
170        let mut buffer = vec![0u8; 1024];
171
172        // Read initial banner
173        let read_timeout = Duration::from_millis(timeout_ms);
174        let n = timeout(read_timeout, stream.read(&mut buffer))
175            .await
176            .map_err(|_| "Read timeout".to_string())?
177            .map_err(|e| format!("Read failed: {}", e))?;
178
179        if n == 0 {
180            return Err("No data received".to_string());
181        }
182
183        let banner = String::from_utf8_lossy(&buffer[..n]).to_string();
184        Ok(banner)
185    }
186
187    /// Grab banner with HTTP probe
188    pub async fn grab_http_banner(
189        host: &str,
190        port: u16,
191        timeout_ms: u64,
192    ) -> Result<String, String> {
193        let addr = format!("{}:{}", host, port);
194        let connect_timeout = Duration::from_millis(timeout_ms);
195
196        let stream = timeout(connect_timeout, TcpStream::connect(&addr))
197            .await
198            .map_err(|_| "Connection timeout".to_string())?
199            .map_err(|e| format!("Connection failed: {}", e))?;
200
201        let mut stream = stream;
202
203        // Send HTTP GET request
204        let request = format!(
205            "GET / HTTP/1.1\r\nHost: {}\r\nUser-Agent: NetworkScanner/1.0\r\n\r\n",
206            host
207        );
208
209        timeout(
210            Duration::from_millis(timeout_ms),
211            stream.write_all(request.as_bytes()),
212        )
213        .await
214        .map_err(|_| "Write timeout".to_string())?
215        .map_err(|e| format!("Write failed: {}", e))?;
216
217        // Read response
218        let mut buffer = vec![0u8; 4096];
219        let read_timeout = Duration::from_millis(timeout_ms);
220        let n = timeout(read_timeout, stream.read(&mut buffer))
221            .await
222            .map_err(|_| "Read timeout".to_string())?
223            .map_err(|e| format!("Read failed: {}", e))?;
224
225        let response = String::from_utf8_lossy(&buffer[..n]).to_string();
226        Ok(response)
227    }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssh_detection() {
        let banner = "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.5";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "SSH");
        assert_eq!(service.version.as_deref(), Some("SSH-2.0-OpenSSH_8.2p1"));
        assert_eq!(service.vendor.as_deref(), Some("OpenSSH"));
    }

    #[test]
    fn test_http_detection() {
        let banner = "HTTP/1.1 200 OK\r\nServer: Apache/2.4.41 (Ubuntu)\r\n";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "HTTP");
        assert_eq!(service.version.as_deref(), Some("Apache/2.4.41 (Ubuntu)"));
        assert_eq!(service.vendor.as_deref(), Some("Apache Software Foundation"));
    }

    #[test]
    fn test_ftp_detection() {
        let service = ServiceSignatures::detect_from_banner("220 (vsFTPd 3.0.3)").unwrap();
        assert_eq!(service.name, "FTP");
        assert_eq!(service.version.as_deref(), Some("3.0.3"));
        assert_eq!(service.vendor, None);
    }

    #[test]
    fn test_smtp_detection() {
        let banner = "220 mail.example.com ESMTP Postfix (Ubuntu)";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "SMTP");
        assert_eq!(service.version, None);
        assert_eq!(service.vendor.as_deref(), Some("Postfix"));
    }

    #[test]
    fn test_mysql_detection() {
        let service =
            ServiceSignatures::detect_from_banner("mysql 5.7.32-0ubuntu0.18.04.1").unwrap();
        assert_eq!(service.name, "MySQL");
        assert_eq!(service.version.as_deref(), Some("5.7.32"));
        assert_eq!(service.vendor.as_deref(), Some("Oracle"));
    }

    #[test]
    fn test_postgresql_detection() {
        let service = ServiceSignatures::detect_from_banner("PostgreSQL 13.4").unwrap();
        assert_eq!(service.name, "PostgreSQL");
        assert_eq!(service.version.as_deref(), Some("13.4"));
        assert_eq!(
            service.vendor.as_deref(),
            Some("PostgreSQL Global Development Group")
        );
    }

    #[test]
    fn test_unknown_banner() {
        assert!(ServiceSignatures::detect_from_banner("hello world").is_none());
    }

    #[test]
    fn test_vulnerability_check() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_7.4".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_old_apache_is_vulnerable() {
        let service = ServiceInfo {
            name: "HTTP".to_string(),
            version: Some("Apache/2.2.34 (Unix)".to_string()),
            vendor: Some("Apache Software Foundation".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_no_vulnerability() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_9.0".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(!service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), None);
    }
}