// tract_gpu/ops/dyn_kv_cache.rs

use crate::fact::DeviceTypedFactExt;
use crate::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
use derive_new::new;
use tract_core::internal::*;
use tract_core::ops::OpStateFreeze;
use tract_transformers::ops::dyn_kv_cache::{DynKeyValueCache, DynKeyValueCacheState};
7
/// Per-session state for [`GpuDynKVCache`]: owns the accumulated KV cache
/// tensor, kept resident on the device between turns.
#[derive(Debug, Clone, new)]
pub struct GpuDynKVCacheState {
    // Graph node this state belongs to; used to allocate the node's output tensor.
    node_id: usize,
    // Cache name, used in error messages and state (de)serialization.
    name: String,
    // Axis along which past and fresh KV entries are concatenated.
    axis: usize,
    // Fact of the past-sequence input; its symbolic dims get resolved against
    // the concrete cache shape at load/eval time.
    past_sequence_fact: TypedFact,
    // Device-resident cache wrapped as an opaque TValue; None until first
    // `eval` or `load_from`.
    kv_cache: Option<TValue>,
}
16
impl OpState for GpuDynKVCacheState {
    /// Restores the cache from serialized session state: takes the next
    /// initializer tensor, resolves the symbolic past-sequence dims against
    /// its concrete shape, then uploads it to the device.
    fn load_from(
        &mut self,
        state: &mut TurnState,
        states: &mut dyn Iterator<Item = TValue>,
    ) -> TractResult<()> {
        let kv_cache = states.next().context("Not enough state initializers")?;
        DynKeyValueCacheState::resolve_symbols(
            state,
            self.past_sequence_fact.clone(),
            Some(kv_cache.shape()),
        )?;
        // Host tensor -> device tensor, re-wrapped as an opaque TValue.
        self.kv_cache = Some(kv_cache.into_tensor().into_device()?.into_tensor().into_tvalue());
        Ok(())
    }

    /// Downloads the cache back to host memory and appends it to `states`.
    /// Errors if the cache was never initialized (no `load_from`/`eval` yet).
    fn save_to(&self, states: &mut Vec<TValue>) -> TractResult<()> {
        if let Some(kv_cache) = &self.kv_cache {
            states.push(kv_cache.to_device_tensor()?.to_host()?.into_tensor().into_tvalue());
            Ok(())
        } else {
            bail!("KV cache {} was never initialized", self.name)
        }
    }

    /// Fact describing the state initializer `load_from` expects.
    fn init_tensor_fact(&self) -> Option<(String, TypedFact)> {
        Some((self.name.clone(), self.past_sequence_fact.clone()))
    }

    /// Re-resolves the past-sequence symbols from the current cache shape
    /// (or with no shape while the cache is still empty).
    fn resolve_symbols(&mut self, state: &mut TurnState) -> TractResult<()> {
        let shape = self
            .kv_cache
            .as_ref()
            .map(|kv_cache| kv_cache.to_device_tensor().expect("Expected GPU Tensor").shape());
        DynKeyValueCacheState::resolve_symbols(state, self.past_sequence_fact.clone(), shape)
    }

