tract_linalg/frame/mmm/
input_store.rs1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::DynClone;
3use dyn_hash::DynHash;
4use std::alloc::Layout;
5use std::fmt::{Debug, Display};
6use std::hash::Hash;
7use std::sync::Arc;
8use tract_data::internal::*;
9
10use crate::WeightType;
11
12pub trait MMMInputFormat: Downcast + Debug + DynHash + DynClone + Send + Sync + Display {
13 fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
14 fn prepare_one(
15 &self,
16 t: &Tensor,
17 k_axis: usize,
18 mn_axis: usize,
19 ) -> TractResult<Box<dyn MMMInputValue>>;
20 fn precursor(&self) -> WeightType;
21 fn r(&self) -> usize;
22 fn k_alignment(&self) -> usize;
23 fn same_as(&self, other: &dyn MMMInputFormat) -> bool;
24 fn merge_with<'o, 'a: 'o, 'b: 'o>(
25 &'a self,
26 other: &'b dyn MMMInputFormat,
27 ) -> Option<&'o dyn MMMInputFormat> {
28 if self.same_as(other) { Some(other) } else { None }
29 }
30 fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
31 fn extract_at_mn_f16(
32 &self,
33 data: &EagerPackedInput,
34 mn: usize,
35 slice: &mut [f16],
36 ) -> TractResult<()>;
37 fn extract_at_mn_f32(
38 &self,
39 data: &EagerPackedInput,
40 mn: usize,
41 slice: &mut [f32],
42 ) -> TractResult<()>;
43}
44
45dyn_clone::clone_trait_object!(MMMInputFormat);
46impl_downcast!(MMMInputFormat);
47dyn_hash::hash_trait_object!(MMMInputFormat);
48
49impl Eq for &dyn MMMInputFormat {}
50impl PartialEq for &dyn MMMInputFormat {
51 fn eq(&self, other: &Self) -> bool {
52 self.same_as(*other)
53 }
54}
55
56pub trait MMMInputValue: DynClone + Debug + DynHash + Send + Sync + Display + Downcast {
57 fn format(&self) -> &dyn MMMInputFormat;
58 fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
59 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
60 fn panels_count(&self) -> usize {
61 self.mn().divceil(self.format().r())
62 }
63 fn mn(&self) -> usize;
64 fn k(&self) -> usize;
65 fn opaque_fact(&self) -> &dyn OpaqueFact;
66 fn same_as(&self, other: &dyn MMMInputValue) -> bool;
67
68 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
69 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
70}
71dyn_clone::clone_trait_object!(MMMInputValue);
72impl_downcast!(MMMInputValue);
73dyn_hash::hash_trait_object!(MMMInputValue);
74
75impl From<Box<dyn MMMInputValue>> for Opaque {
76 fn from(value: Box<dyn MMMInputValue>) -> Self {
77 Opaque(Arc::new(value))
78 }
79}
80
81impl OpaquePayload for Box<dyn MMMInputValue> {
82 fn same_as(&self, other: &dyn OpaquePayload) -> bool {
83 other
84 .downcast_ref::<Self>()
85 .is_some_and(|other| (&**self as &dyn MMMInputValue).same_as(&**other))
86 }
87}
88
89#[allow(clippy::derived_hash_with_manual_eq)]
90#[derive(Clone, Hash, Debug)]
91pub struct PackedOpaqueFact {
92 pub format: Box<dyn MMMInputFormat>,
93 pub mn: TDim,
94 pub k: usize,
95}
96
97impl Display for PackedOpaqueFact {
98 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
99 write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
100 }
101}
102
103impl OpaqueFact for PackedOpaqueFact {
104 fn same_as(&self, other: &dyn OpaqueFact) -> bool {
105 other.downcast_ref::<Self>().is_some_and(|o| o == self)
106 }
107
108 fn buffer_sizes(&self) -> TVec<TDim> {
109 tvec!(self.format.mem_size(self.k.to_dim(), self.mn.clone()))
110 }
111}
112
113impl PartialEq for PackedOpaqueFact {
114 fn eq(&self, other: &Self) -> bool {
115 self.format.same_as(&*other.format) && self.mn == other.mn && self.k == other.k
116 }
117}
118
119#[derive(Clone, Hash)]
120pub struct EagerPackedInput {
121 pub fact: PackedOpaqueFact,
122 pub packed: Arc<Blob>,
123 pub panel_bytes: usize,
124 pub mn: usize,
125}
126
127impl MMMInputValue for EagerPackedInput {
128 fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
129 None
130 }
131 fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
132 unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
133 }
134 fn k(&self) -> usize {
135 self.fact.k
136 }
137 fn mn(&self) -> usize {
138 self.mn
139 }
140 fn format(&self) -> &dyn MMMInputFormat {
141 &*self.fact.format
142 }
143 fn opaque_fact(&self) -> &dyn OpaqueFact {
144 &self.fact
145 }
146 fn same_as(&self, other: &dyn MMMInputValue) -> bool {
147 other.downcast_ref::<Self>().is_some_and(|other| {
148 self.fact.same_as(&other.fact)
149 && self.packed == other.packed
150 && self.panel_bytes == other.panel_bytes
151 })
152 }
153 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
154 ensure!(slice.len() == self.k());
155 ensure!(mn < self.mn());
156 self.fact.format.extract_at_mn_f16(self, mn, slice)
157 }
158 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
159 ensure!(slice.len() == self.k());
160 ensure!(mn < self.mn());
161 self.fact.format.extract_at_mn_f32(self, mn, slice)
162 }
163}
164
165impl Display for EagerPackedInput {
166 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
167 (&self.fact as &dyn Display).fmt(f)
168 }
169}
170
171impl Debug for EagerPackedInput {
172 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
173 <Self as Display>::fmt(self, f)
174 }
175}