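//! `three_dcf_core` — `embedding.rs`
//!
//! Token-hashing text embedder (`HashEmbedder`) and the serde-serializable
//! `EmbeddingRecord` that stores a chunk's embedding alongside its text and
//! metadata.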

1use serde::{Deserialize, Serialize};
2use std::collections::hash_map::DefaultHasher;
3use std::hash::{Hash, Hasher};
4
5use crate::document::CellType;
6
/// Settings for a `HashEmbedder`.
///
/// `PartialEq`/`Eq` let callers verify that two embedders (for example the one
/// that built an index and the one embedding a query) were configured
/// identically before comparing their vectors.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct HashEmbedderConfig {
    /// Length of every embedding vector. The embedder treats `0` as `1`.
    pub dimensions: usize,
    /// Seed written into the hasher before each token, so embeddings are only
    /// comparable when they were produced with the same seed.
    pub seed: u64,
}
12
13impl Default for HashEmbedderConfig {
14    fn default() -> Self {
15        Self {
16            dimensions: 64,
17            seed: 1337,
18        }
19    }
20}
21
22#[derive(Clone)]
23pub struct HashEmbedder {
24    config: HashEmbedderConfig,
25}
26
27impl HashEmbedder {
28    pub fn new(config: HashEmbedderConfig) -> Self {
29        Self { config }
30    }
31
32    pub fn embed_text(&self, text: &str) -> Vec<f32> {
33        let dims = self.config.dimensions.max(1);
34        let mut vector = vec![0f32; dims];
35        for token in text.split_whitespace() {
36            let bucket = self.bucket_for(token);
37            vector[bucket] += 1.0;
38        }
39        normalize(&mut vector);
40        vector
41    }
42
43    fn bucket_for(&self, token: &str) -> usize {
44        let mut hasher = DefaultHasher::new();
45        hasher.write_u64(self.config.seed);
46        token.to_lowercase().hash(&mut hasher);
47        (hasher.finish() as usize) % self.config.dimensions.max(1)
48    }
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct EmbeddingRecord {
53    pub chunk_id: String,
54    pub doc: String,
55    pub chunk_index: usize,
56    #[serde(default)]
57    pub z_start: u32,
58    #[serde(default)]
59    pub z_end: u32,
60    #[serde(default)]
61    pub cell_start: usize,
62    #[serde(default)]
63    pub cell_end: usize,
64    #[serde(default)]
65    pub token_count: usize,
66    #[serde(default = "default_cell_type")]
67    pub dominant_type: CellType,
68    #[serde(default)]
69    pub importance_mean: f32,
70    pub embedding: Vec<f32>,
71    pub text: String,
72}
73
74fn default_cell_type() -> CellType {
75    CellType::Text
76}
77
/// Rescales `vector` in place so that its Euclidean (L2) norm is 1.
///
/// A vector whose norm is exactly zero (all components zero, or an empty
/// slice) is left untouched instead of being divided by zero.
fn normalize(vector: &mut [f32]) {
    let squared_length: f32 = vector.iter().map(|component| component * component).sum();
    let length = squared_length.sqrt();
    if length != 0.0 {
        vector.iter_mut().for_each(|component| *component /= length);
    }
}