1use std::str::FromStr;
2
3use tract_nnef::internal::*;
4use tract_nnef::prelude::tract_itertools::Itertools;
5use tract_nnef::ser::{datum_type, tdims};
6use tract_nnef::tract_core::ops::OpStateFreeze;
7use tract_nnef::tract_core::ops::array::TypedConcat;
8use tract_nnef::tract_core::ops::source::TypedSource;
9
10pub fn register(registry: &mut Registry) {
11 registry.register_dumper(ser_dyn_kv_cache);
12 registry.register_primitive(
13 "tract_transformers_dyn_kv_cache",
14 &[
15 TypeName::Scalar.tensor().named("input"),
16 TypeName::String.named("name"),
17 TypeName::Integer.named("axis"),
18 TypeName::String.named("datum_type"),
19 TypeName::Integer.array().named("past_sequence_shape"),
20 TypeName::Integer.array().named("input_sequence_shape"),
21 ],
22 &[("output", TypeName::Scalar.tensor())],
23 de_dyn_kv_cache,
24 );
25}
26
27fn ser_dyn_kv_cache(
28 ast: &mut IntoAst,
29 node: &TypedNode,
30 op: &DynKeyValueCache,
31) -> TractResult<Option<Arc<RValue>>> {
32 let input = ast.mapping[&node.inputs[0]].clone();
33 Ok(Some(invocation(
34 "tract_transformers_dyn_kv_cache",
35 &[input],
36 &[
37 ("name", string(&op.name)),
38 ("axis", numeric(op.axis)),
39 ("datum_type", datum_type(op.past_sequence_fact.datum_type)),
40 ("past_sequence_shape", tdims(op.past_sequence_fact.shape.dims())),
41 ("input_sequence_shape", tdims(op.input_sequence_fact.shape.dims())),
42 ],
43 )))
44}
45
46fn de_dyn_kv_cache(
47 builder: &mut ModelBuilder,
48 invocation: &ResolvedInvocation,
49) -> TractResult<Value> {
50 let input = invocation.named_arg_as(builder, "input")?;
51 let name: String = invocation.named_arg_as(builder, "name")?;
52 let axis: usize = invocation.named_arg_as(builder, "axis")?;
53 let dt = DatumType::from_str(&invocation.named_arg_as::<String>(builder, "datum_type")?)?;
54 let past_sequence_shape: TVec<TDim> = builder
55 .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "past_sequence_shape"))?;
56 let input_sequence_shape: TVec<TDim> = builder
57 .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "input_sequence_shape"))?;
58 builder.wire(
59 DynKeyValueCache {
60 name,
61 axis,
62 past_sequence_fact: dt.fact(&*past_sequence_shape),
63 input_sequence_fact: dt.fact(&*input_sequence_shape),
64 },
65 &[input],
66 )
67}
68
/// Runtime state for [`DynKeyValueCache`]: holds the tensor accumulated
/// across successive `eval` calls (the KV cache itself).
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    /// Cache name, used for state-initializer lookup and error reporting.
    name: String,
    /// Axis along which past and incoming sequences are concatenated.
    axis: usize,
    /// Fact describing the past-sequence side (shape may contain symbols).
    past_sequence_fact: TypedFact,
    /// Accumulated cache content; `None` until `load_from` or first `eval`.
    kv_cache: Option<TValue>,
}
76
impl DynKeyValueCacheState {
    /// Resolves at most one still-unknown symbol in `fact`'s shape.
    ///
    /// When `concrete_shape` is given, the unresolved symbol is bound to
    /// the concrete dimension found on its axis; otherwise it is bound to 0
    /// (no cached data yet). If no scenario is set on `state`, one is then
    /// guessed from the resolved symbols.
    ///
    /// # Errors
    /// Fails when more than one symbol is unresolved (a single observed
    /// shape can only pin down one unknown), or when the symbol's axis is
    /// out of range of `concrete_shape`.
    pub fn resolve_symbols(
        state: &mut TurnState,
        fact: TypedFact,
        concrete_shape: Option<&[usize]>,
    ) -> TractResult<()> {
        // Collect (axis, symbol) pairs for dimensions not resolved yet.
        let unresolved = fact
            .shape
            .iter()
            .enumerate()
            .filter_map(|(ax, symb)| match symb {
                TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
                _ => None,
            })
            .collect_vec();

        if unresolved.is_empty() {
            return Ok(());
        }

        ensure!(unresolved.len() == 1);
        let (ax, sym) = unresolved[0];
        if let Some(shape) = concrete_shape {
            ensure!(ax < shape.len());
            state.resolved_symbols.set(sym, shape[ax] as i64);
        } else {
            // No cache content yet: the past-sequence length is 0.
            state.resolved_symbols.set(sym, 0);
        }

        if state.scenario.is_none() {
            // NOTE(review): assumes the symbol always belongs to a scope —
            // presumably true for model-created symbols; confirm.
            state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
        }
        Ok(())
    }

