1use DatumType::{F16, F32};
2use tract_core::dyn_clone::clone_box;
3use tract_core::internal::*;
4use tract_core::model::translator::Translate;
5use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
6use tract_core::ops::binary::TypedBinOp;
7use tract_core::ops::cast::Cast;
8use tract_core::ops::cnn::conv::rewrite_kernel_conv_in_oihw;
9use tract_core::ops::cnn::{Conv, rewrite_conv_with_n_axis};
10use tract_core::ops::einsum::prefix_matmul::{PrefixMatMul, rewrite_einsum_to_prefix_matmul};
11use tract_core::ops::element_wise::ElementWiseOp;
12use tract_core::ops::konst::Const;
13use tract_core::ops::logic::Comp;
14use tract_core::ops::nn::{LeakyRelu, Reduce, Softmax};
15use tract_core::tract_data::itertools::Itertools;
16use tract_core::tract_linalg::block_quant::Q4_0;
17use tract_core::transform::ModelTransform;
18use tract_gpu::fact::{DeviceFact, DeviceTypedFactExt};
19use tract_gpu::rewrite_rules::rewire_syncs::rewire_syncs;
20use tract_gpu::rewrite_rules::rms_norm::remove_rms_norm_cast;
21use tract_gpu::sync::{DeviceSync, DeviceSyncKind};
22use tract_gpu::tensor::{DeviceTensor, DeviceTensorExt, IntoDevice};
23use tract_gpu::utils::as_quant_fact;
24use tract_pulse_opl::ops::{Delay, PulsePad};
25use tract_transformers::ops::apply_rope::{ApplyRope, RotateHalf};
26use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;
27use tract_transformers::ops::gelu_approximate::GeluApproximate;
28use tract_transformers::ops::rms_norm::RmsNorm;
29use tract_transformers::ops::scaled_masked_softmax::ScaledMaskedSoftmax;
30use tract_transformers::ops::sdpa::Sdpa;
31use tract_transformers::ops::silu::Silu;
32
33use crate::context::cuda_context;
34use crate::ops::{CudaDelay, CudaPulsePad};
35use crate::ops::{CudaLeakyRelu, wire_cuda_conv};
36use crate::{kernels, ops, rewrite_rules};
37
/// Model transform that rewrites a `TypedModel` so that supported operators run
/// on the CUDA device, inserting host<->device synchronization where required.
#[derive(Debug, Default)]
pub struct CudaTransform;
40
impl ModelTransform for CudaTransform {
    fn name(&self) -> StaticName {
        "cuda-transform".into()
    }

    fn transform(&self, model: &mut TypedModel) -> TractResult<()> {
        // usize::MAX means: run every phase, never stop early.
        self.transform_up_to_phase(model, usize::MAX)
    }
}
50
impl CudaTransform {
    /// Runs the transform pipeline, stopping after `stop_at_phase`.
    ///
    /// Phases, as implemented below:
    /// * 0 — einsums rewritten to `PrefixMatMul`.
    /// * 1 — host-side preparatory rewrites (matmul output/broadcast fixes,
    ///   conv kernel layout, RMS-norm cast removal).
    /// * 2 — node-by-node translation to CUDA ops (see the `Translate` impl).
    /// * 3+ — device-side cleanups: axis-op fusion, sync rewiring, Q4_0 padding.
    pub fn transform_up_to_phase(
        &self,
        model: &mut TypedModel,
        stop_at_phase: usize,
    ) -> TractResult<()> {
        // Touch the CUDA context up front so initialization happens once,
        // before any device work below.
        cuda_context();

        rewrite_einsum_to_prefix_matmul(model, false)?;
        if stop_at_phase == 0 {
            return Ok(());
        }

        Rewriter::default()
            .with_rule_for("untranspose_matmul_output", rewrite_rules::untranspose_matmul_output)
            .with_rule_for("add_broadcast_pre_matmul", rewrite_rules::add_broadcast_pre_matmul)
            .with_rule_for("rewrite_kernel_conv_in_oihw", rewrite_kernel_conv_in_oihw)
            .with_rule_for("rewrite_conv_with_n_axis", rewrite_conv_with_n_axis)
            .rewrite(&(), model)?;

        // Run in a separate Rewriter pass so it sees the graph after the
        // matmul/conv rules above have settled.
        Rewriter::default()
            .with_rule_for("remove_rms_norm_cast", remove_rms_norm_cast)
            .rewrite(&(), model)?;

        if stop_at_phase == 1 {
            return Ok(());
        }

        // Translate nodes to CUDA ops (this consumes the old model).
        *model = self.translate_model(model)?;

        if stop_at_phase == 2 {
            return Ok(());
        }

        Rewriter::default()
            .with_rule_for("fuse_move_axis", rewrite_rules::fuse_move_axis)
            .rewrite(&(), model)?;
        Rewriter::default()
            .with_rule_for("fuse_axis_op", rewrite_rules::fuse_axis_op)
            .rewrite(&(), model)?;

        // Remove/merge redundant host<->device sync nodes introduced during
        // translation.
        rewire_syncs(model)?;

        Rewriter::default()
            .with_rule_for("pad_q40_weights", rewrite_rules::pad_q40_weights)
            .rewrite(&(), model)?;
        Ok(())
    }

