// tract_cuda/transform.rs

use tract_core::internal::tract_smallvec::ToSmallVec;
use tract_core::internal::*;
use tract_core::model::translator::Translate;
use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
use tract_core::ops::binary::TypedBinOp;
use tract_core::ops::cast::Cast;
use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul};
use tract_core::ops::element_wise::ElementWiseOp;
use tract_core::ops::konst::Const;
use tract_core::ops::logic::Comp;
use tract_core::ops::nn::{Reduce, Softmax};
use tract_core::tract_data::itertools::Itertools;
use tract_core::tract_linalg::block_quant::{BlockQuant, BlockQuantFact, BlockQuantValue, Q4_0};
use tract_core::transform::ModelTransform;
use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt};
use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs;
use tract_gpu::sync::{DeviceSync, DeviceSyncKind};
use tract_gpu::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
use tract_gpu::utils::{as_q40_fact, as_q40_tensor};
use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf};
use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;
use tract_transformers::ops::gelu_approximate::GeluApproximate;
use tract_transformers::ops::rms_norm::RmsNorm;
use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax;
use tract_transformers::ops::silu::Silu;

use crate::context::cuda_context;
use crate::kernels::matmul::{GemmKernel, GgmlGemm};
use crate::{Q40_ROW_PADDING, kernels, ops, rewrite_rules};

/// Model transform that rewrites a `TypedModel` so that supported ops run on a
/// CUDA device, with explicit host/device syncs at the boundaries.
#[derive(Debug, Default)]
pub struct CudaTransform;

impl ModelTransform for CudaTransform {
    fn name(&self) -> StaticName {
        "cuda-transform".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        self.transform_up_to_phase(model, usize::MAX)
    }
}

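// Typical usage, as a sketch (assumes the tract-onnx front end and a visible
// CUDA device; "model.onnx" is a placeholder path):
//
//     let mut model = tract_onnx::onnx()
//         .model_for_path("model.onnx")?
//         .into_typed()?
//         .into_decluttered()?;
//     CudaTransform::default().transform(&mut model)?;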
impl CudaTransform {
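    /// Runs the transform up to (and including) the given phase:
    /// phase 0 rewrites einsums into prefix matmuls, phase 1 applies the
    /// pre-translation rewrite rules, phase 2 translates nodes to CUDA ops,
    /// and later phases fuse axis ops and rewire host/device syncs.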
    pub fn transform_up_to_phase(
        &self,
        model: &mut TypedModel,
        stop_at_phase: usize,
    ) -> TractResult<()> {
        // Initialize the CUDA context if it hasn't been created yet
        cuda_context();

        rewrite_einsum_to_prefix_matmul(model)?;
        if stop_at_phase == 0 {
            return Ok(());
        }

        Rewriter::default()
            .with_rule_for("untranspose_matmul_output", rewrite_rules::untranspose_matmul_output)
            .with_rule_for("add_broadcast_pre_matmul", rewrite_rules::add_broadcast_pre_matmul)
            .rewrite(&(), model)?;

        if stop_at_phase == 1 {
            return Ok(());
        }

        *model = self.translate_model(model)?;

        if stop_at_phase == 2 {
            return Ok(());
        }

        Rewriter::default()
            .with_rule_for("fuse_move_axis", rewrite_rules::fuse_move_axis)
            .rewrite(&(), model)?;
        Rewriter::default()
            .with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
            .rewrite(&(), model)?;

        rewire_syncs(model)?;
        Ok(())
    }

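    /// Maps a node's inputs into the target model, inserting `DeviceSync`
    /// nodes where a host/device boundary is crossed. For `ToDevice`, constant
    /// inputs are converted to device tensors in place rather than synced at
    /// runtime.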
    fn sync_inputs_if_required(
        &self,
        model: &mut TypedModel,
        node: &TypedNode,
        mapping: &HashMap<OutletId, OutletId>,
        sync_kind: DeviceSyncKind,
    ) -> TractResult<TVec<OutletId>> {
        let mut mapped_inputs = tvec![];
        for (i_idx, i) in node.inputs.iter().enumerate() {
            let in_fact = model.outlet_fact_mut(mapping[i])?;
            match sync_kind {
                DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-cpu-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
                    if let Some(ref konst) = in_fact.konst {
                        if konst.as_device_tensor().is_none() {
                            let device_konst =
                                konst.as_ref().clone().into_device()?.into_opaque_tensor();
                            let device_fact = DeviceFact::from_host(in_fact.clone())?;

                            *in_fact = TypedFact::dt_scalar(DatumType::Opaque)
                                .with_opaque_fact(device_fact);

                            in_fact.konst = Some(Arc::new(device_konst));
                            mapped_inputs.push(mapping[i]);
                            continue;
                        }
                    }
                    ensure!(
                        in_fact.datum_type.is_copy(),
                        "Only Copy datum types can be synced to the device: {:?}",
                        in_fact.datum_type
                    );

                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-device-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                _ => mapped_inputs.push(mapping[i]),
            }
        }
        Ok(mapped_inputs)
    }

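    /// Inserts a `ToHost` sync after any model output that still lives on the
    /// device, so callers always receive host tensors.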
    fn sync_model_outputs_if_required(
        &self,
        src: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        target_node_outlet_ids: TVec<OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let mut outputs = tvec![];
        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
            // Add a DeviceSync op for each device-resident model output
            let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
            if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
                let sync_output = target.wire_node(
                    format!("{}.to-host-{o_idx}-out", node.name),
                    DeviceSync::new(DeviceSyncKind::ToHost),
                    &[o],
                )?[0];
                outputs.push(sync_output);
            } else {
                outputs.push(o);
            }
        }
        Ok(outputs)
    }
}

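/// Returns true if `node` can be translated to a CUDA op: every input datum
/// type must be representable on the device, and the op must have a CUDA
/// counterpart that supports those types.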
fn can_translate_to_cuda_op(source: &TypedModel, node: &TypedNode) -> TractResult<bool> {
    let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec();
    let input_dts = input_facts
        .iter()
        .map(|f| f.as_device_fact().map(|f| f.datum_type).unwrap_or(f.datum_type))
        .collect_vec();

    let in_dts_compatible =
        input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type));

    Ok(in_dts_compatible
        && (node
            .op_as::<Const>()
            .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
            || node
                .op_as::<Silu>()
                .is_some_and(|_| kernels::UnaryOps::is_supported_dt(input_dts[0]))
            || node.op_as::<ElementWiseOp>().is_some_and(|op| {
                kernels::UnaryOps::is_supported_dt(input_dts[0])
                    && map_element_wise_ops_to_cuda(op).is_some()
            })
            || node.op_as::<TypedBinOp>().is_some_and(|op| {
                map_binary_op_to_cuda(op).is_some_and(|op| op.0.is_supported_dt(input_dts[0]))
            })
            || node
                .op_as::<Comp>()
                .is_some_and(|op| convert_logic_op_to_cuda(op).0.is_supported_dt(input_dts[0]))
            || node.op_as::<Cast>().is_some_and(|op| {
                ops::CudaCast::is_supported_dt(input_dts[0]) && ops::CudaCast::new(op.to).is_some()
            })
            || node.op_is::<MultiBroadcastTo>()
            || node.op_is::<AxisOp>()
            || node.op_is::<Slice>()
            || node.op_is::<TypedConcat>()
            || node.op_is::<DynKeyValueCache>()
            || node.op_as::<Reduce>().is_some_and(|op| {
                kernels::nn::Reducer::is_supported_dt(input_dts[0])
                    && ops::CudaReduce::from_tract_core(op).is_ok()
            })
            || node.op_as::<Softmax>().is_some_and(|op| {
                kernels::nn::Softmax::is_supported_dt(input_dts[0])
                    && ops::CudaSoftmax::from_tract_core(op).is_ok()
            })
            || node
                .op_as::<ScaledMaskedSoftmax>()
                .is_some_and(|_| kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0]))
            || node
                .op_as::<RmsNorm>()
                .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0]))
            || node
                .op_as::<RotateHalf>()
                .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0]))
            || node
                .op_as::<ApplyRope>()
                .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0]))
            || node
                .op_as::<GeluApproximate>()
                .is_some_and(|_| kernels::nn::GeluApproximate::is_supported_dt(input_dts[0])))
        || node.op_as::<PrefixMatMul>().is_some_and(|op| {
            !op.transpose_c
                && op.quantize_output.is_none()
                && (GgmlGemm.is_supported_dts(&input_facts)
                    || GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()]))
        }))
}

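/// Pads the last (k) axis of a Q4_0 block-quantized value so that each row's
/// element count becomes a multiple of `Q40_ROW_PADDING`, appending
/// zero-valued Q4_0 blocks to every row. Returns the input unchanged when it
/// is already aligned.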
pub fn pad_q40(q40_bqv: &BlockQuantValue) -> TractResult<BlockQuantValue> {
    let shape = q40_bqv.fact.shape();
    ensure!(shape.len() >= 2);

    let k = *shape.last().unwrap();
    ensure!(k % 32 == 0, "Q4_0 rows must be whole blocks of 32 values");

    let to_pad = k.next_multiple_of(Q40_ROW_PADDING) - k;
    if to_pad == 0 {
        return Ok(q40_bqv.clone()); // No padding needed
    }

    let outer_rows: usize = shape[..shape.len() - 1].iter().product();
    let row_bytes = k * Q4_0.block_bytes() / Q4_0.block_len();

    let pad_quant = Q4_0.quant_f32(&vec![0f32; to_pad])?;
    let pad_bytes = pad_quant.len();

    let mut new_data = Vec::with_capacity(outer_rows * (row_bytes + pad_bytes));
    let old_bytes = q40_bqv.value.as_bytes();

    // Copy each row, then append the zero-valued padding blocks
    for row in 0..outer_rows {
        let start = row * row_bytes;
        new_data.extend_from_slice(&old_bytes[start..start + row_bytes]);
        new_data.extend_from_slice(&pad_quant);
    }

    let mut new_shape = shape.to_smallvec();
    *new_shape.last_mut().unwrap() += to_pad;

    Ok(BlockQuantValue {
        fact: BlockQuantFact::new(q40_bqv.fact.format.clone(), new_shape),
        value: Arc::new(Blob::from_bytes(&new_data)?),
    })
}

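/// Converts a host `Const` into its device-resident equivalent. Q4_0
/// block-quantized constants are row-padded first (see `pad_q40`); other
/// tensors are uploaded as-is and wrapped in an opaque device fact.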
fn convert_const(op: &Const) -> TractResult<Const> {
    let typed_fact: TypedFact = Arc::clone(op.val()).into();
    let cuda_const = op.val().clone();

    let to_device_opaque = |fact: TypedFact, tensor: Arc<Tensor>| -> TractResult<_> {
        Ok((
            DeviceFact::from_host(fact)?,
            tensor.into_device()?.into_opaque_tensor().into_arc_tensor(),
        ))
    };

    let (cuda_fact, cuda_tensor) = match op.opaque_fact() {
        Some(_) => {
            ensure!(as_q40_fact(&typed_fact).is_some(), "Only Q4_0 block quantization is supported");

            let tensor = cuda_const.into_tensor();
            let bqv = as_q40_tensor(&tensor).unwrap();

            let padded_bqv = pad_q40(bqv)?;
            let padded_fact = typed_fact.with_opaque_fact(padded_bqv.fact.clone());
            let padded_tensor = tensor0(Opaque(Arc::new(padded_bqv)))
                .broadcast_into_rank(op.val().rank())?
                .into_arc_tensor();

            to_device_opaque(padded_fact, padded_tensor)?
        }
        None => to_device_opaque(typed_fact, cuda_const)?,
    };

    Const::new_with_opaque_fact(cuda_tensor, Box::new(cuda_fact))
}

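// Expands to a closure that downcasts an `ElementWiseOp`'s inner op against
// each listed tract op type and returns the matching CUDA unary kernel, if any.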
macro_rules! map_unary_ops {
    ([$(($tract_unary_op:path, $cuda_unary_op:ident)),* $(,)?]) => {
        |op: &tract_core::ops::element_wise::ElementWiseOp| {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_unary_op>() {
                return Some($crate::ops::CudaUnaryOp(kernels::UnaryOps::$cuda_unary_op));
            })*
            None
        }
    };
}

fn map_element_wise_ops_to_cuda(op: &ElementWiseOp) -> Option<ops::CudaUnaryOp> {
    map_unary_ops!([
        (tract_core::ops::math::Abs, Abs),
        (tract_core::ops::math::Exp, Exp),
        (tract_core::ops::math::Ln, Ln),
        (tract_core::ops::nn::Sigmoid, Sigmoid),
        (tract_core::ops::math::Square, Sqr),
        (tract_core::ops::math::Sqrt, Sqrt),
        (tract_core::ops::math::Rsqrt, Rsqrt),
        (tract_core::ops::math::Recip, Recip),
        (tract_core::ops::math::Ceil, Ceil),
        (tract_core::ops::math::Floor, Floor),
        (tract_core::ops::math::Round, Round),
        (tract_core::ops::math::RoundHalfToEven, RoundHalfToEven),
        (tract_core::ops::math::Cos, Cos),
        (tract_core::ops::math::Acos, Acos),
        (tract_core::ops::math::Acosh, Acosh),
        (tract_core::ops::math::Cosh, Cosh),
        (tract_core::ops::math::Sin, Sin),
        (tract_core::ops::math::Asin, Asin),
        (tract_core::ops::math::Asinh, Asinh),
        (tract_core::ops::math::Sinh, Sinh),
        (tract_core::ops::math::Tan, Tan),
        (tract_core::ops::math::Atan, Atan),
        (tract_core::ops::math::Atanh, Atanh),
        (tract_core::ops::math::Tanh, Tanh),
        (tract_core::ops::math::Erf, Erf),
        (tract_core::ops::math::Neg, Neg),
    ])(op)
}

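// Same downcast-and-map scheme as `map_unary_ops!`, for binary ops.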
macro_rules! map_bin_ops {
    ([$(($tract_bin_op:path, $cuda_bin_op:ident)),* $(,)?]) => {
        |op: &TypedBinOp| {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_bin_op>() {
                return Some($crate::ops::CudaBinOp(kernels::BinOps::$cuda_bin_op));
            })*
            None
        }
    };
}

#[allow(clippy::borrowed_box)]
fn map_binary_op_to_cuda(op: &TypedBinOp) -> Option<ops::CudaBinOp> {
    map_bin_ops!([
        (tract_core::ops::math::Mul, Mul),
        (tract_core::ops::math::Add, Add),
        (tract_core::ops::math::Div, Div),
        (tract_core::ops::math::Sub, Sub),
        (tract_core::ops::math::Pow, Pow),
        (tract_core::ops::logic::And, And),
        (tract_core::ops::logic::Or, Or),
    ])(op)
}

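// Comparison ops map one-to-one onto CUDA binary kernels, so this mapping is
// total and needs no Option.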
fn convert_logic_op_to_cuda(op: &Comp) -> ops::CudaBinOp {
    match op {
        Comp::Eq => ops::CudaBinOp(kernels::BinOps::Equals),
        Comp::NE => ops::CudaBinOp(kernels::BinOps::NotEquals),
        Comp::LT => ops::CudaBinOp(kernels::BinOps::Less),
        Comp::LTE => ops::CudaBinOp(kernels::BinOps::LessEqual),
        Comp::GT => ops::CudaBinOp(kernels::BinOps::Greater),
        Comp::GTE => ops::CudaBinOp(kernels::BinOps::GreaterEqual),
    }
}

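/// Wires a `PrefixMatMul` node as a CUDA GGML gemm. The gemm is instantiated
/// with A untransposed and B transposed, so inputs are swapped when only the
/// swapped order is supported (the result is then permuted back), axis moves
/// are inserted for mismatched transpose flags, and casts are added when the
/// kernel's output type differs from the expected one.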
fn convert_matmul_to_cuda(
    model: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    inputs: &mut [OutletId],
    op: &PrefixMatMul,
) -> TractResult<TVec<OutletId>> {
    let mut input_facts = model.node_input_facts(node.id)?;

    let mut swap_inputs = false;
    if !GgmlGemm.is_supported_dts(&[input_facts[0].clone(), input_facts[1].clone()])
        && GgmlGemm.is_supported_dts(&[input_facts[1].clone(), input_facts[0].clone()])
    {
        input_facts.swap(0, 1);
        inputs.swap(0, 1);
        swap_inputs = true;
    }

    // After a swap, the original A operand sits at index 1
    let a_pos = swap_inputs as usize;
    let b_pos = 1 - swap_inputs as usize;
    if op.transpose_a {
        ensure!(as_q40_fact(input_facts[a_pos]).is_none(), "Cannot transpose Q40 tensor");

        let rank = input_facts[a_pos].rank();
        let perm_a_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        let perm_a_name = node.name.clone() + ".perm_a";
        inputs[a_pos] = target.wire_node(perm_a_name, perm_a_op, &[inputs[a_pos]])?[0];
    }

    // Cast F16 activations to F32 when the weights are Q4_0
    if input_facts[0].datum_type == DatumType::F16 && as_q40_fact(input_facts[1]).is_some() {
        let in_cast_op = ops::CudaCast::new(DatumType::F32).unwrap();
        inputs[0] = target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[inputs[0]])?[0];
    }

    if !op.transpose_b {
        ensure!(as_q40_fact(input_facts[b_pos]).is_none(), "Cannot transpose Q40 tensor");

        let rank = input_facts[b_pos].rank();
        let perm_b_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        let perm_b_name = node.name.clone() + ".perm_b";
        inputs[b_pos] = target.wire_node(perm_b_name, perm_b_op, &[inputs[b_pos]])?[0];
    }

    let gemm = ops::CudaGemm::<GgmlGemm>::new(false, true);
    let mut matmul_output = target.wire_node(node.name.clone(), gemm, inputs)?;

    if swap_inputs {
        // Swapping the operands transposes the result; move the axes back
        let out_fact = target.outlet_fact(matmul_output[0])?;
        let rank = out_fact
            .opaque_fact
            .as_ref()
            .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
            .unwrap();

        let perm_out_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
        matmul_output =
            target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?;
    }

    let out_fact = target.outlet_fact(matmul_output[0])?;
    let out_dt = out_fact.to_device_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type);

    let expected_dt = model.node_output_facts(node.id)?[0].datum_type;

    if out_dt != expected_dt {
        ensure!(
            ops::CudaCast::is_supported_dt(out_dt),
            "Matmul output type cannot be cast to the expected type"
        );
        let cast_op = ops::CudaCast::new(expected_dt).unwrap();
        matmul_output =
            target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)?;
    }
    Ok(matmul_output)
}

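// Drives the translation: nodes with a CUDA counterpart are wired on the
// device (syncing their inputs with ToDevice as needed); everything else is
// kept on the host, with ToHost syncs inserted in front of it.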
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for CudaTransform {
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let translatable = can_translate_to_cuda_op(source, node)?;

        if translatable {
            let mut device_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?;

            let outlet_ids: TVec<OutletId> = if let Some(op) = node.op_as::<PrefixMatMul>() {
                convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)?
            } else {
                let op: Box<dyn TypedOp> = if let Some(op) = node.op_as::<Const>() {
                    Box::new(convert_const(op)?)
                } else if let Some(op) = node.op_as::<ElementWiseOp>() {
                    Box::new(map_element_wise_ops_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<TypedBinOp>() {
                    Box::new(map_binary_op_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<Comp>() {
                    Box::new(convert_logic_op_to_cuda(op))
                } else if let Some(_op) = node.op_as::<Silu>() {
                    Box::new(ops::CudaUnaryOp(kernels::UnaryOps::Silu))
                } else if let Some(op) = node.op_as::<MultiBroadcastTo>() {
                    Box::new(ops::CudaMultiBroadcastTo::new(op.shape.clone()))
                } else if let Some(op) = node.op_as::<Cast>() {
                    Box::new(ops::CudaCast::new(op.to).unwrap())
                } else if let Some(op) = node.op_as::<AxisOp>() {
                    let in_fact = source.node_input_facts(node.id)?[0];
                    Box::new(ops::CudaAxisOp::from_tract_core_with_fact(op.clone(), in_fact))
                } else if let Some(op) = node.op_as::<Slice>() {
                    Box::new(ops::CudaSlice::from_tract_core(op.clone()))
                } else if let Some(op) = node.op_as::<TypedConcat>() {
                    Box::new(ops::CudaConcat::from_tract_core(op))
                } else if let Some(op) = node.op_as::<DynKeyValueCache>() {
                    Box::new(ops::CudaDynKVCache::from_tract_transformers(op))
                } else if let Some(op) = node.op_as::<Reduce>() {
                    Box::new(ops::CudaReduce::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<Softmax>() {
                    Box::new(ops::CudaSoftmax::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<ScaledMaskedSoftmax>() {
                    Box::new(ops::CudaScaledMaskedSoftmax { scale: op.scale.clone() })
                } else if let Some(_op) = node.op_as::<RotateHalf>() {
                    Box::new(ops::CudaRotateHalf)
                } else if let Some(_op) = node.op_as::<ApplyRope>() {
                    Box::new(ops::CudaApplyRope)
                } else if let Some(op) = node.op_as::<RmsNorm>() {
                    Box::new(ops::CudaRmsNorm::new(op.axis, op.eps.clone()))
                } else if let Some(op) = node.op_as::<GeluApproximate>() {
                    Box::new(ops::CudaGeluApproximate { fast_impl: op.fast_impl })
                } else {
                    bail!("Node {} was reported CUDA-translatable but no conversion matched", node.name)
                };
                target.wire_node(node.name.clone(), op, &device_inputs)?
            };
            self.sync_model_outputs_if_required(source, node, target, outlet_ids)
        } else {
            let cpu_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToHost)?;
            target.wire_node(&node.name, node.op.clone(), &cpu_inputs)
        }
    }
}
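
#[cfg(test)]
mod sketch {
    // A smoke-test sketch, not part of the original file. It assumes a
    // CUDA-capable device at runtime and that `f32::fact` and
    // `tract_core::ops::math::neg` are exposed as in tract's internal prelude,
    // hence the `#[ignore]`.
    use super::*;

    #[test]
    #[ignore = "requires a CUDA device"]
    fn transform_trivial_model() -> TractResult<()> {
        let mut model = TypedModel::default();
        let x = model.add_source("x", f32::fact([2, 3]))?;
        let y = model.wire_node("neg", tract_core::ops::math::neg(), &[x])?[0];
        model.set_output_outlets(&[y])?;
        // Neg has a CUDA unary kernel, so the node is moved to the device and
        // the model output is synced back to host by the transform.
        CudaTransform.transform(&mut model)?;
        Ok(())
    }
}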