Skip to main content

tract_linalg/frame/mmm/
input_store.rs

1use downcast_rs::{impl_downcast, Downcast};
2use dyn_clone::DynClone;
3use dyn_hash::DynHash;
4use std::alloc::Layout;
5use std::fmt::{Debug, Display};
6use std::hash::Hash;
7use std::sync::Arc;
8use tract_data::internal::*;
9
10use crate::WeightType;
11
12pub trait MMMInputFormat: Downcast + Debug + DynHash + DynClone + Send + Sync + Display {
13    fn prepare_tensor(&self, t: &Tensor, k_axis: usize, mn_axis: usize) -> TractResult<Tensor>;
14    fn prepare_one(
15        &self,
16        t: &Tensor,
17        k_axis: usize,
18        mn_axis: usize,
19    ) -> TractResult<Box<dyn MMMInputValue>>;
20    fn precursor(&self) -> WeightType;
21    fn r(&self) -> usize;
22    fn k_alignment(&self) -> usize;
23    fn same_as(&self, other: &dyn MMMInputFormat) -> bool;
24    fn merge_with<'o, 'a: 'o, 'b: 'o>(
25        &'a self,
26        other: &'b dyn MMMInputFormat,
27    ) -> Option<&'o dyn MMMInputFormat> {
28        if self.same_as(other) {
29            Some(other)
30        } else {
31            None
32        }
33    }
34    fn mem_size(&self, k: TDim, mn: TDim) -> TDim;
35    fn extract_at_mn_f16(
36        &self,
37        data: &EagerPackedInput,
38        mn: usize,
39        slice: &mut [f16],
40    ) -> TractResult<()>;
41    fn extract_at_mn_f32(
42        &self,
43        data: &EagerPackedInput,
44        mn: usize,
45        slice: &mut [f32],
46    ) -> TractResult<()>;
47}
48
49dyn_clone::clone_trait_object!(MMMInputFormat);
50impl_downcast!(MMMInputFormat);
51dyn_hash::hash_trait_object!(MMMInputFormat);
52
53impl Eq for &dyn MMMInputFormat {}
54impl PartialEq for &dyn MMMInputFormat {
55    fn eq(&self, other: &Self) -> bool {
56        self.same_as(*other)
57    }
58}
59
60pub trait MMMInputValue: DynClone + Debug + DynHash + Send + Sync + Display + Downcast {
61    fn format(&self) -> &dyn MMMInputFormat;
62    fn scratch_panel_buffer_layout(&self) -> Option<Layout>;
63    fn panel_bytes(&self, i: usize, buffer: Option<*mut u8>) -> TractResult<*const u8>;
64    fn panels_count(&self) -> usize {
65        self.mn().divceil(self.format().r())
66    }
67    fn mn(&self) -> usize;
68    fn k(&self) -> usize;
69    fn opaque_fact(&self) -> &dyn OpaqueFact;
70    fn same_as(&self, other: &dyn MMMInputValue) -> bool;
71
72    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()>;
73    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()>;
74}
75dyn_clone::clone_trait_object!(MMMInputValue);
76impl_downcast!(MMMInputValue);
77dyn_hash::hash_trait_object!(MMMInputValue);
78
79impl From<Box<dyn MMMInputValue>> for Opaque {
80    fn from(value: Box<dyn MMMInputValue>) -> Self {
81        Opaque(Arc::new(value))
82    }
83}
84
85impl OpaquePayload for Box<dyn MMMInputValue> {
86    fn same_as(&self, other: &dyn OpaquePayload) -> bool {
87        other
88            .downcast_ref::<Self>()
89            .is_some_and(|other| (&**self as &dyn MMMInputValue).same_as(&**other))
90    }
91}
92
93#[allow(clippy::derived_hash_with_manual_eq)]
94#[derive(Clone, Hash, Debug)]
95pub struct PackedOpaqueFact {
96    pub format: Box<dyn MMMInputFormat>,
97    pub mn: TDim,
98    pub k: usize,
99}
100
101impl Display for PackedOpaqueFact {
102    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
103        write!(f, "Eager {} tensor (mn={} k={})", self.format, self.mn, self.k)
104    }
105}
106
107impl OpaqueFact for PackedOpaqueFact {
108    fn mem_size(&self) -> TDim {
109        self.format.mem_size(self.k.to_dim(), self.mn.clone())
110    }
111
112    fn same_as(&self, other: &dyn OpaqueFact) -> bool {
113        other.downcast_ref::<Self>().is_some_and(|o| o == self)
114    }
115}
116
117impl PartialEq for PackedOpaqueFact {
118    fn eq(&self, other: &Self) -> bool {
119        self.format.same_as(&*other.format) && self.mn == other.mn && self.k == other.k
120    }
121}
122
123#[derive(Clone, Hash)]
124pub struct EagerPackedInput {
125    pub fact: PackedOpaqueFact,
126    pub packed: Arc<Blob>,
127    pub panel_bytes: usize,
128    pub mn: usize,
129}
130
131impl MMMInputValue for EagerPackedInput {
132    fn scratch_panel_buffer_layout(&self) -> Option<Layout> {
133        None
134    }
135    fn panel_bytes(&self, i: usize, _buffer: Option<*mut u8>) -> TractResult<*const u8> {
136        unsafe { Ok(self.packed.as_ptr().add(i * self.panel_bytes)) }
137    }
138    fn k(&self) -> usize {
139        self.fact.k
140    }
141    fn mn(&self) -> usize {
142        self.mn
143    }
144    fn format(&self) -> &dyn MMMInputFormat {
145        &*self.fact.format
146    }
147    fn opaque_fact(&self) -> &dyn OpaqueFact {
148        &self.fact
149    }
150    fn same_as(&self, other: &dyn MMMInputValue) -> bool {
151        other.downcast_ref::<Self>().is_some_and(|other| {
152            self.fact.same_as(&other.fact)
153                && self.packed == other.packed
154                && self.panel_bytes == other.panel_bytes
155        })
156    }
157    fn extract_at_mn_f16(&self, mn: usize, slice: &mut [f16]) -> TractResult<()> {
158        ensure!(slice.len() == self.k());
159        ensure!(mn < self.mn());
160        self.fact.format.extract_at_mn_f16(self, mn, slice)
161    }
162    fn extract_at_mn_f32(&self, mn: usize, slice: &mut [f32]) -> TractResult<()> {
163        ensure!(slice.len() == self.k());
164        ensure!(mn < self.mn());
165        self.fact.format.extract_at_mn_f32(self, mn, slice)
166    }
167}
168
169impl Display for EagerPackedInput {
170    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
171        (&self.fact as &dyn Display).fmt(f)
172    }
173}
174
175impl Debug for EagerPackedInput {
176    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177        <Self as Display>::fmt(self, f)
178    }
179}