    /// Concatenates the fresh KV input onto the cached tensor along the cache
    /// axis, entirely on device. The concatenated tensor becomes both the node
    /// output and the new cache.
    fn eval(
        &mut self,
        session: &mut TurnState,
        op: &dyn Op,
        inputs: TVec<TValue>,
    ) -> TractResult<TVec<TValue>> {
        ensure!(inputs.len() == 1);
        let mut op_inputs = TVec::new();

        // Move the cache out (if any); it is replaced below by the
        // concatenated result. Piece order is [cache, fresh input].
        if let Some(kv_cache) = self.kv_cache.take() {
            op_inputs.push(kv_cache);
        }

        op_inputs.push(inputs.into_iter().next().unwrap());

        let gpu_op =
            op.downcast_ref::<GpuDynKVCache>().ok_or_else(|| format_err!("Wrong Op type"))?;
        let axis = gpu_op.axis;

        let inputs =
            op_inputs.iter().map(|it| it.to_device_tensor()).collect::<TractResult<TVec<_>>>()?;
        // Output shape = first piece's shape, with the concat axis sized as
        // the sum of all pieces along that axis.
        let mut output_shape = inputs[0].shape().to_vec();
        output_shape[axis] = inputs.iter().map(|it| it.shape()[axis]).sum();
        let output = crate::session_handler::make_tensor_for_node(
            session,
            self.node_id,
            inputs[0].datum_type(),
            &output_shape,
        )?;

        // Concat inputs into output
        let ctx = crate::device::get_context()?;
        let mut cursor = 0usize; // write position along `axis`, in elements
        for input in &inputs {
            let slice_len = input.shape()[axis];
            if slice_len == 0 {
                continue;
            }
            // Byte offset of this piece's destination within the output buffer.
            let dst_offset =
                cursor * output.strides()[axis] as usize * output.datum_type().size_of();
            // NOTE(review): assumes copy_nd performs a strided n-d copy of
            // `input.shape()` elements from src offset 0 into the output at
            // `dst_offset` — confirm offset units against its signature.
            ctx.copy_nd(
                input,
                0,
                input.strides(),
                &output,
                dst_offset,
                input.shape(),
                output.strides(),
            )?;
            cursor += slice_len;
        }

        let res = output.into_tensor().into_tvalue();
        // Keep a handle to the result as the cache for the next turn.
        self.kv_cache = Some(res.clone());
        Ok(tvec!(res))
    }
}
111
112impl GpuDynKVCacheState {
113    pub fn truncate(&mut self, len: usize) -> TractResult<()> {
114        if let Some(v) = &mut self.kv_cache {
115            let mut t: Tensor = v.to_device_tensor()?.to_host()?.into_tensor();
116            t = t.slice(self.axis, 0, len)?;
117            *v = t.into_device()?.into_tensor().into_tvalue();
118        }
119        Ok(())
120    }
121}
122
/// Frozen (shareable) counterpart of [`GpuDynKVCacheState`]: same fields, but
/// the cache is held directly as a [`DeviceTensor`] instead of a `TValue`.
#[derive(Debug, Clone)]
pub struct FrozenGpuDynKVCacheState {
    node_id: usize,
    name: String,
    axis: usize,
    past_sequence_fact: TypedFact,
    // Snapshot of the device-resident cache; None if never initialized.
    kv_cache: Option<DeviceTensor>,
}
131
132impl OpStateFreeze for GpuDynKVCacheState {
133    fn freeze(&self) -> Box<dyn FrozenOpState + 'static> {
134        Box::new(FrozenGpuDynKVCacheState {
135            node_id: self.node_id,
136            name: self.name.clone(),
137            axis: self.axis,
138            past_sequence_fact: self.past_sequence_fact.clone(),
139            kv_cache: self.kv_cache.clone().map(|t| t.to_device_tensor().cloned().unwrap()),
140        })
141    }
142
143    fn freeze_into(self: Box<Self>) -> Box<dyn FrozenOpState> {
144        Box::new(FrozenGpuDynKVCacheState {
145            node_id: self.node_id,
146            name: self.name,
147            axis: self.axis,
148            past_sequence_fact: self.past_sequence_fact,
149            kv_cache: self.kv_cache.map(|t| t.to_device_tensor().cloned().unwrap()),
150        })
151    }
152}
153
154impl FrozenOpState for FrozenGpuDynKVCacheState {
155    fn unfreeze(&self) -> Box<dyn OpState> {
156        Box::new(GpuDynKVCacheState {
157            node_id: self.node_id,
158            name: self.name.clone(),
159            axis: self.axis,
160            past_sequence_fact: self.past_sequence_fact.clone(),
161            kv_cache: self.kv_cache.clone().map(|t| t.into_tensor().into_tvalue()),
162        })
163    }
164}
165
/// GPU implementation of the dynamic KV-cache op: each turn concatenates the
/// incoming keys/values onto the cached ones along `axis`, keeping the cache
/// on device. Mirrors `tract_transformers`' `DynKeyValueCache`.
#[derive(Clone)]
pub struct GpuDynKVCache {
    pub name: String,
    pub past_sequence_fact: TypedFact,
    pub input_sequence_fact: TypedFact,
    pub axis: usize,
}
173
174impl GpuDynKVCache {
175    pub fn from_tract_transformers(op: &DynKeyValueCache) -> Self {
176        Self {
177            name: op.name.clone(),
178            axis: op.axis,
179            past_sequence_fact: op.past_sequence_fact.clone(),
180            input_sequence_fact: op.input_sequence_fact.clone(),
181        }
182    }
183}
184
185impl std::fmt::Debug for GpuDynKVCache {
186    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
187        write!(f, "GpuDynKVCache({}, axis={})", self.name, self.axis)
188    }
189}
190
191impl PartialEq for GpuDynKVCache {
192    fn eq(&self, other: &Self) -> bool {
193        self.name == other.name
194            && self.axis == other.axis
195            && self.past_sequence_fact == other.past_sequence_fact
196            && self.input_sequence_fact == other.input_sequence_fact
197    }
198}
199
// `eq` compares all fields; assuming `TypedFact` equality is reflexive, `Eq` is sound.
impl Eq for GpuDynKVCache {}
201
impl std::hash::Hash for GpuDynKVCache {
    // Hashes only `name` and `axis` — a subset of the fields `PartialEq`
    // compares, so the Eq/Hash contract (equal values hash alike) holds.
    // Presumably the facts are left out because they do not implement `Hash`
    // — TODO confirm.
    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
        self.name.hash(state);
        self.axis.hash(state);
    }
}
208
209impl Op for GpuDynKVCache {
210    fn name(&self) -> StaticName {
211        "GpuDynKVCache".into()
212    }
213
214    fn info(&self) -> TractResult<Vec<String>> {
215        Ok(vec![format!("axis: {}", self.axis)])
216    }
217
218    op_as_typed_op!();
219}
220
221impl EvalOp for GpuDynKVCache {
222    fn is_stateless(&self) -> bool {
223        false
224    }
225
226    fn state(&self, _session: &TurnState, node_id: usize) -> TractResult<Option<Box<dyn OpState>>> {
227        Ok(Some(Box::new(GpuDynKVCacheState::new(
228            node_id,
229            self.name.clone(),
230            self.axis,
231            self.past_sequence_fact.clone(),
232            None,
233        ))))
234    }
235}
236
impl TypedOp for GpuDynKVCache {
    /// Output fact = input fact with the concat-axis dim replaced by
    /// past_len + input_len (both taken from the op's sequence facts),
    /// converted to a device fact.
    fn output_facts(&self, inputs: &[&TypedFact]) -> TractResult<TVec<TypedFact>> {
        ensure!(inputs.len() == 1);
        let mut facts = crate::utils::facts_to_device_facts(inputs, |facts| {
            let mut fact = facts[0].without_value();
            fact.shape.set(
                self.axis,
                self.past_sequence_fact.shape.dims()[self.axis].clone()
                    + self.input_sequence_fact.shape.dims()[self.axis].clone(),
            );
            Ok(tvec!(fact))
        })
        .with_context(|| format!("Error while computing facts for {:?}", self.name()))?;
        // NOTE(review): presumably marks the output buffer as owned by the op
        // state (it doubles as the cache across turns) — confirm against the
        // device-fact definition.
        facts[0].as_device_fact_mut().unwrap().state_owned = true;
        Ok(facts)
    }

    as_op!();
}