1use bytes::Bytes;
8use prost::Message;
9use reqwest::{Client, Proxy, Response};
10use safebrowsing_hash::HashPrefix;
11use safebrowsing_proto::{
12 safebrowsing_proto, ClientInfo, FetchThreatListUpdatesRequest, FetchThreatListUpdatesResponse,
13 FindFullHashesRequest, FindFullHashesResponse, ThreatEntry, ThreatInfo,
14};
15use safebrowsing_proto::{
16 PlatformType as ProtoPlatformType, ThreatEntryType as ProtoThreatEntryType,
17 ThreatType as ProtoThreatType,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use std::time::Duration;
22use thiserror::Error;
23use tracing::{debug, error};
24
/// Base URL of Google's Safe Browsing API service, suitable for
/// [`ApiConfig::base_url`]. Endpoint paths such as `/v4/fullHashes:find` are
/// appended to the configured base URL.
pub const API_BASE_URL: &str = "https://safebrowsing.googleapis.com";

/// Path of the v4 `threatListUpdates:fetch` endpoint, relative to the base URL.
const THREAT_LIST_UPDATES_PATH: &str = "/v4/threatListUpdates:fetch";

/// Path of the v4 `fullHashes:find` endpoint, relative to the base URL.
const FULL_HASHES_PATH: &str = "/v4/fullHashes:find";
31
/// Errors describing a non-success HTTP status returned by the API.
///
/// `SafeBrowsingApi::handle_response` maps specific status codes onto the
/// dedicated variants; every other non-success status becomes
/// [`ApiError::HttpStatus`].
#[derive(Error, Debug)]
pub enum ApiError {
    /// HTTP 400; carries the response body text.
    #[error("Bad request: {0}")]
    BadRequest(String),

    /// HTTP 401; the client currently reports a fixed "Invalid API key" message.
    #[error("Authentication error: {0}")]
    Authentication(String),

    /// HTTP 403.
    ///
    /// NOTE(review): every 403 is mapped here, although a 403 may also signal
    /// a permission problem rather than exhausted quota — confirm.
    #[error("API quota exceeded")]
    QuotaExceeded,

    /// HTTP 429.
    #[error("Rate limited, retry after {retry_after:?}")]
    RateLimit {
        /// Delay from the `Retry-After` header, present only when the header
        /// holds a whole number of seconds.
        retry_after: Option<Duration>,
    },

    /// HTTP 503; the client currently reports a fixed message.
    #[error("Server unavailable: {0}")]
    ServerUnavailable(String),

    /// Any other non-success status.
    #[error("HTTP error {status}: {message}")]
    HttpStatus {
        /// Numeric HTTP status code.
        status: u16,
        /// Response body text.
        message: String,
    },
}
59
/// Top-level error type returned by [`SafeBrowsingApi`] operations.
#[derive(Error, Debug)]
pub enum Error {
    /// Transport-level failure reported by `reqwest` (sending the request or
    /// reading the response body).
    #[error("HTTP error: {0}")]
    Http(#[from] reqwest::Error),

    /// The API answered with a non-success HTTP status.
    #[error("API error: {0}")]
    Api(#[from] ApiError),

    /// A protobuf message failed to encode or decode; holds the `prost` error text.
    #[error("Protobuf error: {0}")]
    Protobuf(String),

    /// Invalid client configuration, e.g. a proxy URL rejected by `reqwest`.
    #[error("Configuration error: {0}")]
    Configuration(String),
}

/// Module-local result alias using [`Error`].
type Result<T> = std::result::Result<T, Error>;
82
/// Settings used by [`SafeBrowsingApi::new`] to build a client.
///
/// `Debug` is implemented by hand so that `api_key` is redacted; a derived
/// impl would print the secret verbatim into any `{:?}` log line.
#[derive(Clone)]
pub struct ApiConfig {
    /// API key, sent as the `key` query parameter on every request.
    pub api_key: String,

    /// Client identifier; sent in the protobuf `ClientInfo` and as the first
    /// half of the `User-Agent` header (`{client_id}/{client_version}`).
    pub client_id: String,

    /// Client version; sent in the protobuf `ClientInfo` and as the second
    /// half of the `User-Agent` header.
    pub client_version: String,

    /// Base URL that endpoint paths (e.g. `/v4/fullHashes:find`) are appended
    /// to; see [`API_BASE_URL`].
    pub base_url: String,

    /// Optional proxy through which all HTTP traffic is routed.
    pub proxy_url: Option<String>,

    /// Timeout handed to the HTTP client builder and applied to each request.
    pub request_timeout: Duration,
}

impl fmt::Debug for ApiConfig {
    /// Formats every field except `api_key`, which is shown as `"<redacted>"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("ApiConfig")
            .field("api_key", &"<redacted>")
            .field("client_id", &self.client_id)
            .field("client_version", &self.client_version)
            .field("base_url", &self.base_url)
            .field("proxy_url", &self.proxy_url)
            .field("request_timeout", &self.request_timeout)
            .finish()
    }
}
104
/// Identifies one Safe Browsing threat list by its
/// (threat type, platform type, threat entry type) triple.
///
/// `Display` renders it as `THREAT/PLATFORM/ENTRY`, e.g. `MALWARE/ANY_PLATFORM/URL`.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ThreatDescriptor {
    /// Category of threat the list covers.
    pub threat_type: ThreatType,

    /// Platform the list applies to.
    pub platform_type: PlatformType,

    /// Kind of entries the list contains (URL, executable, IP range, ...).
    pub threat_entry_type: ThreatEntryType,
}
117
118impl fmt::Display for ThreatDescriptor {
119 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120 write!(
121 f,
122 "{}/{}/{}",
123 self.threat_type, self.platform_type, self.threat_entry_type
124 )
125 }
126}
127
/// Threat category, mirroring the protobuf `ThreatType` enum.
///
/// Converts to and from the raw protobuf `i32`; values without a dedicated
/// variant decode to [`ThreatType::Unspecified`]. `Display` emits
/// upper-snake-case names (e.g. `SOCIAL_ENGINEERING`), while serde uses the
/// Rust variant names (no `rename` attributes are applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ThreatType {
    /// Unknown or unset threat type.
    Unspecified,
    /// Malware.
    Malware,
    /// Social engineering.
    SocialEngineering,
    /// Unwanted software.
    UnwantedSoftware,
    /// Potentially harmful application.
    PotentiallyHarmfulApplication,
}
142
143impl fmt::Display for ThreatType {
144 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145 match self {
146 Self::Unspecified => write!(f, "UNSPECIFIED"),
147 Self::Malware => write!(f, "MALWARE"),
148 Self::SocialEngineering => write!(f, "SOCIAL_ENGINEERING"),
149 Self::UnwantedSoftware => write!(f, "UNWANTED_SOFTWARE"),
150 Self::PotentiallyHarmfulApplication => write!(f, "POTENTIALLY_HARMFUL_APPLICATION"),
151 }
152 }
153}
154
155impl From<ThreatType> for i32 {
156 fn from(tt: ThreatType) -> i32 {
157 match tt {
158 ThreatType::Unspecified => ProtoThreatType::Unspecified as i32,
159 ThreatType::Malware => ProtoThreatType::Malware as i32,
160 ThreatType::SocialEngineering => ProtoThreatType::SocialEngineering as i32,
161 ThreatType::UnwantedSoftware => ProtoThreatType::UnwantedSoftware as i32,
162 ThreatType::PotentiallyHarmfulApplication => {
163 ProtoThreatType::PotentiallyHarmfulApplication as i32
164 }
165 }
166 }
167}
168
169impl From<i32> for ThreatType {
170 fn from(value: i32) -> Self {
171 match value {
172 x if x == ProtoThreatType::Malware as i32 => Self::Malware,
173 x if x == ProtoThreatType::SocialEngineering as i32 => Self::SocialEngineering,
174 x if x == ProtoThreatType::UnwantedSoftware as i32 => Self::UnwantedSoftware,
175 x if x == ProtoThreatType::PotentiallyHarmfulApplication as i32 => {
176 Self::PotentiallyHarmfulApplication
177 }
178 _ => Self::Unspecified,
179 }
180 }
181}
182
/// Target platform, mirroring the protobuf `PlatformType` enum.
///
/// Converts to and from the raw protobuf `i32`; values without a dedicated
/// variant decode to [`PlatformType::Unspecified`]. `Display` emits
/// upper-snake-case names (e.g. `ANY_PLATFORM`), while serde uses the Rust
/// variant names (no `rename` attributes are applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum PlatformType {
    /// Unknown or unset platform.
    Unspecified,
    /// Windows.
    Windows,
    /// Linux.
    Linux,
    /// Android.
    Android,
    /// macOS (protobuf `Osx`).
    OSX,
    /// iOS (protobuf `Ios`).
    IOS,
    /// The protobuf `AnyPlatform` value.
    AnyPlatform,
    /// The protobuf `AllPlatforms` value.
    AllPlatforms,
    /// Chrome.
    Chrome,
}
205
206impl fmt::Display for PlatformType {
207 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208 match self {
209 Self::Unspecified => write!(f, "UNSPECIFIED"),
210 Self::Windows => write!(f, "WINDOWS"),
211 Self::Linux => write!(f, "LINUX"),
212 Self::Android => write!(f, "ANDROID"),
213 Self::OSX => write!(f, "OSX"),
214 Self::IOS => write!(f, "IOS"),
215 Self::AnyPlatform => write!(f, "ANY_PLATFORM"),
216 Self::AllPlatforms => write!(f, "ALL_PLATFORMS"),
217 Self::Chrome => write!(f, "CHROME"),
218 }
219 }
220}
221
222impl From<PlatformType> for i32 {
223 fn from(pt: PlatformType) -> i32 {
224 match pt {
225 PlatformType::Unspecified => ProtoPlatformType::Unspecified as i32,
226 PlatformType::Windows => ProtoPlatformType::Windows as i32,
227 PlatformType::Linux => ProtoPlatformType::Linux as i32,
228 PlatformType::Android => ProtoPlatformType::Android as i32,
229 PlatformType::OSX => ProtoPlatformType::Osx as i32,
230 PlatformType::IOS => ProtoPlatformType::Ios as i32,
231 PlatformType::AnyPlatform => ProtoPlatformType::AnyPlatform as i32,
232 PlatformType::AllPlatforms => ProtoPlatformType::AllPlatforms as i32,
233 PlatformType::Chrome => ProtoPlatformType::Chrome as i32,
234 }
235 }
236}
237
238impl From<i32> for PlatformType {
239 fn from(value: i32) -> Self {
240 match value {
241 x if x == ProtoPlatformType::Windows as i32 => Self::Windows,
242 x if x == ProtoPlatformType::Linux as i32 => Self::Linux,
243 x if x == ProtoPlatformType::Android as i32 => Self::Android,
244 x if x == ProtoPlatformType::Osx as i32 => Self::OSX,
245 x if x == ProtoPlatformType::Ios as i32 => Self::IOS,
246 x if x == ProtoPlatformType::AnyPlatform as i32 => Self::AnyPlatform,
247 x if x == ProtoPlatformType::AllPlatforms as i32 => Self::AllPlatforms,
248 x if x == ProtoPlatformType::Chrome as i32 => Self::Chrome,
249 _ => Self::Unspecified,
250 }
251 }
252}
253
/// Kind of entry stored in a threat list, mirroring the protobuf
/// `ThreatEntryType` enum.
///
/// Converts to and from the raw protobuf `i32`; values without a dedicated
/// variant decode to [`ThreatEntryType::Unspecified`]. `Display` emits
/// upper-snake-case names (e.g. `IP_RANGE`), while serde uses the Rust variant
/// names (no `rename` attributes are applied).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum ThreatEntryType {
    /// Unknown or unset entry type.
    Unspecified,
    /// URL entries.
    Url,
    /// Executable entries.
    Executable,
    /// IP range entries.
    IpRange,
}
266
267impl fmt::Display for ThreatEntryType {
268 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269 match self {
270 Self::Unspecified => write!(f, "UNSPECIFIED"),
271 Self::Url => write!(f, "URL"),
272 Self::Executable => write!(f, "EXECUTABLE"),
273 Self::IpRange => write!(f, "IP_RANGE"),
274 }
275 }
276}
277
278impl From<ThreatEntryType> for i32 {
279 fn from(tet: ThreatEntryType) -> i32 {
280 match tet {
281 ThreatEntryType::Unspecified => ProtoThreatEntryType::Unspecified as i32,
282 ThreatEntryType::Url => ProtoThreatEntryType::Url as i32,
283 ThreatEntryType::Executable => ProtoThreatEntryType::Executable as i32,
284 ThreatEntryType::IpRange => ProtoThreatEntryType::IpRange as i32,
285 }
286 }
287}
288
289impl From<i32> for ThreatEntryType {
290 fn from(value: i32) -> Self {
291 match value {
292 x if x == ProtoThreatEntryType::Url as i32 => Self::Url,
293 x if x == ProtoThreatEntryType::Executable as i32 => Self::Executable,
294 x if x == ProtoThreatEntryType::IpRange as i32 => Self::IpRange,
295 _ => Self::Unspecified,
296 }
297 }
298}
299
/// A URL pattern paired with the threat descriptor associated with it.
///
/// `Display` renders it as `<pattern>: <descriptor>`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct URLThreat {
    /// The URL pattern. NOTE(review): presumably a canonicalized
    /// host/path expression — confirm with the code that builds these.
    pub pattern: String,

    /// The threat list associated with `pattern`.
    pub threat_descriptor: ThreatDescriptor,
}
309
310impl fmt::Display for URLThreat {
311 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312 write!(f, "{}: {}", self.pattern, self.threat_descriptor)
313 }
314}
315
/// Client for the Safe Browsing v4 REST API that exchanges protobuf messages
/// (requests are sent with `alt=proto`).
///
/// Only `Clone` is derived; the type does not implement `Debug`, which keeps
/// `api_key` out of debug output.
#[derive(Clone)]
pub struct SafeBrowsingApi {
    /// HTTP client configured in `new` (timeout, user agent, gzip, optional proxy).
    client: Client,
    /// Base URL that endpoint paths are appended to.
    base_url: String,
    /// API key sent as the `key` query parameter.
    api_key: String,
    /// Client identification attached to every request message.
    client_info: ClientInfo,
}
324
325impl SafeBrowsingApi {
326 pub fn new(config: &ApiConfig) -> Result<Self> {
328 let mut client_builder = Client::builder()
329 .timeout(config.request_timeout)
330 .user_agent(format!("{}/{}", config.client_id, config.client_version))
331 .gzip(true);
332
333 if let Some(proxy_url) = &config.proxy_url {
335 let proxy = Proxy::all(proxy_url)
336 .map_err(|e| Error::Configuration(format!("Invalid proxy URL: {e}")))?;
337 client_builder = client_builder.proxy(proxy);
338 }
339
340 let client = client_builder
341 .build()
342 .map_err(|e| Error::Configuration(format!("Failed to create HTTP client: {e}")))?;
343
344 let client_info = ClientInfo {
345 client_id: config.client_id.clone(),
346 client_version: config.client_version.clone(),
347 };
348
349 Ok(Self {
350 client,
351 base_url: config.base_url.clone(),
352 api_key: config.api_key.clone(),
353 client_info,
354 })
355 }
356
357 pub async fn fetch_threat_list_update(
359 &self,
360 threat_descriptor: &ThreatDescriptor,
361 client_state: &[u8],
362 ) -> Result<FetchThreatListUpdatesResponse> {
363 let request = FetchThreatListUpdatesRequest {
364 client: Some(self.client_info.clone()),
365 list_update_requests: vec![
366 safebrowsing_proto::fetch_threat_list_updates_request::ListUpdateRequest {
367 threat_type: threat_descriptor.threat_type.into(),
368 platform_type: threat_descriptor.platform_type.into(),
369 threat_entry_type: threat_descriptor.threat_entry_type.into(),
370 state: client_state.to_vec().into(),
371 constraints: Some(
372 safebrowsing_proto::fetch_threat_list_updates_request::list_update_request::Constraints {
373 max_update_entries: 0, max_database_entries: 0, region: String::new(),
376 supported_compressions: vec![
377 safebrowsing_proto::CompressionType::Raw as i32,
378 safebrowsing_proto::CompressionType::Rice as i32,
379 ],
380 },
381 ),
382 },
383 ],
384 };
385
386 self.post_protobuf(THREAT_LIST_UPDATES_PATH, &request).await
387 }
388
389 pub async fn find_full_hashes(
391 &self,
392 hash_prefix: &HashPrefix,
393 threat_descriptors: &[ThreatDescriptor],
394 ) -> Result<FindFullHashesResponse> {
395 let threat_entries = vec![ThreatEntry {
396 hash: Bytes::copy_from_slice(hash_prefix.as_bytes()),
397 url: String::new(),
398 }];
399
400 let threat_types: Vec<i32> = threat_descriptors
401 .iter()
402 .map(|td| td.threat_type.into())
403 .collect();
404
405 let platform_types: Vec<i32> = threat_descriptors
406 .iter()
407 .map(|td| td.platform_type.into())
408 .collect();
409
410 let threat_entry_types: Vec<i32> = threat_descriptors
411 .iter()
412 .map(|td| td.threat_entry_type.into())
413 .collect();
414
415 let request = FindFullHashesRequest {
416 client: Some(self.client_info.clone()),
417 client_states: Vec::new(),
418 threat_info: Some(ThreatInfo {
419 threat_types,
420 platform_types,
421 threat_entry_types,
422 threat_entries,
423 }),
424 };
425
426 self.post_protobuf(FULL_HASHES_PATH, &request).await
427 }
428
429 async fn post_protobuf<T, R>(&self, path: &str, request: &T) -> Result<R>
431 where
432 T: Message,
433 R: Message + Default,
434 {
435 let url = format!("{}{}?key={}&alt=proto", self.base_url, path, self.api_key);
436
437 let mut buf = Vec::new();
439 prost::Message::encode(request, &mut buf).map_err(|e| Error::Protobuf(e.to_string()))?;
440
441 debug!("Making API request to: {}", url);
442 debug!("Request size: {} bytes", buf.len());
443
444 let response = self
446 .client
447 .post(&url)
448 .header("Content-Type", "application/x-protobuf")
449 .body(buf)
450 .send()
451 .await
452 .map_err(Error::Http)?;
453
454 self.handle_response(response).await
455 }
456
457 async fn handle_response<R>(&self, response: Response) -> Result<R>
459 where
460 R: Message + Default,
461 {
462 let status = response.status();
463 let headers = response.headers().clone();
464
465 debug!("API response status: {}", status);
466
467 if !status.is_success() {
468 let body = response
469 .text()
470 .await
471 .unwrap_or_else(|_| "Failed to read response body".to_string());
472
473 let api_error = match status.as_u16() {
474 400 => ApiError::BadRequest(body),
475 401 => ApiError::Authentication("Invalid API key".to_string()),
476 403 => ApiError::QuotaExceeded,
477 429 => {
478 let retry_after = headers
479 .get("retry-after")
480 .and_then(|v| v.to_str().ok())
481 .and_then(|v| v.parse::<u64>().ok())
482 .map(Duration::from_secs);
483 ApiError::RateLimit { retry_after }
484 }
485 503 => ApiError::ServerUnavailable("Service temporarily unavailable".to_string()),
486 _ => ApiError::HttpStatus {
487 status: status.as_u16(),
488 message: body,
489 },
490 };
491
492 return Err(Error::Api(api_error));
493 }
494
495 let body = response.bytes().await.map_err(Error::Http)?;
497 debug!("Response size: {} bytes", body.len());
498
499 prost::Message::decode(body).map_err(|e| Error::Protobuf(e.to_string()))
501 }
502
503 pub fn base_url(&self) -> &str {
505 &self.base_url
506 }
507
508 pub fn client_info(&self) -> &ClientInfo {
510 &self.client_info
511 }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ThreatType` variant, for exhaustive round-trip checks.
    const ALL_THREAT_TYPES: [ThreatType; 5] = [
        ThreatType::Unspecified,
        ThreatType::Malware,
        ThreatType::SocialEngineering,
        ThreatType::UnwantedSoftware,
        ThreatType::PotentiallyHarmfulApplication,
    ];

    /// Every `PlatformType` variant, for exhaustive round-trip checks.
    const ALL_PLATFORM_TYPES: [PlatformType; 9] = [
        PlatformType::Unspecified,
        PlatformType::Windows,
        PlatformType::Linux,
        PlatformType::Android,
        PlatformType::OSX,
        PlatformType::IOS,
        PlatformType::AnyPlatform,
        PlatformType::AllPlatforms,
        PlatformType::Chrome,
    ];

    /// Every `ThreatEntryType` variant, for exhaustive round-trip checks.
    const ALL_THREAT_ENTRY_TYPES: [ThreatEntryType; 4] = [
        ThreatEntryType::Unspecified,
        ThreatEntryType::Url,
        ThreatEntryType::Executable,
        ThreatEntryType::IpRange,
    ];

    /// A raw value that matches no protobuf enum variant used here.
    const UNKNOWN_PROTO_VALUE: i32 = i32::MAX;

    #[test]
    fn test_threat_descriptor_display() {
        let td = ThreatDescriptor {
            threat_type: ThreatType::Malware,
            platform_type: PlatformType::AnyPlatform,
            threat_entry_type: ThreatEntryType::Url,
        };
        assert_eq!(format!("{td}"), "MALWARE/ANY_PLATFORM/URL");
    }

    #[test]
    fn test_url_threat_display() {
        let threat = URLThreat {
            pattern: "example.com/".to_string(),
            threat_descriptor: ThreatDescriptor {
                threat_type: ThreatType::SocialEngineering,
                platform_type: PlatformType::Windows,
                threat_entry_type: ThreatEntryType::Url,
            },
        };
        assert_eq!(
            threat.to_string(),
            "example.com/: SOCIAL_ENGINEERING/WINDOWS/URL"
        );
    }

    #[test]
    fn test_display_names() {
        assert_eq!(ThreatType::Unspecified.to_string(), "UNSPECIFIED");
        assert_eq!(
            ThreatType::PotentiallyHarmfulApplication.to_string(),
            "POTENTIALLY_HARMFUL_APPLICATION"
        );
        assert_eq!(PlatformType::OSX.to_string(), "OSX");
        assert_eq!(PlatformType::AllPlatforms.to_string(), "ALL_PLATFORMS");
        assert_eq!(ThreatEntryType::IpRange.to_string(), "IP_RANGE");
    }

    #[test]
    fn test_threat_type_conversions() {
        assert_eq!(
            i32::from(ThreatType::Malware),
            safebrowsing_proto::ThreatType::Malware as i32
        );
        assert_eq!(
            ThreatType::from(safebrowsing_proto::ThreatType::Malware as i32),
            ThreatType::Malware
        );
        for tt in ALL_THREAT_TYPES {
            assert_eq!(ThreatType::from(i32::from(tt)), tt, "round trip of {tt}");
        }
        assert_eq!(
            ThreatType::from(UNKNOWN_PROTO_VALUE),
            ThreatType::Unspecified
        );
    }

    #[test]
    fn test_platform_type_conversions() {
        assert_eq!(
            i32::from(PlatformType::AnyPlatform),
            safebrowsing_proto::PlatformType::AnyPlatform as i32
        );
        assert_eq!(
            PlatformType::from(safebrowsing_proto::PlatformType::AnyPlatform as i32),
            PlatformType::AnyPlatform
        );
        for pt in ALL_PLATFORM_TYPES {
            assert_eq!(PlatformType::from(i32::from(pt)), pt, "round trip of {pt}");
        }
        assert_eq!(
            PlatformType::from(UNKNOWN_PROTO_VALUE),
            PlatformType::Unspecified
        );
    }

    #[test]
    fn test_threat_entry_type_conversions() {
        assert_eq!(
            i32::from(ThreatEntryType::Url),
            safebrowsing_proto::ThreatEntryType::Url as i32
        );
        assert_eq!(
            ThreatEntryType::from(safebrowsing_proto::ThreatEntryType::Url as i32),
            ThreatEntryType::Url
        );
        for tet in ALL_THREAT_ENTRY_TYPES {
            assert_eq!(
                ThreatEntryType::from(i32::from(tet)),
                tet,
                "round trip of {tet}"
            );
        }
        assert_eq!(
            ThreatEntryType::from(UNKNOWN_PROTO_VALUE),
            ThreatEntryType::Unspecified
        );
    }
}