Skip to main content

zer_compute/backend/cpu/
device.rs

1use std::sync::Arc;
2
3use zer_compare::{FieldComparator, FellegiSunterScorer};
4use zer_core::{
5    comparison::{ComparisonBatch, ComparisonVector},
6    error::ZerError,
7    record::Record,
8    record_pool::RecordPool,
9    schema::Schema,
10    scoring::{ModelParams, ScoredPair},
11    traits::{Comparator, Result as ZerResult, Scorer},
12};
13
// ── CpuDevice ─────────────────────────────────────────────────────────────────

/// Unit handle for the CPU backend (`zer_compute::backend::cpu`).
///
/// The type has no fields, so it is zero-sized. Deriving `Copy` and `Default`
/// lets callers pass or construct it freely, and deriving `Debug` lets it
/// appear in diagnostics like any other public type.
#[derive(Debug, Clone, Copy, Default)]
pub struct CpuDevice;
17
18// ── CpuFallbackComparator ─────────────────────────────────────────────────────
19
20/// CPU-side comparator wrapping `zer_compare::FieldComparator`.
21#[derive(Clone)]
22pub struct CpuFallbackComparator {
23    inner: Arc<FieldComparator>,
24}
25
26impl CpuFallbackComparator {
27    pub fn from_schema(schema: &Schema) -> Self {
28        Self { inner: Arc::new(FieldComparator::from_schema(schema)) }
29    }
30}
31
32impl Comparator for CpuFallbackComparator {
33    fn compare(&self, a: &Record, b: &Record, schema: &Schema) -> ComparisonVector {
34        self.inner.compare(a, b, schema)
35    }
36
37    fn compare_batch_from_pool(
38        &self,
39        pool:    &RecordPool,
40        indices: &[(usize, usize)],
41        schema:  &Schema,
42    ) -> ComparisonBatch {
43        self.inner.compare_batch_from_pool(pool, indices, schema)
44    }
45}
46
47// ── CpuFallbackScorer ─────────────────────────────────────────────────────────
48
49/// CPU-side Fellegi-Sunter scorer wrapping `zer_compare::FellegiSunterScorer`.
50#[derive(Clone)]
51pub struct CpuFallbackScorer;
52
53impl Scorer for CpuFallbackScorer {
54    fn score(&self, vector: &ComparisonVector, params: &ModelParams) -> ScoredPair {
55        FellegiSunterScorer.score(vector, params)
56    }
57
58    fn score_batch(&self, batch: &ComparisonBatch, params: &ModelParams) -> Vec<ScoredPair> {
59        FellegiSunterScorer.score_batch(batch, params)
60    }
61
62    fn estimate_params(
63        &self,
64        batch:    &ComparisonBatch,
65        init:     Option<ModelParams>,
66        max_iter: usize,
67    ) -> ZerResult<ModelParams> {
68        FellegiSunterScorer.estimate_params(batch, init, max_iter)
69    }
70}
71
72/// Convenience wrapper for `DeviceScorer::estimate_params` CPU fallback.
73pub fn cpu_estimate_params(
74    batch:    &ComparisonBatch,
75    init:     Option<ModelParams>,
76    max_iter: usize,
77) -> ZerResult<ModelParams> {
78    zer_compare::run_em(batch, init, max_iter).map_err(ZerError::from)
79}