tract_transformers/ops/
dyn_kv_cache.rs

1use tract_nnef::internal::*;
2use tract_nnef::prelude::tract_itertools::Itertools;
3use tract_nnef::tract_core::ops::array::TypedConcat;
4use tract_nnef::tract_core::ops::source::TypedSource;
5use tract_nnef::tract_core::ops::OpStateFreeze;
6
7use crate::rule_ensure;
8
9use super::next_node;
10
/// Runtime state of a dynamic KV cache: accumulates key/value tensors across
/// evaluations by concatenating along a fixed axis.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    /// Node name, used in error reporting (see `save_to`).
    name: String,
    /// Axis along which new tokens are concatenated onto the cache.
    axis: usize,
    /// Symbolic fact of the cached ("past") sequence; its free symbol is
    /// resolved against the concrete cache shape at session time.
    past_sequence_fact: TypedFact,
    /// Accumulated tensor; `None` until the first `eval` or `load_from`.
    kv_cache: Option<TValue>,
}
18
impl DynKeyValueCacheState {
    /// Resolves the symbolic dimensions of `fact` against an optional concrete shape.
    ///
    /// Symbols already known to the session are skipped. At most one symbol may
    /// remain unresolved; it is bound to the matching dimension of
    /// `concrete_shape`, or to 0 when no concrete shape is available
    /// (i.e. the cache is still empty).
    pub fn resolve_symbols(
        state: &mut SessionState,
        fact: TypedFact,
        concrete_shape: Option<&[usize]>,
    ) -> TractResult<()> {
        // Collect (axis, symbol) pairs for dims whose symbol is still unresolved.
        let unresolved = fact
            .shape
            .iter()
            .enumerate()
            .filter_map(|(ax, symb)| match symb {
                TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
                _ => None,
            })
            .collect_vec();

        if unresolved.is_empty() {
            return Ok(());
        }

        // Only a single free symbol per cache fact is supported.
        ensure!(unresolved.len() == 1);
        let (ax, sym) = unresolved[0];
        if let Some(shape) = concrete_shape {
            ensure!(ax < shape.len());
            state.resolved_symbols.set(sym, shape[ax] as i64);
        } else {
            // No cache content yet: past sequence length is 0.
            state.resolved_symbols.set(sym, 0);
        }

        // With the symbol bound, let its scope guess the execution scenario once.
        if state.scenario.is_none() {
            state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
        }
        Ok(())
    }

    /// Truncates the cached tensor to `len` entries along the cache axis.
    /// Fails if the cache was never initialized.
    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
        if let Some(t) = self.kv_cache.as_mut() {
            *t = t.slice(self.axis, 0, len)?.into_tvalue();
        } else {
            bail!("Can not truncate a zero-len kv-cache value");
        }
        Ok(())
    }
}
63
impl OpState for DynKeyValueCacheState {
    /// Restores the cache from a saved session state and resolves the
    /// past-sequence symbol from the restored tensor's shape.
    fn load_from(&mut self, state: &mut SessionState, states: &mut Vec<TValue>) -> TractResult<()> {
        let kv_cache_init = states.remove(0);
        // KV Cache fact is always at index 0
        Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
        self.kv_cache = Some(kv_cache_init);

        Ok(())
    }

