Skip to main content

tract_linalg/frame/mmm/
input_store.rs

1use downcast_rs::{Downcast, impl_downcast};
2use dyn_clone::DynClone;
3use dyn_eq::DynEq;
4use dyn_hash::DynHash;
5use std::alloc::Layout;
6use std::fmt::{Debug, Display};
7use std::hash::Hash;
8use std::sync::Arc;
9use tract_data::internal::*;
10
11use crate::WeightType;
12
13pub trait MMMInputFormat:
14    Downcast + Debug + DynHash + dyn_eq::DynEq + DynClone + Send + Sync + Display
15{
16    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
17    fn prepare_one(
18        &self,
19        t: &Tensor,
20        k_axis: usize,
21        mn_axis: usize,
22    ) -> TractResult<Box<dyn MMMInputValue>>;
23    fn precursor(&self) -> WeightType;
24    fn r(&self) -> usize;
25    fn k_alignment(&self) -> usize;
26    fn merge_with<'o, 'a: 'o, 'b: 'o>(
27        &'a self,
28        other: &'b dyn MMMInputFormat,
29    ) -> Option<&'o dyn MMMInputFormat> {
30        if self.dyn_eq(other) { Some(other) } else { None }
31    }
32    fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
33    fn extract_at_mn_f16(
34        &self,
35        data: &EagerPackedInput,
36        mn: usize,
37        slice: &mut [f16],
38    ) -> TractResult<()>;
39    fn extract_at_mn_f32(
40        &self,
41        data: &EagerPackedInput,
42        mn: usize,
43        slice: &mut [f32],
44    ) -> TractResult<()>;
45}
46
47dyn_clone::clone_trait_object!(MMMInputFormat);
48impl_downcast!(MMMInputFormat);
49dyn_hash::hash_trait_object!(MMMInputFormat);
50dyn_eq::eq_trait_object!(MMMInputFormat);
51
52pub trait MMMInputValue:
53    DynClone + Debug + DynHash + dyn_eq::DynEq + Send + Sync + Display + Downcast
54{
55    fn format(&self) -> &dyn MMMInputFormat;
56    fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
57    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
58    fn panels_count(&self) -> usize {
59        self.mn().divceil(self.format().r())
60    }
61    fn mn(&self) -> usize;
62    fn k(&self) -> usize;
63    fn exotic_fact(&self) -> &dyn ExoticFact;
64
65    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
66    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
67}
68dyn_clone::clone_trait_object!(MMMInputValue);
69impl_downcast!(MMMInputValue);
70dyn_hash::hash_trait_object!(MMMInputValue);
71dyn_eq::eq_trait_object!(MMMInputValue);
72
73#[allow(clippy::derived_hash_with_manual_eq)]
74#[derive(Clone, Hash, Debug)]
75pub struct PackedExoticFact {
76    pub format: Box<dyn MMMInputFormat>,
77    pub mn: TDim,
78    pub k: usize,
79}
80
81impl Display for PackedExoticFact {
82    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
83        write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
84    }
85}
86
87impl ExoticFact for PackedExoticFact {
88    fn buffer_sizes(&self) -> TVec<TDim> {
89        tvec!(self.format.mem_size(self.k.to_dim(), self.mn.clone()))
90    }
91}
92
93impl PartialEq for PackedExoticFact {
94    fn eq(&self, other: &Self) -> bool {
95        self.format == other.format && self.mn == other.mn && self.k == other.k
96    }
97}
98impl Eq for PackedExoticFact {}
99
100#[derive(Clone, Hash, PartialEq, Eq)]
101pub struct EagerPackedInput {
102    pub fact: PackedExoticFact,
103    pub packed: Arc<Blob>,
104    pub panel_bytes: usize,
105    pub mn: usize,
106}
107
108impl MMMInputValue for EagerPackedInput {
109    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
110        None
111    }
112    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
113        unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
114    }
115    fn k(&self) -> usize {
116        self.fact.k
117    }
118    fn mn(&self) -> usize {
119        self.mn
120    }
121    fn format(&self) -> &dyn MMMInputFormat {
122        &*self.fact.format
123    }
124    fn exotic_fact(&self) -> &dyn ExoticFact {
125        &self.fact
126    }
127    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
128        ensure!(slice.len() == self.k());
129        ensure!(mn < self.mn());
130        self.fact.format.extract_at_mn_f16(self, mn, slice)
131    }
132    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
133        ensure!(slice.len() == self.k());
134        ensure!(mn < self.mn());
135        self.fact.format.extract_at_mn_f32(self, mn, slice)
136    }
137}
138
139impl Display for EagerPackedInput {
140    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
141        (&self.fact as &dyn Display).fmt(f)
142    }
143}
144
145impl Debug for EagerPackedInput {
146    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
147        <Self as Display>::fmt(self, f)
148    }
149}