sim_lib_numbers_tensor_bit/
lib.rs1#![forbid(unsafe_code)]
2#![allow(deprecated)]
3#![deny(missing_docs)]
4
5use sim_kernel::{
37 AbiVersion, DefaultFactory, Dependency, Export, Factory, Lib, LibManifest, LibTarget, Linker,
38 Result, Symbol, Value, Version,
39};
40use sim_lib_numbers_tensor::{
41 SpecTensor, SpecTensorDescriptor, Tensor, domains, element_count, spec_tensor_descriptor_value,
42 spec_tensor_symbol,
43};
44
/// A dense boolean tensor that stores one bit per element, packed into
/// `u64` words.
///
/// Element `i` lives at bit `i % 64` of `words[i / 64]`. Every constructor in
/// this crate leaves the bits past `len` in the final word cleared, so the
/// derived equality compares only meaningful bits.
#[derive(Debug, Clone, Eq, PartialEq)]
pub struct BitTensor {
    /// Logical dimensions of the tensor.
    shape: Vec<usize>,
    /// Number of elements (bits) held; the element count implied by `shape`.
    len: usize,
    /// Packed storage, `len.div_ceil(64)` words long.
    words: Vec<u64>,
}
56
57impl BitTensor {
58 pub fn from_bools(shape: Vec<usize>, bits: &[bool]) -> Option<Self> {
63 let len = element_count(&shape);
64 if len != bits.len() {
65 return None;
66 }
67 let mut words = vec![0u64; len.div_ceil(64)];
68 for (index, bit) in bits.iter().enumerate() {
69 if *bit {
70 words[index / 64] |= 1u64 << (index % 64);
71 }
72 }
73 Some(Self { shape, len, words })
74 }
75
76 pub fn to_bools(&self) -> Vec<bool> {
78 (0..self.len)
79 .map(|index| ((self.words[index / 64] >> (index % 64)) & 1) == 1)
80 .collect()
81 }
82
83 pub fn bit_or(&self, other: &Self) -> Option<Self> {
87 map_words(self, other, |left, right| left | right)
88 }
89
90 pub fn bit_xor(&self, other: &Self) -> Option<Self> {
94 map_words(self, other, |left, right| left ^ right)
95 }
96
97 pub fn bit_and(&self, other: &Self) -> Option<Self> {
101 map_words(self, other, |left, right| left & right)
102 }
103}
104
105impl SpecTensor for BitTensor {
106 fn shape(&self) -> &[usize] {
107 &self.shape
108 }
109
110 fn dtype(&self) -> Symbol {
111 domains::bool()
112 }
113
114 fn to_uniform(&self) -> Tensor {
115 Tensor {
116 shape: self.shape.clone(),
117 dtype: self.dtype(),
118 data: self
119 .to_bools()
120 .into_iter()
121 .map(bool_value)
122 .collect::<Option<Vec<_>>>()
123 .expect("bool tensor values should always encode"),
124 }
125 }
126
127 fn from_uniform(tensor: &Tensor) -> Option<Self> {
128 let bits = tensor
129 .data
130 .iter()
131 .map(parse_bool_cell)
132 .collect::<Option<Vec<_>>>()?;
133 Self::from_bools(tensor.shape.clone(), &bits)
134 }
135}
136
/// Host-registered library that exports the bit-packed spec-tensor
/// descriptor. The type carries no state.
#[derive(Debug, Clone, Copy)]
pub struct BitTensorLib;
143
144impl BitTensorLib {
145 pub fn new() -> Self {
149 Self
150 }
151}
152
153impl Default for BitTensorLib {
154 fn default() -> Self {
155 Self::new()
156 }
157}
158
159impl Lib for BitTensorLib {
160 fn manifest(&self) -> LibManifest {
161 LibManifest {
162 id: tensor_lib_symbol(),
163 version: Version(env!("CARGO_PKG_VERSION").to_owned()),
164 abi: AbiVersion { major: 0, minor: 1 },
165 target: LibTarget::HostRegistered,
166 requires: Vec::<Dependency>::new(),
167 capabilities: Vec::new(),
168 exports: vec![Export::Value {
169 symbol: tensor_spec_symbol(),
170 }],
171 }
172 }
173
174 fn load(&self, _cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
175 linker.value(
176 tensor_spec_symbol(),
177 spec_tensor_descriptor_value(
178 &DefaultFactory,
179 SpecTensorDescriptor {
180 symbol: tensor_spec_symbol(),
181 dtype: domains::bool(),
182 implementation: "BitTensor",
183 storage: "bit-packed u64 words",
184 },
185 )?,
186 )
187 }
188}
189
190pub fn tensor_lib_symbol() -> Symbol {
192 domains::domain("tensor-bit")
193}
194
195pub fn tensor_spec_symbol() -> Symbol {
197 spec_tensor_symbol("bit")
198}
199
200fn bool_value(value: bool) -> Option<Value> {
201 DefaultFactory
202 .number_literal(domains::bool(), value.to_string())
203 .ok()
204}
205
206fn parse_bool_cell(value: &Value) -> Option<bool> {
207 let mut cx = sim_kernel::Cx::new(
208 std::sync::Arc::new(sim_kernel::NoopEvalPolicy),
209 std::sync::Arc::new(DefaultFactory),
210 );
211 let literal = value
212 .object()
213 .as_number_value()?
214 .number_literal(&mut cx)
215 .ok()??;
216 (literal.domain == domains::bool())
217 .then(|| literal.canonical.parse::<bool>().ok())
218 .flatten()
219}
220
221fn map_words(
222 left: &BitTensor,
223 right: &BitTensor,
224 f: impl Fn(u64, u64) -> u64,
225) -> Option<BitTensor> {
226 (left.shape == right.shape).then(|| BitTensor {
227 shape: left.shape.clone(),
228 len: left.len,
229 words: left
230 .words
231 .iter()
232 .zip(right.words.iter())
233 .map(|(left, right)| f(*left, *right))
234 .collect(),
235 })
236}
237
#[cfg(test)]
mod tests {
    use sim_kernel::Lib;

