1use tract_nnef::internal::*;
2use tract_nnef::prelude::tract_itertools::Itertools;
3use tract_nnef::tract_core::ops::array::TypedConcat;
4use tract_nnef::tract_core::ops::source::TypedSource;
5use tract_nnef::tract_core::ops::OpStateFreeze;
6
7use crate::rule_ensure;
8
9use super::next_node;
10
/// Runtime state backing [`DynKeyValueCache`]: holds the key/value tensor
/// accumulated across successive inference calls.
#[derive(Debug, Clone)]
pub struct DynKeyValueCacheState {
    // Node name, used in error reporting.
    name: String,
    // Axis along which past and fresh sequence tensors are concatenated.
    axis: usize,
    // Fact of the cached ("past sequence") tensor, carrying the symbolic
    // sequence-length dimension to resolve at runtime.
    past_sequence_fact: TypedFact,
    // Accumulated cache content; None until first `eval` or `load_from`.
    kv_cache: Option<TValue>,
}
18
19impl DynKeyValueCacheState {
20 pub fn resolve_symbols(
21 state: &mut SessionState,
22 fact: TypedFact,
23 concrete_shape: Option<&[usize]>,
24 ) -> TractResult<()> {
25 let unresolved = fact
26 .shape
27 .iter()
28 .enumerate()
29 .filter_map(|(ax, symb)| match symb {
30 TDim::Sym(s) if state.resolved_symbols.get(s).is_none() => Some((ax, s)),
31 _ => None,
32 })
33 .collect_vec();
34
35 if unresolved.is_empty() {
36 return Ok(());
37 }
38
39 ensure!(unresolved.len() == 1);
40 let (ax, sym) = unresolved[0];
41 if let Some(shape) = concrete_shape {
42 ensure!(ax < shape.len());
43 state.resolved_symbols.set(sym, shape[ax] as i64);
44 } else {
45 state.resolved_symbols.set(sym, 0);
46 }
47
48 if state.scenario.is_none() {
49 state.scenario = sym.scope().unwrap().guess_scenario(&state.resolved_symbols)?;
50 }
51 Ok(())
52 }
53
54 pub fn truncate(&mut self, len: usize) -> TractResult<()> {
55 if let Some(t) = self.kv_cache.as_mut() {
56 *t = t.slice(self.axis, 0, len)?.into_tvalue();
57 } else {
58 bail!("Can not truncate a zero-len kv-cache value");
59 }
60 Ok(())
61 }
62}
63
64impl OpState for DynKeyValueCacheState {
65 fn load_from(&mut self, state: &mut SessionState, states: &mut Vec<TValue>) -> TractResult<()> {
66 let kv_cache_init = states.remove(0);
67 Self::resolve_symbols(state, self.past_sequence_fact.clone(), Some(kv_cache_init.shape()))?;
69 self.kv_cache = Some(kv_cache_init);
70
71 Ok(())
72 }
73
74 fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
75 if let Some(kv_cache) = &self.kv_cache {
76 states.push(kv_cache.clone());
77 Ok(())
78 } else {
79 bail!("KV cache {} was never initialized", self.name)
80 }
81 }
82
83 fn init_tensor_fact(&self) -> Option<TypedFact> {
84 Some(self.past_sequence_fact.clone())
85 }
86
87 fn resolve_symbols(&mut self, state: &mut SessionState) -> TractResult<()> {
88 let shape = self.kv_cache.as_ref().map(|kv_cache| kv_cache.shape());
89 Self::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
90 }
91
92 fn eval(
93 &mut self,
94 _state: &mut SessionState,
95 _op: &dyn Op,
96 inputs: TVec<TValue>,
97 ) -> TractResult<TVec<TValue>> {
98 let input = args_1!(inputs);
99 let output = if let Some(curr) = self.kv_cache.take() {
101 TypedConcat { axis: self.axis }.eval(tvec![curr, input])?.remove(0)
102 } else {
103 input
104 };
105 self.kv_cache = Some(output.clone());
106
107 Ok(tvec!(output))
108 }
109}
110
/// Typed operator replacing a `TypedConcat` that fed a model output: instead
/// of receiving the past sequence as an explicit model input, it accumulates
/// key/value tensors across calls in its [`DynKeyValueCacheState`].
#[derive(Clone, Debug)]
pub struct DynKeyValueCache {
    /// Name of the node (borrowed from the original source node).
    pub name: String,
    /// Axis along which sequences are concatenated.
    pub axis: usize,
    /// Fact of the cached past-sequence side (symbolic length).
    pub past_sequence_fact: TypedFact,
    /// Fact of the fresh input-sequence side (symbolic length).
    pub input_sequence_fact: TypedFact,
}
118
119impl Op for DynKeyValueCache {
120 fn name(&self) -> StaticName {
121 "DynamicKeyValueCache".to_string().into()
122 }
123
124 op_as_typed_op!();
125}
126
127impl EvalOp for DynKeyValueCache {
128 fn is_stateless(&self) -> bool {
129 false
130 }
131
132 fn state(
133 &self,
134 _session: &mut SessionState,
135 _node_id: usize,
136 ) -> TractResult<Option<Box<dyn OpState>>> {
137 Ok(Some(Box::new(DynKeyValueCacheState {
138 name: self.name.clone(),
139 axis: self.axis,
140 past_sequence_fact: self.past_sequence_fact.clone(),
141 kv_cache: None,
142 })))
143 }
144}
145
146impl TypedOp for DynKeyValueCache {
147 fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
148 ensure!(inputs.len() == 1);
149 let input = inputs[0];
150 let mut fact = input.without_value();
151
152 fact.shape.set(
153 self.axis,
154 self.past_sequence_fact.shape.dims()[self.axis].clone()
155 + self.input_sequence_fact.shape.dims()[self.axis].clone(),
156 );
157 Ok(tvec!(fact))
158 }
159
160 fn cost(&self, _inputs: &[&TypedFact]) -> TractResult<TVec<(Cost, TDim)>> {
161 let token_volume = self
162 .past_sequence_fact
163 .shape
164 .iter()
165 .enumerate()
166 .filter(|(axis, _d)| *axis != self.axis)
167 .map(|(_axis, d)| d)
168 .product::<TDim>();
169 Ok(tvec!((Cost::Custom(false, "KVCacheValuesPerToken".to_string()), token_volume)))
170 }
171
172 as_op!();
173}
174
/// Frozen counterpart of [`DynKeyValueCacheState`], with the cache held as an
/// owned `Tensor` instead of a `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenDynKeyValueCacheState {
    // Node name, used in error reporting after unfreeze.
    name: String,
    // Concatenation axis, mirrored from the live state.
    axis: usize,
    // Fact of the cached past-sequence tensor, mirrored from the live state.
    past_sequence_fact: TypedFact,
    // Snapshot of the cache content, if any was accumulated.
    kv_cache: Option<Tensor>,
}
182
183impl OpStateFreeze for DynKeyValueCacheState {
184 fn freeze(&self) -> Box<dyn FrozenOpState> {
185 Box::new(FrozenDynKeyValueCacheState {
186 name: self.name.clone(),
187 axis: self.axis,
188 past_sequence_fact: self.past_sequence_fact.clone(),
189 kv_cache: self.kv_cache.clone().map(|t| t.into_tensor()),
190 })
191 }
192}
193
194impl FrozenOpState for FrozenDynKeyValueCacheState {
195 fn unfreeze(&self) -> Box<dyn OpState> {
196 Box::new(DynKeyValueCacheState {
197 axis: self.axis,
198 name: self.name.clone(),
199 past_sequence_fact: self.past_sequence_fact.clone(),
200 kv_cache: self.kv_cache.clone().map(|t| t.into_tvalue()),
201 })
202 }
203}
204
/// Rewrite rule: if `source_node_id` (a `TypedSource`) feeds a two-input
/// `TypedConcat` whose output is a model output, replace the concat with a
/// [`DynKeyValueCache`] op and detach the source from the graph.
///
/// Returns `Ok(None)` in every non-error path — both when the pattern does
/// not match and after a successful in-place rewrite.
pub fn replace_kv_cache(target: &mut TypedModel, source_node_id: usize) -> TractResult<Option<()>> {
    assert!(target.node(source_node_id).op_is::<TypedSource>());
    // Pattern-matching phase: borrow the model immutably and collect
    // everything needed for the mutation phase below.
    let (concat_node_id, non_source_input_id, axis, input_facts) = {
        let concat_node = if let Some(n_node) = next_node(target, target.node(source_node_id)) {
            n_node
        } else {
            return Ok(None);
        };

        // The candidate must be a binary concat feeding a model output.
        rule_ensure!(
            concat_node.op_is::<TypedConcat>()
                && concat_node.inputs.len() == 2
                && concat_node.outputs.len() == 1
                && target.outputs.contains(&concat_node.id.into())
        );

        let concat_in_facts = target.node_input_facts(concat_node.id)?;

        // The concatenation axis is the single axis on which the two input
        // shapes disagree.
        let concat_in_shapes = [concat_in_facts[0].shape.dims(), concat_in_facts[1].shape.dims()];
        let rank = concat_in_shapes[0].len();
        let axes = (0..rank)
            .filter(|ax| concat_in_shapes[0][*ax] != concat_in_shapes[1][*ax])
            .collect_vec();
        // NOTE(review): a hard error (not a rule_ensure! decline) when the
        // shapes differ on more than one axis — confirm this is intended.
        ensure!(axes.len() == 1);

        let axis = axes[0];
        // Both sides must be symbolic on that axis (dynamic sequence lengths).
        rule_ensure!(
            matches!(concat_in_shapes[0][axis], TDim::Sym(_))
                && matches!(concat_in_shapes[1][axis], TDim::Sym(_))
        );
        // Normalize facts to [past (source side), fresh input] order.
        let mut facts = [concat_in_facts[0].clone(), concat_in_facts[1].clone()];
        if concat_node.inputs[0].node == source_node_id {
            (concat_node.id, concat_node.inputs[1].node, axis, facts)
        } else if concat_node.inputs[1].node == source_node_id {
            facts.swap(0, 1);
            (concat_node.id, concat_node.inputs[0].node, axis, facts)
        } else {
            // The source does not actually feed this concat.
            return Ok(None);
        }
    };

    // Mutation phase 1: swap the concat's op for the cache op, take over the
    // source node's name, and drop the source from the concat's inputs.
    {
        let name = target.node_names().collect_vec()[source_node_id].to_string();
        let concat_node = target.node_mut(concat_node_id);
        concat_node.op = Box::new(DynKeyValueCache {
            name: name.clone(),
            axis,
            past_sequence_fact: input_facts[0].clone(),
            input_sequence_fact: input_facts[1].clone(),
        });
        concat_node.name = name;
        concat_node.inputs.retain(|input| input != &source_node_id.into());
    }

    // Mutation phase 2: neutralize the now-orphaned source node.
    {
        let dummy_op = target.create_dummy();
        let source_node = target.node_mut(source_node_id);
        source_node.outputs[0].successors.clear();
        source_node.op = dummy_op;
    }
    // Mutation phase 3: the surviving input now feeds the cache op's slot 0
    // (the source occupied a slot before removal).
    {
        let non_source_input = target.node_mut(non_source_input_id);
        non_source_input.outputs.iter_mut().for_each(|output| {
            output.successors.iter_mut().for_each(|succ| {
                if succ.node == concat_node_id {
                    succ.slot = 0
                }
            })
        });
    }

    // Bookkeeping: the cache output is no longer a model output, the source
    // is no longer a model input, and any label on the old outlet is dropped.
    target.outputs.retain(|output| output.node != concat_node_id);
    target.inputs.retain(|input| input.node != source_node_id);
    target.outlet_labels.remove(&concat_node_id.into());
    Ok(None)
}
289
#[cfg(test)]
mod tests {
    use super::*;
    use tract_num_traits::AsPrimitive;
    use tract_num_traits::Zero;

