//! tract_transformers/ops/dyn_kv_cache.rs
//!
//! Dynamic key/value cache operator: accumulates past sequence tensors by
//! concatenation, with NNEF (de)serialization and fold/unfold model rewrites.
1use std::str::FromStr;
2
3use tract_nnef::internal::*;
4use tract_nnef::prelude::tract_itertools::Itertools;
5use tract_nnef::ser::{datum_type, tdims};
6use tract_nnef::tract_core::ops::OpStateFreeze;
7use tract_nnef::tract_core::ops::array::TypedConcat;
8use tract_nnef::tract_core::ops::source::TypedSource;
9
10pub fn register(registry: &mut Registry) {
11    registry.register_dumper(ser_dyn_kv_cache);
12    registry.register_primitive(
13        "tract_transformers_dyn_kv_cache",
14        &[
15            TypeName::Scalar.tensor().named("input"),
16            TypeName::String.named("name"),
17            TypeName::Integer.named("axis"),
18            TypeName::String.named("datum_type"),
19            TypeName::Integer.array().named("past_sequence_shape"),
20            TypeName::Integer.array().named("input_sequence_shape"),
21        ],
22        &[("output", TypeName::Scalar.tensor())],
23        de_dyn_kv_cache,
24    );
25}
26
27fn ser_dyn_kv_cache(
28    ast: &mut IntoAst,
29    node: &TypedNode,
30    op: &DynKeyValueCache,
31) -> TractResult<Option<Arc<RValue>>> {
32    let input = ast.mapping[&node.inputs[0]].clone();
33    Ok(Some(invocation(
34        "tract_transformers_dyn_kv_cache",
35        &[input],
36        &[
37            ("name", string(&op.name)),
38            ("axis", numeric(op.axis)),
39            ("datum_type", datum_type(op.past_sequence_fact.datum_type)),
40            ("past_sequence_shape", tdims(op.past_sequence_fact.shape.dims())),
41            ("input_sequence_shape", tdims(op.input_sequence_fact.shape.dims())),
42        ],
43    )))
44}
45
46fn de_dyn_kv_cache(
47    builder: &mut ModelBuilder,
48    invocation: &ResolvedInvocation,
49) -> TractResult<Value> {
50    let input = invocation.named_arg_as(builder, "input")?;
51    let name: String = invocation.named_arg_as(builder, "name")?;
52    let axis: usize = invocation.named_arg_as(builder, "axis")?;
53    let dt = DatumType::from_str(&invocation.named_arg_as::<String>(builder, "datum_type")?)?;
54    let past_sequence_shape: TVec<TDim> = builder
55        .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "past_sequence_shape"))?;
56    let input_sequence_shape: TVec<TDim> = builder
57        .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "input_sequence_shape"))?;
58    builder.wire(
59        DynKeyValueCache {
60            name,
61            axis,
62            past_sequence_fact: dt.fact(&*past_sequence_shape),
63            input_sequence_fact: dt.fact(&*input_sequence_shape),
64        },
65        &[input],
66    )
67}
68
/// Runtime state backing a `DynKeyValueCache` node: holds the accumulated
/// key/value tensor between successive evaluations.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Cache name, used for state initializer naming and error messages.
    name: String,
    // Axis along which successive inputs are concatenated (the sequence axis).
    axis: usize,
    // Symbolic fact describing the past-sequence tensor.
    past_sequence_fact: TypedFact,
    // Accumulated cache value; `None` until `load_from` or the first `eval`.
    kv_cache: Option<TValue>,
}
76
77impl DynKeyValueCacheState {
78    pub fn resolve_symbols(
79        state: &mut TurnState,
80        fact: TypedFact,
81        concrete_shape: Option<&[usize]>,
82    ) -> TractResult<()> {
83        let unresolved = fact
84            .shape
85            .iter()
86            .enumerate()
87            .filter_map(|(ax, symb)| match symb {
88                TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
89                _ => None,
90            })
91            .collect_vec();
92
93        if unresolved.is_empty() {
94            return Ok(());
95        }
96
97        ensure!(unresolved.len() == 1);
98        let (ax, sym) = unresolved[0];
99        if let Some(shape) = concrete_shape {
100            ensure!(ax < shape.len());
101            state.resolved_symbols.set(sym, shape[ax] as i64);
102        } else {
103            state.resolved_symbols.set(sym, 0);
104        }
105
106        if state.scenario.is_none() {
107            state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
108        }
109        Ok(())
110    }
111
112    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
113        if let Some(t) = self.kv_cache.as_mut() {
114            *t = t.slice(self.axis, 0, len)?.into_tvalue();
115        } else {
116            bail!("Can not truncate a zero-len kv-cache value");
117        }
118        Ok(())
119    }
120}
121
122impl OpState for DynKeyValueCacheState {
123    fn load_from(
124        &mut self,
125        state: &mut TurnState,
126        states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
127    ) -> TractResult<()> {
128        // KV Cache fact is always at index 0
129        let kv_cache_init = states.next().context("Not enough state initializers")?;
130        Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
131        self.kv_cache = Some(kv_cache_init.clone());
132
133        Ok(())
134    }
135
136    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
137        if let Some(kv_cache) = &self.kv_cache {
138            states.push(kv_cache.clone());
139            Ok(())
140        } else {
141            bail!("KV cache {} was never initialized", self.name)
142        }
143    }
144
145    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
146        Some((self.name.clone(), self.past_sequence_fact.clone()))
147    }
148
149    fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
150        let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
151        Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
152    }
153
154    fn eval(
155        &mut self,
156        _state: &mut TurnState,
157        _op: &dyn Op,
158        inputs: TVec<TValue>,
159    ) -> TractResult<TVec<TValue>> {
160        let input = args_1!(inputs);
161        // build output
162        let output = if let Some(curr) = self.kv_cache.take() {
163            TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
164        } else {
165            input
166        };
167        self.kv_cache = Some(output.clone());
168
169        Ok(tvec!(output))
170    }
171}
172
/// Stateful operator that grows a key/value cache: each evaluation
/// concatenates its input to the stored past sequence along `axis`.
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    /// Cache name, also used as the state initializer name.
    pub name: String,
    /// Concatenation axis (the sequence axis).
    pub axis: usize,
    /// Fact of the accumulated past sequence (symbolic length).
    pub past_sequence_fact: TypedFact,
    /// Fact of one incoming input sequence (symbolic length).
    pub input_sequence_fact: TypedFact,
}
180
181impl Op for DynKeyValueCache {
182    fn name(&self) -> StaticName {
183        "DynamicKeyValueCache".to_string().into()
184    }
185
186    op_as_typed_op!();
187}
188
189impl EvalOp for DynKeyValueCache {
190    fn is_stateless(&self) -> bool {
191        false
192    }
193
194    fn state(
195        &self,
196        _session: &TurnState,
197        _node_id: usize,
198    ) -> TractResult<Option<Box<dyn OpState>>> {
199        Ok(Some(Box::new(DynKeyValueCacheState {
200            name: self.name.clone(),
201            axis: self.axis,
202            past_sequence_fact: self.past_sequence_fact.clone(),
203            kv_cache: None,
204        })))
205    }
206}
207
208impl TypedOp for DynKeyValueCache {
209    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210        ensure!(inputs.len() == 1);
211        let input = inputs[0];
212        let mut fact = input.without_value();
213
214        fact.shape.set(
215            self.axis,
216            self.past_sequence_fact.shape.dims()[self.axis].clone()
217                + self.input_sequence_fact.shape.dims()[self.axis].clone(),
218        );
219        Ok(tvec!(fact))
220    }
221
222    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
223        let token_volume = self
224            .past_sequence_fact
225            .shape
226            .iter()
227            .enumerate()
228            .filter(|(axis, _d)| *axis != self.axis)
229            .map(|(_axis, d)| d)
230            .product::<TDim>();
231        Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
232    }
233
234    as_op!();
235}
236
/// Frozen (thread-shareable) snapshot of `DynKeyValueCacheState`: the shared
/// `TValue` is replaced by an owned `Tensor`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    // Mirrors `DynKeyValueCacheState::name`.
    name: String,
    // Mirrors `DynKeyValueCacheState::axis`.
    axis: usize,
    // Mirrors `DynKeyValueCacheState::past_sequence_fact`.
    past_sequence_fact: TypedFact,
    // Owned copy of the cached tensor, if any.
    kv_cache: Option<Tensor>,
}
244
245impl OpStateFreeze for DynKeyValueCacheState {
246    fn freeze(&self) -> Box<dyn FrozenOpState> {
247        Box::new(FrozenDynKeyValueCacheState {
248            name: self.name.clone(),
249            axis: self.axis,
250            past_sequence_fact: self.past_sequence_fact.clone(),
251            kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
252        })
253    }
254}
255
256impl FrozenOpState for FrozenDynKeyValueCacheState {
257    fn unfreeze(&self) -> Box<dyn OpState> {
258        Box::new(DynKeyValueCacheState {
259            axis: self.axis,
260            name: self.name.clone(),
261            past_sequence_fact: self.past_sequence_fact.clone(),
262            kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
263        })
264    }
265}
266
/// Reverse of `replace_kv_cache`: replaces a DynKeyValueCache node with Source + Concat,
/// restoring KV cache state as explicit model inputs and outputs.
///
/// The node keeps its id: only its op, name, inputs and output fact are
/// rewritten in place, so downstream consumers stay wired.
///
/// # Errors
/// Fails if `kv_node_id` does not designate a `DynKeyValueCache` node.
pub fn unfold_kv_cache(target: &mut TypedModel, kv_node_id: usize) -> TractResult<()> {
    let node = target.node(kv_node_id);
    let op = node.op_as::<DynKeyValueCache>().context("Not a DynKeyValueCache node")?;
    // Copy out everything we need before mutating the model.
    let name = op.name.clone();
    let axis = op.axis;
    let past_fact = op.past_sequence_fact.clone();
    let input_fact = op.input_sequence_fact.clone();
    let existing_input = node.inputs[0];

    // Add a new Source node for the past KV cache
    let source_outlet = target.add_source(&name, past_fact)?;

    // Compute output fact for the Concat: past length + input length on `axis`.
    let mut output_fact = input_fact.clone();
    output_fact.shape.set(
        axis,
        target.outlet_fact(source_outlet)?.shape.dims()[axis].clone()
            + input_fact.shape.dims()[axis].clone(),
    );

    // Replace DynKeyValueCache op with TypedConcat
    let kv_node = target.node_mut(kv_node_id);
    kv_node.name = format!("{name}_concat");
    kv_node.op = Box::new(TypedConcat { axis });
    kv_node.outputs[0].fact = output_fact;

    // Rewire: Concat takes [source, existing_input] as inputs
    // Currently the node has [existing_input] at slot 0
    // We need [source_outlet, existing_input] at slots [0, 1]
    kv_node.inputs = vec![source_outlet, existing_input];

    // Update successor info on the source node
    target.nodes[source_outlet.node].outputs[source_outlet.slot]
        .successors
        .push(InletId::new(kv_node_id, 0));

    // Update the existing input's successor slot from 0 to 1
    target.nodes[existing_input.node].outputs[existing_input.slot].successors.iter_mut().for_each(
        |succ| {
            if succ.node == kv_node_id && succ.slot == 0 {
                succ.slot = 1;
            }
        },
    );

    // Add the Concat output to model outputs and label it so runtimes preserve the name
    let concat_outlet = OutletId::new(kv_node_id, 0);
    target.outputs.push(concat_outlet);
    target.set_outlet_label(concat_outlet, format!("{name}_concat"))?;

    Ok(())
}
321
/// Search pattern => Input -> Concat -> Output
/// Return type is for using rule-ensure macro
///
/// Folds a `Source + Concat` pair (where one Concat input comes from
/// `source_node_id` and the Concat output is a model output, both with a
/// symbolic dim on the concat axis) into a single `DynKeyValueCache` node,
/// removing the source from the model inputs and the concat from the outputs.
/// Returns `Ok(None)` when the pattern does not match.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        rule_if_some!(concat_node = target.next_node(target.node(source_node_id)));

        // Check KV Cache Pattern
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // Check on shapes: exactly one axis must differ between the two inputs.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        // Both inputs must be symbolic on that axis (dynamic sequence lengths).
        let axis = axes[0];
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize facts order to [past (source side), input (other side)].
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            return Ok(None);
        }
    };

    {
        // Replace Concat by KVCache; the cache inherits the source node's name.
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Replace Source by Dummy Op for it to be cleaned later
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // Non-source input is usually the second input of Concat. Rewire it to the only input of the new KVCache Op
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Clean model I/Os
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Drives a DynKeyValueCacheState through `load_from`, one `eval` per
    /// entry of `input_shapes`, then `save_to`, and checks the saved cache
    /// equals a plain concatenation of all fed tensors along `axis`.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // All shapes must agree on every axis except the concat axis.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Only used as a symbol scope provider for the TDim symbols below.
        let dummy_model = TypedModel::default();

        // Builds a symbolic shape: `sym` on the concat axis, fixed elsewhere.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // Keep every tensor fed in, to build the reference concat at the end.
        let mut inputs = tvec![];

        // Init state with first shape
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed each shape once; the cache grows by concatenation at each step.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Extract the accumulated cache back out of the state.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        // Reference: a single concat of everything that was fed in.
        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    /// Exercises the cache with one, two and three successive inputs on
    /// different concat axes.
    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }

    /// Checks `unfold_kv_cache` turns a DynKeyValueCache node back into an
    /// explicit Source + Concat with matching model I/O.
    #[test]
    fn test_unfold_kv_cache() -> TractResult<()> {
        // Build a model with DynKeyValueCache
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.set_output_outlets(&out)?;

        // Model should have 1 input (input), 1 output (kv_cache)
        assert_eq!(model.inputs.len(), 1);
        assert_eq!(model.outputs.len(), 1);
        assert!(model.node(1).op_is::<DynKeyValueCache>());

        // Unfold
        unfold_kv_cache(&mut model, 1)?;

        // After unfold: 2 inputs (input + kv_cache_0 source), 2 outputs (original + concat)
        assert_eq!(model.inputs.len(), 2);
        assert_eq!(model.outputs.len(), 2);

        // The KV cache node should now be a Concat
        assert!(model.node(1).op_is::<TypedConcat>());
        let concat = model.node(1).op_as::<TypedConcat>().unwrap();
        assert_eq!(concat.axis, 1);

        // The new source node should exist
        let source_node_id = model.inputs[1].node;
        assert!(model.node(source_node_id).op_is::<TypedSource>());
        assert_eq!(model.node(source_node_id).name, "kv_cache_0");

        // Concat should have 2 inputs: [source, input]
        assert_eq!(model.node(1).inputs.len(), 2);
        assert_eq!(model.node(1).inputs[0].node, source_node_id);
        assert_eq!(model.node(1).inputs[1].node, 0); // original input

        Ok(())
    }

    /// Round-trips the graph transform: fold (Source + Concat -> KVCache)
    /// then unfold, and checks the original I/O structure is restored.
    #[test]
    fn test_fold_unfold_round_trip() -> TractResult<()> {
        use crate::rewriter::KeyValueCacheTransform;
        use tract_nnef::tract_core::transform::ModelTransform;

        // Build a model with Source + Concat (the pre-fold pattern)
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let past = model.add_source("kv_past", f32::fact(&past_shape))?;
        let input = model.add_source("input", f32::fact(&input_shape))?;
        let concat = model.wire_node("concat", TypedConcat { axis: 1 }, &[past, input])?;
        model.set_output_outlets(&concat)?;

        let orig_input_count = model.inputs.len();
        let orig_output_count = model.outputs.len();

        // Fold: Source + Concat -> DynKeyValueCache
        KeyValueCacheTransform.transform(&mut model)?;
        assert_eq!(model.inputs.len(), orig_input_count - 1); // past source removed
        assert_eq!(model.outputs.len(), orig_output_count - 1); // concat output removed

        // Find the DynKeyValueCache node
        let kv_node_id = model.nodes().iter().find(|n| n.op_is::<DynKeyValueCache>()).unwrap().id;

        // Unfold: DynKeyValueCache -> Source + Concat
        unfold_kv_cache(&mut model, kv_node_id)?;

        // Should be back to original structure
        assert_eq!(model.inputs.len(), orig_input_count);
        assert_eq!(model.outputs.len(), orig_output_count);

        // Verify it's a Concat again
        let concat_node = model.nodes().iter().find(|n| n.op_is::<TypedConcat>()).unwrap();
        assert_eq!(concat_node.op_as::<TypedConcat>().unwrap().axis, 1);
        assert_eq!(concat_node.inputs.len(), 2);

        Ok(())
    }

    /// Serializes a model containing a DynKeyValueCache to NNEF and reloads
    /// it, checking the op round-trips with all its fields intact.
    #[test]
    fn test_dyn_kv_cache_nnef_round_trip() -> TractResult<()> {
        use crate::WithTractTransformers;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.set_output_outlets(&out)?;

        let nnef = tract_nnef::nnef().with_tract_transformers();
        let mut buffer = vec![];
        nnef.write_to_tar(&model, &mut buffer)?;
        let reloaded = nnef.model_for_read(&mut &*buffer)?;

        assert_eq!(reloaded.nodes().len(), model.nodes().len());
        let reloaded_kv = reloaded.node(1);
        let reloaded_op = reloaded_kv.op_as::<DynKeyValueCache>().unwrap();
        assert_eq!(reloaded_op.name, "kv_cache_0");
        assert_eq!(reloaded_op.axis, 1);
        assert_eq!(reloaded_op.past_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.past_sequence_fact.shape.rank(), 3);
        assert_eq!(reloaded_op.input_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.input_sequence_fact.shape.rank(), 3);
        Ok(())
    }
}
625}