safebrowsing_db/
lib.rs

1//! Database interface and implementations for Google Safe Browsing API
2//!
3//! This crate provides the database abstraction layer for the Safe Browsing API.
4//! It defines a common `Database` trait and provides implementations for
5//! in-memory storage and persistent disk storage.
6
7use async_trait::async_trait;
8
9use safebrowsing_api::{SafeBrowsingApi, ThreatDescriptor};
10
11use safebrowsing_hash::{HashPrefix, HashPrefixSet};
12use safebrowsing_proto::{CompressionType, RawHashes, RawIndices, RiceDeltaEncoding};
13use std::collections::{HashMap, HashSet as StdHashSet};
14use std::fmt;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::{Mutex, RwLock};
18
19use tracing::{debug, error, info, warn};
20
21/// Error and Result types should be imported from a shared error module or defined here
22type Result<T> = std::result::Result<T, DatabaseError>;
23
24/// Default maximum database age before it's considered stale
25pub const DEFAULT_MAX_DATABASE_AGE: Duration = Duration::from_secs(24 * 60 * 60);
26
27/// Maximum retry delay for database updates
28const MAX_RETRY_DELAY: Duration = Duration::from_secs(24 * 60 * 60);
29
30/// Base retry delay for database updates
31const BASE_RETRY_DELAY: Duration = Duration::from_secs(15 * 60);
32
33/// Error types for database operations
34#[derive(thiserror::Error, Debug)]
35pub enum DatabaseError {
36    /// Database is not ready
37    #[error("Database not ready")]
38    NotReady,
39
40    /// Database is stale (not updated recently enough)
41    #[error("Database is stale, last updated {0:?} ago")]
42    Stale(Duration),
43
44    /// Error decoding data
45    #[error("Error decoding data: {0}")]
46    DecodeError(String),
47
48    /// API error
49    #[error("API error: {0}")]
50    ApiError(#[from] safebrowsing_api::Error),
51    /// I/O error
52    #[error("I/O error: {0}")]
53    IoError(#[from] std::io::Error),
54
55    /// Rice decoder error
56    #[error("Rice decoder error: {0}")]
57    RiceDecodeError(String),
58
59    /// Invalid indices
60    #[error("Invalid indices: {0}")]
61    InvalidIndices(String),
62    /// Invalid checksum
63    #[error("Invalid checksum: expected {expected}, got {actual}")]
64    InvalidChecksum { expected: String, actual: String },
65    /// Hash error
66    #[error("Hash error: {0}")]
67    HashError(#[from] safebrowsing_hash::HashError),
68}
69
/// Point-in-time summary of a database's contents and freshness.
#[derive(Debug, Clone, Default)]
pub struct DatabaseStats {
    /// Number of hash prefixes stored across all threat lists.
    pub total_hashes: usize,

    /// Number of threat lists currently held.
    pub threat_lists: usize,

    /// Rough estimate of the memory used by the stored data, in bytes.
    pub memory_usage: usize,

    /// When the database was last updated, if ever.
    pub last_update: Option<Instant>,

    /// `true` if the data is older than the configured maximum age (or was never loaded).
    pub is_stale: bool,
}
88
89impl fmt::Display for DatabaseStats {
90    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
91        let last_update = match self.last_update {
92            Some(time) => format!("{:?} ago", time.elapsed()),
93            None => "never".to_string(),
94        };
95
96        write!(
97            f,
98            "Database stats: {} hashes in {} lists, ~{} bytes, last update: {}, {}",
99            self.total_hashes,
100            self.threat_lists,
101            self.memory_usage,
102            last_update,
103            if self.is_stale { "STALE" } else { "up-to-date" }
104        )
105    }
106}
107
108/// Database interface for Safe Browsing
109///
110/// This trait defines the methods required for a Safe Browsing database
111/// implementation. It provides methods for looking up hash prefixes and
112/// updating the database from the Safe Browsing API.
113#[async_trait]
114pub trait Database {
115    /// Check if the database is ready for queries
116    async fn is_ready(&self) -> Result<bool>;
117
118    /// Get the current database status
119    async fn status(&self) -> Result<()>;
120
121    /// Update the database with the latest threat lists
122    async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()>;
123
124    /// Look up a hash prefix in the database
125    ///
126    /// If found, returns the matching hash prefix and the list of
127    /// threat descriptors that contain it
128    async fn lookup(
129        &self,
130        hash: &HashPrefix,
131    ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>>;
132
133    /// Get the time since the last successful update
134    async fn time_since_last_update(&self) -> Option<Duration>;
135
136    /// Get database statistics
137    async fn stats(&self) -> DatabaseStats;
138}
139
140/// Entry in a threat list
141struct ThreatListEntry {
142    /// Set of hash prefixes
143    hash_set: HashPrefixSet,
144
145    /// Client state for this list
146    client_state: Vec<u8>,
147
148    /// Checksum of this list
149    checksum: Vec<u8>,
150
151    /// Last update time
152    last_update: Instant,
153}
154
155impl ThreatListEntry {
156    /// Create a new threat list entry
157    fn new(hash_set: HashPrefixSet, client_state: Vec<u8>, checksum: Vec<u8>) -> Self {
158        Self {
159            hash_set,
160            client_state,
161            checksum,
162            last_update: Instant::now(),
163        }
164    }
165
166    /// Check if this entry is stale
167    fn is_stale(&self, max_age: Duration) -> bool {
168        self.last_update.elapsed() > max_age
169    }
170}
171
172/// In-memory database implementation
173///
174/// This is a simple database that keeps all threat lists in memory.
175/// It's suitable for applications with moderate memory usage requirements.
176pub struct InMemoryDatabase {
177    /// Inner data protected by RwLock for concurrent access
178    inner: Arc<RwLock<InMemoryDatabaseInner>>,
179}
180
181struct InMemoryDatabaseInner {
182    /// Threat lists indexed by ThreatDescriptor
183    threat_lists: HashMap<ThreatDescriptor, ThreatListEntry>,
184
185    /// Whether the database has been initialized
186    initialized: bool,
187
188    /// Last update time
189    last_update: Option<Instant>,
190
191    /// Maximum database age before it's considered stale
192    max_age: Duration,
193
194    /// Total hash count for statistics
195    hash_count: usize,
196}
197
198impl InMemoryDatabase {
199    /// Create a new in-memory database
200    pub fn new() -> Self {
201        Self {
202            inner: Arc::new(RwLock::new(InMemoryDatabaseInner {
203                threat_lists: HashMap::new(),
204                initialized: false,
205                last_update: None,
206                max_age: Duration::from_secs(2 * 60 * 60),
207                hash_count: 0,
208            })),
209        }
210    }
211
212    /// Create a new in-memory database with a specific maximum age
213    pub fn with_max_age(max_age: Duration) -> Self {
214        Self {
215            inner: Arc::new(RwLock::new(InMemoryDatabaseInner {
216                threat_lists: HashMap::new(),
217                initialized: false,
218                last_update: None,
219                max_age,
220                hash_count: 0,
221            })),
222        }
223    }
224
225    /// Update a single threat list from the API
226    async fn update_threat_list(
227        &self,
228        api: &SafeBrowsingApi,
229        threat_descriptor: &ThreatDescriptor,
230    ) -> Result<()> {
231        let client_state = {
232            let inner = self.inner.read().await;
233            inner
234                .threat_lists
235                .get(threat_descriptor)
236                .map(|entry| entry.client_state.clone())
237                .unwrap_or_default()
238        };
239
240        debug!(
241            "Updating threat list: {} (state len: {})",
242            threat_descriptor,
243            client_state.len()
244        );
245
246        let response = api
247            .fetch_threat_list_update(threat_descriptor, &client_state)
248            .await?;
249
250        if response.list_update_responses.is_empty() {
251            warn!("Empty list update response for {}", threat_descriptor);
252            return Ok(());
253        }
254
255        let list_update = &response.list_update_responses[0];
256        let response_type = list_update.response_type;
257
258        debug!(
259            "Received response for {}: type={}, additions={}, removals={}",
260            threat_descriptor,
261            response_type,
262            list_update.additions.len(),
263            list_update.removals.len()
264        );
265
266        // Process based on response type
267        {
268            let mut inner = self.inner.write().await;
269            match response_type {
270                0 | 1 => {
271                    // Unspecified or partial update
272                    // If we have an existing entry, apply incremental updates
273                    if let Some(entry) = inner.threat_lists.get_mut(threat_descriptor) {
274                        // Create a local mutable borrow of hash_set
275                        let hash_set = &mut entry.hash_set;
276
277                        // Collect removal and addition sets
278                        let removals: Vec<safebrowsing_proto::ThreatEntrySet> =
279                            list_update.removals.clone();
280                        let additions: Vec<safebrowsing_proto::ThreatEntrySet> =
281                            list_update.additions.clone();
282
283                        // Apply removals first
284                        for removal_set in &removals {
285                            Self::process_raw_hashes_removal(hash_set, removal_set)?;
286                        }
287
288                        // Then apply additions
289                        for addition_set in &additions {
290                            Self::process_raw_hashes_addition(hash_set, addition_set)?;
291                        }
292
293                        // Update state and checksum
294                        if !list_update.new_client_state.is_empty() {
295                            entry.client_state = list_update.new_client_state.clone().to_vec();
296                        }
297
298                        if let Some(checksum) = &list_update.checksum {
299                            entry.checksum = checksum.sha256.clone().to_vec();
300
301                            // Verify the checksum
302                            let computed_checksum = entry.hash_set.compute_checksum();
303                            if computed_checksum.as_bytes() != &checksum.sha256[..] {
304                                return Err(DatabaseError::InvalidChecksum {
305                                    expected: hex::encode(&checksum.sha256),
306                                    actual: hex::encode(computed_checksum.as_bytes()),
307                                });
308                            }
309                        }
310
311                        entry.last_update = Instant::now();
312                    } else if response_type == 0 {
313                        // Unspecified response but no existing entry, treat as full update
314                        Self::process_full_update_inner(
315                            &mut inner,
316                            threat_descriptor,
317                            list_update,
318                        )?;
319                    } else {
320                        // Partial update but no existing entry, error
321                        warn!(
322                            "Received partial update for non-existent threat list: {}",
323                            threat_descriptor
324                        );
325                        // We'll request a full update next time
326                    }
327                }
328                2 => {
329                    // Full update - replace entire list
330                    Self::process_full_update_inner(&mut inner, threat_descriptor, list_update)?;
331                }
332                _ => {
333                    warn!("Unknown response type: {}", response_type);
334                }
335            }
336
337            Self::update_hash_count_inner(&mut inner).await;
338        }
339        Ok(())
340    }
341
342    /// Update hash count metadata
343    async fn update_hash_count_inner(inner: &mut InMemoryDatabaseInner) {
344        inner.hash_count = inner
345            .threat_lists
346            .values()
347            .map(|entry| entry.hash_set.len())
348            .sum();
349    }
350
351    /// Process a full update for a threat list
352    fn process_full_update_inner(
353        inner: &mut InMemoryDatabaseInner,
354        threat_descriptor: &ThreatDescriptor,
355        list_update: &safebrowsing_proto::fetch_threat_list_updates_response::ListUpdateResponse,
356    ) -> Result<()> {
357        let mut hash_set = HashPrefixSet::new();
358
359        // Process additions only for full updates
360        for addition_set in &list_update.additions {
361            Self::process_raw_hashes_addition(&mut hash_set, addition_set)?;
362        }
363
364        // Create new entry
365        let entry = ThreatListEntry::new(
366            hash_set,
367            list_update.new_client_state.clone().to_vec(),
368            list_update
369                .checksum
370                .as_ref()
371                .map_or_else(Vec::new, |c| c.sha256.clone().to_vec()),
372        );
373
374        // Verify checksum if present
375        if let Some(checksum) = &list_update.checksum {
376            let computed_checksum = entry.hash_set.compute_checksum();
377            if computed_checksum.as_bytes() != &checksum.sha256[..] {
378                return Err(DatabaseError::InvalidChecksum {
379                    expected: hex::encode(&checksum.sha256),
380                    actual: hex::encode(computed_checksum.as_bytes()),
381                });
382            }
383        }
384
385        // Add to threat lists
386        inner.threat_lists.insert(threat_descriptor.clone(), entry);
387
388        Ok(())
389    }
390
391    /// Process a threat entry addition set (free function)
392    fn process_raw_hashes_addition(
393        hash_set: &mut HashPrefixSet,
394        addition: &safebrowsing_proto::ThreatEntrySet,
395    ) -> Result<()> {
396        match addition.compression_type {
397            x if x == CompressionType::Raw as i32 => {
398                // Raw hashes
399                if let Some(raw_hashes) = &addition.raw_hashes {
400                    Self::process_raw_hashes(hash_set, raw_hashes)?;
401                }
402            }
403            x if x == CompressionType::Rice as i32 => {
404                // Rice-encoded hashes
405                if let Some(rice_hashes) = &addition.rice_hashes {
406                    Self::process_rice_hashes(hash_set, rice_hashes)?;
407                }
408            }
409            _ => {
410                warn!(
411                    "Unsupported compression type: {}",
412                    addition.compression_type
413                );
414            }
415        }
416
417        Ok(())
418    }
419
420    /// Process a threat entry removal set (free function)
421    fn process_raw_hashes_removal(
422        hash_set: &mut HashPrefixSet,
423        removal: &safebrowsing_proto::ThreatEntrySet,
424    ) -> Result<()> {
425        match removal.compression_type {
426            x if x == CompressionType::Raw as i32 => {
427                // Raw indices
428                if let Some(raw_indices) = &removal.raw_indices {
429                    Self::process_raw_indices(hash_set, raw_indices)?;
430                }
431            }
432            x if x == CompressionType::Rice as i32 => {
433                // Rice-encoded indices
434                if let Some(rice_indices) = &removal.rice_indices {
435                    Self::process_rice_indices(hash_set, rice_indices)?;
436                }
437            }
438            _ => {
439                warn!(
440                    "Unsupported compression type for removal: {}",
441                    removal.compression_type
442                );
443            }
444        }
445
446        Ok(())
447    }
448
449    // Move these helper functions outside the impl block
450
451    /// Process raw hash additions
452    fn process_raw_hashes(hash_set: &mut HashPrefixSet, raw_hashes: &RawHashes) -> Result<()> {
453        let prefix_size = raw_hashes.prefix_size as usize;
454        if !(4..=32).contains(&prefix_size) {
455            return Err(DatabaseError::DecodeError(format!(
456                "Invalid prefix size: {prefix_size}"
457            )));
458        }
459
460        let hashes = &raw_hashes.raw_hashes;
461        if hashes.len() % prefix_size != 0 {
462            return Err(DatabaseError::DecodeError(format!(
463                "Raw hashes length {} is not a multiple of prefix size {}",
464                hashes.len(),
465                prefix_size
466            )));
467        }
468
469        for i in (0..hashes.len()).step_by(prefix_size) {
470            let end = i + prefix_size;
471            if end > hashes.len() {
472                break;
473            }
474
475            // Convert to a Vec<u8> to avoid lifetime issues
476            let hash_vec = hashes[i..end].to_vec();
477            match HashPrefix::new(hash_vec) {
478                Ok(hash) => {
479                    hash_set.insert(hash);
480                }
481                Err(e) => {
482                    warn!("Skipping invalid hash: {}", e);
483                }
484            }
485        }
486
487        Ok(())
488    }
489
490    /// Process Rice-encoded hash additions
491    /// Process Rice-encoded hashes and add them to the hash set.
492    ///
493    /// IMPORTANT: This uses little-endian byte order to match the Go implementation.
494    /// The Safe Browsing API's Go reference implementation uses `binary.LittleEndian.PutUint32`
495    /// when converting Rice-decoded integers to hash bytes. Using big-endian would result
496    /// in completely different hash values and checksum mismatches.
497    ///
498    /// See: https://github.com/google/safebrowsing/blob/master/hash.go#L183
499    fn process_rice_hashes(
500        hash_set: &mut HashPrefixSet,
501        rice_hashes: &RiceDeltaEncoding,
502    ) -> Result<()> {
503        let decoded_hashes = Self::decode_rice_delta_encoding(rice_hashes)?;
504
505        for hash_value in decoded_hashes {
506            // Rice encoding is for 4-byte hashes
507            // CRITICAL: Use little-endian to match Go implementation
508            // Go code: binary.LittleEndian.PutUint32(buf[:], h)
509            let hash_vec = hash_value.to_le_bytes().to_vec();
510            match HashPrefix::new(hash_vec) {
511                Ok(hash) => {
512                    hash_set.insert(hash);
513                }
514                Err(e) => {
515                    warn!("Skipping invalid hash: {}", e);
516                }
517            }
518        }
519
520        Ok(())
521    }
522
523    /// Process raw indices for removals
524    fn process_raw_indices(hash_set: &mut HashPrefixSet, raw_indices: &RawIndices) -> Result<()> {
525        // Need to sort the hashes and remove by index
526        let sorted_hashes = hash_set.to_sorted_vec();
527
528        // Build a set of indices to remove
529        let mut indices_to_remove = StdHashSet::new();
530        for &index in &raw_indices.indices {
531            if index >= 0 && (index as usize) < sorted_hashes.len() {
532                indices_to_remove.insert(index as usize);
533            } else {
534                return Err(DatabaseError::InvalidIndices(format!(
535                    "Index out of bounds: {} (max: {})",
536                    index,
537                    sorted_hashes.len()
538                )));
539            }
540        }
541
542        // Remove the hashes at the specified indices
543        for (i, hash) in sorted_hashes.iter().enumerate() {
544            if indices_to_remove.contains(&i) {
545                hash_set.remove(hash);
546            }
547        }
548
549        Ok(())
550    }
551
552    /// Process Rice-encoded indices for removals
553    fn process_rice_indices(
554        hash_set: &mut HashPrefixSet,
555        rice_indices: &RiceDeltaEncoding,
556    ) -> Result<()> {
557        let decoded_indices = Self::decode_rice_delta_encoding(rice_indices)?;
558
559        // Convert hash set to sorted vector for indexed removal
560        let sorted_hashes: Vec<HashPrefix> = hash_set.to_sorted_vec();
561
562        // Build a set of hashes to remove based on indices
563        let mut hashes_to_remove = StdHashSet::new();
564        for index in decoded_indices {
565            let index = index as usize;
566            if index < sorted_hashes.len() {
567                hashes_to_remove.insert(sorted_hashes[index].clone());
568            } else {
569                return Err(DatabaseError::InvalidIndices(format!(
570                    "Rice-encoded index out of bounds: {} (max: {})",
571                    index,
572                    sorted_hashes.len()
573                )));
574            }
575        }
576
577        // Remove all identified hashes
578        for hash in hashes_to_remove {
579            hash_set.remove(&hash);
580        }
581
582        Ok(())
583    }
584
585    /// Decode a Rice-delta encoded value using proper Rice-Golomb decoding
586    fn decode_rice_delta_encoding(rice: &RiceDeltaEncoding) -> Result<Vec<u32>> {
587        use safebrowsing_hash::rice::decode_rice_integers;
588
589        decode_rice_integers(
590            rice.rice_parameter,
591            rice.first_value,
592            rice.num_entries,
593            &rice.encoded_data,
594        )
595        .map_err(|e| DatabaseError::RiceDecodeError(e.to_string()))
596    }
597
598    /// Update the total hash count for statistics
599    async fn update_hash_count(&self) {
600        let mut inner = self.inner.write().await;
601        Self::update_hash_count_inner(&mut inner).await;
602    }
603}
604
605impl Default for InMemoryDatabase {
606    fn default() -> Self {
607        Self::new()
608    }
609}
610
611#[async_trait]
612impl Database for InMemoryDatabase {
613    async fn is_ready(&self) -> Result<bool> {
614        let inner = self.inner.read().await;
615        if !inner.initialized {
616            return Ok(false);
617        }
618
619        // Check if any list is stale
620        for (descriptor, entry) in &inner.threat_lists {
621            if entry.is_stale(inner.max_age) {
622                warn!(
623                    "Threat list {} is stale (last updated {:?} ago)",
624                    descriptor,
625                    entry.last_update.elapsed()
626                );
627                return Ok(false);
628            }
629        }
630
631        Ok(true)
632    }
633
634    async fn status(&self) -> Result<()> {
635        let inner = self.inner.read().await;
636        if !inner.initialized {
637            return Err(DatabaseError::NotReady);
638        }
639
640        if let Some(last_update) = inner.last_update {
641            let elapsed = last_update.elapsed();
642            if elapsed > inner.max_age {
643                return Err(DatabaseError::Stale(elapsed));
644            }
645        } else {
646            return Err(DatabaseError::NotReady);
647        }
648
649        Ok(())
650    }
651
652    async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()> {
653        info!("Updating database with {} threat lists", threat_lists.len());
654
655        for threat_descriptor in threat_lists {
656            if let Err(e) = self.update_threat_list(api, threat_descriptor).await {
657                error!("Failed to update threat list {}: {}", threat_descriptor, e);
658                // Continue with other lists
659            }
660        }
661
662        {
663            let mut inner = self.inner.write().await;
664            inner.initialized = !inner.threat_lists.is_empty();
665            inner.last_update = Some(Instant::now());
666        }
667
668        Ok(())
669    }
670
671    async fn lookup(
672        &self,
673        hash: &HashPrefix,
674    ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>> {
675        let inner = self.inner.read().await;
676        if !inner.initialized {
677            return Err(DatabaseError::NotReady);
678        }
679
680        let mut matching_descriptors = Vec::new();
681        let mut matching_prefix = None;
682
683        for (descriptor, entry) in &inner.threat_lists {
684            if let Some(prefix) = entry.hash_set.find_prefix(hash) {
685                matching_descriptors.push(descriptor.clone());
686
687                // We'll use the first matching prefix we find
688                if matching_prefix.is_none() {
689                    matching_prefix = Some(prefix.clone());
690                }
691            }
692        }
693
694        if let Some(prefix) = matching_prefix {
695            Ok(Some((prefix, matching_descriptors)))
696        } else {
697            Ok(None)
698        }
699    }
700
701    async fn time_since_last_update(&self) -> Option<Duration> {
702        let inner = self.inner.read().await;
703        inner.last_update.map(|time| time.elapsed())
704    }
705
706    async fn stats(&self) -> DatabaseStats {
707        let inner = self.inner.read().await;
708        let is_stale = if let Some(last_update) = inner.last_update {
709            last_update.elapsed() > inner.max_age
710        } else {
711            true
712        };
713
714        // Estimate memory usage (very rough approximation)
715        let mut memory_usage = 0;
716        for entry in inner.threat_lists.values() {
717            // Each hash is approximately its length + overhead
718            memory_usage += entry.hash_set.len() * 8; // ~8 bytes per hash with overhead
719            memory_usage += entry.client_state.len();
720            memory_usage += entry.checksum.len();
721            memory_usage += 32; // Struct overhead
722        }
723
724        // Map overhead
725        memory_usage += inner.threat_lists.len() * 32;
726
727        DatabaseStats {
728            total_hashes: inner.hash_count,
729            threat_lists: inner.threat_lists.len(),
730            memory_usage,
731            last_update: inner.last_update,
732            is_stale,
733        }
734    }
735}
736
737/// Thread-safe wrapper around an in-memory database
738///
739/// This provides a concurrent version of the InMemoryDatabase
740/// that can be safely shared between threads.
741pub struct ConcurrentDatabase {
742    /// The inner database
743    db: Arc<Mutex<InMemoryDatabase>>,
744}
745
746impl ConcurrentDatabase {
747    /// Create a new concurrent database
748    pub fn new() -> Self {
749        Self {
750            db: Arc::new(Mutex::new(InMemoryDatabase::new())),
751        }
752    }
753
754    /// Create a new concurrent database with a specific maximum age
755    pub fn with_max_age(max_age: Duration) -> Self {
756        Self {
757            db: Arc::new(Mutex::new(InMemoryDatabase::with_max_age(max_age))),
758        }
759    }
760}
761
762#[async_trait]
763impl Database for ConcurrentDatabase {
764    async fn is_ready(&self) -> Result<bool> {
765        let db = self.db.lock().await;
766        db.is_ready().await
767    }
768
769    async fn status(&self) -> Result<()> {
770        let db = self.db.lock().await;
771        db.status().await
772    }
773
774    async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()> {
775        let db = self.db.lock().await;
776        db.update(api, threat_lists).await
777    }
778
779    async fn lookup(
780        &self,
781        hash: &HashPrefix,
782    ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>> {
783        let db = self.db.lock().await;
784        db.lookup(hash).await
785    }
786
787    async fn time_since_last_update(&self) -> Option<Duration> {
788        let db = self.db.lock().await;
789        db.time_since_last_update().await
790    }
791
792    async fn stats(&self) -> DatabaseStats {
793        let db = self.db.lock().await;
794        db.stats().await
795    }
796}
797
798impl Default for ConcurrentDatabase {
799    fn default() -> Self {
800        Self::new()
801    }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use safebrowsing_api::{PlatformType, ThreatEntryType, ThreatType};

    fn create_test_threat_descriptor() -> ThreatDescriptor {
        ThreatDescriptor {
            threat_type: ThreatType::Malware,
            platform_type: PlatformType::AnyPlatform,
            threat_entry_type: ThreatEntryType::Url,
        }
    }

    #[tokio::test]
    async fn test_database_initialization() {
        let db = InMemoryDatabase::new();

        // New database should not be ready
        assert!(!db.is_ready().await.unwrap());

        // Status should return an error
        assert!(matches!(db.status().await, Err(DatabaseError::NotReady)));

        // Last update time should be None
        assert!(db.time_since_last_update().await.is_none());
    }

    #[tokio::test]
    async fn test_database_stats() {
        let db = InMemoryDatabase::new();

        let stats = db.stats().await;
        assert_eq!(stats.total_hashes, 0);
        assert_eq!(stats.threat_lists, 0);
        assert!(stats.last_update.is_none());
        assert!(stats.is_stale);
    }

    #[tokio::test]
    async fn test_lookup_uninitialized_database() {
        let db = InMemoryDatabase::new();
        let hash = HashPrefix::from_pattern("test");

        // Lookup should fail on uninitialized database
        assert!(matches!(
            db.lookup(&hash).await,
            Err(DatabaseError::NotReady)
        ));
    }

    #[tokio::test]
    async fn test_lookup_empty_database() {
        let db = InMemoryDatabase::new();
        let hash = HashPrefix::from_pattern("test");

        // Manually set initialized to simulate a database with no entries
        {
            let mut inner = db.inner.write().await;
            inner.initialized = true;
            inner.last_update = Some(Instant::now());
        }

        // Lookup should return None for an empty database
        assert!(matches!(db.lookup(&hash).await, Ok(None)));
    }

    #[tokio::test]
    async fn test_lookup_with_empty_threat_list() {
        let db = InMemoryDatabase::new();

        // Simulate a database holding one (empty) threat list without contacting the API.
        {
            let mut inner = db.inner.write().await;
            inner.threat_lists.insert(
                create_test_threat_descriptor(),
                ThreatListEntry::new(HashPrefixSet::new(), Vec::new(), Vec::new()),
            );
            inner.initialized = true;
            inner.last_update = Some(Instant::now());
        }

        assert!(db.is_ready().await.unwrap());
        assert!(db.status().await.is_ok());

        let stats = db.stats().await;
        assert_eq!(stats.threat_lists, 1);
        assert!(!stats.is_stale);

        let hash = HashPrefix::from_pattern("test");
        assert!(matches!(db.lookup(&hash).await, Ok(None)));
    }

    #[tokio::test]
    async fn test_time_since_last_update() {
        let db = InMemoryDatabase::new();

        // Initially, there should be no update time
        assert!(db.time_since_last_update().await.is_none());

        // Set an update time
        {
            let mut inner = db.inner.write().await;
            inner.last_update = Some(Instant::now());
        }
        assert!(db.time_since_last_update().await.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_database() {
        let db = ConcurrentDatabase::new();

        // New database should not be ready
        assert!(!db.is_ready().await.unwrap());

        // Status should return an error
        assert!(matches!(db.status().await, Err(DatabaseError::NotReady)));

        // Stats should be default values
        let stats = db.stats().await;
        assert_eq!(stats.total_hashes, 0);
    }

    #[tokio::test]
    async fn test_database_stats_display() {
        let db = InMemoryDatabase::new();

        // Without update time
        let stats = db.stats().await;
        let display = format!("{stats}");
        assert!(display.contains("never"));
        assert!(display.contains("STALE"));

        // With update time
        {
            let mut inner = db.inner.write().await;
            inner.last_update = Some(Instant::now());
        }
        let stats = db.stats().await;
        let display = format!("{stats}");
        assert!(display.contains("ago"));
        assert!(display.contains("up-to-date"));
    }
}