    /// Maps `node`'s inputs through `mapping`, inserting `DeviceSync` nodes
    /// where an input lives on the wrong side for `sync_kind`.
    ///
    /// For `ToDevice`, constant inputs are not synced at runtime: the constant
    /// tensor is moved to the device once, here, by mutating its outlet fact
    /// in place (the fact becomes an opaque scalar carrying a `DeviceFact`).
    fn sync_inputs_if_required(
        &self,
        model: &mut TypedModel,
        node: &TypedNode,
        mapping: &HashMap<OutletId, OutletId>,
        sync_kind: DeviceSyncKind,
    ) -> TractResult<TVec<OutletId>> {
        let mut mapped_inputs = tvec![];
        for (i_idx, i) in node.inputs.iter().enumerate() {
            let in_fact = model.outlet_fact_mut(mapping[i])?;
            match sync_kind {
                // Input is on device but the node runs on host: sync to host.
                DeviceSyncKind::ToHost if in_fact.as_device_fact().is_some() => {
                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-cpu-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                // Input is on host but the node runs on device.
                DeviceSyncKind::ToDevice if in_fact.as_device_fact().is_none() => {
                    if let Some(ref konst) = in_fact.konst {
                        if konst.as_device_tensor().is_none() {
                            // Move the constant to the device now, instead of
                            // wiring a runtime sync node.
                            let device_konst =
                                konst.as_ref().clone().into_device()?.into_opaque_tensor();
                            let device_fact = DeviceFact::from_host(in_fact.clone())?;

                            *in_fact = TypedFact::dt_scalar(DatumType::Opaque)
                                .with_opaque_fact(device_fact);

                            in_fact.konst = Some(Arc::new(device_konst));
                            mapped_inputs.push(mapping[i]);
                            continue;
                        }
                    }
                    ensure!(
                        in_fact.datum_type.is_copy(),
                        "Only copy DatumType can be sync to Device: {:?}",
                        in_fact.datum_type
                    );

                    mapped_inputs.push(
                        model.wire_node(
                            format!("{}.to-device-{i_idx}", node.name),
                            DeviceSync::new(sync_kind),
                            &[mapping[i]],
                        )?[0],
                    );
                }
                // Already on the right side: pass through unchanged.
                _ => mapped_inputs.push(mapping[i]),
            }
        }
        Ok(mapped_inputs)
    }

    /// Appends a to-host `DeviceSync` after any of `node`'s outputs that is a
    /// model output in `src` and lives on device in `target`, so the model's
    /// outputs are always host tensors.
    fn sync_model_outputs_if_required(
        &self,
        src: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        target_node_outlet_ids: TVec<OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let mut outputs = tvec![];
        for (o_idx, o) in target_node_outlet_ids.into_iter().enumerate() {
            // Outputs are matched by (node id, slot) against the source model.
            let is_src_output = src.outputs.contains(&OutletId::new(node.id, o_idx));
            if target.outlet_fact(o)?.as_device_fact().is_some() && is_src_output {
                let sync_output = target.wire_node(
                    format!("{}.to-host-{o_idx}-out", node.name),
                    DeviceSync::new(DeviceSyncKind::ToHost),
                    &[o],
                )?[0];
                outputs.push(sync_output);
            } else {
                outputs.push(o)
            }
        }
        Ok(outputs)
    }
}
181
182fn can_translate_to_cuda_op(source: &TypedModel, node: &TypedNode) -> TractResult<bool> {
183 let input_facts = source.node_input_facts(node.id)?.iter().map(|f| (*f).clone()).collect_vec();
184 let input_dts = input_facts
185 .iter()
186 .map(|f| f.as_device_fact().map(|f| f.datum_type).unwrap_or(f.datum_type))
187 .collect_vec();
188
189 let in_dts_compatible =
190 input_facts.iter().all(|fact| DeviceTensor::is_supported_dt(fact.datum_type));
191
192 Ok(in_dts_compatible
193 && (node
194 .op_as::<Const>()
195 .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
196 || node
197 .op_as::<Silu>()
198 .is_some_and(|_| kernels::UnaryOps::is_supported_dt(input_dts[0]))
199 || node.op_as::<ElementWiseOp>().is_some_and(|op| op.0.is::<LeakyRelu>())
200 || node.op_as::<ElementWiseOp>().is_some_and(|op| {
201 kernels::UnaryOps::is_supported_dt(input_dts[0])
202 && map_element_wise_ops_to_cuda(op).is_some()
203 })
204 || node.op_as::<TypedBinOp>().is_some_and(|op| {
205 map_binary_op_to_cuda(op).is_some_and(|op| op.0.is_supported_dt(input_dts[0]))
206 })
207 || node
208 .op_as::<Comp>()
209 .is_some_and(|op| convert_logic_op_to_cuda(op).0.is_supported_dt(input_dts[0]))
210 || node
211 .op_as::<Const>()
212 .is_some_and(|op| DeviceTensor::is_supported_dt(op.val().datum_type()))
213 || node.op_as::<Cast>().is_some_and(|op| {
214 ops::CudaCast::is_supported_dt(input_dts[0]) && ops::CudaCast::new(op.to).is_some()
215 })
216 || node.op_is::<MultiBroadcastTo>()
217 || node.op_is::<AxisOp>()
218 || node.op_is::<Slice>()
219 || node.op_is::<Delay>()
220 || node.op_is::<PulsePad>()
221 || node.op_is::<TypedConcat>()
222 || node.op_is::<DynKeyValueCache>()
223 || node.op_as::<Reduce>().is_some_and(|op| {
224 ops::CudaReduce::from_tract_core(op)
225 .is_ok_and(|op| op.reducer.is_supported_dt(input_dts[0]))
226 })
227 || node.op_as::<Softmax>().is_some_and(|op| {
228 kernels::nn::Softmax::is_supported_dt(input_dts[0])
229 && ops::CudaSoftmax::from_tract_core(op).is_ok()
230 })
231 || node
232 .op_as::<ScaledMaskedSoftmax>()
233 .is_some_and(|_| kernels::nn::ScaledMaskedSoftmax::is_supported_dt(input_dts[0]))
234 || node
235 .op_as::<RmsNorm>()
236 .is_some_and(|_| kernels::nn::RmsNorm::is_supported_dt(input_dts[0]))
237 || node
238 .op_as::<RotateHalf>()
239 .is_some_and(|_| kernels::array::RotateHalf::is_supported_dt(input_dts[0]))
240 || node
241 .op_as::<ApplyRope>()
242 .is_some_and(|_| kernels::nn::ApplyRope::is_supported_dt(input_dts[0]))
243 || node
244 .op_as::<GeluApproximate>()
245 .is_some_and(|_| kernels::nn::GeluApproximate::is_supported_dt(input_dts[0]))
246 || node.op_as::<Sdpa>().is_some()
247 || node.op_as::<PrefixMatMul>().is_some_and(|op| {
248 !op.transpose_c
249 && op.quantize_output.is_none()
250 && (can_convert_to_cuda_gemm(&input_facts)
251 || can_convert_to_cuda_gemm(&[
252 input_facts[1].clone(),
253 input_facts[0].clone(),
254 ]))
255 })
256 || (node.op_is::<Conv>() && input_facts[0].datum_type.is::<f32>())))
257}
258
259fn convert_const(op: &Const) -> TractResult<Const> {
260 let typed_fact: TypedFact = Arc::clone(op.val()).into();
261 let cuda_fact = if let Some(of) = op.opaque_fact() {
262 DeviceFact::from_host(typed_fact.with_opaque_fact(clone_box(of)))?
263 } else {
264 DeviceFact::from_host(typed_fact)?
265 };
266
267 let cuda_const = op.val().clone().into_device()?.into_opaque_tensor().into_arc_tensor();
268 Const::new_with_opaque_fact(cuda_const, Box::new(cuda_fact))
269}
270
/// Expands to a closure mapping a tract `ElementWiseOp` to the matching
/// `CudaUnaryOp`, or `None` when no pair in the list matches. Each pair is
/// `(tract op type, CUDA kernel variant)`; the first successful downcast wins.
macro_rules! map_unary_ops {
    ([$(($tract_unary_op:path, $cuda_unary_op:ident)),* $(,)?]) => {
        |op: &tract_core::ops::element_wise::ElementWiseOp| {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_unary_op>() {
                return Some($crate::ops::CudaUnaryOp(kernels::UnaryOps::$cuda_unary_op));
            })*
            return None;
        }
    };
}
281
/// Maps a tract element-wise op to its CUDA unary kernel, or `None` when the
/// op has no CUDA equivalent (e.g. LeakyRelu, which is handled separately).
/// Note `Square` maps to the kernel named `Sqr`.
fn map_element_wise_ops_to_cuda(op: &ElementWiseOp) -> Option<ops::CudaUnaryOp> {
    map_unary_ops!([
        (tract_core::ops::math::Abs, Abs),
        (tract_core::ops::math::Exp, Exp),
        (tract_core::ops::math::Ln, Ln),
        (tract_core::ops::nn::Sigmoid, Sigmoid),
        (tract_core::ops::math::Square, Sqr),
        (tract_core::ops::math::Sqrt, Sqrt),
        (tract_core::ops::math::Rsqrt, Rsqrt),
        (tract_core::ops::math::Recip, Recip),
        (tract_core::ops::math::Ceil, Ceil),
        (tract_core::ops::math::Floor, Floor),
        (tract_core::ops::math::Round, Round),
        (tract_core::ops::math::RoundHalfToEven, RoundHalfToEven),
        (tract_core::ops::math::Cos, Cos),
        (tract_core::ops::math::Acos, Acos),
        (tract_core::ops::math::Acosh, Acosh),
        (tract_core::ops::math::Cosh, Cosh),
        (tract_core::ops::math::Sin, Sin),
        (tract_core::ops::math::Asin, Asin),
        (tract_core::ops::math::Asinh, Asinh),
        (tract_core::ops::math::Sinh, Sinh),
        (tract_core::ops::math::Tan, Tan),
        (tract_core::ops::math::Atan, Atan),
        (tract_core::ops::math::Atanh, Atanh),
        (tract_core::ops::math::Tanh, Tanh),
        (tract_core::ops::math::Erf, Erf),
        (tract_core::ops::math::Neg, Neg),
    ])(op)
}
312
/// Expands to a closure mapping a tract `TypedBinOp` to the matching
/// `CudaBinOp`, or `None` when no pair in the list matches. Same shape as
/// `map_unary_ops!`: first successful downcast wins.
macro_rules! map_bin_ops {
    ([$(($tract_bin_op:path, $cuda_bin_op:ident)),* $(,)?]) => {
        |op: &TypedBinOp | {
            $(if let Some(_op) = op.0.downcast_ref::<$tract_bin_op>() {
                return Some($crate::ops::CudaBinOp(kernels::BinOps::$cuda_bin_op));
            })*
            return None;
        }
    };
}
323
/// Maps a tract binary op to its CUDA kernel, or `None` when unsupported
/// (comparison ops go through `convert_logic_op_to_cuda` instead).
// NOTE(review): this `allow` looks stale — the parameter is `&TypedBinOp`, not
// `&Box<_>`; confirm clippy no longer fires here before removing it.
#[allow(clippy::borrowed_box)]
fn map_binary_op_to_cuda(op: &TypedBinOp) -> Option<ops::CudaBinOp> {
    map_bin_ops!([
        (tract_core::ops::math::Mul, Mul),
        (tract_core::ops::math::Add, Add),
        (tract_core::ops::math::Div, Div),
        (tract_core::ops::math::Sub, Sub),
        (tract_core::ops::math::Min, Min),
        (tract_core::ops::math::Max, Max),
        (tract_core::ops::math::Pow, Pow),
        (tract_core::ops::logic::And, And),
        (tract_core::ops::logic::Or, Or),
    ])(op)
}
338
339fn convert_logic_op_to_cuda(op: &Comp) -> ops::CudaBinOp {
340 match op {
341 Comp::Eq => ops::CudaBinOp(kernels::BinOps::Equals),
342 Comp::NE => ops::CudaBinOp(kernels::BinOps::NotEquals),
343 Comp::LT => ops::CudaBinOp(kernels::BinOps::Less),
344 Comp::LTE => ops::CudaBinOp(kernels::BinOps::LessEqual),
345 Comp::GT => ops::CudaBinOp(kernels::BinOps::Greater),
346 Comp::GTE => ops::CudaBinOp(kernels::BinOps::GreaterEqual),
347 }
348}
349
350fn can_convert_to_cuda_gemm(facts: &[TypedFact]) -> bool {
351 assert!(facts.len() == 2, "Ggml: Expected 2 inputs for Matmul");
352
353 let regular_types_support =
354 matches!((facts[0].datum_type, facts[1].datum_type), (F32, F32) | (F16, F16) | (F16, F32));
355
356 regular_types_support
357 || (as_quant_fact(&facts[1], &Q4_0).is_some() && matches!(facts[0].datum_type, F16 | F32))
358}
359
360fn convert_matmul_to_cuda(
361 model: &TypedModel,
362 node: &TypedNode,
363 target: &mut TypedModel,
364 inputs: &mut [OutletId],
365 op: &PrefixMatMul,
366) -> TractResult<TVec<OutletId>> {
367 let mut input_facts = model.node_input_facts(node.id)?;
368 let mut swap_inputs = false;
372 if !can_convert_to_cuda_gemm(&[input_facts[0].clone(), input_facts[1].clone()])
373 && can_convert_to_cuda_gemm(&[input_facts[1].clone(), input_facts[0].clone()])
374 {
375 input_facts.swap(0, 1);
376 inputs.swap(0, 1);
377 swap_inputs = true;
378 }
379
380 let act_fact = input_facts[0];
381 let weight_fact = input_facts[1];
382 let outlets = inputs.split_at_mut(1);
383 let act_outlet = &mut outlets.0[0];
384 let weights_outlet = &mut outlets.1[0];
385
386 let transpose_act = if swap_inputs { !op.transpose_b } else { op.transpose_a };
387 let transpose_weight = if swap_inputs { !op.transpose_a } else { op.transpose_b };
388
389 if transpose_act {
390 let rank = act_fact.rank();
391 let perm_act_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
392 let perm_act_name = node.name.clone() + ".perm_activs";
393 *act_outlet = target.wire_node(perm_act_name, perm_act_op, &[*act_outlet])?[0];
394 }
395
396 if act_fact.datum_type == DatumType::F16 && as_quant_fact(weight_fact, &Q4_0).is_some() {
397 let in_cast_op = ops::CudaCast::new(DatumType::F32).unwrap();
398 *act_outlet =
399 target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*act_outlet])?[0];
400 } else if act_fact.datum_type == DatumType::F16 && weight_fact.datum_type == DatumType::F32 {
401 let in_cast_op = ops::CudaCast::new(DatumType::F16).unwrap();
402 *weights_outlet =
403 target.wire_node(node.name.clone() + ".in_cast", in_cast_op, &[*weights_outlet])?[0];
404 }
405
406 if !transpose_weight {
407 ensure!(as_quant_fact(weight_fact, &Q4_0).is_none(), "Cannot transpose Q40 tensor");
408
409 let rank = weight_fact.rank();
410 let perm_weights_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
411 let perm_weights_name = node.name.clone() + ".perm_weights";
412 *weights_outlet =
413 target.wire_node(perm_weights_name, perm_weights_op, &[*weights_outlet])?[0];
414 }
415
416 if as_quant_fact(weight_fact, &Q4_0).is_some() {
417 let device_fact = target.outlet_fact(*act_outlet)?.to_device_fact()?;
418 let quant_op = ops::CudaGgmlQuantQ81::new(device_fact.shape.clone())?;
419 *act_outlet =
420 target.wire_node(node.name.clone() + ".quant_activs", quant_op, &[*act_outlet])?[0];
421 }
422 let mut matmul_output =
423 target.wire_node(node.name.clone(), *Box::new(ops::CudaGgmlGemm), inputs)?;
424
425 if swap_inputs {
426 let out_fact = target.outlet_fact(matmul_output[0])?;
427 let rank = &out_fact
428 .opaque_fact
429 .clone()
430 .map(|fact| fact.clarify_dt_shape().unwrap().1.len())
431 .unwrap();
432
433 let perm_out_op = ops::CudaAxisOp::from_tract_core(AxisOp::Move(rank - 2, rank - 1));
434 matmul_output =
435 target.wire_node(node.name.clone() + ".perm_out", perm_out_op, &matmul_output)?;
436 }
437
438 let out_fact = target.outlet_fact(matmul_output[0])?;
439 let out_dt = out_fact.as_device_fact().map(|f| f.datum_type).unwrap_or(out_fact.datum_type);
440
441 let expected_dt = model.node_output_facts(node.id)?[0].datum_type;
442 if out_dt != expected_dt {
443 ensure!(
444 ops::CudaCast::is_supported_dt(out_dt),
445 "Matmul output type cannot be casted to expected type"
446 );
447 let cast_op = ops::CudaCast::new(model.node_output_facts(node.id)?[0].datum_type).unwrap();
448 matmul_output =
449 target.wire_node(node.name.clone() + ".out_cast", cast_op, &matmul_output)?
450 }
451 Ok(matmul_output)
452}
453
/// Translates an `Sdpa` node into the CUDA flash-attention kernel.
///
/// The kernel works on rank-4 F16 inputs with head dim 64 or 128: Q/K/V (and
/// the optional mask) are cast to F16 as needed, rank-3 inputs get a head axis
/// inserted at position 1, and both adjustments are undone on the output
/// (axis removal, cast back to Q's original dtype).
fn convert_sdpa_to_cuda_flash_attn(
    model: &TypedModel,
    node: &TypedNode,
    target: &mut TypedModel,
    inputs: &mut [OutletId],
    op: &Sdpa,
) -> TractResult<TVec<OutletId>> {
    let facts = model.node_input_facts(node.id)?;

    let [qf, kf, vf] = [facts[0], facts[1], facts[2]];
    ensure!(kf.datum_type() == vf.datum_type(), "K/V dtypes must match");

    // A fourth input fact, when present, is the attention mask.
    let mask_fact = if facts.len() == 4 { Some(facts[3]) } else { None };

    let (q, k, v, m_opt) = match &mut inputs[..] {
        [q, k, v, m, ..] => (q, k, v, Some(m)),
        [q, k, v] => (q, k, v, None),
        _ => bail!("unexpected number of inputs"),
    };

    // Builds "<node name><suffix>" names for the helper nodes wired below.
    fn name(base: &str, suffix: &str) -> String {
        format!("{base}{suffix}")
    }

    // Rewires `dst` through a CudaCast to `want` iff `have != want`.
    fn mut_cast(
        target: &mut TypedModel,
        node_name: &str,
        dst: &mut OutletId,
        have: DatumType,
        want: DatumType,
        suffix: &str,
    ) -> TractResult<()> {
        if have != want {
            *dst = target.wire_node(
                name(node_name, suffix),
                ops::CudaCast::new(want).unwrap(),
                &[*dst],
            )?[0];
        }
        Ok(())
    }

    // Inserts a head axis at position 1 for rank-3 inputs; returns whether an
    // axis was added so the output can be reshaped back. Rejects ranks != 3/4.
    fn add_head_axis_if_rank3(
        target: &mut TypedModel,
        node_name: &str,
        dst: &mut OutletId,
        fact: &TypedFact,
        suffix: &str,
    ) -> TractResult<bool> {
        if fact.rank() == 3 {
            let ax = ops::CudaAxisOp::from_tract_core(AxisOp::Add(1));
            *dst = target.wire_node(name(node_name, suffix), ax, &[*dst])?[0];
            Ok(true)
        } else {
            ensure!(fact.rank() == 4, "Q/K/V must be rank 3 or 4");
            Ok(false)
        }
    }

    // The kernel computes in F16; remember Q's dtype to cast the output back.
    let q_dt = qf.datum_type().unwrap();
    let kv_dt = kf.datum_type().unwrap();
    mut_cast(target, &node.name, k, kv_dt, DatumType::F16, ".cast_k")?;
    mut_cast(target, &node.name, v, kv_dt, DatumType::F16, ".cast_v")?;
    mut_cast(target, &node.name, q, q_dt, DatumType::F16, ".cast_q")?;

    let mut added_head_axis = false;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, q, qf, ".reshape_q")?;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, k, kf, ".reshape_k")?;
    added_head_axis |= add_head_axis_if_rank3(target, &node.name, v, vf, ".reshape_v")?;

    // Head dimension is K's innermost axis; the kernel supports 64 and 128.
    let out_dim = kf.shape[kf.rank() - 1].to_i64()?;
    ensure!(matches!(out_dim, 64 | 128), "Unsupported head dim (D): {out_dim}");
    ensure!(kf.shape == vf.shape, "K and V shapes must be identical");

    if let Some(mf) = mask_fact {
        let m = m_opt.unwrap();
        mut_cast(target, &node.name, m, mf.datum_type().unwrap(), DatumType::F16, ".cast_m")?;
        if mf.rank() != 4 {
            let ax = ops::CudaAxisOp::from_tract_core(AxisOp::Add(1));
            *m = target.wire_node(name(&node.name, ".reshape_m"), ax, &[*m])?[0];
        }
    }

    // Default scale is 1/sqrt(D) when the op carries no explicit scale tensor.
    let scale = op
        .scale
        .as_ref()
        .map(|s| *s.to_scalar::<f32>().unwrap())
        .unwrap_or(1.0 / (out_dim as f32).sqrt());
    let sdpa = ops::CudaFlashAttention::new(scale, op.is_causal);

    let mut out = target.wire_node(node.name.clone(), sdpa, inputs)?;

    // Undo the head-axis insertion, if any input needed one.
    if added_head_axis {
        out = target.wire_node(
            name(&node.name, ".reshape_out"),
            ops::CudaAxisOp::from_tract_core(AxisOp::Rm(1)),
            &out,
        )?;
    }

    // Restore Q's original dtype on the output.
    if q_dt != DatumType::F16 {
        out = target.wire_node(
            name(&node.name, ".cast_out"),
            ops::CudaCast::new(q_dt).unwrap(),
            &out,
        )?;
    }

    Ok(out)
}
568
impl Translate<TypedFact, Box<dyn TypedOp>, TypedFact, Box<dyn TypedOp>> for CudaTransform {
    /// Translates one node: if `can_translate_to_cuda_op` accepts it, inputs
    /// are synced to device and the node is rewired as its CUDA counterpart
    /// (with to-host syncs appended on model outputs); otherwise inputs are
    /// synced back to host and the original op is kept as-is.
    fn translate_node(
        &self,
        source: &TypedModel,
        node: &TypedNode,
        target: &mut TypedModel,
        mapping: &HashMap<OutletId, OutletId>,
    ) -> TractResult<TVec<OutletId>> {
        let translatable = can_translate_to_cuda_op(source, node)?;

        if translatable {
            let mut device_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToDevice)?;

            // Matmul, SDPA and Conv need bespoke wiring (extra casts/permutes);
            // everything else is a one-for-one op substitution.
            let outlet_ids: TVec<OutletId> = if let Some(op) = node.op_as::<PrefixMatMul>() {
                convert_matmul_to_cuda(source, node, target, &mut device_inputs, op)?
            } else if let Some(op) = node.op_as::<Sdpa>() {
                convert_sdpa_to_cuda_flash_attn(source, node, target, &mut device_inputs, op)?
            } else if let Some(conv) = node.op_as::<Conv>() {
                wire_cuda_conv(source, node, target, &device_inputs, conv)?
            } else {
                // NB: the unwrap()s below are justified by the corresponding
                // checks in can_translate_to_cuda_op.
                let op: Box<dyn TypedOp> = if let Some(op) = node.op_as::<Const>() {
                    Box::new(convert_const(op)?)
                } else if let Some(op) = node.op_as::<ElementWiseOp>() {
                    if let Some(leaky) = op.0.downcast_ref::<LeakyRelu>() {
                        Box::new(CudaLeakyRelu { alpha: leaky.alpha })
                    } else {
                        Box::new(map_element_wise_ops_to_cuda(op).unwrap())
                    }
                } else if let Some(op) = node.op_as::<TypedBinOp>() {
                    Box::new(map_binary_op_to_cuda(op).unwrap())
                } else if let Some(op) = node.op_as::<Comp>() {
                    Box::new(convert_logic_op_to_cuda(op))
                } else if let Some(_op) = node.op_as::<Silu>() {
                    Box::new(ops::CudaUnaryOp(kernels::UnaryOps::Silu))
                } else if let Some(op) = node.op_as::<MultiBroadcastTo>() {
                    Box::new(ops::CudaMultiBroadcastTo::new(op.shape.clone()))
                } else if let Some(op) = node.op_as::<Cast>() {
                    Box::new(ops::CudaCast::new(op.to).unwrap())
                } else if let Some(op) = node.op_as::<AxisOp>() {
                    let in_fact = source.node_input_facts(node.id)?[0];
                    Box::new(ops::CudaAxisOp::from_tract_core_with_fact(op.clone(), in_fact))
                } else if let Some(op) = node.op_as::<Slice>() {
                    Box::new(ops::CudaSlice::from_tract_core(op.clone()))
                } else if let Some(op) = node.op_as::<TypedConcat>() {
                    Box::new(ops::CudaConcat::from_tract_core(op))
                } else if let Some(op) = node.op_as::<DynKeyValueCache>() {
                    Box::new(ops::CudaDynKVCache::from_tract_transformers(op))
                } else if let Some(op) = node.op_as::<Reduce>() {
                    Box::new(ops::CudaReduce::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<Softmax>() {
                    Box::new(ops::CudaSoftmax::from_tract_core(op)?)
                } else if let Some(op) = node.op_as::<ScaledMaskedSoftmax>() {
                    Box::new(ops::CudaScaledMaskedSoftmax { scale: op.scale.clone() })
                } else if let Some(_op) = node.op_as::<RotateHalf>() {
                    Box::new(ops::CudaRotateHalf)
                } else if let Some(_op) = node.op_as::<ApplyRope>() {
                    Box::new(ops::CudaApplyRope)
                } else if let Some(op) = node.op_as::<RmsNorm>() {
                    Box::new(ops::CudaRmsNorm::new(op.axis, op.eps.clone()))
                } else if let Some(op) = node.op_as::<GeluApproximate>() {
                    Box::new(ops::CudaGeluApproximate { fast_impl: op.fast_impl })
                } else if let Some(op) = node.op_as::<Delay>() {
                    Box::new(CudaDelay::new(op.clone()))
                } else if let Some(op) = node.op_as::<PulsePad>() {
                    Box::new(CudaPulsePad::new(op)?)
                } else {
                    // can_translate_to_cuda_op accepted an op this dispatch
                    // does not handle — the two lists are out of sync.
                    bail!("Failed to translate a supported CUDA Op")
                };
                target.wire_node(node.name.clone(), op, &device_inputs)?
            };
            self.sync_model_outputs_if_required(source, node, target, outlet_ids)
        } else {
            // Untranslatable node: keep the original op on host, syncing any
            // device-resident inputs back first.
            let cpu_inputs =
                self.sync_inputs_if_required(target, node, mapping, DeviceSyncKind::ToHost)?;
            target.wire_node(&node.name, node.op.clone(), &cpu_inputs)
        }
    }
}
648
#[cfg(test)]
mod test {
    use super::*;

