tract_linalg/frame/mmm/input_store.rs

1use downcast_rs::{impl_downcast, Downcast};
2use dyn_clone::DynClone;
3use dyn_hash::DynHash;
4use std::alloc::Layout;
5use std::fmt::{Debug, Display};
6use std::hash::Hash;
7use std::sync::Arc;
8use tract_data::internal::*;
9
10use crate::WeightType;
11
/// Describes the memory format of one matrix-multiplication (MMM) input operand.
///
/// A format turns a plain [`Tensor`] into its prepared representation
/// (`prepare_tensor` / `prepare_one`) and can read values back out of an
/// [`EagerPackedInput`] stored in it (`extract_at_mn_f16` / `extract_at_mn_f32`).
///
/// Boxed trait objects are made cloneable, downcastable and hashable by the
/// macro registrations right after this definition.
pub trait MMMInputFormat: Downcast + Debug + DynHash + DynClone + Send + Sync + Display {
    /// Prepares `t` into this format and returns the result as a tensor.
    ///
    /// `k_axis` and `mn_axis` designate the axes of `t` that hold the k and m/n
    /// dimensions. NOTE(review): the layout of the returned tensor is
    /// implementor-defined — confirm with the concrete formats.
    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
    /// Prepares a single operand `t` into this format, returned as an [`MMMInputValue`].
    ///
    /// `k_axis` and `mn_axis` have the same meaning as in `prepare_tensor`.
    fn prepare_one(
        &self,
        t: &Tensor,
        k_axis: usize,
        mn_axis: usize,
    ) -> TractResult<Box<dyn MMMInputValue>>;
    /// Weight type associated with this format.
    /// NOTE(review): meaning inferred from the name `precursor` only — confirm.
    fn precursor(&self) -> WeightType;
    /// Number of m/n entries per panel: `MMMInputValue::panels_count` divides
    /// `mn` by this value, rounding up.
    fn r(&self) -> usize;
    /// Alignment constraint on the k dimension.
    /// NOTE(review): exact semantics are not visible in this file — confirm with implementors.
    fn k_alignment(&self) -> usize;
    /// Returns true when `other` denotes the same format. Backs the `PartialEq`
    /// impl on `&dyn MMMInputFormat` and the equality of `PackedOpaqueFact`.
    fn same_as(&self, other: &dyn MMMInputFormat) -> bool;
    /// Memory size of a `k` x `mn` operand stored in this format; this is what
    /// `PackedOpaqueFact::mem_size` reports.
    /// NOTE(review): unit (presumably bytes) is not established here — confirm.
    fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
    /// Extracts, as f16, the values at m/n index `mn` of `data` into `slice`.
    ///
    /// `EagerPackedInput` only calls this after checking `slice.len() == k` and
    /// `mn < data.mn()`.
    fn extract_at_mn_f16(
        &self,
        data: &EagerPackedInput,
        mn: usize,
        slice: &mut [f16],
    ) -> TractResult<()>;
    /// Extracts, as f32, the values at m/n index `mn` of `data` into `slice`.
    ///
    /// `EagerPackedInput` only calls this after checking `slice.len() == k` and
    /// `mn < data.mn()`.
    fn extract_at_mn_f32(
        &self,
        data: &EagerPackedInput,
        mn: usize,
        slice: &mut [f32],
    ) -> TractResult<()>;
}

// Let `Box<dyn MMMInputFormat>` be cloned, downcast to its concrete type, and hashed.
dyn_clone::clone_trait_object!(MMMInputFormat);
impl_downcast!(MMMInputFormat);
dyn_hash::hash_trait_object!(MMMInputFormat);
42
43impl Eq for &dyn MMMInputFormat {}
44impl PartialEq for &dyn MMMInputFormat {
45    fn eq(&self, other: &Self) -> bool {
46        self.same_as(*other)
47    }
48}
49
/// A prepared matrix-multiplication input operand, stored in some [`MMMInputFormat`].
///
/// Its data is exposed panel by panel through `panel_bytes`.
pub trait MMMInputValue: DynClone + Debug + DynHash + Send + Sync + Display + Downcast {
    /// Format this value is stored in.
    fn format(&self) -> &dyn MMMInputFormat;
    /// Layout of a scratch buffer needed to produce panels, if any
    /// (`EagerPackedInput` returns `None`).
    /// NOTE(review): presumably callers allocate it and pass it as `buffer` to
    /// `panel_bytes` — confirm.
    fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
    /// Returns a pointer to the bytes of panel `i`. `buffer` is an optional
    /// caller-provided scratch area (ignored by `EagerPackedInput`).
    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
    /// Number of panels: `mn` divided by the format's `r()`, rounded up so a
    /// trailing partial panel is counted.
    fn panels_count(&self) -> usize {
        self.mn().divceil(self.format().r())
    }
    /// Size of the m/n dimension.
    fn mn(&self) -> usize;
    /// Size of the k dimension.
    fn k(&self) -> usize;
    /// Opaque fact describing this value.
    fn opaque_fact(&self) -> &dyn OpaqueFact;
    /// Returns true when `other` denotes the same value; used by the
    /// `OpaquePayload` impl on `Box<dyn MMMInputValue>`.
    fn same_as(&self, other: &dyn MMMInputValue) -> bool;

