// rust_network_scanner/service_detection.rs

//! Service detection and banner grabbing.

3use std::time::Duration;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::TcpStream;
6use tokio::time::timeout;
7
8/// Known service signatures
9pub struct ServiceSignatures;
10
11impl ServiceSignatures {
12    /// Detect service from banner
13    pub fn detect_from_banner(banner: &str) -> Option<ServiceInfo> {
14        let banner_lower = banner.to_lowercase();
15
16        if banner_lower.contains("ssh") {
17            return Some(ServiceInfo {
18                name: "SSH".to_string(),
19                version: Self::extract_ssh_version(banner),
20                vendor: Self::extract_vendor(&banner_lower),
21            });
22        }
23
24        if banner_lower.contains("http") || banner_lower.contains("server:") {
25            return Some(ServiceInfo {
26                name: "HTTP".to_string(),
27                version: Self::extract_http_version(banner),
28                vendor: Self::extract_vendor(&banner_lower),
29            });
30        }
31
32        if banner_lower.contains("ftp") {
33            return Some(ServiceInfo {
34                name: "FTP".to_string(),
35                version: Self::extract_version_generic(banner),
36                vendor: Self::extract_vendor(&banner_lower),
37            });
38        }
39
40        if banner_lower.contains("smtp") || banner_lower.contains("postfix") {
41            return Some(ServiceInfo {
42                name: "SMTP".to_string(),
43                version: Self::extract_version_generic(banner),
44                vendor: Self::extract_vendor(&banner_lower),
45            });
46        }
47
48        if banner_lower.contains("mysql") {
49            return Some(ServiceInfo {
50                name: "MySQL".to_string(),
51                version: Self::extract_mysql_version(banner),
52                vendor: Some("Oracle".to_string()),
53            });
54        }
55
56        if banner_lower.contains("postgresql") || banner_lower.contains("postgres") {
57            return Some(ServiceInfo {
58                name: "PostgreSQL".to_string(),
59                version: Self::extract_version_generic(banner),
60                vendor: Some("PostgreSQL Global Development Group".to_string()),
61            });
62        }
63
64        None
65    }
66
67    fn extract_ssh_version(banner: &str) -> Option<String> {
68        if let Some(start) = banner.find("SSH-") {
69            let version_str = &banner[start..];
70            if let Some(end) = version_str.find(|c: char| c.is_whitespace() || c == '\r' || c == '\n') {
71                return Some(version_str[..end].to_string());
72            }
73        }
74        None
75    }
76
77    fn extract_http_version(banner: &str) -> Option<String> {
78        for line in banner.lines() {
79            if line.to_lowercase().starts_with("server:") {
80                return Some(line[7..].trim().to_string());
81            }
82        }
83        None
84    }
85
86    fn extract_mysql_version(banner: &str) -> Option<String> {
87        // MySQL version typically in format: 5.7.32 or 8.0.23
88        let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
89        re.captures(banner)
90            .and_then(|cap| cap.get(1))
91            .map(|m| m.as_str().to_string())
92    }
93
94    fn extract_version_generic(banner: &str) -> Option<String> {
95        let re = regex::Regex::new(r"(\d+\.\d+(?:\.\d+)?)").ok()?;
96        re.captures(banner)
97            .and_then(|cap| cap.get(1))
98            .map(|m| m.as_str().to_string())
99    }
100
101    fn extract_vendor(banner: &str) -> Option<String> {
102        if banner.contains("apache") {
103            Some("Apache Software Foundation".to_string())
104        } else if banner.contains("nginx") {
105            Some("Nginx Inc.".to_string())
106        } else if banner.contains("microsoft") || banner.contains("iis") {
107            Some("Microsoft".to_string())
108        } else if banner.contains("openssh") {
109            Some("OpenSSH".to_string())
110        } else if banner.contains("postfix") {
111            Some("Postfix".to_string())
112        } else {
113            None
114        }
115    }
116}
117
/// Service information.
///
/// Result of classifying a banner: the protocol name plus whatever version
/// and vendor details could be extracted from it.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceInfo {
    /// Protocol name, e.g. `"SSH"` or `"HTTP"`.
    pub name: String,
    /// Version text taken from the banner, if any; its format varies by protocol.
    pub version: Option<String>,
    /// Vendor inferred from the banner, if any.
    pub vendor: Option<String>,
}

impl ServiceInfo {
    /// Check if this is a known vulnerable version.
    ///
    /// This is a deliberately simplified substring check, not a vulnerability
    /// database lookup. It flags:
    /// - SSH services whose version contains `OpenSSH_7.`
    /// - HTTP services whose version contains `Apache/2.2`
    ///
    /// A service without a version is never flagged.
    pub fn has_known_vulnerabilities(&self) -> bool {
        match (self.name.as_str(), self.version.as_deref()) {
            ("SSH", Some(version)) => version.contains("OpenSSH_7."),
            ("HTTP", Some(version)) => version.contains("Apache/2.2"),
            _ => false,
        }
    }

