1use std::str::FromStr;
2
3use tract_nnef::internal::*;
4use tract_nnef::prelude::tract_itertools::Itertools;
5use tract_nnef::ser::{datum_type, tdims};
6use tract_nnef::tract_core::ops::OpStateFreeze;
7use tract_nnef::tract_core::ops::array::TypedConcat;
8use tract_nnef::tract_core::ops::source::TypedSource;
9
10pub fn register(registry: &mut Registry) {
11 registry.register_dumper(ser_dyn_kv_cache);
12 registry.register_primitive(
13 "tract_transformers_dyn_kv_cache",
14 &[
15 TypeName::Scalar.tensor().named("input"),
16 TypeName::String.named("name"),
17 TypeName::Integer.named("axis"),
18 TypeName::String.named("datum_type"),
19 TypeName::Integer.array().named("past_sequence_shape"),
20 TypeName::Integer.array().named("input_sequence_shape"),
21 ],
22 &[("output", TypeName::Scalar.tensor())],
23 de_dyn_kv_cache,
24 );
25}
26
27fn ser_dyn_kv_cache(
28 ast: &mut IntoAst,
29 node: &TypedNode,
30 op: &DynKeyValueCache,
31) -> TractResult<Option<Arc<RValue>>> {
32 let input = ast.mapping[&node.inputs[0]].clone();
33 Ok(Some(invocation(
34 "tract_transformers_dyn_kv_cache",
35 &[input],
36 &[
37 ("name", string(&op.name)),
38 ("axis", numeric(op.axis)),
39 ("datum_type", datum_type(op.past_sequence_fact.datum_type)),
40 ("past_sequence_shape", tdims(op.past_sequence_fact.shape.dims())),
41 ("input_sequence_shape", tdims(op.input_sequence_fact.shape.dims())),
42 ],
43 )))
44}
45
46fn de_dyn_kv_cache(
47 builder: &mut ModelBuilder,
48 invocation: &ResolvedInvocation,
49) -> TractResult<Value> {
50 let input = invocation.named_arg_as(builder, "input")?;
51 let name: String = invocation.named_arg_as(builder, "name")?;
52 let axis: usize = invocation.named_arg_as(builder, "axis")?;
53 let dt = DatumType::from_str(&invocation.named_arg_as::<String>(builder, "datum_type")?)?;
54 let past_sequence_shape: TVec<TDim> = builder
55 .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "past_sequence_shape"))?;
56 let input_sequence_shape: TVec<TDim> = builder
57 .allowing_new_symbols(|builder| invocation.named_arg_as(builder, "input_sequence_shape"))?;
58 builder.wire(
59 DynKeyValueCache {
60 name,
61 axis,
62 past_sequence_fact: dt.fact(&*past_sequence_shape),
63 input_sequence_fact: dt.fact(&*input_sequence_shape),
64 },
65 &[input],
66 )
67}
68
/// Runtime state backing a [`DynKeyValueCache`] node: holds the accumulated
/// key/value tensor across successive `eval` calls.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Cache name, used as the state-initializer key and in error messages.
    name: String,
    // Axis along which new tokens are concatenated to the cache.
    axis: usize,
    // Fact of the past-sequence side; its shape may contain symbolic dims.
    past_sequence_fact: TypedFact,
    // Accumulated cache value; None until the first eval or load_from.
    kv_cache: Option<TValue>,
}
76
77impl DynKeyValueCacheState {
78 pub fn resolve_symbols(
79 state: &mut TurnState,
80 fact: TypedFact,
81 concrete_shape: Option<&[usize]>,
82 ) -> TractResult<()> {
83 let unresolved = fact
84 .shape
85 .iter()
86 .enumerate()
87 .filter_map(|(ax, symb)| match symb {
88 TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
89 _ => None,
90 })
91 .collect_vec();
92
93 if unresolved.is_empty() {
94 return Ok(());
95 }
96
97 ensure!(unresolved.len() == 1);
98 let (ax, sym) = unresolved[0];
99 if let Some(shape) = concrete_shape {
100 ensure!(ax < shape.len());
101 state.resolved_symbols.set(sym, shape[ax] as i64);
102 } else {
103 state.resolved_symbols.set(sym, 0);
104 }
105
106 if state.scenario.is_none() {
107 state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
108 }
109 Ok(())
110 }
111
112 pub fn truncate(&mut self, len: usize) -> TractResult<()> {
113 if let Some(t) = self.kv_cache.as_mut() {
114 *t = t.slice(self.axis, 0, len)?.into_tvalue();
115 } else {
116 bail!("Can not truncate a zero-len kv-cache value");
117 }
118 Ok(())
119 }
120}
121
122impl OpState for DynKeyValueCacheState {
123 fn load_from(
124 &mut self,
125 state: &mut TurnState,
126 states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
127 ) -> TractResult<()> {
128 let kv_cache_init = states.next().context("Not enough state initializers")?;
130 Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
131 self.kv_cache = Some(kv_cache_init.clone());
132
133 Ok(())
134 }
135
136 fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
137 if let Some(kv_cache) = &self.kv_cache {
138 states.push(kv_cache.clone());
139 Ok(())
140 } else {
141 bail!("KV cache {} was never initialized", self.name)
142 }
143 }
144
145 fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
146 Some((self.name.clone(), self.past_sequence_fact.clone()))
147 }
148
149 fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
150 let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
151 Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
152 }
153
154 fn eval(
155 &mut self,
156 _state: &mut TurnState,
157 _op: &dyn Op,
158 inputs: TVec<TValue>,
159 ) -> TractResult<TVec<TValue>> {
160 let input = args_1!(inputs);
161 let output = if let Some(curr) = self.kv_cache.take() {
163 TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
164 } else {
165 input
166 };
167 self.kv_cache = Some(output.clone());
168
169 Ok(tvec!(output))
170 }
171}
172
/// Typed operator implementing a dynamically growing key/value cache: each
/// evaluation concatenates the incoming sequence to the cached past sequence
/// along `axis` (see [`DynKeyValueCacheState`]).
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    /// Cache name; also used as the state-initializer key.
    pub name: String,
    /// Concatenation (sequence) axis.
    pub axis: usize,
    /// Fact of the accumulated past sequence (symbolic length on `axis`).
    pub past_sequence_fact: TypedFact,
    /// Fact of the per-step input sequence (symbolic length on `axis`).
    pub input_sequence_fact: TypedFact,
}
180
181impl Op for DynKeyValueCache {
182 fn name(&self) -> StaticName {
183 "DynamicKeyValueCache".to_string().into()
184 }
185
186 op_as_typed_op!();
187}
188
189impl EvalOp for DynKeyValueCache {
190 fn is_stateless(&self) -> bool {
191 false
192 }
193
194 fn state(
195 &self,
196 _session: &TurnState,
197 _node_id: usize,
198 ) -> TractResult<Option<Box<dyn OpState>>> {
199 Ok(Some(Box::new(DynKeyValueCacheState {
200 name: self.name.clone(),
201 axis: self.axis,
202 past_sequence_fact: self.past_sequence_fact.clone(),
203 kv_cache: None,
204 })))
205 }
206}
207
208impl TypedOp for DynKeyValueCache {
209 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
210 ensure!(inputs.len() == 1);
211 let input = inputs[0];
212 let mut fact = input.without_value();
213
214 fact.shape.set(
215 self.axis,
216 self.past_sequence_fact.shape.dims()[self.axis].clone()
217 + self.input_sequence_fact.shape.dims()[self.axis].clone(),
218 );
219 Ok(tvec!(fact))
220 }
221
222 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
223 let token_volume = self
224 .past_sequence_fact
225 .shape
226 .iter()
227 .enumerate()
228 .filter(|(axis, _d)| *axis != self.axis)
229 .map(|(_axis, d)| d)
230 .product::<TDim>();
231 Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
232 }
233
234 as_op!();
235}
236
/// Frozen snapshot counterpart of [`DynKeyValueCacheState`], storing the cache
/// as an owned `Tensor` instead of a shared `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    // Same fields as DynKeyValueCacheState, with the cache materialized.
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    kv_cache: Option<Tensor>,
}
244
245impl OpStateFreeze for DynKeyValueCacheState {
246 fn freeze(&self) -> Box<dyn FrozenOpState> {
247 Box::new(FrozenDynKeyValueCacheState {
248 name: self.name.clone(),
249 axis: self.axis,
250 past_sequence_fact: self.past_sequence_fact.clone(),
251 kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
252 })
253 }
254}
255
256impl FrozenOpState for FrozenDynKeyValueCacheState {
257 fn unfreeze(&self) -> Box<dyn OpState> {
258 Box::new(DynKeyValueCacheState {
259 axis: self.axis,
260 name: self.name.clone(),
261 past_sequence_fact: self.past_sequence_fact.clone(),
262 kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
263 })
264 }
265}
266
/// Rewrite a [`DynKeyValueCache`] node in place into an explicit
/// `TypedConcat(past_source, input)`: the implicit cache state becomes a
/// regular model input (a new source named after the cache) and the concat
/// result is also exposed as a labelled model output, so the cache flows
/// through the network interface instead of op state.
pub fn unfold_kv_cache(target: &mut TypedModel, kv_node_id: usize) -> TractResult<()> {
    let node = target.node(kv_node_id);
    let op = node.op_as::<DynKeyValueCache>().context("Not a DynKeyValueCache node")?;
    let name = op.name.clone();
    let axis = op.axis;
    let past_fact = op.past_sequence_fact.clone();
    let input_fact = op.input_sequence_fact.clone();
    let existing_input = node.inputs[0];

    // New source carrying the past sequence, named after the cache.
    let source_outlet = target.add_source(&name, past_fact)?;

    // Output fact: input fact with the concat axis grown to past + input.
    let mut output_fact = input_fact.clone();
    output_fact.shape.set(
        axis,
        target.outlet_fact(source_outlet)?.shape.dims()[axis].clone()
            + input_fact.shape.dims()[axis].clone(),
    );

    // Mutate the kv-cache node in place into a concat node.
    let kv_node = target.node_mut(kv_node_id);
    kv_node.name = format!("{name}_concat");
    kv_node.op = Box::new(TypedConcat { axis });
    kv_node.outputs[0].fact = output_fact;

    // Past source feeds slot 0, the pre-existing input moves to slot 1.
    kv_node.inputs = vec![source_outlet, existing_input];

    // The graph is patched by hand (wire_node is bypassed), so successors
    // lists must be maintained manually: register the concat as a successor
    // of the new source...
    target.nodes[source_outlet.node].outputs[source_outlet.slot]
        .successors
        .push(InletId::new(kv_node_id, 0));

    // ...and renumber the pre-existing input's successor slot from 0 to 1.
    target.nodes[existing_input.node].outputs[existing_input.slot].successors.iter_mut().for_each(
        |succ| {
            if succ.node == kv_node_id && succ.slot == 0 {
                succ.slot = 1;
            }
        },
    );

    // Expose the concat result as an additional, labelled model output.
    let concat_outlet = OutletId::new(kv_node_id, 0);
    target.outputs.push(concat_outlet);
    target.set_outlet_label(concat_outlet, format!("{name}_concat"))?;

    Ok(())
}
321
/// Inverse of [`unfold_kv_cache`]: starting from a `TypedSource` node, detect
/// the `source -> concat -> model output` pattern and fold it back into a
/// [`DynKeyValueCache`] op, removing the source from the model interface.
///
/// NOTE(review): returns `Ok(None)` both on success and when the pattern does
/// not match (via the `rule_if!`/`rule_if_some!` early exits) — callers
/// apparently only rely on the `Err`/`Ok` distinction; confirm.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        // The source must feed exactly one node...
        rule_if_some!(concat_node = target.next_node(target.node(source_node_id)));

        // ...which is a 2-input, 1-output concat whose result is a model output.
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // The concat axis is the single axis on which the two input shapes
        // differ, and both inputs must be symbolic on that axis.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        let axis = axes[0];
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize facts so index 0 is always the past (source) side.
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            return Ok(None);
        }
    };

    {
        // Turn the concat node into a DynKeyValueCache named after the source,
        // dropping the source from its inputs.
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Neutralize the now-dangling source node with a dummy op.
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // The remaining input now feeds slot 0 of the rewritten node.
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Clean up the model interface: the cache output, the source input and
    // the outlet label all disappear.
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
402
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Feed `input_shapes` one by one through a DynKeyValueCacheState (seeded
    /// via `load_from` with the first shape) and check the saved cache equals
    /// a plain TypedConcat of all inputs along `axis`.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // All shapes must agree on every axis except the concat axis.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Dummy model only used as a symbol scope for P and S.
        let dummy_model = TypedModel::default();

        // Build a shape that is symbolic on the concat axis, concrete elsewhere.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // `inputs` collects everything fed to the cache, for the reference concat.
        let mut inputs = tvec![];

        // Seed tensor: ramp values over the first shape.
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        // Initialize the cache from the seed tensor.
        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed each shape through eval; the cache grows along `axis`.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // The saved state must equal the concat of every tensor fed in.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }

    /// unfold_kv_cache turns a DynKeyValueCache node into a TypedConcat with
    /// an extra model input (the past source) and an extra model output.
    #[test]
    fn test_unfold_kv_cache() -> TractResult<()> {
        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.set_output_outlets(&out)?;

        assert_eq!(model.inputs.len(), 1);
        assert_eq!(model.outputs.len(), 1);
        assert!(model.node(1).op_is::<DynKeyValueCache>());

        unfold_kv_cache(&mut model, 1)?;

        // One more input (past source) and one more output (concat result).
        assert_eq!(model.inputs.len(), 2);
        assert_eq!(model.outputs.len(), 2);

        assert!(model.node(1).op_is::<TypedConcat>());
        let concat = model.node(1).op_as::<TypedConcat>().unwrap();
        assert_eq!(concat.axis, 1);

        // The new source carries the cache's name.
        let source_node_id = model.inputs[1].node;
        assert!(model.node(source_node_id).op_is::<TypedSource>());
        assert_eq!(model.node(source_node_id).name, "kv_cache_0");

        // Slot 0 is the past source, slot 1 the original input.
        assert_eq!(model.node(1).inputs.len(), 2);
        assert_eq!(model.node(1).inputs[0].node, source_node_id);
        assert_eq!(model.node(1).inputs[1].node, 0); Ok(())
    }

    /// Fold (KeyValueCacheTransform) then unfold must restore the original
    /// interface arity and leave a 2-input concat on axis 1.
    #[test]
    fn test_fold_unfold_round_trip() -> TractResult<()> {
        use crate::rewriter::KeyValueCacheTransform;
        use tract_nnef::tract_core::transform::ModelTransform;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let past = model.add_source("kv_past", f32::fact(&past_shape))?;
        let input = model.add_source("input", f32::fact(&input_shape))?;
        let concat = model.wire_node("concat", TypedConcat { axis: 1 }, &[past, input])?;
        model.set_output_outlets(&concat)?;

        let orig_input_count = model.inputs.len();
        let orig_output_count = model.outputs.len();

        // Folding absorbs one input (the past) and one output (the concat).
        KeyValueCacheTransform.transform(&mut model)?;
        assert_eq!(model.inputs.len(), orig_input_count - 1); assert_eq!(model.outputs.len(), orig_output_count - 1); let kv_node_id = model.nodes().iter().find(|n| n.op_is::<DynKeyValueCache>()).unwrap().id;

        unfold_kv_cache(&mut model, kv_node_id)?;

        // Unfolding restores the original interface arity.
        assert_eq!(model.inputs.len(), orig_input_count);
        assert_eq!(model.outputs.len(), orig_output_count);

        let concat_node = model.nodes().iter().find(|n| n.op_is::<TypedConcat>()).unwrap();
        assert_eq!(concat_node.op_as::<TypedConcat>().unwrap().axis, 1);
        assert_eq!(concat_node.inputs.len(), 2);

        Ok(())
    }

    /// NNEF dump + reload must preserve the DynKeyValueCache op and its
    /// attributes (name, axis, datum types, shape ranks).
    #[test]
    fn test_dyn_kv_cache_nnef_round_trip() -> TractResult<()> {
        use crate::WithTractTransformers;

        let mut model = TypedModel::default();
        let s = model.sym("S");
        let p = model.sym("P");

        let input_shape: TVec<TDim> = tvec![1.to_dim(), s.into(), 64.to_dim()];
        let past_shape: TVec<TDim> = tvec![1.to_dim(), p.into(), 64.to_dim()];

        let input = model.add_source("input", f32::fact(&input_shape))?;
        let op = DynKeyValueCache {
            name: "kv_cache_0".to_string(),
            axis: 1,
            past_sequence_fact: f32::fact(&past_shape),
            input_sequence_fact: f32::fact(&input_shape),
        };
        let out = model.wire_node("kv_cache", op, &[input])?;
        model.set_output_outlets(&out)?;

        // Serialize to an in-memory tar and read it back.
        let nnef = tract_nnef::nnef().with_tract_transformers();
        let mut buffer = vec![];
        nnef.write_to_tar(&model, &mut buffer)?;
        let reloaded = nnef.model_for_read(&mut &*buffer)?;

        assert_eq!(reloaded.nodes().len(), model.nodes().len());
        let reloaded_kv = reloaded.node(1);
        let reloaded_op = reloaded_kv.op_as::<DynKeyValueCache>().unwrap();
        assert_eq!(reloaded_op.name, "kv_cache_0");
        assert_eq!(reloaded_op.axis, 1);
        assert_eq!(reloaded_op.past_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.past_sequence_fact.shape.rank(), 3);
        assert_eq!(reloaded_op.input_sequence_fact.datum_type, DatumType::F32);
        assert_eq!(reloaded_op.input_sequence_fact.shape.rank(), 3);
        Ok(())
    }
}