rust_network_scanner/
service_detection.rs1use std::time::Duration;
4use tokio::io::{AsyncReadExt, AsyncWriteExt};
5use tokio::net::TcpStream;
6use tokio::time::timeout;
7
8pub struct ServiceSignatures;
10
11impl ServiceSignatures {
12 pub fn detect_from_banner(banner: &str) -> Option<ServiceInfo> {
14 let banner_lower = banner.to_lowercase();
15
16 if banner_lower.contains("ssh") {
17 return Some(ServiceInfo {
18 name: "SSH".to_string(),
19 version: Self::extract_ssh_version(banner),
20 vendor: Self::extract_vendor(&banner_lower),
21 });
22 }
23
24 if banner_lower.contains("http") || banner_lower.contains("server:") {
25 return Some(ServiceInfo {
26 name: "HTTP".to_string(),
27 version: Self::extract_http_version(banner),
28 vendor: Self::extract_vendor(&banner_lower),
29 });
30 }
31
32 if banner_lower.contains("ftp") {
33 return Some(ServiceInfo {
34 name: "FTP".to_string(),
35 version: Self::extract_version_generic(banner),
36 vendor: Self::extract_vendor(&banner_lower),
37 });
38 }
39
40 if banner_lower.contains("smtp") || banner_lower.contains("postfix") {
41 return Some(ServiceInfo {
42 name: "SMTP".to_string(),
43 version: Self::extract_version_generic(banner),
44 vendor: Self::extract_vendor(&banner_lower),
45 });
46 }
47
48 if banner_lower.contains("mysql") {
49 return Some(ServiceInfo {
50 name: "MySQL".to_string(),
51 version: Self::extract_mysql_version(banner),
52 vendor: Some("Oracle".to_string()),
53 });
54 }
55
56 if banner_lower.contains("postgresql") || banner_lower.contains("postgres") {
57 return Some(ServiceInfo {
58 name: "PostgreSQL".to_string(),
59 version: Self::extract_version_generic(banner),
60 vendor: Some("PostgreSQL Global Development Group".to_string()),
61 });
62 }
63
64 None
65 }
66
67 fn extract_ssh_version(banner: &str) -> Option<String> {
68 if let Some(start) = banner.find("SSH-") {
69 let version_str = &banner[start..];
70 if let Some(end) = version_str.find(|c: char| c.is_whitespace() || c == '\r' || c == '\n') {
71 return Some(version_str[..end].to_string());
72 }
73 }
74 None
75 }
76
77 fn extract_http_version(banner: &str) -> Option<String> {
78 for line in banner.lines() {
79 if line.to_lowercase().starts_with("server:") {
80 return Some(line[7..].trim().to_string());
81 }
82 }
83 None
84 }
85
86 fn extract_mysql_version(banner: &str) -> Option<String> {
87 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
89 re.captures(banner)
90 .and_then(|cap| cap.get(1))
91 .map(|m| m.as_str().to_string())
92 }
93
94 fn extract_version_generic(banner: &str) -> Option<String> {
95 let re = regex::Regex::new(r"(\d+\.\d+(?:\.\d+)?)").ok()?;
96 re.captures(banner)
97 .and_then(|cap| cap.get(1))
98 .map(|m| m.as_str().to_string())
99 }
100
101 fn extract_vendor(banner: &str) -> Option<String> {
102 if banner.contains("apache") {
103 Some("Apache Software Foundation".to_string())
104 } else if banner.contains("nginx") {
105 Some("Nginx Inc.".to_string())
106 } else if banner.contains("microsoft") || banner.contains("iis") {
107 Some("Microsoft".to_string())
108 } else if banner.contains("openssh") {
109 Some("OpenSSH".to_string())
110 } else if banner.contains("postfix") {
111 Some("Postfix".to_string())
112 } else {
113 None
114 }
115 }
116}
117
/// Result of fingerprinting a service: what it is, plus whatever version and
/// vendor details could be recovered from its banner.
#[derive(Debug, Clone)]
pub struct ServiceInfo {
    /// Short service label such as `"SSH"` or `"HTTP"`.
    pub name: String,
    /// Version text recovered from the banner, if any.
    pub version: Option<String>,
    /// Vendor or project name, if one was recognised.
    pub vendor: Option<String>,
}

impl ServiceInfo {
    /// Reports whether the name/version pair matches one of the built-in
    /// vulnerable signatures.
    ///
    /// Two signatures are checked: an `"SSH"` service whose version contains
    /// `OpenSSH_7.`, and an `"HTTP"` service whose version contains
    /// `Apache/2.2`. A missing version never matches.
    pub fn has_known_vulnerabilities(&self) -> bool {
        match (self.name.as_str(), self.version.as_deref()) {
            ("SSH", Some(ver)) => ver.contains("OpenSSH_7."),
            ("HTTP", Some(ver)) => ver.contains("Apache/2.2"),
            _ => false,
        }
    }

