// rust_network_scanner/service_detection.rs

use std::sync::OnceLock;
use std::time::Duration;

use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::time::timeout;

/// Namespace for banner-based service identification; see
/// [`ServiceSignatures::detect_from_banner`].
#[derive(Debug, Clone, Copy, Default)]
pub struct ServiceSignatures;

11impl ServiceSignatures {
12 pub fn detect_from_banner(banner: &str) -> Option<ServiceInfo> {
14 let banner_lower = banner.to_lowercase();
15
16 if banner_lower.contains("ssh") {
17 return Some(ServiceInfo {
18 name: "SSH".to_string(),
19 version: Self::extract_ssh_version(banner),
20 vendor: Self::extract_vendor(&banner_lower),
21 });
22 }
23
24 if banner_lower.contains("http") || banner_lower.contains("server:") {
25 return Some(ServiceInfo {
26 name: "HTTP".to_string(),
27 version: Self::extract_http_version(banner),
28 vendor: Self::extract_vendor(&banner_lower),
29 });
30 }
31
32 if banner_lower.contains("ftp") {
33 return Some(ServiceInfo {
34 name: "FTP".to_string(),
35 version: Self::extract_version_generic(banner),
36 vendor: Self::extract_vendor(&banner_lower),
37 });
38 }
39
40 if banner_lower.contains("smtp") || banner_lower.contains("postfix") {
41 return Some(ServiceInfo {
42 name: "SMTP".to_string(),
43 version: Self::extract_version_generic(banner),
44 vendor: Self::extract_vendor(&banner_lower),
45 });
46 }
47
48 if banner_lower.contains("mysql") {
49 return Some(ServiceInfo {
50 name: "MySQL".to_string(),
51 version: Self::extract_mysql_version(banner),
52 vendor: Some("Oracle".to_string()),
53 });
54 }
55
56 if banner_lower.contains("postgresql") || banner_lower.contains("postgres") {
57 return Some(ServiceInfo {
58 name: "PostgreSQL".to_string(),
59 version: Self::extract_version_generic(banner),
60 vendor: Some("PostgreSQL Global Development Group".to_string()),
61 });
62 }
63
64 None
65 }
66
67 fn extract_ssh_version(banner: &str) -> Option<String> {
68 if let Some(start) = banner.find("SSH-") {
69 let version_str = &banner[start..];
70 if let Some(end) =
71 version_str.find(|c: char| c.is_whitespace() || c == '\r' || c == '\n')
72 {
73 return Some(version_str[..end].to_string());
74 }
75 }
76 None
77 }
78
79 fn extract_http_version(banner: &str) -> Option<String> {
80 for line in banner.lines() {
81 if line.to_lowercase().starts_with("server:") {
82 return Some(line[7..].trim().to_string());
83 }
84 }
85 None
86 }
87
88 fn extract_mysql_version(banner: &str) -> Option<String> {
89 let re = regex::Regex::new(r"(\d+\.\d+\.\d+)").ok()?;
91 re.captures(banner)
92 .and_then(|cap| cap.get(1))
93 .map(|m| m.as_str().to_string())
94 }
95
96 fn extract_version_generic(banner: &str) -> Option<String> {
97 let re = regex::Regex::new(r"(\d+\.\d+(?:\.\d+)?)").ok()?;
98 re.captures(banner)
99 .and_then(|cap| cap.get(1))
100 .map(|m| m.as_str().to_string())
101 }
102
103 fn extract_vendor(banner: &str) -> Option<String> {
104 if banner.contains("apache") {
105 Some("Apache Software Foundation".to_string())
106 } else if banner.contains("nginx") {
107 Some("Nginx Inc.".to_string())
108 } else if banner.contains("microsoft") || banner.contains("iis") {
109 Some("Microsoft".to_string())
110 } else if banner.contains("openssh") {
111 Some("OpenSSH".to_string())
112 } else if banner.contains("postfix") {
113 Some("Postfix".to_string())
114 } else {
115 None
116 }
117 }
118}
119
/// Result of banner-based service identification.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ServiceInfo {
    /// Service name, e.g. `"SSH"`, `"HTTP"` or `"MySQL"`.
    pub name: String,
    /// Version text extracted from the banner, if any.
    ///
    /// The format depends on the service: for example, the full SSH
    /// identification string or the HTTP `Server:` header value.
    pub version: Option<String>,
    /// Vendor inferred from the banner, if any.
    pub vendor: Option<String>,
}

128impl ServiceInfo {
129 pub fn has_known_vulnerabilities(&self) -> bool {
131 if let Some(ref version) = self.version {
133 if self.name == "SSH" && version.contains("OpenSSH_7.") {
135 return true;
136 }
137 if self.name == "HTTP" && version.contains("Apache/2.2") {
139 return true;
140 }
141 }
142 false
143 }
144
145 pub fn vulnerability_severity(&self) -> Option<&'static str> {
147 if self.has_known_vulnerabilities() {
148 Some("HIGH")
149 } else {
150 None
151 }
152 }
153}
154
/// Namespace for async TCP banner-grabbing helpers; see
/// [`BannerGrabber::grab_banner`] and [`BannerGrabber::grab_http_banner`].
#[derive(Debug, Clone, Copy, Default)]
pub struct BannerGrabber;

