1use async_trait::async_trait;
8
9use safebrowsing_api::{SafeBrowsingApi, ThreatDescriptor};
10
11use safebrowsing_hash::{HashPrefix, HashPrefixSet};
12use safebrowsing_proto::{CompressionType, RawHashes, RawIndices, RiceDeltaEncoding};
13use std::collections::{HashMap, HashSet as StdHashSet};
14use std::fmt;
15use std::sync::Arc;
16use std::time::{Duration, Instant};
17use tokio::sync::{Mutex, RwLock};
18
19use tracing::{debug, error, info, warn};
20
21type Result<T> = std::result::Result<T, DatabaseError>;
23
/// Default maximum database age: 24 hours (86 400 seconds).
pub const DEFAULT_MAX_DATABASE_AGE: Duration = Duration::from_secs(86_400);

/// Maximum retry delay: 24 hours (86 400 seconds).
///
/// NOTE(review): not referenced anywhere in this file.
const MAX_RETRY_DELAY: Duration = Duration::from_secs(86_400);

/// Base retry delay: 15 minutes (900 seconds).
///
/// NOTE(review): not referenced anywhere in this file.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(900);
32
33#[derive(thiserror::Error, Debug)]
35pub enum DatabaseError {
36 #[error("Database not ready")]
38 NotReady,
39
40 #[error("Database is stale, last updated {0:?} ago")]
42 Stale(Duration),
43
44 #[error("Error decoding data: {0}")]
46 DecodeError(String),
47
48 #[error("API error: {0}")]
50 ApiError(#[from] safebrowsing_api::Error),
51 #[error("I/O error: {0}")]
53 IoError(#[from] std::io::Error),
54
55 #[error("Rice decoder error: {0}")]
57 RiceDecodeError(String),
58
59 #[error("Invalid indices: {0}")]
61 InvalidIndices(String),
62 #[error("Invalid checksum: expected {expected}, got {actual}")]
64 InvalidChecksum { expected: String, actual: String },
65 #[error("Hash error: {0}")]
67 HashError(#[from] safebrowsing_hash::HashError),
68}
69
/// Point-in-time summary of database size and freshness.
#[derive(Debug, Clone, Default)]
pub struct DatabaseStats {
    /// Total number of hash prefixes across all threat lists.
    pub total_hashes: usize,

    /// Number of threat lists currently held.
    pub threat_lists: usize,

    /// Rough, estimated memory footprint in bytes (not an exact measurement).
    pub memory_usage: usize,

    /// When the database was last updated, or `None` if it never was.
    pub last_update: Option<Instant>,

    /// Whether the data should be considered out of date.
    pub is_stale: bool,
}

impl fmt::Display for DatabaseStats {
    /// Renders a single human-readable status line.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let freshness = match self.is_stale {
            true => "STALE",
            false => "up-to-date",
        };
        let updated = self.last_update.map_or_else(
            || String::from("never"),
            |instant| format!("{:?} ago", instant.elapsed()),
        );

        write!(
            f,
            "Database stats: {} hashes in {} lists, ~{} bytes, last update: {}, {}",
            self.total_hashes, self.threat_lists, self.memory_usage, updated, freshness
        )
    }
}
107
108#[async_trait]
114pub trait Database {
115 async fn is_ready(&self) -> Result<bool>;
117
118 async fn status(&self) -> Result<()>;
120
121 async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()>;
123
124 async fn lookup(
129 &self,
130 hash: &HashPrefix,
131 ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>>;
132
133 async fn time_since_last_update(&self) -> Option<Duration>;
135
136 async fn stats(&self) -> DatabaseStats;
138}
139
140struct ThreatListEntry {
142 hash_set: HashPrefixSet,
144
145 client_state: Vec<u8>,
147
148 checksum: Vec<u8>,
150
151 last_update: Instant,
153}
154
155impl ThreatListEntry {
156 fn new(hash_set: HashPrefixSet, client_state: Vec<u8>, checksum: Vec<u8>) -> Self {
158 Self {
159 hash_set,
160 client_state,
161 checksum,
162 last_update: Instant::now(),
163 }
164 }
165
166 fn is_stale(&self, max_age: Duration) -> bool {
168 self.last_update.elapsed() > max_age
169 }
170}
171
172pub struct InMemoryDatabase {
177 inner: Arc<RwLock<InMemoryDatabaseInner>>,
179}
180
181struct InMemoryDatabaseInner {
182 threat_lists: HashMap<ThreatDescriptor, ThreatListEntry>,
184
185 initialized: bool,
187
188 last_update: Option<Instant>,
190
191 max_age: Duration,
193
194 hash_count: usize,
196}
197
198impl InMemoryDatabase {
199 pub fn new() -> Self {
201 Self {
202 inner: Arc::new(RwLock::new(InMemoryDatabaseInner {
203 threat_lists: HashMap::new(),
204 initialized: false,
205 last_update: None,
206 max_age: Duration::from_secs(2 * 60 * 60),
207 hash_count: 0,
208 })),
209 }
210 }
211
212 pub fn with_max_age(max_age: Duration) -> Self {
214 Self {
215 inner: Arc::new(RwLock::new(InMemoryDatabaseInner {
216 threat_lists: HashMap::new(),
217 initialized: false,
218 last_update: None,
219 max_age,
220 hash_count: 0,
221 })),
222 }
223 }
224
225 async fn update_threat_list(
227 &self,
228 api: &SafeBrowsingApi,
229 threat_descriptor: &ThreatDescriptor,
230 ) -> Result<()> {
231 let client_state = {
232 let inner = self.inner.read().await;
233 inner
234 .threat_lists
235 .get(threat_descriptor)
236 .map(|entry| entry.client_state.clone())
237 .unwrap_or_default()
238 };
239
240 debug!(
241 "Updating threat list: {} (state len: {})",
242 threat_descriptor,
243 client_state.len()
244 );
245
246 let response = api
247 .fetch_threat_list_update(threat_descriptor, &client_state)
248 .await?;
249
250 if response.list_update_responses.is_empty() {
251 warn!("Empty list update response for {}", threat_descriptor);
252 return Ok(());
253 }
254
255 let list_update = &response.list_update_responses[0];
256 let response_type = list_update.response_type;
257
258 debug!(
259 "Received response for {}: type={}, additions={}, removals={}",
260 threat_descriptor,
261 response_type,
262 list_update.additions.len(),
263 list_update.removals.len()
264 );
265
266 {
268 let mut inner = self.inner.write().await;
269 match response_type {
270 0 | 1 => {
271 if let Some(entry) = inner.threat_lists.get_mut(threat_descriptor) {
274 let hash_set = &mut entry.hash_set;
276
277 let removals: Vec<safebrowsing_proto::ThreatEntrySet> =
279 list_update.removals.clone();
280 let additions: Vec<safebrowsing_proto::ThreatEntrySet> =
281 list_update.additions.clone();
282
283 for removal_set in &removals {
285 Self::process_raw_hashes_removal(hash_set, removal_set)?;
286 }
287
288 for addition_set in &additions {
290 Self::process_raw_hashes_addition(hash_set, addition_set)?;
291 }
292
293 if !list_update.new_client_state.is_empty() {
295 entry.client_state = list_update.new_client_state.clone().to_vec();
296 }
297
298 if let Some(checksum) = &list_update.checksum {
299 entry.checksum = checksum.sha256.clone().to_vec();
300
301 let computed_checksum = entry.hash_set.compute_checksum();
303 if computed_checksum.as_bytes() != &checksum.sha256[..] {
304 return Err(DatabaseError::InvalidChecksum {
305 expected: hex::encode(&checksum.sha256),
306 actual: hex::encode(computed_checksum.as_bytes()),
307 });
308 }
309 }
310
311 entry.last_update = Instant::now();
312 } else if response_type == 0 {
313 Self::process_full_update_inner(
315 &mut inner,
316 threat_descriptor,
317 list_update,
318 )?;
319 } else {
320 warn!(
322 "Received partial update for non-existent threat list: {}",
323 threat_descriptor
324 );
325 }
327 }
328 2 => {
329 Self::process_full_update_inner(&mut inner, threat_descriptor, list_update)?;
331 }
332 _ => {
333 warn!("Unknown response type: {}", response_type);
334 }
335 }
336
337 Self::update_hash_count_inner(&mut inner).await;
338 }
339 Ok(())
340 }
341
342 async fn update_hash_count_inner(inner: &mut InMemoryDatabaseInner) {
344 inner.hash_count = inner
345 .threat_lists
346 .values()
347 .map(|entry| entry.hash_set.len())
348 .sum();
349 }
350
351 fn process_full_update_inner(
353 inner: &mut InMemoryDatabaseInner,
354 threat_descriptor: &ThreatDescriptor,
355 list_update: &safebrowsing_proto::fetch_threat_list_updates_response::ListUpdateResponse,
356 ) -> Result<()> {
357 let mut hash_set = HashPrefixSet::new();
358
359 for addition_set in &list_update.additions {
361 Self::process_raw_hashes_addition(&mut hash_set, addition_set)?;
362 }
363
364 let entry = ThreatListEntry::new(
366 hash_set,
367 list_update.new_client_state.clone().to_vec(),
368 list_update
369 .checksum
370 .as_ref()
371 .map_or_else(Vec::new, |c| c.sha256.clone().to_vec()),
372 );
373
374 if let Some(checksum) = &list_update.checksum {
376 let computed_checksum = entry.hash_set.compute_checksum();
377 if computed_checksum.as_bytes() != &checksum.sha256[..] {
378 return Err(DatabaseError::InvalidChecksum {
379 expected: hex::encode(&checksum.sha256),
380 actual: hex::encode(computed_checksum.as_bytes()),
381 });
382 }
383 }
384
385 inner.threat_lists.insert(threat_descriptor.clone(), entry);
387
388 Ok(())
389 }
390
391 fn process_raw_hashes_addition(
393 hash_set: &mut HashPrefixSet,
394 addition: &safebrowsing_proto::ThreatEntrySet,
395 ) -> Result<()> {
396 match addition.compression_type {
397 x if x == CompressionType::Raw as i32 => {
398 if let Some(raw_hashes) = &addition.raw_hashes {
400 Self::process_raw_hashes(hash_set, raw_hashes)?;
401 }
402 }
403 x if x == CompressionType::Rice as i32 => {
404 if let Some(rice_hashes) = &addition.rice_hashes {
406 Self::process_rice_hashes(hash_set, rice_hashes)?;
407 }
408 }
409 _ => {
410 warn!(
411 "Unsupported compression type: {}",
412 addition.compression_type
413 );
414 }
415 }
416
417 Ok(())
418 }
419
420 fn process_raw_hashes_removal(
422 hash_set: &mut HashPrefixSet,
423 removal: &safebrowsing_proto::ThreatEntrySet,
424 ) -> Result<()> {
425 match removal.compression_type {
426 x if x == CompressionType::Raw as i32 => {
427 if let Some(raw_indices) = &removal.raw_indices {
429 Self::process_raw_indices(hash_set, raw_indices)?;
430 }
431 }
432 x if x == CompressionType::Rice as i32 => {
433 if let Some(rice_indices) = &removal.rice_indices {
435 Self::process_rice_indices(hash_set, rice_indices)?;
436 }
437 }
438 _ => {
439 warn!(
440 "Unsupported compression type for removal: {}",
441 removal.compression_type
442 );
443 }
444 }
445
446 Ok(())
447 }
448
449 fn process_raw_hashes(hash_set: &mut HashPrefixSet, raw_hashes: &RawHashes) -> Result<()> {
453 let prefix_size = raw_hashes.prefix_size as usize;
454 if !(4..=32).contains(&prefix_size) {
455 return Err(DatabaseError::DecodeError(format!(
456 "Invalid prefix size: {prefix_size}"
457 )));
458 }
459
460 let hashes = &raw_hashes.raw_hashes;
461 if hashes.len() % prefix_size != 0 {
462 return Err(DatabaseError::DecodeError(format!(
463 "Raw hashes length {} is not a multiple of prefix size {}",
464 hashes.len(),
465 prefix_size
466 )));
467 }
468
469 for i in (0..hashes.len()).step_by(prefix_size) {
470 let end = i + prefix_size;
471 if end > hashes.len() {
472 break;
473 }
474
475 let hash_vec = hashes[i..end].to_vec();
477 match HashPrefix::new(hash_vec) {
478 Ok(hash) => {
479 hash_set.insert(hash);
480 }
481 Err(e) => {
482 warn!("Skipping invalid hash: {}", e);
483 }
484 }
485 }
486
487 Ok(())
488 }
489
490 fn process_rice_hashes(
500 hash_set: &mut HashPrefixSet,
501 rice_hashes: &RiceDeltaEncoding,
502 ) -> Result<()> {
503 let decoded_hashes = Self::decode_rice_delta_encoding(rice_hashes)?;
504
505 for hash_value in decoded_hashes {
506 let hash_vec = hash_value.to_le_bytes().to_vec();
510 match HashPrefix::new(hash_vec) {
511 Ok(hash) => {
512 hash_set.insert(hash);
513 }
514 Err(e) => {
515 warn!("Skipping invalid hash: {}", e);
516 }
517 }
518 }
519
520 Ok(())
521 }
522
523 fn process_raw_indices(hash_set: &mut HashPrefixSet, raw_indices: &RawIndices) -> Result<()> {
525 let sorted_hashes = hash_set.to_sorted_vec();
527
528 let mut indices_to_remove = StdHashSet::new();
530 for &index in &raw_indices.indices {
531 if index >= 0 && (index as usize) < sorted_hashes.len() {
532 indices_to_remove.insert(index as usize);
533 } else {
534 return Err(DatabaseError::InvalidIndices(format!(
535 "Index out of bounds: {} (max: {})",
536 index,
537 sorted_hashes.len()
538 )));
539 }
540 }
541
542 for (i, hash) in sorted_hashes.iter().enumerate() {
544 if indices_to_remove.contains(&i) {
545 hash_set.remove(hash);
546 }
547 }
548
549 Ok(())
550 }
551
552 fn process_rice_indices(
554 hash_set: &mut HashPrefixSet,
555 rice_indices: &RiceDeltaEncoding,
556 ) -> Result<()> {
557 let decoded_indices = Self::decode_rice_delta_encoding(rice_indices)?;
558
559 let sorted_hashes: Vec<HashPrefix> = hash_set.to_sorted_vec();
561
562 let mut hashes_to_remove = StdHashSet::new();
564 for index in decoded_indices {
565 let index = index as usize;
566 if index < sorted_hashes.len() {
567 hashes_to_remove.insert(sorted_hashes[index].clone());
568 } else {
569 return Err(DatabaseError::InvalidIndices(format!(
570 "Rice-encoded index out of bounds: {} (max: {})",
571 index,
572 sorted_hashes.len()
573 )));
574 }
575 }
576
577 for hash in hashes_to_remove {
579 hash_set.remove(&hash);
580 }
581
582 Ok(())
583 }
584
585 fn decode_rice_delta_encoding(rice: &RiceDeltaEncoding) -> Result<Vec<u32>> {
587 use safebrowsing_hash::rice::decode_rice_integers;
588
589 decode_rice_integers(
590 rice.rice_parameter,
591 rice.first_value,
592 rice.num_entries,
593 &rice.encoded_data,
594 )
595 .map_err(|e| DatabaseError::RiceDecodeError(e.to_string()))
596 }
597
598 async fn update_hash_count(&self) {
600 let mut inner = self.inner.write().await;
601 Self::update_hash_count_inner(&mut inner).await;
602 }
603}
604
605impl Default for InMemoryDatabase {
606 fn default() -> Self {
607 Self::new()
608 }
609}
610
611#[async_trait]
612impl Database for InMemoryDatabase {
613 async fn is_ready(&self) -> Result<bool> {
614 let inner = self.inner.read().await;
615 if !inner.initialized {
616 return Ok(false);
617 }
618
619 for (descriptor, entry) in &inner.threat_lists {
621 if entry.is_stale(inner.max_age) {
622 warn!(
623 "Threat list {} is stale (last updated {:?} ago)",
624 descriptor,
625 entry.last_update.elapsed()
626 );
627 return Ok(false);
628 }
629 }
630
631 Ok(true)
632 }
633
634 async fn status(&self) -> Result<()> {
635 let inner = self.inner.read().await;
636 if !inner.initialized {
637 return Err(DatabaseError::NotReady);
638 }
639
640 if let Some(last_update) = inner.last_update {
641 let elapsed = last_update.elapsed();
642 if elapsed > inner.max_age {
643 return Err(DatabaseError::Stale(elapsed));
644 }
645 } else {
646 return Err(DatabaseError::NotReady);
647 }
648
649 Ok(())
650 }
651
652 async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()> {
653 info!("Updating database with {} threat lists", threat_lists.len());
654
655 for threat_descriptor in threat_lists {
656 if let Err(e) = self.update_threat_list(api, threat_descriptor).await {
657 error!("Failed to update threat list {}: {}", threat_descriptor, e);
658 }
660 }
661
662 {
663 let mut inner = self.inner.write().await;
664 inner.initialized = !inner.threat_lists.is_empty();
665 inner.last_update = Some(Instant::now());
666 }
667
668 Ok(())
669 }
670
671 async fn lookup(
672 &self,
673 hash: &HashPrefix,
674 ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>> {
675 let inner = self.inner.read().await;
676 if !inner.initialized {
677 return Err(DatabaseError::NotReady);
678 }
679
680 let mut matching_descriptors = Vec::new();
681 let mut matching_prefix = None;
682
683 for (descriptor, entry) in &inner.threat_lists {
684 if let Some(prefix) = entry.hash_set.find_prefix(hash) {
685 matching_descriptors.push(descriptor.clone());
686
687 if matching_prefix.is_none() {
689 matching_prefix = Some(prefix.clone());
690 }
691 }
692 }
693
694 if let Some(prefix) = matching_prefix {
695 Ok(Some((prefix, matching_descriptors)))
696 } else {
697 Ok(None)
698 }
699 }
700
701 async fn time_since_last_update(&self) -> Option<Duration> {
702 let inner = self.inner.read().await;
703 inner.last_update.map(|time| time.elapsed())
704 }
705
706 async fn stats(&self) -> DatabaseStats {
707 let inner = self.inner.read().await;
708 let is_stale = if let Some(last_update) = inner.last_update {
709 last_update.elapsed() > inner.max_age
710 } else {
711 true
712 };
713
714 let mut memory_usage = 0;
716 for entry in inner.threat_lists.values() {
717 memory_usage += entry.hash_set.len() * 8; memory_usage += entry.client_state.len();
720 memory_usage += entry.checksum.len();
721 memory_usage += 32; }
723
724 memory_usage += inner.threat_lists.len() * 32;
726
727 DatabaseStats {
728 total_hashes: inner.hash_count,
729 threat_lists: inner.threat_lists.len(),
730 memory_usage,
731 last_update: inner.last_update,
732 is_stale,
733 }
734 }
735}
736
737pub struct ConcurrentDatabase {
742 db: Arc<Mutex<InMemoryDatabase>>,
744}
745
746impl ConcurrentDatabase {
747 pub fn new() -> Self {
749 Self {
750 db: Arc::new(Mutex::new(InMemoryDatabase::new())),
751 }
752 }
753
754 pub fn with_max_age(max_age: Duration) -> Self {
756 Self {
757 db: Arc::new(Mutex::new(InMemoryDatabase::with_max_age(max_age))),
758 }
759 }
760}
761
762#[async_trait]
763impl Database for ConcurrentDatabase {
764 async fn is_ready(&self) -> Result<bool> {
765 let db = self.db.lock().await;
766 db.is_ready().await
767 }
768
769 async fn status(&self) -> Result<()> {
770 let db = self.db.lock().await;
771 db.status().await
772 }
773
774 async fn update(&self, api: &SafeBrowsingApi, threat_lists: &[ThreatDescriptor]) -> Result<()> {
775 let db = self.db.lock().await;
776 db.update(api, threat_lists).await
777 }
778
779 async fn lookup(
780 &self,
781 hash: &HashPrefix,
782 ) -> Result<Option<(HashPrefix, Vec<ThreatDescriptor>)>> {
783 let db = self.db.lock().await;
784 db.lookup(hash).await
785 }
786
787 async fn time_since_last_update(&self) -> Option<Duration> {
788 let db = self.db.lock().await;
789 db.time_since_last_update().await
790 }
791
792 async fn stats(&self) -> DatabaseStats {
793 let db = self.db.lock().await;
794 db.stats().await
795 }
796}
797
798impl Default for ConcurrentDatabase {
799 fn default() -> Self {
800 Self::new()
801 }
802}
803
#[cfg(test)]
mod tests {
    use super::*;
    use safebrowsing_api::{PlatformType, ThreatDescriptor, ThreatEntryType, ThreatType};
    use safebrowsing_hash::HashPrefix;

    /// Descriptor for the malware / any-platform / URL threat list.
    fn create_test_threat_descriptor() -> ThreatDescriptor {
        ThreatDescriptor {
            threat_type: ThreatType::Malware,
            platform_type: PlatformType::AnyPlatform,
            threat_entry_type: ThreatEntryType::Url,
        }
    }

    #[tokio::test]
    async fn test_database_initialization() {
        let db = InMemoryDatabase::new();

        assert!(!db.is_ready().await.unwrap());
        assert!(matches!(db.status().await, Err(DatabaseError::NotReady)));
        assert!(db.time_since_last_update().await.is_none());
    }

    #[tokio::test]
    async fn test_database_stats() {
        let db = InMemoryDatabase::new();

        let stats = db.stats().await;
        assert_eq!(stats.total_hashes, 0);
        assert_eq!(stats.threat_lists, 0);
        assert!(stats.last_update.is_none());
        assert!(stats.is_stale);
    }

    #[tokio::test]
    async fn test_lookup_before_initialization() {
        let db = InMemoryDatabase::new();
        let hash = HashPrefix::from_pattern("test");

        // Lookups on a never-updated database are rejected.
        assert!(matches!(
            db.lookup(&hash).await,
            Err(DatabaseError::NotReady)
        ));
    }

    #[tokio::test]
    async fn test_lookup_empty_database() {
        let db = InMemoryDatabase::new();
        let hash = HashPrefix::from_pattern("test");

        {
            // Simulate a completed update without contacting the API.
            let mut inner = db.inner.write().await;
            inner.initialized = true;
            inner.last_update = Some(Instant::now());
        }

        assert!(matches!(db.lookup(&hash).await, Ok(None)));
    }

    #[tokio::test]
    async fn test_time_since_last_update() {
        let db = InMemoryDatabase::new();

        assert!(db.time_since_last_update().await.is_none());

        {
            let mut inner = db.inner.write().await;
            inner.last_update = Some(Instant::now());
        }
        assert!(db.time_since_last_update().await.is_some());
    }

    #[tokio::test]
    async fn test_concurrent_database() {
        let db = ConcurrentDatabase::new();

        assert!(!db.is_ready().await.unwrap());
        assert!(matches!(db.status().await, Err(DatabaseError::NotReady)));

        let stats = db.stats().await;
        assert_eq!(stats.total_hashes, 0);
    }

    #[tokio::test]
    async fn test_database_stats_display() {
        let db = InMemoryDatabase::new();

        let stats = db.stats().await;
        let display = format!("{stats}");
        assert!(display.contains("never"));
        assert!(display.contains("STALE"));

        {
            let mut inner = db.inner.write().await;
            inner.last_update = Some(Instant::now());
        }
        let stats = db.stats().await;
        let display = format!("{stats}");
        assert!(display.contains("ago"));
        assert!(display.contains("up-to-date"));
    }

    #[tokio::test]
    async fn test_ready_with_fresh_threat_list() {
        let db = InMemoryDatabase::new();

        {
            let mut inner = db.inner.write().await;
            inner.threat_lists.insert(
                create_test_threat_descriptor(),
                ThreatListEntry::new(HashPrefixSet::new(), Vec::new(), Vec::new()),
            );
            inner.initialized = true;
            inner.last_update = Some(Instant::now());
        }

        assert!(db.is_ready().await.unwrap());
        assert!(db.status().await.is_ok());
        assert_eq!(db.stats().await.threat_lists, 1);

        // An empty list cannot match anything.
        let hash = HashPrefix::from_pattern("test");
        assert!(matches!(db.lookup(&hash).await, Ok(None)));
    }

    #[test]
    fn test_raw_hashes_validation() {
        let mut set = HashPrefixSet::new();

        let bad_size = RawHashes {
            prefix_size: 3,
            ..Default::default()
        };
        assert!(matches!(
            InMemoryDatabase::process_raw_hashes(&mut set, &bad_size),
            Err(DatabaseError::DecodeError(_))
        ));

        let ragged = RawHashes {
            prefix_size: 4,
            raw_hashes: vec![0u8; 6].into(),
            ..Default::default()
        };
        assert!(matches!(
            InMemoryDatabase::process_raw_hashes(&mut set, &ragged),
            Err(DatabaseError::DecodeError(_))
        ));

        assert_eq!(set.len(), 0);
    }
}