1use tract_nnef::internal::*;
2use tract_nnef::prelude::tract_itertools::Itertools;
3use tract_nnef::tract_core::ops::OpStateFreeze;
4use tract_nnef::tract_core::ops::array::TypedConcat;
5use tract_nnef::tract_core::ops::source::TypedSource;
6
7use super::next_node;
8
/// Runtime state for [`DynKeyValueCache`]: carries the accumulated key/value
/// tensor from one evaluation to the next within a session.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Node name; used to label this cache entry in saved/loaded state lists.
    name: String,
    // Axis along which successive inputs are concatenated (the sequence axis).
    axis: usize,
    // Symbolic fact of the past-sequence input; its symbol is resolved against
    // the concrete cache shape at load/resolve time.
    past_sequence_fact: TypedFact,
    // Accumulated cache value; `None` until the first `eval` or `load_from`.
    kv_cache: Option<TValue>,
}
16
17impl DynKeyValueCacheState {
18 pub fn resolve_symbols(
19 state: &mut TurnState,
20 fact: TypedFact,
21 concrete_shape: Option<&[usize]>,
22 ) -> TractResult<()> {
23 let unresolved = fact
24 .shape
25 .iter()
26 .enumerate()
27 .filter_map(|(ax, symb)| match symb {
28 TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
29 _ => None,
30 })
31 .collect_vec();
32
33 if unresolved.is_empty() {
34 return Ok(());
35 }
36
37 ensure!(unresolved.len() == 1);
38 let (ax, sym) = unresolved[0];
39 if let Some(shape) = concrete_shape {
40 ensure!(ax < shape.len());
41 state.resolved_symbols.set(sym, shape[ax] as i64);
42 } else {
43 state.resolved_symbols.set(sym, 0);
44 }
45
46 if state.scenario.is_none() {
47 state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
48 }
49 Ok(())
50 }
51
52 pub fn truncate(&mut self, len: usize) -> TractResult<()> {
53 if let Some(t) = self.kv_cache.as_mut() {
54 *t = t.slice(self.axis, 0, len)?.into_tvalue();
55 } else {
56 bail!("Can not truncate a zero-len kv-cache value");
57 }
58 Ok(())
59 }
60}
61
62impl OpState for DynKeyValueCacheState {
63 fn load_from(
64 &mut self,
65 state: &mut TurnState,
66 states: &mut dyn Iterator<Item = tract_nnef::prelude::TValue>,
67 ) -> TractResult<()> {
68 let kv_cache_init = states.next().context("Not enough state initializers")?;
70 Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
71 self.kv_cache = Some(kv_cache_init.clone());
72
73 Ok(())
74 }
75
76 fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
77 if let Some(kv_cache) = &self.kv_cache {
78 states.push(kv_cache.clone());
79 Ok(())
80 } else {
81 bail!("KV cache {} was never initialized", self.name)
82 }
83 }
84
85 fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
86 Some((self.name.clone(), self.past_sequence_fact.clone()))
87 }
88
89 fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
90 let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
91 Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
92 }
93
94 fn eval(
95 &mut self,
96 _state: &mut TurnState,
97 _op: &dyn Op,
98 inputs: TVec<TValue>,
99 ) -> TractResult<TVec<TValue>> {
100 let input = args_1!(inputs);
101 let output = if let Some(curr) = self.kv_cache.take() {
103 TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
104 } else {
105 input
106 };
107 self.kv_cache = Some(output.clone());
108
109 Ok(tvec!(output))
110 }
111}
112
/// A concat-like typed op that accumulates its input into an internal
/// key/value cache across successive evaluations (see [`DynKeyValueCacheState`]).
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    // Node name, also used to label the saved state entry.
    pub name: String,
    // Sequence axis along which inputs are concatenated.
    pub axis: usize,
    // Fact of the accumulated (past) sequence; symbolic length on `axis`.
    pub past_sequence_fact: TypedFact,
    // Fact of the fresh input sequence; symbolic length on `axis`.
    pub input_sequence_fact: TypedFact,
}
120
121impl Op for DynKeyValueCache {
122 fn name(&self) -> StaticName {
123 "DynamicKeyValueCache".to_string().into()
124 }
125
126 op_as_typed_op!();
127}
128
129impl EvalOp for DynKeyValueCache {
130 fn is_stateless(&self) -> bool {
131 false
132 }
133
134 fn state(
135 &self,
136 _session: &TurnState,
137 _node_id: usize,
138 ) -> TractResult<Option<Box<dyn OpState>>> {
139 Ok(Some(Box::new(DynKeyValueCacheState {
140 name: self.name.clone(),
141 axis: self.axis,
142 past_sequence_fact: self.past_sequence_fact.clone(),
143 kv_cache: None,
144 })))
145 }
146}
147
148impl TypedOp for DynKeyValueCache {
149 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
150 ensure!(inputs.len() == 1);
151 let input = inputs[0];
152 let mut fact = input.without_value();
153
154 fact.shape.set(
155 self.axis,
156 self.past_sequence_fact.shape.dims()[self.axis].clone()
157 + self.input_sequence_fact.shape.dims()[self.axis].clone(),
158 );
159 Ok(tvec!(fact))
160 }
161
162 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
163 let token_volume = self
164 .past_sequence_fact
165 .shape
166 .iter()
167 .enumerate()
168 .filter(|(axis, _d)| *axis != self.axis)
169 .map(|(_axis, d)| d)
170 .product::<TDim>();
171 Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
172 }
173
174 as_op!();
175}
176
/// Frozen counterpart of [`DynKeyValueCacheState`], produced by `freeze` and
/// turned back into a live state by `unfreeze`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    // Cache snapshot held as an owned `Tensor` instead of a `TValue`.
    kv_cache: Option<Tensor>,
}
184
185impl OpStateFreeze for DynKeyValueCacheState {
186 fn freeze(&self) -> Box<dyn FrozenOpState> {
187 Box::new(FrozenDynKeyValueCacheState {
188 name: self.name.clone(),
189 axis: self.axis,
190 past_sequence_fact: self.past_sequence_fact.clone(),
191 kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
192 })
193 }
194}
195
196impl FrozenOpState for FrozenDynKeyValueCacheState {
197 fn unfreeze(&self) -> Box<dyn OpState> {
198 Box::new(DynKeyValueCacheState {
199 axis: self.axis,
200 name: self.name.clone(),
201 past_sequence_fact: self.past_sequence_fact.clone(),
202 kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
203 })
204 }
205}
206
/// Rewrite rule: collapse a `TypedSource` + `TypedConcat` kv-cache pattern
/// into a single stateful [`DynKeyValueCache`] node.
///
/// The expected pattern: the source feeds a 2-input, 1-output concat whose
/// output is a model output, and the two concat inputs differ on exactly one
/// axis, where both dims are symbolic (past length vs. input length). On a
/// match, the concat node's op is swapped for a `DynKeyValueCache`, the source
/// node is neutralized into a dummy, and the corresponding model input and
/// output are dropped (the cache now lives in op state).
///
/// NOTE(review): this returns `Ok(None)` even after performing the rewrite, so
/// callers cannot distinguish "no match" from "rewritten" — confirm whether
/// `Ok(Some(()))` was intended on the success path.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        // The source must feed a single next node; otherwise bail with Ok(None).
        rule_if_some!(concat_node = next_node(target, target.node(source_node_id)));

        // That node must be a 2-input, 1-output concat that is a model output.
        rule_if!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // The two inputs must differ on exactly one axis: the sequence axis.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        ensure!(axes.len() == 1);

        let axis = axes[0];
        // Both dims on that axis must be symbolic (past and input lengths).
        rule_if!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize fact order so facts[0] is always the source (past) side.
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            // The source is not actually an input of the concat: no match.
            return Ok(None);
        }
    };

    {
        // Swap the concat's op for the stateful kv-cache op, reusing the
        // source node's name for the new node.
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        // Drop the edge coming from the now-obsolete source node.
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    {
        // Neutralize the source node: disconnect it and replace its op.
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    {
        // The remaining (fresh-input) edge now feeds input slot 0 of the node.
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // The cache is carried by op state now: drop the model output that exposed
    // it, the model input that fed the source, and any label on the old outlet.
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
287
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Drive a `DynKeyValueCache` state through `load_from` plus one `eval`
    /// per shape in `input_shapes`, then check the saved cache equals a plain
    /// `TypedConcat` of all inputs along `axis`.
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // Sanity-check the fixture: all shapes agree on every dim except `axis`.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Used only as a symbol scope to mint the "P" and "S" symbols.
        let dummy_model = TypedModel::default();

        // input_shapes[0] with the dim on `axis` replaced by a fresh symbol.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = TurnState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // Everything fed to the cache, collected for the reference concat.
        let mut inputs = tvec![];

        // Seed the cache via load_from with a tensor of the first shape.
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()].into_iter();

        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed one tensor per shape; eval outputs are discarded here — only
        // the accumulated cache is checked at the end.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Extract the final cache from the saved state list.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        // Reference: a single concat of every input pushed above.
        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        // Single turn, growth along axis 0, and varying lengths along axis 1.
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }
}