Skip to main content

sim_lib_numbers_tensor_bit/
lib.rs

1#![forbid(unsafe_code)]
2#![allow(deprecated)]
3#![deny(missing_docs)]
4
5//! Bit-tensor specialization: a packed-word boolean tensor element type and its
6//! `SpecTensor` backend, with bitwise operations over the tensor domain.
7//!
8//! [`BitTensor`] is the storage type (booleans packed into `u64` words) with
9//! element-wise [`bit_and`](BitTensor::bit_and), [`bit_or`](BitTensor::bit_or),
10//! and [`bit_xor`](BitTensor::bit_xor). [`BitTensorLib`] registers it as the
11//! `bool` element-type backend for the base tensor domain.
12//!
13//! # Examples
14//!
15//! Pack booleans, combine two tensors bit-for-bit, and unpack the result:
16//!
17//! ```
18//! use sim_lib_numbers_tensor_bit::BitTensor;
19//!
20//! let left = BitTensor::from_bools(vec![4], &[true, false, true, true]).unwrap();
21//! let right = BitTensor::from_bools(vec![4], &[true, true, false, true]).unwrap();
22//! let masked = left.bit_and(&right).unwrap();
23//! assert_eq!(masked.to_bools(), vec![true, false, false, true]);
24//! ```
25//!
26//! Shape mismatches fail closed rather than truncate:
27//!
28//! ```
29//! use sim_lib_numbers_tensor_bit::BitTensor;
30//!
31//! let a = BitTensor::from_bools(vec![2], &[true, false]).unwrap();
32//! let b = BitTensor::from_bools(vec![3], &[true, false, true]).unwrap();
33//! assert!(a.bit_or(&b).is_none());
34//! ```
35
36use sim_kernel::{
37    AbiVersion, DefaultFactory, Dependency, Export, Factory, Lib, LibManifest, LibTarget, Linker,
38    Result, Symbol, Value, Version,
39};
40use sim_lib_numbers_tensor::{
41    SpecTensor, SpecTensorDescriptor, Tensor, domains, element_count, spec_tensor_descriptor_value,
42    spec_tensor_symbol,
43};
44
/// A boolean tensor stored as bit-packed `u64` words.
///
/// Each element occupies a single bit, so an `n`-element tensor is held in
/// `ceil(n / 64)` words. The logical [`shape`](Self::shape) and element count
/// drive layout; bitwise operations work directly on the packed words.
///
/// Layout: element `i` (in flat order) lives at bit `i % 64` of word `i / 64`.
/// Bits past the last element in the final word are kept at zero, which is
/// what lets the derived `PartialEq` compare `words` directly.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct BitTensor {
    /// Logical tensor shape.
    shape: Vec<usize>,
    /// Number of elements, computed from `shape` when the tensor is built.
    len: usize,
    /// Packed element bits, least-significant bit first within each word.
    words: Vec<u64>,
}
56
57impl BitTensor {
58    /// Packs a slice of booleans into a bit tensor of the given shape.
59    ///
60    /// Returns `None` when `bits.len()` does not match the element count
61    /// implied by `shape`.
62    pub fn from_bools(shape: Vec<usize>, bits: &[bool]) -> Option<Self> {
63        let len = element_count(&shape);
64        if len != bits.len() {
65            return None;
66        }
67        let mut words = vec![0u64; len.div_ceil(64)];
68        for (index, bit) in bits.iter().enumerate() {
69            if *bit {
70                words[index / 64] |= 1u64 << (index % 64);
71            }
72        }
73        Some(Self { shape, len, words })
74    }
75
76    /// Unpacks the tensor back into one boolean per element, in flat order.
77    pub fn to_bools(&self) -> Vec<bool> {
78        (0..self.len)
79            .map(|index| ((self.words[index / 64] >> (index % 64)) & 1) == 1)
80            .collect()
81    }
82
83    /// Element-wise bitwise OR with another bit tensor of the same shape.
84    ///
85    /// Returns `None` when the shapes differ.
86    pub fn bit_or(&self, other: &Self) -> Option<Self> {
87        map_words(self, other, |left, right| left | right)
88    }
89
90    /// Element-wise bitwise XOR with another bit tensor of the same shape.
91    ///
92    /// Returns `None` when the shapes differ.
93    pub fn bit_xor(&self, other: &Self) -> Option<Self> {
94        map_words(self, other, |left, right| left ^ right)
95    }
96
97    /// Element-wise bitwise AND with another bit tensor of the same shape.
98    ///
99    /// Returns `None` when the shapes differ.
100    pub fn bit_and(&self, other: &Self) -> Option<Self> {
101        map_words(self, other, |left, right| left & right)
102    }
103}
104
105impl SpecTensor for BitTensor {
106    fn shape(&self) -> &[usize] {
107        &self.shape
108    }
109
110    fn dtype(&self) -> Symbol {
111        domains::bool()
112    }
113
114    fn to_uniform(&self) -> Tensor {
115        Tensor {
116            shape: self.shape.clone(),
117            dtype: self.dtype(),
118            data: self
119                .to_bools()
120                .into_iter()
121                .map(bool_value)
122                .collect::<Option<Vec<_>>>()
123                .expect("bool tensor values should always encode"),
124        }
125    }
126
127    fn from_uniform(tensor: &Tensor) -> Option<Self> {
128        let bits = tensor
129            .data
130            .iter()
131            .map(parse_bool_cell)
132            .collect::<Option<Vec<_>>>()?;
133        Self::from_bools(tensor.shape.clone(), &bits)
134    }
135}
136
/// Registered library that installs the bit-packed boolean tensor backend.
///
/// Loading this [`Lib`] registers a [`SpecTensor`] descriptor binding the
/// `bool` element type to the [`BitTensor`] storage, so the base tensor domain
/// can construct and round-trip boolean tensors through packed `u64` words.
///
/// The type carries no state, so it is freely `Copy`.
#[derive(Debug, Clone, Copy)]
pub struct BitTensorLib;
143
144impl BitTensorLib {
145    /// Creates the bit-tensor library. The value is stateless; the spec-tensor
146    /// descriptor is installed when it is loaded into a
147    /// [`Cx`](sim_kernel::Cx).
148    pub fn new() -> Self {
149        Self
150    }
151}
152
153impl Default for BitTensorLib {
154    fn default() -> Self {
155        Self::new()
156    }
157}
158
159impl Lib for BitTensorLib {
160    fn manifest(&self) -> LibManifest {
161        LibManifest {
162            id: tensor_lib_symbol(),
163            version: Version(env!("CARGO_PKG_VERSION").to_owned()),
164            abi: AbiVersion { major: 0, minor: 1 },
165            target: LibTarget::HostRegistered,
166            requires: Vec::<Dependency>::new(),
167            capabilities: Vec::new(),
168            exports: vec![Export::Value {
169                symbol: tensor_spec_symbol(),
170            }],
171        }
172    }
173
174    fn load(&self, _cx: &mut sim_kernel::LoadCx, linker: &mut Linker<'_>) -> Result<()> {
175        linker.value(
176            tensor_spec_symbol(),
177            spec_tensor_descriptor_value(
178                &DefaultFactory,
179                SpecTensorDescriptor {
180                    symbol: tensor_spec_symbol(),
181                    dtype: domains::bool(),
182                    implementation: "BitTensor",
183                    storage: "bit-packed u64 words",
184                },
185            )?,
186        )
187    }
188}
189
190/// The manifest id symbol for this library (`numbers/tensor-bit`).
191pub fn tensor_lib_symbol() -> Symbol {
192    domains::domain("tensor-bit")
193}
194
195/// The symbol under which the bit-tensor [`SpecTensor`] descriptor is exported.
196pub fn tensor_spec_symbol() -> Symbol {
197    spec_tensor_symbol("bit")
198}
199
200fn bool_value(value: bool) -> Option<Value> {
201    DefaultFactory
202        .number_literal(domains::bool(), value.to_string())
203        .ok()
204}
205
206fn parse_bool_cell(value: &Value) -> Option<bool> {
207    let mut cx = sim_kernel::Cx::new(
208        std::sync::Arc::new(sim_kernel::NoopEvalPolicy),
209        std::sync::Arc::new(DefaultFactory),
210    );
211    let literal = value
212        .object()
213        .as_number_value()?
214        .number_literal(&mut cx)
215        .ok()??;
216    (literal.domain == domains::bool())
217        .then(|| literal.canonical.parse::<bool>().ok())
218        .flatten()
219}
220
221fn map_words(
222    left: &BitTensor,
223    right: &BitTensor,
224    f: impl Fn(u64, u64) -> u64,
225) -> Option<BitTensor> {
226    (left.shape == right.shape).then(|| BitTensor {
227        shape: left.shape.clone(),
228        len: left.len,
229        words: left
230            .words
231            .iter()
232            .zip(right.words.iter())
233            .map(|(left, right)| f(*left, *right))
234            .collect(),
235    })
236}
237
#[cfg(test)]
mod tests {
    use sim_kernel::Lib;

