// tract_gpu/ops/copy_based.rs
//! Translators for ops that only need the generic copy_nd dispatch.
//! These are fully backend-agnostic and can be constructed without
//! any backend-specific arguments.

use tract_core::internal::*;
use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
use tract_pulse_opl::ops::{Delay, PulsePad};
use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;

10/// Try to translate a node into a copy-based GPU op.
11/// Returns `Some(gpu_op)` if the node is one of the 7 copy-based ops.
12pub fn try_make_copy_based_op(
13    source: &TypedModel,
14    node: &TypedNode,
15) -> TractResult<Option<Box<dyn TypedOp>>> {
16    if let Some(op) = node.op_as::<MultiBroadcastTo>() {
17        return Ok(Some(Box::new(super::broadcast::GpuMultiBroadcastTo::new(op.shape.clone()))));
18    }
19    if let Some(op) = node.op_as::<AxisOp>() {
20        let in_fact = source.node_input_facts(node.id)?[0];
21        return Ok(Some(Box::new(super::change_axes::GpuAxisOp::from_tract_core_with_fact(
22            op.clone(),
23            in_fact,
24        ))));
25    }
26    if let Some(op) = node.op_as::<Slice>() {
27        return Ok(Some(Box::new(super::slice::GpuSlice::new(op.clone()))));
28    }
29    if let Some(op) = node.op_as::<TypedConcat>() {
30        return Ok(Some(Box::new(super::concat::GpuConcat::new(op.axis))));
31    }
32    if let Some(op) = node.op_as::<DynKeyValueCache>() {
33        return Ok(Some(Box::new(super::dyn_kv_cache::GpuDynKVCache::from_tract_transformers(op))));
34    }
35    if let Some(op) = node.op_as::<Delay>() {
36        return Ok(Some(Box::new(super::pulse::GpuDelay::new(op))));
37    }
38    if let Some(op) = node.op_as::<PulsePad>() {
39        return Ok(Some(Box::new(super::pulse::GpuPulsePad::new(op)?)));
40    }
41    Ok(None)
42}