safebrowsing_api/
lib.rs

1//! HTTP API client for Google Safe Browsing API v4
2//!
3//! This crate provides the HTTP client implementation for communicating with
4//! the Google Safe Browsing API servers. It handles request formation,
5//! authorization, and response parsing.
6
7use bytes::Bytes;
8use prost::Message;
9use reqwest::{Client, Proxy, Response};
10use safebrowsing_hash::HashPrefix;
11use safebrowsing_proto::{
12    safebrowsing_proto, ClientInfo, FetchThreatListUpdatesRequest, FetchThreatListUpdatesResponse,
13    FindFullHashesRequest, FindFullHashesResponse, ThreatEntry, ThreatInfo,
14};
15use safebrowsing_proto::{
16    PlatformType as ProtoPlatformType, ThreatEntryType as ProtoThreatEntryType,
17    ThreatType as ProtoThreatType,
18};
19use serde::{Deserialize, Serialize};
20use std::fmt;
21use std::time::Duration;
22use thiserror::Error;
23use tracing::{debug, error};
24
/// Default Safe Browsing API base URL; endpoint paths below are appended to it.
pub const API_BASE_URL: &str = "https://safebrowsing.googleapis.com";

/// Endpoint path for `threatListUpdates:fetch` (relative to the base URL).
const THREAT_LIST_UPDATES_PATH: &str = "/v4/threatListUpdates:fetch";
/// Endpoint path for `fullHashes:find` (relative to the base URL).
const FULL_HASHES_PATH: &str = "/v4/fullHashes:find";
31
32/// Error types specific to the Safe Browsing API
33#[derive(Error, Debug)]
34pub enum ApiError {
35    /// Bad request error (HTTP 400)
36    #[error("Bad request: {0}")]
37    BadRequest(String),
38
39    /// Authentication error (HTTP 401)
40    #[error("Authentication error: {0}")]
41    Authentication(String),
42
43    /// API quota exceeded (HTTP 403)
44    #[error("API quota exceeded")]
45    QuotaExceeded,
46
47    /// Rate limiting error (HTTP 429)
48    #[error("Rate limited, retry after {retry_after:?}")]
49    RateLimit { retry_after: Option<Duration> },
50
51    /// Server unavailable (HTTP 503)
52    #[error("Server unavailable: {0}")]
53    ServerUnavailable(String),
54
55    /// Other HTTP error
56    #[error("HTTP error {status}: {message}")]
57    HttpStatus { status: u16, message: String },
58}
59
60/// Error type for API operations
61#[derive(Error, Debug)]
62pub enum Error {
63    /// HTTP client error
64    #[error("HTTP error: {0}")]
65    Http(#[from] reqwest::Error),
66
67    /// Safe Browsing API error
68    #[error("API error: {0}")]
69    Api(#[from] ApiError),
70
71    /// Protobuf encoding/decoding error
72    #[error("Protobuf error: {0}")]
73    Protobuf(String),
74
75    /// Configuration error
76    #[error("Configuration error: {0}")]
77    Configuration(String),
78}
79
80/// Result type for API operations
81type Result<T> = std::result::Result<T, Error>;
82
/// Configuration for the Safe Browsing API client.
///
/// The `Debug` implementation redacts [`ApiConfig::api_key`] so that logging or
/// printing a configuration never leaks the credential.
#[derive(Clone)]
pub struct ApiConfig {
    /// The API key for authenticating with the Safe Browsing API
    pub api_key: String,

    /// Client ID to identify the client to the API
    pub client_id: String,

    /// Client version string
    pub client_version: String,

    /// Base URL for the Safe Browsing API
    pub base_url: String,

    /// Optional HTTP proxy URL
    pub proxy_url: Option<String>,

    /// Request timeout duration
    pub request_timeout: Duration,
}

impl fmt::Debug for ApiConfig {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The key is a secret: show a placeholder instead of its value.
        f.debug_struct("ApiConfig")
            .field("api_key", &format_args!("<redacted>"))
            .field("client_id", &self.client_id)
            .field("client_version", &self.client_version)
            .field("base_url", &self.base_url)
            .field("proxy_url", &self.proxy_url)
            .field("request_timeout", &self.request_timeout)
            .finish()
    }
}
104
105/// A threat descriptor describes a specific threat list
106#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
107pub struct ThreatDescriptor {
108    /// The type of threat (malware, phishing, etc)
109    pub threat_type: ThreatType,
110
111    /// The platform this threat applies to
112    pub platform_type: PlatformType,
113
114    /// The type of entries in the threat list
115    pub threat_entry_type: ThreatEntryType,
116}
117
118impl fmt::Display for ThreatDescriptor {
119    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
120        write!(
121            f,
122            "{}/{}/{}",
123            self.threat_type, self.platform_type, self.threat_entry_type
124        )
125    }
126}
127
128/// Types of threats in the Safe Browsing API
129#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
130pub enum ThreatType {
131    /// Unknown threat type
132    Unspecified,
133    /// Malware threat
134    Malware,
135    /// Social engineering/phishing
136    SocialEngineering,
137    /// Unwanted software
138    UnwantedSoftware,
139    /// Potentially harmful application
140    PotentiallyHarmfulApplication,
141}
142
143impl fmt::Display for ThreatType {
144    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
145        match self {
146            Self::Unspecified => write!(f, "UNSPECIFIED"),
147            Self::Malware => write!(f, "MALWARE"),
148            Self::SocialEngineering => write!(f, "SOCIAL_ENGINEERING"),
149            Self::UnwantedSoftware => write!(f, "UNWANTED_SOFTWARE"),
150            Self::PotentiallyHarmfulApplication => write!(f, "POTENTIALLY_HARMFUL_APPLICATION"),
151        }
152    }
153}
154
155impl From<ThreatType> for i32 {
156    fn from(tt: ThreatType) -> i32 {
157        match tt {
158            ThreatType::Unspecified => ProtoThreatType::Unspecified as i32,
159            ThreatType::Malware => ProtoThreatType::Malware as i32,
160            ThreatType::SocialEngineering => ProtoThreatType::SocialEngineering as i32,
161            ThreatType::UnwantedSoftware => ProtoThreatType::UnwantedSoftware as i32,
162            ThreatType::PotentiallyHarmfulApplication => {
163                ProtoThreatType::PotentiallyHarmfulApplication as i32
164            }
165        }
166    }
167}
168
169impl From<i32> for ThreatType {
170    fn from(value: i32) -> Self {
171        match value {
172            x if x == ProtoThreatType::Malware as i32 => Self::Malware,
173            x if x == ProtoThreatType::SocialEngineering as i32 => Self::SocialEngineering,
174            x if x == ProtoThreatType::UnwantedSoftware as i32 => Self::UnwantedSoftware,
175            x if x == ProtoThreatType::PotentiallyHarmfulApplication as i32 => {
176                Self::PotentiallyHarmfulApplication
177            }
178            _ => Self::Unspecified,
179        }
180    }
181}
182
183/// Platform types in the Safe Browsing API
184#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
185pub enum PlatformType {
186    /// Unknown platform
187    Unspecified,
188    /// Windows platform
189    Windows,
190    /// Linux platform
191    Linux,
192    /// Android platform
193    Android,
194    /// macOS platform
195    OSX,
196    /// iOS platform
197    IOS,
198    /// Any platform (at least one platform)
199    AnyPlatform,
200    /// All platforms
201    AllPlatforms,
202    /// Chrome browser
203    Chrome,
204}
205
206impl fmt::Display for PlatformType {
207    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
208        match self {
209            Self::Unspecified => write!(f, "UNSPECIFIED"),
210            Self::Windows => write!(f, "WINDOWS"),
211            Self::Linux => write!(f, "LINUX"),
212            Self::Android => write!(f, "ANDROID"),
213            Self::OSX => write!(f, "OSX"),
214            Self::IOS => write!(f, "IOS"),
215            Self::AnyPlatform => write!(f, "ANY_PLATFORM"),
216            Self::AllPlatforms => write!(f, "ALL_PLATFORMS"),
217            Self::Chrome => write!(f, "CHROME"),
218        }
219    }
220}
221
222impl From<PlatformType> for i32 {
223    fn from(pt: PlatformType) -> i32 {
224        match pt {
225            PlatformType::Unspecified => ProtoPlatformType::Unspecified as i32,
226            PlatformType::Windows => ProtoPlatformType::Windows as i32,
227            PlatformType::Linux => ProtoPlatformType::Linux as i32,
228            PlatformType::Android => ProtoPlatformType::Android as i32,
229            PlatformType::OSX => ProtoPlatformType::Osx as i32,
230            PlatformType::IOS => ProtoPlatformType::Ios as i32,
231            PlatformType::AnyPlatform => ProtoPlatformType::AnyPlatform as i32,
232            PlatformType::AllPlatforms => ProtoPlatformType::AllPlatforms as i32,
233            PlatformType::Chrome => ProtoPlatformType::Chrome as i32,
234        }
235    }
236}
237
238impl From<i32> for PlatformType {
239    fn from(value: i32) -> Self {
240        match value {
241            x if x == ProtoPlatformType::Windows as i32 => Self::Windows,
242            x if x == ProtoPlatformType::Linux as i32 => Self::Linux,
243            x if x == ProtoPlatformType::Android as i32 => Self::Android,
244            x if x == ProtoPlatformType::Osx as i32 => Self::OSX,
245            x if x == ProtoPlatformType::Ios as i32 => Self::IOS,
246            x if x == ProtoPlatformType::AnyPlatform as i32 => Self::AnyPlatform,
247            x if x == ProtoPlatformType::AllPlatforms as i32 => Self::AllPlatforms,
248            x if x == ProtoPlatformType::Chrome as i32 => Self::Chrome,
249            _ => Self::Unspecified,
250        }
251    }
252}
253
254/// Types of threat entries in the Safe Browsing API
255#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
256pub enum ThreatEntryType {
257    /// Unknown entry type
258    Unspecified,
259    /// URL entry
260    Url,
261    /// Executable file
262    Executable,
263    /// IP range
264    IpRange,
265}
266
267impl fmt::Display for ThreatEntryType {
268    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
269        match self {
270            Self::Unspecified => write!(f, "UNSPECIFIED"),
271            Self::Url => write!(f, "URL"),
272            Self::Executable => write!(f, "EXECUTABLE"),
273            Self::IpRange => write!(f, "IP_RANGE"),
274        }
275    }
276}
277
278impl From<ThreatEntryType> for i32 {
279    fn from(tet: ThreatEntryType) -> i32 {
280        match tet {
281            ThreatEntryType::Unspecified => ProtoThreatEntryType::Unspecified as i32,
282            ThreatEntryType::Url => ProtoThreatEntryType::Url as i32,
283            ThreatEntryType::Executable => ProtoThreatEntryType::Executable as i32,
284            ThreatEntryType::IpRange => ProtoThreatEntryType::IpRange as i32,
285        }
286    }
287}
288
289impl From<i32> for ThreatEntryType {
290    fn from(value: i32) -> Self {
291        match value {
292            x if x == ProtoThreatEntryType::Url as i32 => Self::Url,
293            x if x == ProtoThreatEntryType::Executable as i32 => Self::Executable,
294            x if x == ProtoThreatEntryType::IpRange as i32 => Self::IpRange,
295            _ => Self::Unspecified,
296        }
297    }
298}
299
300/// Information about a URL that matched a threat list
301#[derive(Debug, Clone, PartialEq, Eq)]
302pub struct URLThreat {
303    /// The URL pattern that matched
304    pub pattern: String,
305
306    /// The threat descriptor that matched
307    pub threat_descriptor: ThreatDescriptor,
308}
309
310impl fmt::Display for URLThreat {
311    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
312        write!(f, "{}: {}", self.pattern, self.threat_descriptor)
313    }
314}
315
316/// Safe Browsing API client
317#[derive(Clone)]
318pub struct SafeBrowsingApi {
319    client: Client,
320    base_url: String,
321    api_key: String,
322    client_info: ClientInfo,
323}
324
325impl SafeBrowsingApi {
326    /// Create a new API client with the provided configuration
327    pub fn new(config: &ApiConfig) -> Result<Self> {
328        let mut client_builder = Client::builder()
329            .timeout(config.request_timeout)
330            .user_agent(format!("{}/{}", config.client_id, config.client_version))
331            .gzip(true);
332
333        // Configure proxy if specified
334        if let Some(proxy_url) = &config.proxy_url {
335            let proxy = Proxy::all(proxy_url)
336                .map_err(|e| Error::Configuration(format!("Invalid proxy URL: {e}")))?;
337            client_builder = client_builder.proxy(proxy);
338        }
339
340        let client = client_builder
341            .build()
342            .map_err(|e| Error::Configuration(format!("Failed to create HTTP client: {e}")))?;
343
344        let client_info = ClientInfo {
345            client_id: config.client_id.clone(),
346            client_version: config.client_version.clone(),
347        };
348
349        Ok(Self {
350            client,
351            base_url: config.base_url.clone(),
352            api_key: config.api_key.clone(),
353            client_info,
354        })
355    }
356
357    /// Fetch threat list updates from the API
358    pub async fn fetch_threat_list_update(
359        &self,
360        threat_descriptor: &ThreatDescriptor,
361        client_state: &[u8],
362    ) -> Result<FetchThreatListUpdatesResponse> {
363        let request = FetchThreatListUpdatesRequest {
364            client: Some(self.client_info.clone()),
365            list_update_requests: vec![
366                safebrowsing_proto::fetch_threat_list_updates_request::ListUpdateRequest {
367                    threat_type: threat_descriptor.threat_type.into(),
368                    platform_type: threat_descriptor.platform_type.into(),
369                    threat_entry_type: threat_descriptor.threat_entry_type.into(),
370                    state: client_state.to_vec().into(),
371                    constraints: Some(
372                        safebrowsing_proto::fetch_threat_list_updates_request::list_update_request::Constraints {
373                            max_update_entries: 0, // No limit
374                            max_database_entries: 0, // No limit
375                            region: String::new(),
376                            supported_compressions: vec![
377                                safebrowsing_proto::CompressionType::Raw as i32,
378                                safebrowsing_proto::CompressionType::Rice as i32,
379                            ],
380                        },
381                    ),
382                },
383            ],
384        };
385
386        self.post_protobuf(THREAT_LIST_UPDATES_PATH, &request).await
387    }
388
389    /// Find full hashes for the given hash prefixes
390    pub async fn find_full_hashes(
391        &self,
392        hash_prefix: &HashPrefix,
393        threat_descriptors: &[ThreatDescriptor],
394    ) -> Result<FindFullHashesResponse> {
395        let threat_entries = vec![ThreatEntry {
396            hash: Bytes::copy_from_slice(hash_prefix.as_bytes()),
397            url: String::new(),
398        }];
399
400        let threat_types: Vec<i32> = threat_descriptors
401            .iter()
402            .map(|td| td.threat_type.into())
403            .collect();
404
405        let platform_types: Vec<i32> = threat_descriptors
406            .iter()
407            .map(|td| td.platform_type.into())
408            .collect();
409
410        let threat_entry_types: Vec<i32> = threat_descriptors
411            .iter()
412            .map(|td| td.threat_entry_type.into())
413            .collect();
414
415        let request = FindFullHashesRequest {
416            client: Some(self.client_info.clone()),
417            client_states: Vec::new(),
418            threat_info: Some(ThreatInfo {
419                threat_types,
420                platform_types,
421                threat_entry_types,
422                threat_entries,
423            }),
424        };
425
426        self.post_protobuf(FULL_HASHES_PATH, &request).await
427    }
428
429    /// Make a POST request with protobuf payload
430    async fn post_protobuf<T, R>(&self, path: &str, request: &T) -> Result<R>
431    where
432        T: Message,
433        R: Message + Default,
434    {
435        let url = format!("{}{}?key={}&alt=proto", self.base_url, path, self.api_key);
436
437        // Encode the request
438        let mut buf = Vec::new();
439        prost::Message::encode(request, &mut buf).map_err(|e| Error::Protobuf(e.to_string()))?;
440
441        debug!("Making API request to: {}", url);
442        debug!("Request size: {} bytes", buf.len());
443
444        // Make the request
445        let response = self
446            .client
447            .post(&url)
448            .header("Content-Type", "application/x-protobuf")
449            .body(buf)
450            .send()
451            .await
452            .map_err(Error::Http)?;
453
454        self.handle_response(response).await
455    }
456
457    /// Handle HTTP response and decode protobuf
458    async fn handle_response<R>(&self, response: Response) -> Result<R>
459    where
460        R: Message + Default,
461    {
462        let status = response.status();
463        let headers = response.headers().clone();
464
465        debug!("API response status: {}", status);
466
467        if !status.is_success() {
468            let body = response
469                .text()
470                .await
471                .unwrap_or_else(|_| "Failed to read response body".to_string());
472
473            let api_error = match status.as_u16() {
474                400 => ApiError::BadRequest(body),
475                401 => ApiError::Authentication("Invalid API key".to_string()),
476                403 => ApiError::QuotaExceeded,
477                429 => {
478                    let retry_after = headers
479                        .get("retry-after")
480                        .and_then(|v| v.to_str().ok())
481                        .and_then(|v| v.parse::<u64>().ok())
482                        .map(Duration::from_secs);
483                    ApiError::RateLimit { retry_after }
484                }
485                503 => ApiError::ServerUnavailable("Service temporarily unavailable".to_string()),
486                _ => ApiError::HttpStatus {
487                    status: status.as_u16(),
488                    message: body,
489                },
490            };
491
492            return Err(Error::Api(api_error));
493        }
494
495        // Read response body
496        let body = response.bytes().await.map_err(Error::Http)?;
497        debug!("Response size: {} bytes", body.len());
498
499        // Decode protobuf
500        prost::Message::decode(body).map_err(|e| Error::Protobuf(e.to_string()))
501    }
502
503    /// Get the API base URL
504    pub fn base_url(&self) -> &str {
505        &self.base_url
506    }
507
508    /// Get the client info
509    pub fn client_info(&self) -> &ClientInfo {
510        &self.client_info
511    }
512}
513
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ThreatType` variant, for exhaustive round-trip checks.
    const THREAT_TYPES: [ThreatType; 5] = [
        ThreatType::Unspecified,
        ThreatType::Malware,
        ThreatType::SocialEngineering,
        ThreatType::UnwantedSoftware,
        ThreatType::PotentiallyHarmfulApplication,
    ];

    /// Every `PlatformType` variant, for exhaustive round-trip checks.
    const PLATFORM_TYPES: [PlatformType; 9] = [
        PlatformType::Unspecified,
        PlatformType::Windows,
        PlatformType::Linux,
        PlatformType::Android,
        PlatformType::OSX,
        PlatformType::IOS,
        PlatformType::AnyPlatform,
        PlatformType::AllPlatforms,
        PlatformType::Chrome,
    ];

    /// Every `ThreatEntryType` variant, for exhaustive round-trip checks.
    const THREAT_ENTRY_TYPES: [ThreatEntryType; 4] = [
        ThreatEntryType::Unspecified,
        ThreatEntryType::Url,
        ThreatEntryType::Executable,
        ThreatEntryType::IpRange,
    ];

    fn malware_any_url() -> ThreatDescriptor {
        ThreatDescriptor {
            threat_type: ThreatType::Malware,
            platform_type: PlatformType::AnyPlatform,
            threat_entry_type: ThreatEntryType::Url,
        }
    }

    #[test]
    fn test_threat_descriptor_display() {
        let td = malware_any_url();
        assert_eq!(format!("{td}"), "MALWARE/ANY_PLATFORM/URL");
    }

    #[test]
    fn test_url_threat_display() {
        let threat = URLThreat {
            pattern: "evil.example/".to_string(),
            threat_descriptor: malware_any_url(),
        };
        assert_eq!(threat.to_string(), "evil.example/: MALWARE/ANY_PLATFORM/URL");
    }

    #[test]
    fn test_enum_display_names() {
        assert_eq!(
            ThreatType::PotentiallyHarmfulApplication.to_string(),
            "POTENTIALLY_HARMFUL_APPLICATION"
        );
        assert_eq!(PlatformType::AllPlatforms.to_string(), "ALL_PLATFORMS");
        assert_eq!(PlatformType::OSX.to_string(), "OSX");
        assert_eq!(ThreatEntryType::IpRange.to_string(), "IP_RANGE");
    }

    #[test]
    fn test_threat_type_conversions() {
        assert_eq!(
            i32::from(ThreatType::Malware),
            safebrowsing_proto::ThreatType::Malware as i32
        );
        assert_eq!(
            ThreatType::from(safebrowsing_proto::ThreatType::Malware as i32),
            ThreatType::Malware
        );
        for tt in THREAT_TYPES {
            assert_eq!(ThreatType::from(i32::from(tt)), tt, "round trip failed for {tt}");
        }
    }

    #[test]
    fn test_platform_type_conversions() {
        assert_eq!(
            i32::from(PlatformType::AnyPlatform),
            safebrowsing_proto::PlatformType::AnyPlatform as i32
        );
        assert_eq!(
            PlatformType::from(safebrowsing_proto::PlatformType::AnyPlatform as i32),
            PlatformType::AnyPlatform
        );
        for pt in PLATFORM_TYPES {
            assert_eq!(PlatformType::from(i32::from(pt)), pt, "round trip failed for {pt}");
        }
    }

    #[test]
    fn test_threat_entry_type_conversions() {
        assert_eq!(
            i32::from(ThreatEntryType::Url),
            safebrowsing_proto::ThreatEntryType::Url as i32
        );
        assert_eq!(
            ThreatEntryType::from(safebrowsing_proto::ThreatEntryType::Url as i32),
            ThreatEntryType::Url
        );
        for tet in THREAT_ENTRY_TYPES {
            assert_eq!(
                ThreatEntryType::from(i32::from(tet)),
                tet,
                "round trip failed for {tet}"
            );
        }
    }

    #[test]
    fn test_unknown_wire_values_map_to_unspecified() {
        assert_eq!(ThreatType::from(-1), ThreatType::Unspecified);
        assert_eq!(PlatformType::from(-1), PlatformType::Unspecified);
        assert_eq!(ThreatEntryType::from(-1), ThreatEntryType::Unspecified);
    }
}