tract_linalg/frame/mmm/
input_store.rs1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::DynClone;
3use dyn_eq::DynEq;
4use dyn_hash::DynHash;
5use std::alloc::Layout;
6use std::fmt::{Debug, Display};
7use std::hash::Hash;
8use std::sync::Arc;
9use tract_data::internal::*;
10
11use crate::WeightType;
12
13pub trait MMMInputFormat:
14 Downcast + Debug + DynHash + dyn_eq::DynEq + DynClone + Send + Sync + Display
15{
16 fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
17 fn prepare_one(
18 &self,
19 t: &Tensor,
20 k_axis: usize,
21 mn_axis: usize,
22 ) -> TractResult<Box<dyn MMMInputValue>>;
23 fn precursor(&self) -> WeightType;
24 fn r(&self) -> usize;
25 fn k_alignment(&self) -> usize;
26 fn merge_with<'o, 'a: 'o, 'b: 'o>(
27 &'a self,
28 other: &'b dyn MMMInputFormat,
29 ) -> Option<&'o dyn MMMInputFormat> {
30 if self.dyn_eq(other) { Some(other) } else { None }
31 }
32 fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
33 fn extract_at_mn_f16(
34 &self,
35 data: &EagerPackedInput,
36 mn: usize,
37 slice: &mut [f16],
38 ) -> TractResult<()>;
39 fn extract_at_mn_f32(
40 &self,
41 data: &EagerPackedInput,
42 mn: usize,
43 slice: &mut [f32],
44 ) -> TractResult<()>;
45}
46
47dyn_clone::clone_trait_object!(MMMInputFormat);
48impl_downcast!(MMMInputFormat);
49dyn_hash::hash_trait_object!(MMMInputFormat);
50dyn_eq::eq_trait_object!(MMMInputFormat);
51
52pub trait MMMInputValue:
53 DynClone + Debug + DynHash + dyn_eq::DynEq + Send + Sync + Display + Downcast
54{
55 fn format(&self) -> &dyn MMMInputFormat;
56 fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
57 fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
58 fn panels_count(&self) -> usize {
59 self.mn().divceil(self.format().r())
60 }
61 fn mn(&self) -> usize;
62 fn k(&self) -> usize;
63 fn exotic_fact(&self) -> &dyn ExoticFact;
64
65 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
66 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
67}
68dyn_clone::clone_trait_object!(MMMInputValue);
69impl_downcast!(MMMInputValue);
70dyn_hash::hash_trait_object!(MMMInputValue);
71dyn_eq::eq_trait_object!(MMMInputValue);
72
73#[allow(clippy::derived_hash_with_manual_eq)]
74#[derive(Clone, Hash, Debug)]
75pub struct PackedExoticFact {
76 pub format: Box<dyn MMMInputFormat>,
77 pub mn: TDim,
78 pub k: usize,
79}
80
81impl Display for PackedExoticFact {
82 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83 write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
84 }
85}
86
87impl ExoticFact for PackedExoticFact {
88 fn buffer_sizes(&self) -> TVec<TDim> {
89 tvec!(self.format.mem_size(self.k.to_dim(), self.mn.clone()))
90 }
91}
92
93impl PartialEq for PackedExoticFact {
94 fn eq(&self, other: &Self) -> bool {
95 self.format == other.format && self.mn == other.mn && self.k == other.k
96 }
97}
98impl Eq for PackedExoticFact {}
99
100#[derive(Clone, Hash, PartialEq, Eq)]
101pub struct EagerPackedInput {
102 pub fact: PackedExoticFact,
103 pub packed: Arc<Blob>,
104 pub panel_bytes: usize,
105 pub mn: usize,
106}
107
108impl MMMInputValue for EagerPackedInput {
109 fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
110 None
111 }
112 fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
113 unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
114 }
115 fn k(&self) -> usize {
116 self.fact.k
117 }
118 fn mn(&self) -> usize {
119 self.mn
120 }
121 fn format(&self) -> &dyn MMMInputFormat {
122 &*self.fact.format
123 }
124 fn exotic_fact(&self) -> &dyn ExoticFact {
125 &self.fact
126 }
127 fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
128 ensure!(slice.len() == self.k());
129 ensure!(mn < self.mn());
130 self.fact.format.extract_at_mn_f16(self, mn, slice)
131 }
132 fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
133 ensure!(slice.len() == self.k());
134 ensure!(mn < self.mn());
135 self.fact.format.extract_at_mn_f32(self, mn, slice)
136 }
137}
138
139impl Display for EagerPackedInput {
140 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141 (&self.fact as &dyn Display).fmt(f)
142 }
143}
144
145impl Debug for EagerPackedInput {
146 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147 <Self as Display>::fmt(self, f)
148 }
149}