//! Dynamic key/value cache op for transformer inference.
//! (tract_transformers/ops/dyn_kv_cache.rs)
1use tract_nnef::internal::*;
2use tract_nnef::prelude::tract_itertools::Itertools;
3use tract_nnef::tract_core::ops::OpStateFreeze;
4use tract_nnef::tract_core::ops::array::TypedConcat;
5use tract_nnef::tract_core::ops::source::TypedSource;
6
7use super::next_node;
8
/// Per-session state for [`DynKeyValueCache`]: accumulates the key/value
/// tensor across decoding turns.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Node name, used for error reporting and state identification.
    name: String,
    // Axis along which successive sequences are concatenated.
    axis: usize,
    // Symbolic fact describing the past-sequence (cached) input.
    past_sequence_fact: TypedFact,
    // Accumulated cache content; `None` until first `load_from`/`eval`.
    kv_cache: Option<TValue>,
}
16
17impl DynKeyValueCacheState {
18    pub fn resolve_symbols(
19        state: &mut TurnState,
20        fact: TypedFact,
21        concrete_shape: Option<&[usize]>,
22    ) -> TractResult<()> {
23        let unresolved = fact
24            .shape
25            .iter()
26            .enumerate()
27            .filter_map(|(ax, symb)| match symb {
28                TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
29                _ => None,
30            })
31            .collect_vec();
32
33        if unresolved.is_empty() {
34            return Ok(());
35        }
36
37        ensure!(unresolved.len() == 1);
38        let (ax, sym) = unresolved[0];
39        if let Some(shape) = concrete_shape {
40            ensure!(ax < shape.len());
41            state.resolved_symbols.set(sym, shape[ax] as i64);
42        } else {
43            state.resolved_symbols.set(sym, 0);
44        }
45
46        if state.scenario.is_none() {
47            state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
48        }
49        Ok(())
50    }
51
52    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
53        if let Some(t) = self.kv_cache.as_mut() {
54            *t = t.slice(self.axis, 0, len)?.into_tvalue();
55        } else {
56            bail!("Can not truncate a zero-len kv-cache value");
57        }
58        Ok(())
59    }
60}
61
62impl OpState for DynKeyValueCacheState {
63    fn load_from(
64        &mut self,
65        state: &mut TurnState,
66        states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
67    ) -> TractResult<()> {
68        // KV Cache fact is always at index 0
69        let kv_cache_init = states.next().context("Not enough state initializers")?;
70        Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
71        self.kv_cache = Some(kv_cache_init.clone());
72
73        Ok(())
74    }
75
76    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
77        if let Some(kv_cache) = &self.kv_cache {
78            states.push(kv_cache.clone());
79            Ok(())
80        } else {
81            bail!("KV cache {} was never initialized", self.name)
82        }
83    }
84
85    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
86        Some((self.name.clone(), self.past_sequence_fact.clone()))
87    }
88
89    fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
90        let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
91        Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
92    }
93
94    fn eval(
95        &mut self,
96        _state: &mut TurnState,
97        _op: &dyn Op,
98        inputs: TVec<TValue>,
99    ) -> TractResult<TVec<TValue>> {
100        let input = args_1!(inputs);
101        // build output
102        let output = if let Some(curr) = self.kv_cache.take() {
103            TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
104        } else {
105            input
106        };
107        self.kv_cache = Some(output.clone());
108
109        Ok(tvec!(output))
110    }
111}
112
/// Stateful op implementing a transformer KV cache: each eval appends the
/// incoming sequence to the cached past sequence along `axis`.
/// Installed by `replace_kv_cache` in place of a matching `Concat` pattern.
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    // Name carried over from the replaced source node.
    pub name: String,
    // Sequence axis along which caching/concatenation happens.
    pub axis: usize,
    // Symbolic fact of the cached (past) sequence input.
    pub past_sequence_fact: TypedFact,
    // Symbolic fact of the incoming sequence input.
    pub input_sequence_fact: TypedFact,
}
120
121impl Op for DynKeyValueCache {
122    fn name(&self) -> StaticName {
123        "DynamicKeyValueCache".to_string().into()
124    }
125
126    op_as_typed_op!();
127}
128
129impl EvalOp for DynKeyValueCache {
130    fn is_stateless(&self) -> bool {
131        false
132    }
133
134    fn state(
135        &self,
136        _session: &TurnState,
137        _node_id: usize,
138    ) -> TractResult<Option<Box<dyn OpState>>> {
139        Ok(Some(Box::new(DynKeyValueCacheState {
140            name: self.name.clone(),
141            axis: self.axis,
142            past_sequence_fact: self.past_sequence_fact.clone(),
143            kv_cache: None,
144        })))
145    }
146}
147
148impl TypedOp for DynKeyValueCache {
149    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
150        ensure!(inputs.len() == 1);
151        let input = inputs[0];
152        let mut fact = input.without_value();
153
154        fact.shape.set(
155            self.axis,
156            self.past_sequence_fact.shape.dims()[self.axis].clone()
157                + self.input_sequence_fact.shape.dims()[self.axis].clone(),
158        );
159        Ok(tvec!(fact))
160    }
161
162    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
163        let token_volume = self
164            .past_sequence_fact
165            .shape
166            .iter()
167            .enumerate()
168            .filter(|(axis, _d)| *axis != self.axis)
169            .map(|(_axis, d)| d)
170            .product::<TDim>();
171        Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
172    }
173
174    as_op!();
175}
176
/// Frozen counterpart of [`DynKeyValueCacheState`]: same fields, but the
/// cache is held as an owned `Tensor` instead of a `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    kv_cache: Option<Tensor>,
}
184
185impl OpStateFreeze for DynKeyValueCacheState {
186    fn freeze(&self) -> Box<dyn FrozenOpState> {
187        Box::new(FrozenDynKeyValueCacheState {
188            name: self.name.clone(),
189            axis: self.axis,
190            past_sequence_fact: self.past_sequence_fact.clone(),
191            kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
192        })
193    }
194}
195
196impl FrozenOpState for FrozenDynKeyValueCacheState {
197    fn unfreeze(&self) -> Box<dyn OpState> {
198        Box::new(DynKeyValueCacheState {
199            axis: self.axis,
200            name: self.name.clone(),
201            past_sequence_fact: self.past_sequence_fact.clone(),
202            kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
203        })
204    }
205}
206
/// Search pattern => Input -> Concat -> Output
/// Return type is for using rule-ensure macro
///
/// Rewrites a `TypedSource -> TypedConcat -> output` pattern into a single
/// [`DynKeyValueCache`] node: the concat node becomes the cache op, the
/// source node is neutralized, and the model I/Os are cleaned up accordingly.
/// `source_node_id` must point at a `TypedSource` node (asserted).
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        // `rule_if_some!`/`rule_if!` bail out of the rewrite (returning early)
        // when the pattern does not match.
        rule_if_some!(concat_node = next_node(target, target.node(source_node_id)));

        // Check KV Cache Pattern
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // Check on shapes
        // Exactly one axis may differ between the two concat inputs: that is
        // the sequence/cache axis.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        let axis = axes[0];
        // Both dims on that axis must be symbolic (dynamic sequence lengths).
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize facts so index 0 is always the source (past) side.
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            return Ok(None);
        }
    };

    {
        // Replace Concat by KVCache
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        // Drop the source input: the cache op only takes the live input.
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Replace Source by Dummy Op for it to be cleaned later
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // Non-source input is usually the second input of Concat. Rewire it to the only input of the new KVCache Op
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Clean model I/Os
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
287
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    // Drives a DynKeyValueCache state through `load_from` + repeated `eval`
    // and checks the saved cache equals a plain concat of every input fed.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // All shapes must agree on every axis except the cache axis.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        let dummy_model = TypedModel::default();

        // Build a symbolic shape: `sym` on the cache axis, concrete elsewhere.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // `inputs` mirrors everything fed to the state, to build the reference.
        let mut inputs = tvec![];

        // Init state with first shape
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed every shape (the first one is thus fed twice: once via
        // load_from above, once here), accumulating in the cache.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Extract the final cache and compare against a direct concat.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }
}