1use crate::fact::DeviceTypedFactExt;
2use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
3use derive_new::new;
4use tract_core::internal::*;
5use tract_core::ops::OpStateFreeze;
6use tract_transformers::ops::dyn_kv_cache::{DynKeyValueCache, DynKeyValueCacheState};
7
/// Per-session state backing [`GpuDynKVCache`]: owns the device-resident
/// key/value cache tensor that is grown across eval calls.
#[derive(Debug, Clone, new)]
pub struct GpuDynKVCacheState {
    // Graph node id, used to allocate the output buffer via the session handler.
    node_id: usize,
    // Cache name, surfaced in init facts and error messages.
    name: String,
    // Axis along which new sequence chunks are concatenated onto the cache.
    axis: usize,
    // Fact describing the past sequence; its symbols are resolved against the
    // concrete cache shape at runtime.
    past_sequence_fact: TypedFact,
    // Accumulated cache as a device tensor wrapped in a TValue; None until
    // the first eval or a load_from.
    kv_cache: Option<TValue>,
}
16
impl OpState for GpuDynKVCacheState {
    /// Restore the cache from a saved state initializer: resolve the
    /// past-sequence symbols against the initializer's concrete shape, then
    /// upload the host tensor to the device.
    fn load_from(
        &mut self,
        state: &mut TurnState,
        states: &mut dyn Iterator<Item = TValue>,
    ) -> TractResult<()> {
        let kv_cache = states.next().context("Not enough state initializers")?;
        // Bind the symbolic dims of `past_sequence_fact` to the concrete shape.
        DynKeyValueCacheState::resolve_symbols(
            state,
            self.past_sequence_fact.clone(),
            Some(kv_cache.shape()),
        )?;
        // Host tensor -> device, then re-wrapped as an (opaque) TValue.
        self.kv_cache = Some(kv_cache.into_tensor().into_device()?.into_tensor().into_tvalue());
        Ok(())
    }

