1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::broadcast::multi_broadcast;
5use tract_core::ops::cast::wire_cast;
6pub use tract_core::ops::change_axes::wire_with_rank_broadcast;
7pub use tract_core::ops::logic::*;
8
9impl Expansion for Comp {
10 fn name(&self) -> StaticName {
11 <Comp as Op>::name(self)
12 }
13
14 fn rules<'r, 'p: 'r, 's: 'r>(
15 &'s self,
16 s: &mut Solver<'r>,
17 inputs: &'p [TensorProxy],
18 outputs: &'p [TensorProxy],
19 ) -> InferenceResult {
20 super::binary::rules(s, inputs, outputs, |_, _| Ok(bool::datum_type()))
21 }
22
23 fn wire(
24 &self,
25 prefix: &str,
26 target: &mut TypedModel,
27 inputs: &[OutletId],
28 ) -> TractResult<TVec<OutletId>> {
29 let a = target.outlet_fact(inputs[0])?;
30 let b = target.outlet_fact(inputs[1])?;
31 let operating_datum_type = a
32 .datum_type
33 .common_super_type(b.datum_type)
34 .with_context(|| format!("No super type for {a:?} and {b:?}"))?;
35 let wires = wire_rank_broadcast(prefix, target, inputs)?;
36 let wires = wire_cast(prefix, target, &wires, operating_datum_type)?;
37 target.wire_node(prefix, *self, &wires)
38 }
39}
40
/// Inference-level (HIR) three-input operator `Iff(condition, a, b)`.
///
/// Its `Expansion` implementation requires a boolean condition and unifies
/// the two value inputs to a common super type. It then lowers to the core
/// `tract_core::ops::logic::Iff` op. NOTE(review): the element-wise semantics
/// are those of the core op and are not visible in this module.
#[derive(Clone, Debug, Hash)]
pub struct Iff;
43
44impl Expansion for Iff {
45 fn name(&self) -> StaticName {
46 "Iff".into()
47 }
48
49 fn rules<'r, 'p: 'r, 's: 'r>(
50 &'s self,
51 s: &mut Solver<'r>,
52 inputs: &'p [TensorProxy],
53 outputs: &'p [TensorProxy],
54 ) -> InferenceResult {
55 check_input_arity(inputs, 3)?;
56 check_output_arity(outputs, 1)?;
57 s.equals(&inputs[0].datum_type, DatumType::Bool)?;
58 s.given_2(&inputs[1].datum_type, &inputs[2].datum_type, move |s, a, b| {
59 let dt = a
60 .common_super_type(b)
61 .with_context(|| format!("No super type for {a:?} and {b:?}"))?;
62 s.equals(&outputs[0].datum_type, dt)
63 })?;
64 s.given_3(&inputs[0].shape, &inputs[1].shape, &inputs[2].shape, move |s, c, t, f| {
65 let shape = multi_broadcast(&[&c, &t, &f])?;
66 s.equals(&outputs[0].shape, shape)
67 })?;
68 Ok(())
69 }
70
71 fn wire(
72 &self,
73 prefix: &str,
74 model: &mut TypedModel,
75 inputs: &[OutletId],
76 ) -> TractResult<TVec<OutletId>> {
77 let dta = model.outlet_fact(inputs[1])?.datum_type;
78 let dtb = model.outlet_fact(inputs[2])?.datum_type;
79 let dt = dta
80 .common_super_type(dtb)
81 .with_context(|| format!("No super type for {dta:?} and {dtb:?}"))?;
82 let mut casted = wire_cast(prefix, model, &inputs[1..], dt)?;
83 casted.insert(0, inputs[0]);
84 wire_with_rank_broadcast(prefix, model, tract_core::ops::logic::Iff, &casted)
85 }
86}