    /// Drive a DynKeyValueCacheState through `load_from` plus repeated `eval`
    /// calls and check the saved cache equals a plain concatenation of every
    /// tensor fed in (seed included).
    fn run_test_case<F: Datum + Zero + Copy>(
        input_shapes: &[Vec<usize>],
        axis: usize,
    ) -> TractResult<()>
    where
        usize: AsPrimitive<F>,
    {
        // Sanity: all shapes must agree on every dimension except `axis`.
        let first_shape = &input_shapes[0];
        ensure!(input_shapes.iter().all(|shape| (shape.len() == first_shape.len())
            && (shape[..axis] == first_shape[..axis])
            && (if axis != (shape.len() - 1) {
                shape[(axis + 1)..] == first_shape[(axis + 1)..]
            } else {
                true
            })));

        let op_name = "test".to_string();
        // Dummy model used only as a symbol table for the P/S symbols.
        let dummy_model = TypedModel::default();

        // Build input_shapes[0] with a symbolic dimension substituted on `axis`.
        let make_shape =
            |sym: &str| {
                input_shapes[0]
                    .iter()
                    .enumerate()
                    .map(|(i, &dim)| {
                        if i == axis {
                            TDim::Sym(dummy_model.sym(sym))
                        } else {
                            TDim::Val(dim as _)
                        }
                    })
                    .collect::<TVec<TDim>>()
            };

        let past_shape = make_shape("P");
        let input_shape = make_shape("S");

        let op = DynKeyValueCache {
            name: op_name.clone(),
            past_sequence_fact: TypedFact::dt_shape(F::datum_type(), past_shape),
            input_sequence_fact: TypedFact::dt_shape(F::datum_type(), input_shape),
            axis,
        };

        let mut session_state = SessionState::default();
        let mut state = op.state(&mut session_state, 0)?.unwrap();

        // `inputs` records every tensor fed to the cache, to build the
        // reference concat at the end.
        let mut inputs = tvec![];

        // Seed the cache via load_from with a first tensor filled 0..len.
        let shape = &input_shapes[0];
        let len = shape.iter().product::<usize>();
        let input = Tensor::from_shape(shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
        inputs.push(input.clone().into_tvalue());

        let mut state_initializers = vec![input.into()];

        state.load_from(&mut session_state, &mut state_initializers)?;

        // Feed each shape through eval; the cache grows along `axis`.
        // eval's return value is not checked here — only the final state is.
        for shape in input_shapes {
            let len = shape.iter().product::<usize>();
            let input = Tensor::from_shape(&shape, &(0..len).map(|f| f.as_()).collect::<Vec<F>>())?;
            inputs.push(input.clone().into_tvalue());
            state.eval(&mut session_state, &op, tvec!(input.clone().into()))?[0]
                .clone()
                .into_tensor();
        }

        // Extract the accumulated cache through save_to.
        let mut curr_states = vec![];
        state.save_to(&mut curr_states)?;
        let output = curr_states.remove(0);

        // Reference: concatenating all recorded inputs in one shot.
        let reference = &TypedConcat { axis }.eval(inputs)?[0];
        output.close_enough(&reference.clone().into_tensor(), Approximation::Close)?;
        Ok(())
    }

    #[test]
    fn test_dyn_kv_cache() -> TractResult<()> {
        run_test_case::<f32>(&[vec![2, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![4, 2]], 0)?;
        run_test_case::<f32>(&[vec![2, 2], vec![2, 1], vec![2, 3]], 1)?;
        Ok(())
    }
}