terminals_core/substrate/
splat.rs1use super::projection::{Projection, ProjectionId};
7
/// Number of `f32` components stored in one splat embedding.
pub const EMBEDDING_DIM: usize = 384;

/// Serialized size of one embedding in bytes: one `f32` per component.
///
/// The element width is taken from the type instead of a literal `4` so the
/// byte count cannot drift from the in-memory element type.
const SPLAT_BYTES: usize = EMBEDDING_DIM * std::mem::size_of::<f32>();
13
14#[derive(Debug, Clone)]
15pub struct SplatProjection {
16 pub embedding: [f32; EMBEDDING_DIM],
17}
18
19impl Default for SplatProjection {
20 fn default() -> Self {
21 Self {
22 embedding: [0.0; EMBEDDING_DIM],
23 }
24 }
25}
26
27impl Projection for SplatProjection {
28 fn byte_size() -> usize {
29 SPLAT_BYTES
30 }
31
32 fn id() -> ProjectionId {
33 ProjectionId::Splat
34 }
35
36 fn read(buf: &[u8]) -> Self {
37 assert!(buf.len() >= SPLAT_BYTES, "SplatProjection: buffer too small");
38 let mut embedding = [0.0f32; EMBEDDING_DIM];
39 for i in 0..EMBEDDING_DIM {
40 let offset = i * 4;
41 embedding[i] = f32::from_le_bytes([
42 buf[offset],
43 buf[offset + 1],
44 buf[offset + 2],
45 buf[offset + 3],
46 ]);
47 }
48 Self { embedding }
49 }
50
51 fn write(&self, buf: &mut [u8]) {
52 assert!(buf.len() >= SPLAT_BYTES, "SplatProjection: buffer too small");
53 for i in 0..EMBEDDING_DIM {
54 let bytes = self.embedding[i].to_le_bytes();
55 let offset = i * 4;
56 buf[offset..offset + 4].copy_from_slice(&bytes);
57 }
58 }
59
60 fn shape_hash_contribution(&self) -> u32 {
61 let mut hash = 0x811c_9dc5u32;
63 for &v in &self.embedding[..16.min(EMBEDDING_DIM)] {
64 let bits = v.to_bits();
65 for byte in bits.to_le_bytes() {
66 hash ^= byte as u32;
67 hash = hash.wrapping_mul(0x0100_0193);
68 }
69 }
70 hash
71 }
72}
73
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a projection in which every component holds a distinct value.
    fn patterned() -> SplatProjection {
        let mut proj = SplatProjection::default();
        for (i, v) in proj.embedding.iter_mut().enumerate() {
            *v = i as f32 * 0.25 - 7.5;
        }
        proj
    }

    #[test]
    fn test_splat_byte_size() {
        assert_eq!(SplatProjection::byte_size(), 1536);
        assert_eq!(SplatProjection::byte_size(), EMBEDDING_DIM * 4);
    }

    #[test]
    fn test_splat_roundtrip() {
        let mut proj = SplatProjection::default();
        proj.embedding[0] = 1.0;
        proj.embedding[100] = -0.5;
        proj.embedding[383] = 0.42;

        let mut buf = vec![0u8; SplatProjection::byte_size()];
        proj.write(&mut buf);
        let restored = SplatProjection::read(&buf);

        // The encoding stores raw `f32` bits, so the round trip is exact.
        assert_eq!(restored.embedding[0].to_bits(), 1.0f32.to_bits());
        assert_eq!(restored.embedding[100].to_bits(), (-0.5f32).to_bits());
        assert_eq!(restored.embedding[383].to_bits(), 0.42f32.to_bits());
    }

    #[test]
    fn test_splat_roundtrip_covers_every_component() {
        let proj = patterned();
        let mut buf = vec![0u8; SplatProjection::byte_size()];
        proj.write(&mut buf);
        let restored = SplatProjection::read(&buf);

        for (expected, actual) in proj.embedding.iter().zip(restored.embedding.iter()) {
            assert_eq!(expected.to_bits(), actual.to_bits());
        }
    }

    #[test]
    fn test_splat_write_is_little_endian() {
        let mut proj = SplatProjection::default();
        proj.embedding[0] = 1.0;

        let mut buf = vec![0xFFu8; SplatProjection::byte_size()];
        proj.write(&mut buf);

        // 1.0f32 is 0x3f80_0000; the least significant byte comes first.
        assert_eq!(buf[0..4].to_vec(), vec![0x00u8, 0x00, 0x80, 0x3f]);
        assert!(buf[4..].iter().all(|&b| b == 0));
    }

    #[test]
    #[should_panic(expected = "buffer too small")]
    fn test_splat_read_rejects_short_buffer() {
        let buf = vec![0u8; SplatProjection::byte_size() - 1];
        let _ = SplatProjection::read(&buf);
    }

    #[test]
    #[should_panic(expected = "buffer too small")]
    fn test_splat_write_rejects_short_buffer() {
        let mut buf = vec![0u8; SplatProjection::byte_size() - 1];
        SplatProjection::default().write(&mut buf);
    }

    #[test]
    fn test_splat_shape_hash_varies() {
        let a = SplatProjection::default();
        let mut b = SplatProjection::default();
        b.embedding[0] = 1.0;
        assert_ne!(a.shape_hash_contribution(), b.shape_hash_contribution());
    }
}