    /// Extracts, as f16, the values at m/n index `mn` into `slice`.
    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
    /// Extracts, as f32, the values at m/n index `mn` into `slice`.
    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
}
// Let `Box<dyn MMMInputValue>` be cloned, downcast to its concrete type, and hashed.
dyn_clone::clone_trait_object!(MMMInputValue);
impl_downcast!(MMMInputValue);
dyn_hash::hash_trait_object!(MMMInputValue);
68
69impl From<Box<dyn MMMInputValue>> for Opaque {
70    fn from(value: Box<dyn MMMInputValue>) -> Self {
71        Opaque(Arc::new(value))
72    }
73}
74
75impl OpaquePayload for Box<dyn MMMInputValue> {
76    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
77        other
78            .downcast_ref::<Self>()
79            .is_some_and(|other| (&**self as &dyn MMMInputValue).same_as(&**other))
80    }
81}
82
/// Opaque fact describing a packed MMM input: its format and its m/n and k sizes.
// `Hash` is derived but `PartialEq` is written by hand (it compares `format`
// through `MMMInputFormat::same_as`), hence the clippy allowance.
// NOTE(review): this relies on formats that are `same_as` each other also
// hashing identically — confirm for all implementors.
#[allow(clippy::derived_hash_with_manual_eq)]
#[derive(Clone, Hash, Debug)]
pub struct PackedOpaqueFact {
    /// Format the input is stored in.
    pub format: Box<dyn MMMInputFormat>,
    /// Size of the m/n dimension; a `TDim`, so it may be symbolic.
    pub mn: TDim,
    /// Size of the k dimension.
    pub k: usize,
}
90
91impl Display for PackedOpaqueFact {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
94    }
95}
96
97impl OpaqueFact for PackedOpaqueFact {
98    fn mem_size(&self) -> TDim {
99        self.format.mem_size(self.k.to_dim(), self.mn.clone())
100    }
101
102    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
103        other.downcast_ref::<Self>().is_some_and(|o| o == self)
104    }
105}
106
107impl PartialEq for PackedOpaqueFact {
108    fn eq(&self, other: &Self) -> bool {
109        self.format.same_as(&*other.format) && self.mn == other.mn && self.k == other.k
110    }
111}
112
/// An MMM input whose panels are all stored contiguously in one blob:
/// panel `i` starts `i * panel_bytes` bytes into `packed`.
#[derive(Clone, Hash)]
pub struct EagerPackedInput {
    /// Fact describing the format and the m/n and k sizes.
    pub fact: PackedOpaqueFact,
    /// Packed data for all panels, shared through an `Arc`.
    pub packed: Arc<Blob>,
    /// Size in bytes of one panel, i.e. the distance between consecutive panels in `packed`.
    pub panel_bytes: usize,
    /// Concrete m/n size, returned by `mn()`.
    /// NOTE(review): presumably agrees with `fact.mn` once that is evaluated — confirm.
    pub mn: usize,
}
120
121impl MMMInputValue for EagerPackedInput {
122    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
123        None
124    }
125    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
126        unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
127    }
128    fn k(&self) -> usize {
129        self.fact.k
130    }
131    fn mn(&self) -> usize {
132        self.mn
133    }
134    fn format(&self) -> &dyn MMMInputFormat {
135        &*self.fact.format
136    }
137    fn opaque_fact(&self) -> &dyn OpaqueFact {
138        &self.fact
139    }
140    fn same_as(&self, other: &dyn MMMInputValue) -> bool {
141        other.downcast_ref::<Self>().is_some_and(|other| {
142            self.fact.same_as(&other.fact)
143                && self.packed == other.packed
144                && self.panel_bytes == other.panel_bytes
145        })
146    }
147    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
148        ensure!(slice.len() == self.k());
149        ensure!(mn < self.mn());
150        self.fact.format.extract_at_mn_f16(self, mn, slice)
151    }
152    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
153        ensure!(slice.len() == self.k());
154        ensure!(mn < self.mn());
155        self.fact.format.extract_at_mn_f32(self, mn, slice)
156    }
157}
158
159impl Display for EagerPackedInput {
160    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
161        (&self.fact as &dyn Display).fmt(f)
162    }
163}
164
165impl Debug for EagerPackedInput {
166    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167        <Self as Display>::fmt(self, f)
168    }
169}