    /// Returns `Some("HIGH")` when [`Self::has_known_vulnerabilities`] is
    /// true, otherwise `None`.
    pub fn vulnerability_severity(&self) -> Option<&'static str> {
        self.has_known_vulnerabilities().then_some("HIGH")
    }
}
152
153pub struct BannerGrabber;
155
156impl BannerGrabber {
157 pub async fn grab_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
159 let addr = format!("{}:{}", host, port);
160 let connect_timeout = Duration::from_millis(timeout_ms);
161
162 let stream = timeout(connect_timeout, TcpStream::connect(&addr))
163 .await
164 .map_err(|_| "Connection timeout".to_string())?
165 .map_err(|e| format!("Connection failed: {}", e))?;
166
167 let mut stream = stream;
168 let mut buffer = vec![0u8; 1024];
169
170 let read_timeout = Duration::from_millis(timeout_ms);
172 let n = timeout(read_timeout, stream.read(&mut buffer))
173 .await
174 .map_err(|_| "Read timeout".to_string())?
175 .map_err(|e| format!("Read failed: {}", e))?;
176
177 if n == 0 {
178 return Err("No data received".to_string());
179 }
180
181 let banner = String::from_utf8_lossy(&buffer[..n]).to_string();
182 Ok(banner)
183 }
184
185 pub async fn grab_http_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
187 let addr = format!("{}:{}", host, port);
188 let connect_timeout = Duration::from_millis(timeout_ms);
189
190 let stream = timeout(connect_timeout, TcpStream::connect(&addr))
191 .await
192 .map_err(|_| "Connection timeout".to_string())?
193 .map_err(|e| format!("Connection failed: {}", e))?;
194
195 let mut stream = stream;
196
197 let request = format!(
199 "GET / HTTP/1.1\r\nHost: {}\r\nUser-Agent: NetworkScanner/1.0\r\n\r\n",
200 host
201 );
202
203 timeout(Duration::from_millis(timeout_ms), stream.write_all(request.as_bytes()))
204 .await
205 .map_err(|_| "Write timeout".to_string())?
206 .map_err(|e| format!("Write failed: {}", e))?;
207
208 let mut buffer = vec![0u8; 4096];
210 let read_timeout = Duration::from_millis(timeout_ms);
211 let n = timeout(read_timeout, stream.read(&mut buffer))
212 .await
213 .map_err(|_| "Read timeout".to_string())?
214 .map_err(|e| format!("Read failed: {}", e))?;
215
216 let response = String::from_utf8_lossy(&buffer[..n]).to_string();
217 Ok(response)
218 }
219}
220
#[cfg(test)]
mod tests {
    //! Unit tests for banner classification and the vulnerability heuristics.
    //! None of these touch the network.

    use super::*;

    #[test]
    fn test_ssh_detection() {
        let banner = "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.5";
        let service = ServiceSignatures::detect_from_banner(banner)
            .expect("an SSH identification string should be recognised");
        assert_eq!(service.name, "SSH");
        // The version is the identification token up to the first whitespace.
        assert_eq!(service.version.as_deref(), Some("SSH-2.0-OpenSSH_8.2p1"));
        assert_eq!(service.vendor.as_deref(), Some("OpenSSH"));
    }

    #[test]
    fn test_http_detection() {
        let banner = "HTTP/1.1 200 OK\r\nServer: Apache/2.4.41 (Ubuntu)\r\n";
        let service = ServiceSignatures::detect_from_banner(banner)
            .expect("an HTTP response should be recognised");
        assert_eq!(service.name, "HTTP");
        assert_eq!(service.version.as_deref(), Some("Apache/2.4.41 (Ubuntu)"));
        assert_eq!(
            service.vendor.as_deref(),
            Some("Apache Software Foundation")
        );
    }

    #[test]
    fn test_mysql_detection() {
        // The banner must contain "mysql" for the MySQL rule to fire; a bare
        // version string on its own would not be classified.
        let banner = "mysql 5.7.32-0ubuntu0.18.04.1";
        let service = ServiceSignatures::detect_from_banner(banner)
            .expect("a banner mentioning mysql should be recognised");
        assert_eq!(service.name, "MySQL");
        assert_eq!(service.version.as_deref(), Some("5.7.32"));
        assert_eq!(service.vendor.as_deref(), Some("Oracle"));
    }

    #[test]
    fn test_vulnerability_check() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_7.4".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_no_vulnerability() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_9.0".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(!service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), None);
    }
}