tract_hir/ops/
logic.rs

1use crate::infer::*;
2use crate::internal::*;
3
4use tract_core::broadcast::multi_broadcast;
5use tract_core::ops::cast::wire_cast;
6pub use tract_core::ops::change_axes::wire_with_rank_broadcast;
7pub use tract_core::ops::logic::*;
8
9impl Expansion for Comp {
10    fn name(&self) -> StaticName {
11        <Comp as Op>::name(self)
12    }
13
14    fn rules<'r, 'p: 'r, 's: 'r>(
15        &'s self,
16        s: &mut Solver<'r>,
17        inputs: &'p [TensorProxy],
18        outputs: &'p [TensorProxy],
19    ) -> InferenceResult {
20        super::binary::rules(s, inputs, outputs, |_, _| Ok(bool::datum_type()))
21    }
22
23    fn wire(
24        &self,
25        prefix: &str,
26        target: &mut TypedModel,
27        inputs: &[OutletId],
28    ) -> TractResult<TVec<OutletId>> {
29        let a = target.outlet_fact(inputs[0])?;
30        let b = target.outlet_fact(inputs[1])?;
31        let operating_datum_type = a
32            .datum_type
33            .common_super_type(b.datum_type)
34            .with_context(|| format!("No super type for {a:?} and {b:?}"))?;
35        let wires = wire_rank_broadcast(prefix, target, inputs)?;
36        let wires = wire_cast(prefix, target, &wires, operating_datum_type)?;
37        target.wire_node(prefix, *self, &wires)
38    }
39}
40
/// Inference-level `Iff` operator with three inputs.
///
/// The first input must be boolean. The other two are unified to a common datum type.
/// When wired into a typed model, it expands to the core `tract_core::ops::logic::Iff` op.
#[derive(Clone, Debug, Hash)]
pub struct Iff;
43
44impl Expansion for Iff {
45    fn name(&self) -> StaticName {
46        "Iff".into()
47    }
48
49    fn rules<'r, 'p: 'r, 's: 'r>(
50        &'s self,
51        s: &mut Solver<'r>,
52        inputs: &'p [TensorProxy],
53        outputs: &'p [TensorProxy],
54    ) -> InferenceResult {
55        check_input_arity(inputs, 3)?;
56        check_output_arity(outputs, 1)?;
57        s.equals(&inputs[0].datum_type, DatumType::Bool)?;
58        s.given_2(&inputs[1].datum_type, &inputs[2].datum_type, move |s, a, b| {
59            let dt = a
60                .common_super_type(b)
61                .with_context(|| format!("No super type for {a:?} and {b:?}"))?;
62            s.equals(&outputs[0].datum_type, dt)
63        })?;
64        s.given_3(&inputs[0].shape, &inputs[1].shape, &inputs[2].shape, move |s, c, t, f| {
65            let shape = multi_broadcast(&[&c, &t, &f])?;
66            s.equals(&outputs[0].shape, shape)
67        })?;
68        Ok(())
69    }
70
71    fn wire(
72        &self,
73        prefix: &str,
74        model: &mut TypedModel,
75        inputs: &[OutletId],
76    ) -> TractResult<TVec<OutletId>> {
77        let dta = model.outlet_fact(inputs[1])?.datum_type;
78        let dtb = model.outlet_fact(inputs[2])?.datum_type;
79        let dt = dta
80            .common_super_type(dtb)
81            .with_context(|| format!("No super type for {dta:?} and {dtb:?}"))?;
82        let mut casted = wire_cast(prefix, model, &inputs[1..], dt)?;
83        casted.insert(0, inputs[0]);
84        wire_with_rank_broadcast(prefix, model, tract_core::ops::logic::Iff, &casted)
85    }
86}