Skip to main content

tract_linalg/frame/mmm/
input_store.rs

1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::DynClone;
3use dyn_hash::DynHash;
4use std::alloc::Layout;
5use std::fmt::{Debug, Display};
6use std::hash::Hash;
7use std::sync::Arc;
8use tract_data::internal::*;
9
10use crate::WeightType;
11
12pub trait MMMInputFormat: Downcast + Debug + DynHash + DynClone + Send + Sync + Display {
13    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
14    fn prepare_one(
15        &self,
16        t: &Tensor,
17        k_axis: usize,
18        mn_axis: usize,
19    ) -> TractResult<Box<dyn MMMInputValue>>;
20    fn precursor(&self) -> WeightType;
21    fn r(&self) -> usize;
22    fn k_alignment(&self) -> usize;
23    fn same_as(&self, other: &dyn MMMInputFormat) -> bool;
24    fn merge_with<'o, 'a: 'o, 'b: 'o>(
25        &'a self,
26        other: &'b dyn MMMInputFormat,
27    ) -> Option<&'o dyn MMMInputFormat> {
28        if self.same_as(other) { Some(other) } else { None }
29    }
30    fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
31    fn extract_at_mn_f16(
32        &self,
33        data: &EagerPackedInput,
34        mn: usize,
35        slice: &mut [f16],
36    ) -> TractResult<()>;
37    fn extract_at_mn_f32(
38        &self,
39        data: &EagerPackedInput,
40        mn: usize,
41        slice: &mut [f32],
42    ) -> TractResult<()>;
43}
44
45dyn_clone::clone_trait_object!(MMMInputFormat);
46impl_downcast!(MMMInputFormat);
47dyn_hash::hash_trait_object!(MMMInputFormat);
48
49impl Eq for &dyn MMMInputFormat {}
50impl PartialEq for &dyn MMMInputFormat {
51    fn eq(&self, other: &Self) -> bool {
52        self.same_as(*other)
53    }
54}
55
56pub trait MMMInputValue: DynClone + Debug + DynHash + Send + Sync + Display + Downcast {
57    fn format(&self) -> &dyn MMMInputFormat;
58    fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
59    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
60    fn panels_count(&self) -> usize {
61        self.mn().divceil(self.format().r())
62    }
63    fn mn(&self) -> usize;
64    fn k(&self) -> usize;
65    fn opaque_fact(&self) -> &dyn OpaqueFact;
66    fn same_as(&self, other: &dyn MMMInputValue) -> bool;
67
68    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
69    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
70}
71dyn_clone::clone_trait_object!(MMMInputValue);
72impl_downcast!(MMMInputValue);
73dyn_hash::hash_trait_object!(MMMInputValue);
74
75impl From<Box<dyn MMMInputValue>> for Opaque {
76    fn from(value: Box<dyn MMMInputValue>) -> Self {
77        Opaque(Arc::new(value))
78    }
79}
80
81impl OpaquePayload for Box<dyn MMMInputValue> {
82    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
83        other
84            .downcast_ref::<Self>()
85            .is_some_and(|other| (&**self as &dyn MMMInputValue).same_as(&**other))
86    }
87}
88
89#[allow(clippy::derived_hash_with_manual_eq)]
90#[derive(Clone, Hash, Debug)]
91pub struct PackedOpaqueFact {
92    pub format: Box<dyn MMMInputFormat>,
93    pub mn: TDim,
94    pub k: usize,
95}
96
97impl Display for PackedOpaqueFact {
98    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99        write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
100    }
101}
102
103impl OpaqueFact for PackedOpaqueFact {
104    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
105        other.downcast_ref::<Self>().is_some_and(|o| o == self)
106    }
107
108    fn buffer_sizes(&self) -> TVec<TDim> {
109        tvec!(self.format.mem_size(self.k.to_dim(), self.mn.clone()))
110    }
111}
112
113impl PartialEq for PackedOpaqueFact {
114    fn eq(&self, other: &Self) -> bool {
115        self.format.same_as(&*other.format) && self.mn == other.mn && self.k == other.k
116    }
117}
118
119#[derive(Clone, Hash)]
120pub struct EagerPackedInput {
121    pub fact: PackedOpaqueFact,
122    pub packed: Arc<Blob>,
123    pub panel_bytes: usize,
124    pub mn: usize,
125}
126
127impl MMMInputValue for EagerPackedInput {
128    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
129        None
130    }
131    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
132        unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
133    }
134    fn k(&self) -> usize {
135        self.fact.k
136    }
137    fn mn(&self) -> usize {
138        self.mn
139    }
140    fn format(&self) -> &dyn MMMInputFormat {
141        &*self.fact.format
142    }
143    fn opaque_fact(&self) -> &dyn OpaqueFact {
144        &self.fact
145    }
146    fn same_as(&self, other: &dyn MMMInputValue) -> bool {
147        other.downcast_ref::<Self>().is_some_and(|other| {
148            self.fact.same_as(&other.fact)
149                && self.packed == other.packed
150                && self.panel_bytes == other.panel_bytes
151        })
152    }
153    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
154        ensure!(slice.len() == self.k());
155        ensure!(mn < self.mn());
156        self.fact.format.extract_at_mn_f16(self, mn, slice)
157    }
158    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
159        ensure!(slice.len() == self.k());
160        ensure!(mn < self.mn());
161        self.fact.format.extract_at_mn_f32(self, mn, slice)
162    }
163}
164
165impl Display for EagerPackedInput {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        (&self.fact as &dyn Display).fmt(f)
168    }
169}
170
171impl Debug for EagerPackedInput {
172    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173        <Self as Display>::fmt(self, f)
174    }
175}