    use super::{BitTensor, BitTensorLib, SpecTensor, tensor_spec_symbol};

    /// Builds a 1-D tensor holding `bits`.
    fn vector(bits: &[bool]) -> BitTensor {
        BitTensor::from_bools(vec![bits.len()], bits).expect("shape matches bit count")
    }

    #[test]
    fn bit_tensor_and_matches_bool_and() {
        let left = BitTensor::from_bools(vec![4], &[true, false, true, true]).unwrap();
        let right = BitTensor::from_bools(vec![4], &[true, true, false, true]).unwrap();
        let out = left.bit_and(&right).unwrap();
        assert_eq!(out.to_bools(), vec![true, false, false, true]);
        let uniform = out.to_uniform();
        assert_eq!(uniform.shape, vec![4]);
    }

    #[test]
    fn bit_tensor_or_and_xor_match_bool_ops() {
        let left = vector(&[true, false, true, false]);
        let right = vector(&[true, true, false, false]);
        assert_eq!(
            left.bit_or(&right).unwrap().to_bools(),
            vec![true, true, true, false]
        );
        assert_eq!(
            left.bit_xor(&right).unwrap().to_bools(),
            vec![false, true, true, false]
        );
    }

    #[test]
    fn bits_survive_word_boundaries() {
        // 130 bits span three u64 words, so this exercises the packing
        // arithmetic on both sides of each 64-bit boundary.
        let bits: Vec<bool> = (0..130).map(|i| i % 3 == 0 || i == 64 || i == 129).collect();
        let tensor = vector(&bits);
        assert_eq!(tensor.to_bools(), bits);

        let zeros = vector(&[false; 130]);
        assert_eq!(tensor.bit_or(&zeros).unwrap(), tensor);
        assert_eq!(tensor.bit_and(&zeros).unwrap(), zeros);
        assert_eq!(tensor.bit_xor(&tensor).unwrap(), zeros);
    }

    #[test]
    fn empty_tensor_has_no_bits() {
        let empty = BitTensor::from_bools(vec![0], &[]).unwrap();
        assert!(empty.to_bools().is_empty());
        assert_eq!(empty.bit_and(&empty).unwrap(), empty);
    }

    #[test]
    fn from_bools_rejects_length_mismatch() {
        assert!(BitTensor::from_bools(vec![3], &[true, false]).is_none());
        assert!(BitTensor::from_bools(vec![1], &[true, false]).is_none());
    }

    #[test]
    fn binary_ops_reject_shape_mismatch() {
        let flat = vector(&[true, false, true, true]);
        let square = BitTensor::from_bools(vec![2, 2], &[true, false, true, true]).unwrap();
        assert!(flat.bit_and(&square).is_none());
        assert!(flat.bit_or(&square).is_none());
        assert!(flat.bit_xor(&square).is_none());
    }

    #[test]
    fn uniform_round_trip_preserves_bits() {
        let tensor = vector(&[true, false, false, true, true]);
        let uniform = tensor.to_uniform();
        assert_eq!(uniform.shape, vec![5]);
        assert_eq!(uniform.data.len(), 5);
        assert_eq!(BitTensor::from_uniform(&uniform), Some(tensor));
    }

    #[test]
    fn lib_exports_spec_tensor_descriptor() {
        assert_eq!(
            BitTensorLib::new().manifest().exports[0].symbol(),
            &tensor_spec_symbol()
        );
    }
}