158impl BannerGrabber {
159 pub async fn grab_banner(host: &str, port: u16, timeout_ms: u64) -> Result<String, String> {
161 let addr = format!("{}:{}", host, port);
162 let connect_timeout = Duration::from_millis(timeout_ms);
163
164 let stream = timeout(connect_timeout, TcpStream::connect(&addr))
165 .await
166 .map_err(|_| "Connection timeout".to_string())?
167 .map_err(|e| format!("Connection failed: {}", e))?;
168
169 let mut stream = stream;
170 let mut buffer = vec![0u8; 1024];
171
172 let read_timeout = Duration::from_millis(timeout_ms);
174 let n = timeout(read_timeout, stream.read(&mut buffer))
175 .await
176 .map_err(|_| "Read timeout".to_string())?
177 .map_err(|e| format!("Read failed: {}", e))?;
178
179 if n == 0 {
180 return Err("No data received".to_string());
181 }
182
183 let banner = String::from_utf8_lossy(&buffer[..n]).to_string();
184 Ok(banner)
185 }
186
187 pub async fn grab_http_banner(
189 host: &str,
190 port: u16,
191 timeout_ms: u64,
192 ) -> Result<String, String> {
193 let addr = format!("{}:{}", host, port);
194 let connect_timeout = Duration::from_millis(timeout_ms);
195
196 let stream = timeout(connect_timeout, TcpStream::connect(&addr))
197 .await
198 .map_err(|_| "Connection timeout".to_string())?
199 .map_err(|e| format!("Connection failed: {}", e))?;
200
201 let mut stream = stream;
202
203 let request = format!(
205 "GET / HTTP/1.1\r\nHost: {}\r\nUser-Agent: NetworkScanner/1.0\r\n\r\n",
206 host
207 );
208
209 timeout(
210 Duration::from_millis(timeout_ms),
211 stream.write_all(request.as_bytes()),
212 )
213 .await
214 .map_err(|_| "Write timeout".to_string())?
215 .map_err(|e| format!("Write failed: {}", e))?;
216
217 let mut buffer = vec![0u8; 4096];
219 let read_timeout = Duration::from_millis(timeout_ms);
220 let n = timeout(read_timeout, stream.read(&mut buffer))
221 .await
222 .map_err(|_| "Read timeout".to_string())?
223 .map_err(|e| format!("Read failed: {}", e))?;
224
225 let response = String::from_utf8_lossy(&buffer[..n]).to_string();
226 Ok(response)
227 }
228}
229
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_ssh_detection() {
        let banner = "SSH-2.0-OpenSSH_8.2p1 Ubuntu-4ubuntu0.5";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "SSH");
        assert_eq!(service.version.as_deref(), Some("SSH-2.0-OpenSSH_8.2p1"));
        assert_eq!(service.vendor.as_deref(), Some("OpenSSH"));
    }

    #[test]
    fn test_http_detection() {
        let banner = "HTTP/1.1 200 OK\r\nServer: Apache/2.4.41 (Ubuntu)\r\n";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "HTTP");
        assert_eq!(service.version.as_deref(), Some("Apache/2.4.41 (Ubuntu)"));
        assert_eq!(
            service.vendor.as_deref(),
            Some("Apache Software Foundation")
        );
    }

    #[test]
    fn test_mysql_detection() {
        let banner = "mysql 5.7.32-0ubuntu0.18.04.1";
        let service = ServiceSignatures::detect_from_banner(banner).unwrap();
        assert_eq!(service.name, "MySQL");
        assert_eq!(service.version.as_deref(), Some("5.7.32"));
        assert_eq!(service.vendor.as_deref(), Some("Oracle"));
    }

    #[test]
    fn test_unknown_banner() {
        assert!(ServiceSignatures::detect_from_banner("hello world").is_none());
    }

    #[test]
    fn test_vulnerability_check() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_7.4".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_apache_2_2_vulnerability() {
        let service = ServiceInfo {
            name: "HTTP".to_string(),
            version: Some("Apache/2.2.15 (CentOS)".to_string()),
            vendor: Some("Apache Software Foundation".to_string()),
        };
        assert!(service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), Some("HIGH"));
    }

    #[test]
    fn test_no_vulnerability() {
        let service = ServiceInfo {
            name: "SSH".to_string(),
            version: Some("OpenSSH_9.0".to_string()),
            vendor: Some("OpenSSH".to_string()),
        };
        assert!(!service.has_known_vulnerabilities());
        assert_eq!(service.vulnerability_severity(), None);
    }
}