tract_gpu/ops/
copy_based.rs1use tract_core::internal::*;
6use tract_core::ops::array::{MultiBroadcastTo, Slice, TypedConcat};
7use tract_pulse_opl::ops::{Delay, PulsePad};
8use tract_transformers::ops::dyn_kv_cache::DynKeyValueCache;
9
10pub fn try_make_copy_based_op(
13 source: &TypedModel,
14 node: &TypedNode,
15) -> TractResult<Option<Box<dyn TypedOp>>> {
16 if let Some(op) = node.op_as::<MultiBroadcastTo>() {
17 return Ok(Some(Box::new(super::broadcast::GpuMultiBroadcastTo::new(op.shape.clone()))));
18 }
19 if let Some(op) = node.op_as::<AxisOp>() {
20 let in_fact = source.node_input_facts(node.id)?[0];
21 return Ok(Some(Box::new(super::change_axes::GpuAxisOp::from_tract_core_with_fact(
22 op.clone(),
23 in_fact,
24 ))));
25 }
26 if let Some(op) = node.op_as::<Slice>() {
27 return Ok(Some(Box::new(super::slice::GpuSlice::new(op.clone()))));
28 }
29 if let Some(op) = node.op_as::<TypedConcat>() {
30 return Ok(Some(Box::new(super::concat::GpuConcat::new(op.axis))));
31 }
32 if let Some(op) = node.op_as::<DynKeyValueCache>() {
33 return Ok(Some(Box::new(super::dyn_kv_cache::GpuDynKVCache::from_tract_transformers(op))));
34 }
35 if let Some(op) = node.op_as::<Delay>() {
36 return Ok(Some(Box::new(super::pulse::GpuDelay::new(op))));
37 }
38 if let Some(op) = node.op_as::<PulsePad>() {
39 return Ok(Some(Box::new(super::pulse::GpuPulsePad::new(op)?)));
40 }
41 Ok(None)
42}