    /// Get severity if vulnerable.
    ///
    /// Returns `Some("HIGH")` whenever
    /// [`has_known_vulnerabilities`](Self::has_known_vulnerabilities) is true,
    /// otherwise `None`.
    pub fn vulnerability_severity(&self) -> Option<&'static str> {
        if self.has_known_vulnerabilities() {
            Some("HIGH")
        } else {
            None
        }
    }
}
152
153/// Banner grabber
154pub struct BannerGrabber;
155
156impl BannerGrabber {
157    /// Grab banner from service
158    pub async fn grab_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
159        let addr = format!("{}:{}", host, port);
160        let connect_timeout = Duration::from_millis(timeout_ms);
161
162        let stream = timeout(connect_timeout, TcpStream::connect(&addr))
163            .await
164            .map_err(|_| "Connection timeout".to_string())?
165            .map_err(|e| format!("Connection failed: {}", e))?;
166
167        let mut stream = stream;
168        let mut buffer = vec![0u8; 1024];
169
170        // Read initial banner
171        let read_timeout = Duration::from_millis(timeout_ms);
172        let n = timeout(read_timeout, stream.read(&mut buffer))
173            .await
174            .map_err(|_| "Read timeout".to_string())?
175            .map_err(|e| format!("Read failed: {}", e))?;
176
177        if n == 0 {
178            return Err("No data received".to_string());
179        }
180
181        let banner = String::from_utf8_lossy(&buffer[..n]).to_string();
182        Ok(banner)
183    }
184
185    /// Grab banner with HTTP probe
186    pub async fn grab_http_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
187        let addr = format!("{}:{}", host, port);
188        let connect_timeout = Duration::from_millis(timeout_ms);
189
190        let stream = timeout(connect_timeout, TcpStream::connect(&addr))
191            .await
192            .map_err(|_| "Connection timeout".to_string())?
193            .map_err(|e| format!("Connection failed: {}", e))?;
194
195        let mut stream = stream;
196
197        // Send HTTP GET request
198        let request = format!(
199            "GET / HTTP/1.1\r\nHost: {}\r\nUser-Agent: NetworkScanner/1.0\r\n\r\n",
200            host
201        );
202
203        timeout(Duration::from_millis(timeout_ms), stream.write_all(request.as_bytes()))
204            .await
205            .map_err(|_| "Write timeout".to_string())?
206            .map_err(|e| format!("Write failed: {}", e))?;
207
208        // Read response
209        let mut buffer = vec![0u8; 4096];
210        let read_timeout = Duration::from_millis(timeout_ms);
211        let n = timeout(read_timeout, stream.read(&mut buffer))
212            .await
213            .map_err(|_| "Read timeout".to_string())?
214            .map_err(|e| format!("Read failed: {}", e))?;
215
216        let response = String::from_utf8_lossy(&buffer[..n]).to_string();
217        Ok(response)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssh_detection() {
        let banner = "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.5";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "SSH");
        assert_eq!(service.version.as_deref(), Some("SSH-2.0-OpenSSH_8.2p1"));
        assert_eq!(service.vendor.as_deref(), Some("OpenSSH"));
    }

    #[test]
    fn test_http_detection() {
        let banner = "HTTP/1.1 200 OK\r\nServer: Apache/2.4.41 (Ubuntu)\r\n";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "HTTP");
        assert_eq!(service.version.as_deref(), Some("Apache/2.4.41 (Ubuntu)"));
        assert_eq!(service.vendor.as_deref(), Some("Apache Software Foundation"));
    }

    #[test]
    fn test_mysql_detection() {
        let banner = "mysql 5.7.32-0ubuntu0.18.04.1";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "MySQL");
        assert_eq!(service.version.as_deref(), Some("5.7.32"));
        assert_eq!(service.vendor.as_deref(), Some("Oracle"));
    }

    #[test]
    fn test_unknown_banner() {
        assert!(ServiceSignatures::detect_from_banner("hello world").is_none());
    }

    #[test]
    fn test_vulnerability_check() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_7.4".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_apache_22_is_vulnerable() {
        let service = ServiceInfo {
            name: "HTTP".to_string(),
            version: Some("Apache/2.2.15 (CentOS)".to_string()),
            vendor: Some("Apache Software Foundation".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_no_vulnerability() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_9.0".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(!service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), None);
    }

    #[test]
    fn test_missing_version_is_not_vulnerable() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: None,
            vendor: None,
        };
        assert!(!service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), None);
    }
}