    /// Appends the current cache tensor to `states` so it can be restored later.
    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
        if let Some(kv_cache) = &self.kv_cache {
            states.push(kv_cache.clone());
            Ok(())
        } else {
            bail!("KV cache {} was never initialized", self.name)
        }
    }

    /// Fact describing the tensor expected by `load_from`.
    fn init_tensor_fact(&self) -> Option<TypedFact> {
        Some(self.past_sequence_fact.clone())
    }

    /// Resolves the past-sequence symbol from the current cache shape,
    /// or binds it to 0 when the cache is still empty.
    fn resolve_symbols(&mut self, state: &mut SessionState) -> TractResult<()> {
        let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
        Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
    }

    /// Concatenates the incoming tensor onto the cache along `self.axis`,
    /// stores the result as the new cache, and returns it.
    fn eval(
        &mut self,
        _state: &mut SessionState,
        _op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        let input = args_1!(inputs);
        // build output
        let output = if let Some(curr) = self.kv_cache.take() {
            TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
        } else {
            // First evaluation: the input itself becomes the cache.
            input
        };
        self.kv_cache = Some(output.clone());

        Ok(tvec!(output))
    }
}
110
/// Typed op implementing a dynamic key/value cache: each evaluation appends
/// the incoming sequence to the cached past sequence along `axis`.
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    pub name: String,
    /// Concatenation axis (the sequence axis).
    pub axis: usize,
    /// Fact of the cached ("past") sequence; symbolic length on `axis`.
    pub past_sequence_fact: TypedFact,
    /// Fact of the incoming sequence; symbolic length on `axis`.
    pub input_sequence_fact: TypedFact,
}
118
119impl Op for DynKeyValueCache {
120    fn name(&self) -> StaticName {
121        "DynamicKeyValueCache".to_string().into()
122    }
123
124    op_as_typed_op!();
125}
126
127impl EvalOp for DynKeyValueCache {
128    fn is_stateless(&self) -> bool {
129        false
130    }
131
132    fn state(
133        &self,
134        _session: &mut SessionState,
135        _node_id: usize,
136    ) -> TractResult<Option<Box<dyn OpState>>> {
137        Ok(Some(Box::new(DynKeyValueCacheState {
138            name: self.name.clone(),
139            axis: self.axis,
140            past_sequence_fact: self.past_sequence_fact.clone(),
141            kv_cache: None,
142        })))
143    }
144}
145
146impl TypedOp for DynKeyValueCache {
147    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
148        ensure!(inputs.len() == 1);
149        let input = inputs[0];
150        let mut fact = input.without_value();
151
152        fact.shape.set(
153            self.axis,
154            self.past_sequence_fact.shape.dims()[self.axis].clone()
155                + self.input_sequence_fact.shape.dims()[self.axis].clone(),
156        );
157        Ok(tvec!(fact))
158    }
159
160    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
161        let token_volume = self
162            .past_sequence_fact
163            .shape
164            .iter()
165            .enumerate()
166            .filter(|(axis, _d)| *axis != self.axis)
167            .map(|(_axis, d)| d)
168            .product::<TDim>();
169        Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
170    }
171
172    as_op!();
173}
174
/// Frozen counterpart of `DynKeyValueCacheState`, holding an owned `Tensor`
/// instead of a `TValue` (see `OpStateFreeze` / `FrozenOpState`).
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    // Mirrors the fields of `DynKeyValueCacheState`.
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    kv_cache: Option<Tensor>,
}
182
183impl OpStateFreeze for DynKeyValueCacheState {
184    fn freeze(&self) -> Box<dyn FrozenOpState> {
185        Box::new(FrozenDynKeyValueCacheState {
186            name: self.name.clone(),
187            axis: self.axis,
188            past_sequence_fact: self.past_sequence_fact.clone(),
189            kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
190        })
191    }
192}
193
194impl FrozenOpState for FrozenDynKeyValueCacheState {
195    fn unfreeze(&self) -> Box<dyn OpState> {
196        Box::new(DynKeyValueCacheState {
197            axis: self.axis,
198            name: self.name.clone(),
199            past_sequence_fact: self.past_sequence_fact.clone(),
200            kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
201        })
202    }
203}
204
205/// Search pattern => Input -> Concat -> Output
206/// Return type is for using rule-ensure macro
207pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
208    assert!(target.node(source_node_id).op_is::<TypedSource>());
209    let (concat_node_id, non_source_input_id, axis, input_facts) = {
210        let concat_node = if let Some(n_node) = next_node(target, target.node(source_node_id)) {
211            n_node
212        } else {
213            return Ok(None);
214        };
215
216        // Check KV Cache Pattern
217        rule_ensure!(
218            concat_node.op_is::<TypedConcat>()
219                && concat_node.inputs.len() == 2
220                && concat_node.outputs.len() == 1
221                && target.outputs.contains(&concat_node.id.into())
222        );
223
224        let concat_in_facts = target.node_input_facts(concat_node.id)?;
225
226        // Check on shapes
227        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
228        let rank = concat_in_shapes[0].len();
229        let axes = (0..rank)
230            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
231            .collect_vec();
232        ensure!(axes.len() == 1);
233
234        let axis = axes[0];
235        rule_ensure!(
236            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
237                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
238        );
239        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
240        if concat_node.inputs[0].node == source_node_id {
241            (concat_node.id, concat_node.inputs[1].node, axis, facts)
242        } else if concat_node.inputs[1].node == source_node_id {
243            facts.swap(0, 1);
244            (concat_node.id, concat_node.inputs[0].node, axis, facts)
245        } else {
246            return Ok(None);
247        }
248    };
249
250    {
251        // Replace Concat by KVCache
252        let name = target.node_names().collect_vec()[source_node_id].to_string();
253        let concat_node = target.node_mut(concat_node_id);
254        concat_node.op = Box::new(DynKeyValueCache {
255            name: name.clone(),
256            axis,
257            past_sequence_fact: input_facts[0].clone(),
258            input_sequence_fact: input_facts[1].clone(),
259        });
260        concat_node.name = name;
261        concat_node.inputs.retain(|input| input != &source_node_id.into());
262    }
263
264    {
265        // Replace Source by Dummy Op for it to be cleaned later
266        let dummy_op = target.create_dummy();
267        let source_node = target.node_mut(source_node_id);
268        source_node.outputs[0].successors.clear();
269        source_node.op = dummy_op;
270    }
271    {
272        // Non-source input is usually the second input of Concat. Rewire it to the only input of the new KVCache Op
273        let non_source_input = target.node_mut(non_source_input_id);
274        non_source_input.outputs.iter_mut().for_each(|output| {
275            output.successors.iter_mut().for_each(|succ| {
276                if succ.node == concat_node_id {
277                    succ.slot = 0
278                }
279            })
280        });
281    }
282
283    // Clean model I/Os
284    target.outputs.retain(|output| output.node != concat_node_id);
285    target.inputs.retain(|input| input.node != source_node_id);
286    target.outlet_labels.remove(&concat_node_id.into());
287    Ok(None)
288}
289
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Feeds a sequence of inputs through a `DynKeyValueCache` state and
    /// checks the final saved cache against a plain `TypedConcat` of the init
    /// tensor followed by every input.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // All shapes must agree on every dimension except `axis`.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        let dummy_model = TypedModel::default();

        // Symbolic fact shape: `sym` on `axis`, concrete dims elsewhere.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = SessionState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // Deterministic test tensor for a given shape: 0, 1, 2, ... as F.
        // (Hoisted into a closure: this was duplicated at both call sites.)
        let make_tensor = |shape: &[usize]| {
            let len = shape.iter().product::<usize>();
            Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())
        };

        let mut inputs = tvec![];

        // Init state with first shape
        let init = make_tensor(&input_shapes[0])?;
        inputs.push(init.clone().into_tvalue());
        let mut state_initializers = vec![init.into()];
        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed every input; correctness is checked via the saved state below,
        // so each eval's return value is deliberately ignored (the original
        // cloned it into a Tensor only to drop it).
        for shape in input_shapes {
            let input = make_tensor(shape)?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.into()))?;
        }

        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        // Reference: plain concatenation of init tensor and all inputs.
        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }
}