    /// Download the cache to host and append it to the serialized states.
    /// Errors if the cache was never populated.
    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
        if let Some(kv_cache) = &self.kv_cache {
            states.push(kv_cache.to_device_tensor()?.to_host()?.into_tensor().into_tvalue());
            Ok(())
        } else {
            bail!("KV cache {} was never initialized", self.name)
        }
    }

    /// Fact advertised for state initialization: cache name plus the
    /// (possibly symbolic) past-sequence fact.
    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
        Some((self.name.clone(), self.past_sequence_fact.clone()))
    }

    /// Resolve the past-sequence symbols from the current cache shape, when a
    /// cache exists; otherwise resolve with no shape hint.
    fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
        let shape = self
            .kv_cache
            .as_ref()
            .map(|kv_cache| kv_cache.to_device_tensor().expect("Expected GPU Tensor").shape());
        DynKeyValueCacheState::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
    }

    /// Concatenate the cached tensor (if any) with the incoming chunk along
    /// the op's axis, entirely on device, and keep the result as the new cache.
    fn eval(
        &mut self,
        session: &mut TurnState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        ensure!(inputs.len() == 1);
        let mut op_inputs = TVec::new();

        // Cache first (when present), then the new chunk: this ordering fixes
        // the concatenation layout below.
        if let Some(kv_cache) = self.kv_cache.take() {
            op_inputs.push(kv_cache);
        }

        op_inputs.push(inputs.into_iter().next().unwrap());

        let gpu_op =
            op.downcast_ref::<GpuDynKVCache>().ok_or_else(|| format_err!("Wrong Op type"))?;
        let axis = gpu_op.axis;

        let inputs =
            op_inputs.iter().map(|it| it.to_device_tensor()).collect::<TractResult<TVec<_>>>()?;
        // Output shape: same as the first input, except on `axis` where it is
        // the sum of all inputs' extents.
        let mut output_shape = inputs[0].shape().to_vec();
        output_shape[axis] = inputs.iter().map(|it| it.shape()[axis]).sum();
        let output = crate::session_handler::make_tensor_for_node(
            session,
            self.node_id,
            inputs[0].datum_type(),
            &output_shape,
        )?;

        let ctx = crate::device::get_context()?;
        // Copy each input into its slot of the output; `cursor` advances in
        // elements along `axis`.
        let mut cursor = 0usize;
        for input in &inputs {
            let slice_len = input.shape()[axis];
            if slice_len == 0 {
                continue;
            }
            // Destination offset in bytes (stride * element size). Assumes a
            // non-negative stride on `axis` — the `as usize` cast would wrap
            // on a negative stride.
            let dst_offset =
                cursor * output.strides()[axis] as usize * output.datum_type().size_of();
            // NOTE(review): argument order assumed to be (src, src offset,
            // src strides, dst, dst byte offset, copy shape, dst strides) —
            // confirm against the device context's copy_nd signature.
            ctx.copy_nd(
                input,
                0,
                input.strides(),
                &output,
                dst_offset,
                input.shape(),
                output.strides(),
            )?;
            cursor += slice_len;
        }

        let res = output.into_tensor().into_tvalue();
        // The output doubles as the cache for the next step.
        self.kv_cache = Some(res.clone());
        Ok(tvec!(res))
    }
}
111
112impl GpuDynKVCacheState {
113 pub fn truncate(&mut self, len: usize) -> TractResult<()> {
114 if let Some(v) = &mut self.kv_cache {
115 let mut t: Tensor = v.to_device_tensor()?.to_host()?.into_tensor();
116 t = t.slice(self.axis, 0, len)?;
117 *v = t.into_device()?.into_tensor().into_tvalue();
118 }
119 Ok(())
120 }
121}
122
/// Frozen (session-independent) snapshot of [`GpuDynKVCacheState`], holding
/// the cache as a bare [`DeviceTensor`] instead of a `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenGpuDynKVCacheState {
    node_id: usize,
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    // Device-resident cache; None when the live state had not been populated.
    kv_cache: Option<DeviceTensor>,
}
131
132impl OpStateFreeze for GpuDynKVCacheState {
133 fn freeze(&self) -> Box<dyn FrozenOpState + 'static> {
134 Box::new(FrozenGpuDynKVCacheState {
135 node_id: self.node_id,
136 name: self.name.clone(),
137 axis: self.axis,
138 past_sequence_fact: self.past_sequence_fact.clone(),
139 kv_cache: self.kv_cache.clone().map(|t| t.to_device_tensor().cloned().unwrap()),
140 })
141 }
142
143 fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
144 Box::new(FrozenGpuDynKVCacheState {
145 node_id: self.node_id,
146 name: self.name,
147 axis: self.axis,
148 past_sequence_fact: self.past_sequence_fact,
149 kv_cache: self.kv_cache.map(|t| t.to_device_tensor().cloned().unwrap()),
150 })
151 }
152}
153
154impl FrozenOpState for FrozenGpuDynKVCacheState {
155 fn unfreeze(&self) -> Box<dyn OpState> {
156 Box::new(GpuDynKVCacheState {
157 node_id: self.node_id,
158 name: self.name.clone(),
159 axis: self.axis,
160 past_sequence_fact: self.past_sequence_fact.clone(),
161 kv_cache: self.kv_cache.clone().map(|t| t.into_tensor().into_tvalue()),
162 })
163 }
164}
165
/// GPU dynamic key/value cache op: each eval concatenates the incoming
/// sequence chunk onto a device-resident cache along `axis`. GPU counterpart
/// of tract-transformers' `DynKeyValueCache`.
#[derive(Clone)]
pub struct GpuDynKVCache {
    // Cache name, also used as the state-initializer key.
    pub name: String,
    // Fact of the already-accumulated (past) sequence.
    pub past_sequence_fact: TypedFact,
    // Fact of the incoming sequence chunk.
    pub input_sequence_fact: TypedFact,
    // Concatenation axis.
    pub axis: usize,
}
173
174impl GpuDynKVCache {
175 pub fn from_tract_transformers(op: &DynKeyValueCache) -> Self {
176 Self {
177 name: op.name.clone(),
178 axis: op.axis,
179 past_sequence_fact: op.past_sequence_fact.clone(),
180 input_sequence_fact: op.input_sequence_fact.clone(),
181 }
182 }
183}
184
185impl std::fmt::Debug for GpuDynKVCache {
186 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
187 write!(f, "GpuDynKVCache({}, axis={})", self.name, self.axis)
188 }
189}
190
191impl PartialEq for GpuDynKVCache {
192 fn eq(&self, other: &Self) -> bool {
193 self.name == other.name
194 && self.axis == other.axis
195 && self.past_sequence_fact == other.past_sequence_fact
196 && self.input_sequence_fact == other.input_sequence_fact
197 }
198}
199
200impl Eq for GpuDynKVCache {}
201
202impl std::hash::Hash for GpuDynKVCache {
203 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
204 self.name.hash(state);
205 self.axis.hash(state);
206 }
207}
208
209impl Op for GpuDynKVCache {
210 fn name(&self) -> StaticName {
211 "GpuDynKVCache".into()
212 }
213
214 fn info(&self) -> TractResult<Vec<String>> {
215 Ok(vec![format!("axis: {}", self.axis)])
216 }
217
218 op_as_typed_op!();
219}
220
221impl EvalOp for GpuDynKVCache {
222 fn is_stateless(&self) -> bool {
223 false
224 }
225
226 fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
227 Ok(Some(Box::new(GpuDynKVCacheState::new(
228 node_id,
229 self.name.clone(),
230 self.axis,
231 self.past_sequence_fact.clone(),
232 None,
233 ))))
234 }
235}
236
impl TypedOp for GpuDynKVCache {
    /// Output fact: same as the input fact (as a device fact), except the
    /// cache-axis extent becomes past + input sequence lengths; the result is
    /// marked state-owned so the buffer's lifetime is tied to the op state.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 1);
        let mut facts = crate::utils::facts_to_device_facts(inputs, |facts| {
            let mut fact = facts[0].without_value();
            // Concatenated extent on the cache axis: past + incoming dims.
            fact.shape.set(
                self.axis,
                self.past_sequence_fact.shape.dims()[self.axis].clone()
                    + self.input_sequence_fact.shape.dims()[self.axis].clone(),
            );
            Ok(tvec!(fact))
        })
        .with_context(|| format!("Error while computing facts for {:?}", self.name()))?;
        // NOTE(review): assumes facts_to_device_facts always yields device
        // facts — otherwise this unwrap panics. Confirm against the helper.
        facts[0].as_device_fact_mut().unwrap().state_owned = true;
        Ok(facts)
    }

    as_op!();
}