vyre_reference/dual_impls/memory/scatter/
reference.rs1use crate::{dual_impls::common, workgroup::Memory};
2use vyre_primitives::Scatter;
3
4impl common::ReferenceEvaluator for Scatter {
5 fn evaluate(&self, inputs: &[Memory]) -> Result<Memory, common::EvalError> {
6 let (values, indices) = common::two_inputs(inputs, "scatter")?;
7 let values = common::u32_words(values, "scatter")?;
8 let indices = common::u32_words(indices, "scatter")?;
9 if values.len() != indices.len() {
10 return Err(common::EvalError::new(format!(
11 "primitive `scatter` expected equal value/index counts, got {} and {}. Fix: make scatter inputs the same length.",
12 values.len(),
13 indices.len()
14 )));
15 }
16 let max_index = indices.iter().copied().max().unwrap_or(0);
17 let len = usize::try_from(max_index).map_err(|_| {
18 common::EvalError::new(
19 "primitive `scatter` max index does not fit usize. Fix: keep scatter indices addressable.",
20 )
21 })?;
22 let mut output = vec![0; len.saturating_add(1)];
23 for (value, index) in values.into_iter().zip(indices) {
24 let slot = usize::try_from(index).map_err(|_| {
25 common::EvalError::new(
26 "primitive `scatter` index does not fit usize. Fix: keep scatter indices addressable.",
27 )
28 })?;
29 output[slot] = value;
30 }
31 Ok(common::write_u32s(output))
32 }
33}