    /// Truncates the cached tensor to `len` entries along the cache axis.
    ///
    /// # Errors
    /// Fails if the cache was never initialized.
    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
        if let Some(t) = self.kv_cache.as_mut() {
            *t = t.slice(self.axis, 0, len)?.into_tvalue();
        } else {
            bail!("Can not truncate a zero-len kv-cache value");
        }
        Ok(())
    }
}
121
122impl OpState for DynKeyValueCacheState {
123 fn load_from(
124 &mut self,
125 state: &mut TurnState,
126 states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
127 ) -> TractResult<()> {
128 let kv_cache_init = states.next().context("Not enough state initializers")?;
130 Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
131 self.kv_cache = Some(kv_cache_init.clone());
132
133 Ok(())
134 }
135
136 fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
137 if let Some(kv_cache) = &self.kv_cache {
138 states.push(kv_cache.clone());
139 Ok(())
140 } else {
141 bail!("KV cache {} was never initialized", self.name)
142 }
143 }
144
145 fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
146 Some((self.name.clone(), self.past_sequence_fact.clone()))
147 }
148
149 fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
150 let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
151 Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
152 }
153
154 fn eval(
155 &mut self,
156 _state: &mut TurnState,
157 _op: &dyn Op,
158 inputs: TVec<TValue>,
159 ) -> TractResult<TVec<TValue>> {
160 let input = args_1!(inputs);
161 let output = if let Some(curr) = self.kv_cache.take() {
163 TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
164 } else {
165 input
166 };
167 self.kv_cache = Some(output.clone());
168
169 Ok(tvec!(output))
170 }
171}
172
/// Typed op implementing a dynamically growing key/value cache: each
/// evaluation concatenates the incoming sequence to the cached past along
/// `axis` and keeps the result as the new cache content.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct DynKeyValueCache {
    /// Cache name; also used as the state-initializer identifier.
    pub name: String,
    /// Concatenation (sequence) axis.
    pub axis: usize,
    /// Fact for the cached past sequence (symbolic length on `axis`).
    pub past_sequence_fact: TypedFact,
    /// Fact for the incoming sequence (symbolic length on `axis`).
    pub input_sequence_fact: TypedFact,
}
180
181impl Op for DynKeyValueCache {
182 fn name(&self) -> StaticName {
183 "DynamicKeyValueCache".to_string().into()
184 }
185
186 op_as_typed_op!();
187}
188
189impl EvalOp for DynKeyValueCache {
190 fn is_stateless(&self) -> bool {
191 false
192 }
193
194 fn state(
195 &self,
196 _session: &TurnState,
197 _node_id: usize,
198 ) -> TractResult<Option<Box<dyn OpState>>> {
199 Ok(Some(Box::new(DynKeyValueCacheState {
200 name: self.name.clone(),
201 axis: self.axis,
202 past_sequence_fact: self.past_sequence_fact.clone(),
203 kv_cache: None,
204 })))
205 }
206}
207
208impl TypedOp for DynKeyValueCache {
209 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210 ensure!(inputs.len() == 1);
211 let input = inputs[0];
212 let mut fact = input.without_value();
213
214 fact.shape.set(
215 self.axis,
216 self.past_sequence_fact.shape.dims()[self.axis].clone()
217 + self.input_sequence_fact.shape.dims()[self.axis].clone(),
218 );
219 Ok(tvec!(fact))
220 }
221
222 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
223 let token_volume = self
224 .past_sequence_fact
225 .shape
226 .iter()
227 .enumerate()
228 .filter(|(axis, _d)| *axis != self.axis)
229 .map(|(_axis, d)| d)
230 .product::<TDim>();
231 Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
232 }
233
234 as_op!();
235}
236
/// Frozen (session-independent) counterpart of [`DynKeyValueCacheState`]:
/// same fields, but the cache is held as a plain `Tensor` instead of a
/// session `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    /// Cache name, carried over from the live state.
    name: String,
    /// Concatenation axis, carried over from the live state.
    axis: usize,
    /// Past-sequence fact, carried over from the live state.
    past_sequence_fact: TypedFact,
    /// Frozen cache content, if the cache was initialized.
    kv_cache: Option<Tensor>,
}
244
245impl OpStateFreeze for DynKeyValueCacheState {
246 fn freeze(&self) -> Box<dyn FrozenOpState> {
247 Box::new(FrozenDynKeyValueCacheState {
248 name: self.name.clone(),
249 axis: self.axis,
250 past_sequence_fact: self.past_sequence_fact.clone(),
251 kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
252 })
253 }
254
255 fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
256 Box::new(FrozenDynKeyValueCacheState {
257 name: self.name,
258 axis: self.axis,
259 past_sequence_fact: self.past_sequence_fact,
260 kv_cache: self.kv_cache.map(|t| t.into_tensor()),
261 })
262 }
263}
264
265impl FrozenOpState for FrozenDynKeyValueCacheState {
266 fn unfreeze(&self) -> Box<dyn OpState> {
267 Box::new(DynKeyValueCacheState {
268 axis: self.axis,
269 name: self.name.clone(),
270 past_sequence_fact: self.past_sequence_fact.clone(),
271 kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
272 })
273 }
274}
275
/// Rewrites the [`DynKeyValueCache`] node `kv_node_id` into an explicit
/// `TypedConcat(past_source, input)`: the past cache becomes a new model
/// input (a `TypedSource` named after the cache) and the concatenated
/// result becomes an extra, labelled model output. This exposes the cache
/// as plain model I/O instead of internal op state.
pub fn unfold_kv_cache(target: &mut TypedModel, kv_node_id: usize) -> TractResult<()> {
    let node = target.node(kv_node_id);
    let op = node.op_as::<DynKeyValueCache>().context("Not a DynKeyValueCache node")?;
    let name = op.name.clone();
    let axis = op.axis;
    let past_fact = op.past_sequence_fact.clone();
    let input_fact = op.input_sequence_fact.clone();
    let existing_input = node.inputs[0];

    // New source feeding the past sequence, named after the cache.
    let source_outlet = target.add_source(&name, past_fact)?;

    // Output length on `axis` is past + incoming sequence lengths.
    let mut output_fact = input_fact.clone();
    output_fact.shape.set(
        axis,
        target.outlet_fact(source_outlet)?.shape.dims()[axis].clone()
            + input_fact.shape.dims()[axis].clone(),
    );

    // Repurpose the kv-cache node in place as a concat node.
    let kv_node = target.node_mut(kv_node_id);
    kv_node.name = format!("{name}_concat");
    kv_node.op = Box::new(TypedConcat { axis });
    kv_node.outputs[0].fact = output_fact;

    // Past goes into slot 0; the original input moves to slot 1.
    kv_node.inputs = vec![source_outlet, existing_input];

    // We edited the graph directly, so successor cross-references must be
    // maintained by hand: register the concat as successor of the source...
    target.nodes[source_outlet.node].outputs[source_outlet.slot]
        .successors
        .push(InletId::new(kv_node_id, 0));

    // ...and re-slot the existing input's edge from inlet 0 to inlet 1.
    target.nodes[existing_input.node].outputs[existing_input.slot].successors.iter_mut().for_each(
        |succ| {
            if succ.node == kv_node_id && succ.slot == 0 {
                succ.slot = 1;
            }
        },
    );

    // Expose the concatenated cache as an additional, labelled model output.
    let concat_outlet = OutletId::new(kv_node_id, 0);
    target.outputs.push(concat_outlet);
    target.set_outlet_label(concat_outlet, format!("{name}_concat"))?;

    Ok(())
}
330
/// Inverse of [`unfold_kv_cache`]: starting from the `TypedSource` node
/// `source_node_id`, detects the `source -> concat -> model output` pattern
/// (two symbolic-length inputs differing on exactly one axis) and folds it
/// back into a single [`DynKeyValueCache`] node, removing the source input
/// and the concat output from the model's I/O.
///
/// Returns `Ok(None)` when the pattern does not match (via the `rule_if*`
/// macros). NOTE(review): the success path also returns `Ok(None)` —
/// callers appear to rely on side effects only; confirm whether `Some(())`
/// was intended on success.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        rule_if_some!(concat_node = target.next_node(target.node(source_node_id)));

        // Candidate must be a 2-input, 1-output concat feeding a model output.
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // Exactly one axis may differ between the two inputs: the sequence axis.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        let axis = axes[0];
        // Both sequence lengths must be symbolic (dynamic) on that axis.
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize facts so index 0 is always the past (source) side.
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            return Ok(None);
        }
    };

    {
        // Rewrite the concat node in place as a DynKeyValueCache (named
        // after the source node) and drop its edge to the source.
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Neutralize the now-orphaned source node with a dummy op.
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // The remaining input now feeds slot 0 of the rewritten node.
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Cache output and past input are now internal state, not model I/O.
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
411
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Feeds `input_shapes` through a fresh cache state one by one and
    /// checks the saved cache equals a single concat of all inputs
    /// (including the initializer) along `axis`.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // All shapes must agree on every dimension except `axis`.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Only used as a symbol scope for P/S below.
        let dummy_model = TypedModel::default();

        // Builds a shape with the given symbol on `axis`, concrete elsewhere.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // Reference inputs for the final concat comparison.
        let mut inputs = tvec![];

        // Seed the cache via load_from with a tensor of the first shape.
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed every shape through eval; the state accumulates internally.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Saved state must equal one big concat of all fed inputs.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }

    /// Checks that unfolding turns the cache node into a concat with a new
    /// source input and a new model output.
    #[test]
    fn test_unfold_kv_cache() -> TractResult<()> {
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.select_output_outlets(&out)?;

        assert_eq!(model.inputs.len(), 1);
        assert_eq!(model.outputs.len(), 1);
        assert!(model.node(1).op_is::<DynKeyValueCache>());

        unfold_kv_cache(&mut model, 1)?;

        // Unfolding adds one input (the past) and one output (the concat).
        assert_eq!(model.inputs.len(), 2);
        assert_eq!(model.outputs.len(), 2);

        assert!(model.node(1).op_is::<TypedConcat>());
        let concat = model.node(1).op_as::<TypedConcat>().unwrap();
        assert_eq!(concat.axis, 1);

        let source_node_id = model.inputs[1].node;
        assert!(model.node(source_node_id).op_is::<TypedSource>());
        assert_eq!(model.node(source_node_id).name, "kv_cache_0");

        // Past source wired on slot 0, original input on slot 1.
        assert_eq!(model.node(1).inputs.len(), 2);
        assert_eq!(model.node(1).inputs[0].node, source_node_id);
        assert_eq!(model.node(1).inputs[1].node, 0);
        Ok(())
    }

    /// Folds an explicit source+concat into a cache node, then unfolds it
    /// again, and checks the model I/O counts round-trip.
    #[test]
    fn test_fold_unfold_round_trip() -> TractResult<()> {
        use crate::rewriter::KeyValueCacheTransform;
        use tract_nnef::tract_core::transform::ModelTransform;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let past = model.add_source("kv_past", f32::fact(&past_shape))?;
        let input = model.add_source("input", f32::fact(&input_shape))?;
        let concat = model.wire_node("concat", TypedConcat { axis: 1 }, &[past, input])?;
        model.select_output_outlets(&concat)?;

        let orig_input_count = model.inputs.len();
        let orig_output_count = model.outputs.len();

        // Folding removes the past input and the concat output from I/O.
        KeyValueCacheTransform.transform(&mut model)?;
        assert_eq!(model.inputs.len(), orig_input_count - 1);
        assert_eq!(model.outputs.len(), orig_output_count - 1);
        let kv_node_id = model.nodes().iter().find(|n| n.op_is::<DynKeyValueCache>()).unwrap().id;

        unfold_kv_cache(&mut model, kv_node_id)?;

        // Unfolding restores the original I/O counts.
        assert_eq!(model.inputs.len(), orig_input_count);
        assert_eq!(model.outputs.len(), orig_output_count);

        let concat_node = model.nodes().iter().find(|n| n.op_is::<TypedConcat>()).unwrap();
        assert_eq!(concat_node.op_as::<TypedConcat>().unwrap().axis, 1);
        assert_eq!(concat_node.inputs.len(), 2);

        Ok(())
    }

    /// Serializes a model containing the op to NNEF and reloads it,
    /// checking the op attributes survive the round trip.
    #[test]
    fn test_dyn_kv_cache_nnef_round_trip() -> TractResult<()> {
        use crate::WithTractTransformers;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.select_output_outlets(&out)?;

        let nnef = tract_nnef::nnef().with_tract_transformers();
        let mut buffer = vec![];
        nnef.write_to_tar(&model, &mut buffer)?;
        let reloaded = nnef.model_for_read(&mut &*buffer)?;

        assert_eq!(reloaded.nodes().len(), model.nodes().len());
        let reloaded_kv = reloaded.node(1);
        let reloaded_op = reloaded_kv.op_as::<DynKeyValueCache>().unwrap();
        assert_eq!(reloaded_op.name, "kv_cache_0");
        assert_eq!(reloaded_op.axis, 1);
        assert_eq!(reloaded_op.past_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.past_sequence_fact.shape.rank(), 3);
        assert_eq!(reloaded_op.input_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.input_sequence_fact.shape.rank(), 3);
        Ok(())
    }
}