    /// End-to-end smoke test: a mixed-precision (F32 x F16) matmul model is
    /// transformed and run on device. Requires a working CUDA environment.
    #[test]
    fn test_prefix_matmul_transform_f32_f16() -> TractResult<()> {
        let mut model = TypedModel::default();
        let (b, m, k, n) = (1, 16, 128, 32);

        let a_fact = TypedFact::dt_shape(DatumType::F32, &[b, m, k]);
        let b_fact = TypedFact::dt_shape(DatumType::F16, &[b, k, n]);

        let source_a = model.add_source("a", a_fact)?;
        let source_b = model.add_source("b", b_fact)?;

        let op = PrefixMatMul {
            transpose_a: false,
            transpose_b: false,
            transpose_c: false,
            quantize_output: None,
            operating_dt: Some(DatumType::F32),
        };

        let matmul_out = model.wire_node("matmul", op, &[source_a, source_b])?;
        model.set_output_outlets(&matmul_out)?;

        let tensor_a = Tensor::zero::<f32>(&[b, m, k])?;
        let tensor_b = Tensor::zero::<f16>(&[b, k, n])?;
        let inputs = tvec!(tensor_a.into(), tensor_b.into());

        let transform = CudaTransform::default();
        transform.transform(&mut model)?;

        // Only checks that transform + execution succeed; outputs (all zeros)
        // are not asserted against a reference.
        let cuda_runnable = model.into_runnable()?;
        let _ = cuda_runnable.run(inputs)?;
        Ok(())
    }
}