//! tract_transformers/ops/dyn_kv_cache.rs
//!
//! Dynamic key-value cache: a stateful operator that concatenates incoming
//! key/value tensors onto a per-session cache along a sequence axis, with
//! NNEF (de)serialization and fold/unfold rewrites between the cache op and
//! an explicit Source + Concat pattern.
1use std::str::FromStr;
2
3use tract_nnef::internal::*;
4use tract_nnef::prelude::tract_itertools::Itertools;
5use tract_nnef::ser::{datum_type, tdims};
6use tract_nnef::tract_core::ops::OpStateFreeze;
7use tract_nnef::tract_core::ops::array::TypedConcat;
8use tract_nnef::tract_core::ops::source::TypedSource;
9
10pub fn register(registry: &mut Registry) {
11    registry.register_dumper(ser_dyn_kv_cache);
12    registry.register_primitive(
13        "tract_transformers_dyn_kv_cache",
14        &[
15            TypeName::Scalar.tensor().named("input"),
16            TypeName::String.named("name"),
17            TypeName::Integer.named("axis"),
18            TypeName::String.named("datum_type"),
19            TypeName::Integer.array().named("past_sequence_shape"),
20            TypeName::Integer.array().named("input_sequence_shape"),
21        ],
22        &[("output", TypeName::Scalar.tensor())],
23        de_dyn_kv_cache,
24    );
25}
26
27fn ser_dyn_kv_cache(
28    ast: &mut IntoAst,
29    node: &TypedNode,
30    op: &DynKeyValueCache,
31) -> TractResult<Option<Arc<RValue>>> {
32    let input = ast.mapping[&node.inputs[0]].clone();
33    Ok(Some(invocation(
34        "tract_transformers_dyn_kv_cache",
35        &[input],
36        &[
37            ("name", string(&op.name)),
38            ("axis", numeric(op.axis)),
39            ("datum_type", datum_type(op.past_sequence_fact.datum_type)),
40            ("past_sequence_shape", tdims(op.past_sequence_fact.shape.dims())),
41            ("input_sequence_shape", tdims(op.input_sequence_fact.shape.dims())),
42        ],
43    )))
44}
45
46fn de_dyn_kv_cache(
47    builder: &mut ModelBuilder,
48    invocation: &ResolvedInvocation,
49) -> TractResult<Value> {
50    let input = invocation.named_arg_as(builder, "input")?;
51    let name: String = invocation.named_arg_as(builder, "name")?;
52    let axis: usize = invocation.named_arg_as(builder, "axis")?;
53    let dt = DatumType::from_str(&invocation.named_arg_as::<String>(builder, "datum_type")?)?;
54    let past_sequence_shape: TVec<TDim> = builder
55        .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "past_sequence_shape"))?;
56    let input_sequence_shape: TVec<TDim> = builder
57        .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "input_sequence_shape"))?;
58    builder.wire(
59        DynKeyValueCache {
60            name,
61            axis,
62            past_sequence_fact: dt.fact(&*past_sequence_shape),
63            input_sequence_fact: dt.fact(&*input_sequence_shape),
64        },
65        &[input],
66    )
67}
68
/// Runtime state for [`DynKeyValueCache`]: holds the cache accumulated so far.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Cache name, used as the state-initializer key and in error messages.
    name: String,
    // Axis along which new values are concatenated.
    axis: usize,
    // Symbolic fact describing the past-sequence (cache) shape.
    past_sequence_fact: TypedFact,
    // Accumulated cache content; `None` until `load_from` or the first `eval`.
    kv_cache: Option<TValue>,
}
76
77impl DynKeyValueCacheState {
78    pub fn resolve_symbols(
79        state: &mut TurnState,
80        fact: TypedFact,
81        concrete_shape: Option<&[usize]>,
82    ) -> TractResult<()> {
83        let unresolved = fact
84            .shape
85            .iter()
86            .enumerate()
87            .filter_map(|(ax, symb)| match symb {
88                TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
89                _ => None,
90            })
91            .collect_vec();
92
93        if unresolved.is_empty() {
94            return Ok(());
95        }
96
97        ensure!(unresolved.len() == 1);
98        let (ax, sym) = unresolved[0];
99        if let Some(shape) = concrete_shape {
100            ensure!(ax < shape.len());
101            state.resolved_symbols.set(sym, shape[ax] as i64);
102        } else {
103            state.resolved_symbols.set(sym, 0);
104        }
105
106        if state.scenario.is_none() {
107            state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
108        }
109        Ok(())
110    }
111
112    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
113        if let Some(t) = self.kv_cache.as_mut() {
114            *t = t.slice(self.axis, 0, len)?.into_tvalue();
115        } else {
116            bail!("Can not truncate a zero-len kv-cache value");
117        }
118        Ok(())
119    }
120}
121
122impl OpState for DynKeyValueCacheState {
123    fn load_from(
124        &mut self,
125        state: &mut TurnState,
126        states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
127    ) -> TractResult<()> {
128        // KV Cache fact is always at index 0
129        let kv_cache_init = states.next().context("Not enough state initializers")?;
130        Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
131        self.kv_cache = Some(kv_cache_init.clone());
132
133        Ok(())
134    }
135
136    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
137        if let Some(kv_cache) = &self.kv_cache {
138            states.push(kv_cache.clone());
139            Ok(())
140        } else {
141            bail!("KV cache {} was never initialized", self.name)
142        }
143    }
144
145    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
146        Some((self.name.clone(), self.past_sequence_fact.clone()))
147    }
148
149    fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
150        let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
151        Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
152    }
153
154    fn eval(
155        &mut self,
156        _state: &mut TurnState,
157        _op: &dyn Op,
158        inputs: TVec<TValue>,
159    ) -> TractResult<TVec<TValue>> {
160        let input = args_1!(inputs);
161        // build output
162        let output = if let Some(curr) = self.kv_cache.take() {
163            TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
164        } else {
165            input
166        };
167        self.kv_cache = Some(output.clone());
168
169        Ok(tvec!(output))
170    }
171}
172
/// Stateful KV-cache operator: at run time it concatenates each incoming
/// tensor onto a per-session cache along `axis` (see `DynKeyValueCacheState`).
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DynKeyValueCache {
    /// Cache name, also used as the state-initializer key.
    pub name: String,
    /// Concatenation (sequence) axis.
    pub axis: usize,
    /// Fact of the accumulated past sequence (symbolic length along `axis`).
    pub past_sequence_fact: TypedFact,
    /// Fact of one incoming sequence chunk (symbolic length along `axis`).
    pub input_sequence_fact: TypedFact,
}
180
181impl Op for DynKeyValueCache {
182    fn name(&self) -> StaticName {
183        "DynamicKeyValueCache".to_string().into()
184    }
185
186    op_as_typed_op!();
187}
188
189impl EvalOp for DynKeyValueCache {
190    fn is_stateless(&self) -> bool {
191        false
192    }
193
194    fn state(
195        &self,
196        _session: &TurnState,
197        _node_id: usize,
198    ) -> TractResult<Option<Box<dyn OpState>>> {
199        Ok(Some(Box::new(DynKeyValueCacheState {
200            name: self.name.clone(),
201            axis: self.axis,
202            past_sequence_fact: self.past_sequence_fact.clone(),
203            kv_cache: None,
204        })))
205    }
206}
207
208impl TypedOp for DynKeyValueCache {
209    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210        ensure!(inputs.len() == 1);
211        let input = inputs[0];
212        let mut fact = input.without_value();
213
214        fact.shape.set(
215            self.axis,
216            self.past_sequence_fact.shape.dims()[self.axis].clone()
217                + self.input_sequence_fact.shape.dims()[self.axis].clone(),
218        );
219        Ok(tvec!(fact))
220    }
221
222    fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
223        let token_volume = self
224            .past_sequence_fact
225            .shape
226            .iter()
227            .enumerate()
228            .filter(|(axis, _d)| *axis != self.axis)
229            .map(|(_axis, d)| d)
230            .product::<TDim>();
231        Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
232    }
233
234    as_op!();
235}
236
/// Frozen snapshot of `DynKeyValueCacheState`: same fields, but the cache is
/// held as an owned `Tensor` instead of a `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    // Frozen cache content; `None` if the live state had never been fed.
    kv_cache: Option<Tensor>,
}
244
245impl OpStateFreeze for DynKeyValueCacheState {
246    fn freeze(&self) -> Box<dyn FrozenOpState> {
247        Box::new(FrozenDynKeyValueCacheState {
248            name: self.name.clone(),
249            axis: self.axis,
250            past_sequence_fact: self.past_sequence_fact.clone(),
251            kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
252        })
253    }
254
255    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
256        Box::new(FrozenDynKeyValueCacheState {
257            name: self.name,
258            axis: self.axis,
259            past_sequence_fact: self.past_sequence_fact,
260            kv_cache: self.kv_cache.map(|t| t.into_tensor()),
261        })
262    }
263}
264
265impl FrozenOpState for FrozenDynKeyValueCacheState {
266    fn unfreeze(&self) -> Box<dyn OpState> {
267        Box::new(DynKeyValueCacheState {
268            axis: self.axis,
269            name: self.name.clone(),
270            past_sequence_fact: self.past_sequence_fact.clone(),
271            kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
272        })
273    }
274}
275
/// Reverse of `replace_kv_cache`: replaces a DynKeyValueCache node with Source + Concat,
/// restoring KV cache state as explicit model inputs and outputs.
///
/// The node id is reused for the Concat, so downstream consumers keep their
/// wiring; only successor bookkeeping on the two producers needs patching.
pub fn unfold_kv_cache(target: &mut TypedModel, kv_node_id: usize) -> TractResult<()> {
    let node = target.node(kv_node_id);
    let op = node.op_as::<DynKeyValueCache>().context("Not a DynKeyValueCache node")?;
    let name = op.name.clone();
    let axis = op.axis;
    let past_fact = op.past_sequence_fact.clone();
    let input_fact = op.input_sequence_fact.clone();
    let existing_input = node.inputs[0];

    // Add a new Source node for the past KV cache
    let source_outlet = target.add_source(&name, past_fact)?;

    // Compute output fact for the Concat: the input fact, with the concat
    // axis widened to past + incoming sequence length.
    let mut output_fact = input_fact.clone();
    output_fact.shape.set(
        axis,
        target.outlet_fact(source_outlet)?.shape.dims()[axis].clone()
            + input_fact.shape.dims()[axis].clone(),
    );

    // Replace DynKeyValueCache op with TypedConcat
    let kv_node = target.node_mut(kv_node_id);
    kv_node.name = format!("{name}_concat");
    kv_node.op = Box::new(TypedConcat { axis });
    kv_node.outputs[0].fact = output_fact;

    // Rewire: Concat takes [source, existing_input] as inputs
    // Currently the node has [existing_input] at slot 0
    // We need [source_outlet, existing_input] at slots [0, 1]
    kv_node.inputs = vec![source_outlet, existing_input];

    // Update successor info on the source node
    target.nodes[source_outlet.node].outputs[source_outlet.slot]
        .successors
        .push(InletId::new(kv_node_id, 0));

    // Update the existing input's successor slot from 0 to 1
    // (it now feeds inlet 1 of the Concat instead of inlet 0 of the cache op)
    target.nodes[existing_input.node].outputs[existing_input.slot].successors.iter_mut().for_each(
        |succ| {
            if succ.node == kv_node_id && succ.slot == 0 {
                succ.slot = 1;
            }
        },
    );

    // Add the Concat output to model outputs and label it so runtimes preserve the name
    let concat_outlet = OutletId::new(kv_node_id, 0);
    target.outputs.push(concat_outlet);
    target.set_outlet_label(concat_outlet, format!("{name}_concat"))?;

    Ok(())
}
330
/// Search pattern => Input -> Concat -> Output
/// Return type is for using rule-ensure macro
///
/// Folds a Source + binary Concat (whose output is a model output, and whose
/// inputs differ on exactly one symbolic axis) into a single DynKeyValueCache
/// node, turning the explicit cache I/O into internal op state.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        rule_if_some!(concat_node = target.next_node(target.node(source_node_id)));

        // Check KV Cache Pattern: a binary Concat feeding a model output.
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // Check on shapes: the two inputs must differ on exactly one axis
        // (the sequence axis), and both dims there must be symbolic.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        let axis = axes[0];
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // `facts` is normalized to [past (source side), incoming sequence].
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            return Ok(None);
        }
    };

    {
        // Replace Concat by KVCache, reusing the source's name for the op.
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Replace Source by Dummy Op for it to be cleaned later
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // Non-source input is usually the second input of Concat. Rewire it to the only input of the new KVCache Op
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Clean model I/Os: the cache is now internal state, not an input/output.
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Stream `input_shapes` one by one through a DynKeyValueCache state and
    /// check the accumulated cache equals a single concat of all inputs.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // Sanity: all shapes must agree on every axis except the concat axis.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Only used as a provider of a symbol scope for P and S.
        let dummy_model = TypedModel::default();

        // input_shapes[0] with the concat axis replaced by symbol `sym`.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // Reference stream: every tensor ever fed to the cache (init + evals).
        let mut inputs = tvec![];

        // Init state with first shape
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        state.load_from(&mut session_state, &mut state_initializers)?;

        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Saved state must match one concat of the full reference stream.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }

    #[test]
    fn test_unfold_kv_cache() -> TractResult<()> {
        // Build a model with DynKeyValueCache
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.select_output_outlets(&out)?;

        // Model should have 1 input (input), 1 output (kv_cache)
        assert_eq!(model.inputs.len(), 1);
        assert_eq!(model.outputs.len(), 1);
        assert!(model.node(1).op_is::<DynKeyValueCache>());

        // Unfold
        unfold_kv_cache(&mut model, 1)?;

        // After unfold: 2 inputs (input + kv_cache_0 source), 2 outputs (original + concat)
        assert_eq!(model.inputs.len(), 2);
        assert_eq!(model.outputs.len(), 2);

        // The KV cache node should now be a Concat
        assert!(model.node(1).op_is::<TypedConcat>());
        let concat = model.node(1).op_as::<TypedConcat>().unwrap();
        assert_eq!(concat.axis, 1);

        // The new source node should exist
        let source_node_id = model.inputs[1].node;
        assert!(model.node(source_node_id).op_is::<TypedSource>());
        assert_eq!(model.node(source_node_id).name, "kv_cache_0");

        // Concat should have 2 inputs: [source, input]
        assert_eq!(model.node(1).inputs.len(), 2);
        assert_eq!(model.node(1).inputs[0].node, source_node_id);
        assert_eq!(model.node(1).inputs[1].node, 0); // original input

        Ok(())
    }

    #[test]
    fn test_fold_unfold_round_trip() -> TractResult<()> {
        use crate::rewriter::KeyValueCacheTransform;
        use tract_nnef::tract_core::transform::ModelTransform;

        // Build a model with Source + Concat (the pre-fold pattern)
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let past = model.add_source("kv_past", f32::fact(&past_shape))?;
        let input = model.add_source("input", f32::fact(&input_shape))?;
        let concat = model.wire_node("concat", TypedConcat { axis: 1 }, &[past, input])?;
        model.select_output_outlets(&concat)?;

        let orig_input_count = model.inputs.len();
        let orig_output_count = model.outputs.len();

        // Fold: Source + Concat -> DynKeyValueCache
        KeyValueCacheTransform.transform(&mut model)?;
        assert_eq!(model.inputs.len(), orig_input_count - 1); // past source removed
        assert_eq!(model.outputs.len(), orig_output_count - 1); // concat output removed

        // Find the DynKeyValueCache node
        let kv_node_id = model.nodes().iter().find(|n| n.op_is::<DynKeyValueCache>()).unwrap().id;

        // Unfold: DynKeyValueCache -> Source + Concat
        unfold_kv_cache(&mut model, kv_node_id)?;

        // Should be back to original structure
        assert_eq!(model.inputs.len(), orig_input_count);
        assert_eq!(model.outputs.len(), orig_output_count);

        // Verify it's a Concat again
        let concat_node = model.nodes().iter().find(|n| n.op_is::<TypedConcat>()).unwrap();
        assert_eq!(concat_node.op_as::<TypedConcat>().unwrap().axis, 1);
        assert_eq!(concat_node.inputs.len(), 2);

        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache_nnef_round_trip() -> TractResult<()> {
        use crate::WithTractTransformers;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.select_output_outlets(&out)?;

        // Serialize to an in-memory NNEF tar and read it back.
        let nnef = tract_nnef::nnef().with_tract_transformers();
        let mut buffer = vec![];
        nnef.write_to_tar(&model, &mut buffer)?;
        let reloaded = nnef.model_for_read(&mut &*buffer)?;

        // The reloaded op must round-trip name, axis, dtype and ranks.
        assert_eq!(reloaded.nodes().len(), model.nodes().len());
        let reloaded_kv = reloaded.node(1);
        let reloaded_op = reloaded_kv.op_as::<DynKeyValueCache>().unwrap();
        assert_eq!(reloaded_op.name, "kv_cache_0");
        assert_eq!(reloaded_op.axis, 1);
        assert_eq!(reloaded_op.past_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.past_sequence_fact.shape.rank(), 3);
        assert_eq!(reloaded_op.input_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.input_sequence_fact.shape.rank(), 3);
        Ok(())
    }
}