zer_compute/backend/cpu/
device.rs1use std::sync::Arc;
2
3use zer_compare::{FieldComparator, FellegiSunterScorer};
4use zer_core::{
5 comparison::{ComparisonBatch, ComparisonVector},
6 error::ZerError,
7 record::Record,
8 record_pool::RecordPool,
9 schema::Schema,
10 scoring::{ModelParams, ScoredPair},
11 traits::{Comparator, Result as ZerResult, Scorer},
12};
13
/// Zero-sized marker type for the CPU compute backend.
///
/// Carries no state, so it is freely copyable and trivially constructible via
/// [`Default`]. NOTE(review): presumably used to select/identify the CPU device
/// elsewhere in the crate — confirm against callers.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuDevice;
17
18#[derive(Clone)]
22pub struct CpuFallbackComparator {
23 inner: Arc<FieldComparator>,
24}
25
26impl CpuFallbackComparator {
27 pub fn from_schema(schema: &Schema) -> Self {
28 Self { inner: Arc::new(FieldComparator::from_schema(schema)) }
29 }
30}
31
32impl Comparator for CpuFallbackComparator {
33 fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
34 self.inner.compare(a, b, schema)
35 }
36
37 fn compare_batch_from_pool(
38 &self,
39 pool: &RecordPool,
40 indices: &[(usize, usize)],
41 schema: &Schema,
42 ) -> ComparisonBatch {
43 self.inner.compare_batch_from_pool(pool, indices, schema)
44 }
45}
46
/// Zero-sized CPU [`Scorer`] that forwards every call to
/// [`FellegiSunterScorer`].
///
/// Being stateless, it is `Copy` and constructible via [`Default`]; `Debug` is
/// derived so it can appear in diagnostics alongside other public types.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuFallbackScorer;
52
53impl Scorer for CpuFallbackScorer {
54 fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
55 FellegiSunterScorer.score(vector, params)
56 }
57
58 fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
59 FellegiSunterScorer.score_batch(batch, params)
60 }
61
62 fn estimate_params(
63 &self,
64 batch: &ComparisonBatch,
65 init: Option<ModelParams>,
66 max_iter: usize,
67 ) -> ZerResult<ModelParams> {
68 FellegiSunterScorer.estimate_params(batch, init, max_iter)
69 }
70}
71
72pub fn cpu_estimate_params(
74 batch: &ComparisonBatch,
75 init: Option<ModelParams>,
76 max_iter: usize,
77) -> ZerResult<ModelParams> {
78 zer_compare::run_em(batch, init, max_iter).map_err(ZerError::from)
79}