    use super::{BitTensor, BitTensorLib, SpecTensor, tensor_spec_symbol};

    #[test]
    fn bit_tensor_and_matches_bool_and() {
        let left = BitTensor::from_bools(vec![4], &[true, false, true, true]).unwrap();
        let right = BitTensor::from_bools(vec![4], &[true, true, false, true]).unwrap();
        let out = left.bit_and(&right).unwrap();
        assert_eq!(out.to_bools(), vec![true, false, false, true]);
        let uniform = out.to_uniform();
        assert_eq!(uniform.shape, vec![4]);
    }

    #[test]
    fn bit_or_and_xor_match_bool_ops() {
        let left = BitTensor::from_bools(vec![4], &[true, false, true, false]).unwrap();
        let right = BitTensor::from_bools(vec![4], &[true, true, false, false]).unwrap();
        assert_eq!(
            left.bit_or(&right).unwrap().to_bools(),
            vec![true, true, true, false]
        );
        assert_eq!(
            left.bit_xor(&right).unwrap().to_bools(),
            vec![false, true, true, false]
        );
    }

    #[test]
    fn packing_round_trips_across_word_boundaries() {
        // 130 elements span three u64 words; the last one is only partly used.
        let bits: Vec<bool> = (0..130).map(|i| i % 3 == 0 || i == 64).collect();
        let tensor = BitTensor::from_bools(vec![130], &bits).unwrap();
        assert_eq!(tensor.to_bools(), bits);

        let other: Vec<bool> = (0..130).map(|i| i % 2 == 0).collect();
        let other_tensor = BitTensor::from_bools(vec![130], &other).unwrap();
        let expected: Vec<bool> = bits
            .iter()
            .zip(&other)
            .map(|(&x, &y)| x ^ y)
            .collect();
        assert_eq!(tensor.bit_xor(&other_tensor).unwrap().to_bools(), expected);
    }

    #[test]
    fn mismatched_inputs_are_rejected() {
        // Element count implied by the shape differs from the number of bools.
        assert!(BitTensor::from_bools(vec![2, 2], &[true, false, true]).is_none());
        // Same element count, different shape: combining must fail closed.
        let wide = BitTensor::from_bools(vec![2, 3], &[true; 6]).unwrap();
        let tall = BitTensor::from_bools(vec![3, 2], &[true; 6]).unwrap();
        assert!(wide.bit_and(&tall).is_none());
        assert!(wide.bit_or(&tall).is_none());
        assert!(wide.bit_xor(&tall).is_none());
    }

    #[test]
    fn lib_exports_spec_tensor_descriptor() {
        assert_eq!(
            BitTensorLib::new().manifest().exports[0].symbol(),
            &tensor_spec_symbol()
        );
    }
}