tc_scalar/reference/
if.rs

//! Resolve a reference conditionally.
2
3use std::collections::HashSet;
4use std::fmt;
5
6use async_hash::{Digest, Hash, Output};
7use async_trait::async_trait;
8use destream::{de, en};
9use get_size::GetSize;
10use get_size_derive::*;
11use log::debug;
12use safecast::{Match, TryCastFrom, TryCastInto};
13
14use tc_error::*;
15use tc_transact::public::{Public, StateInstance, ToState};
16use tc_value::Value;
17use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf};
18
19use crate::{OpDef, Scalar, Scope};
20
21use super::{Refer, TCRef};
22
23/// A conditional reference.
24#[derive(Clone, Eq, PartialEq, GetSize)]
25pub struct IfRef {
26    cond: TCRef,
27    then: Scalar,
28    or_else: Scalar,
29}
30
31#[async_trait]
32impl<State> Refer<State> for IfRef
33where
34    State: StateInstance + Refer<State> + From<Scalar>,
35    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
36    Map<State>: TryFrom<State, Error = TCError>,
37    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
38    bool: TryCastFrom<State>,
39{
40    fn dereference_self(self, path: &TCPathBuf) -> Self {
41        Self {
42            cond: self.cond.dereference_self(path),
43            then: self.then.dereference_self(path),
44            or_else: self.or_else.dereference_self(path),
45        }
46    }
47
48    fn is_conditional(&self) -> bool {
49        true
50    }
51
52    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
53        self.cond.is_inter_service_write(cluster_path)
54            || self.then.is_inter_service_write(cluster_path)
55            || self.or_else.is_inter_service_write(cluster_path)
56    }
57
58    fn is_ref(&self) -> bool {
59        true
60    }
61
62    fn reference_self(self, path: &TCPathBuf) -> Self {
63        Self {
64            cond: self.cond.reference_self(path),
65            then: self.then.reference_self(path),
66            or_else: self.or_else.reference_self(path),
67        }
68    }
69
70    fn requires(&self, deps: &mut HashSet<Id>) {
71        self.cond.requires(deps);
72    }
73
74    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
75        self,
76        context: &'a Scope<'a, State, T>,
77        txn: &'a State::Txn,
78    ) -> TCResult<State> {
79        debug!("If::resolve {:?}", self);
80
81        if self.cond.is_conditional() {
82            return Err(bad_request!(
83                "If does not allow a nested conditional {:?}",
84                self.cond,
85            ));
86        }
87
88        let cond = self.cond.resolve(context, txn).await?;
89        debug!("If condition is {:?}", cond);
90
91        if cond.matches::<bool>() {
92            if cond.opt_cast_into().expect("if condition") {
93                Ok(self.then.into())
94            } else {
95                Ok(self.or_else.into())
96            }
97        } else {
98            Err(TCError::unexpected(cond, "a boolean condition"))
99        }
100    }
101}
102
103impl<'a, D: Digest> Hash<D> for &'a IfRef {
104    fn hash(self) -> Output<D> {
105        Hash::<D>::hash((&self.cond, &self.then, &self.or_else))
106    }
107}
108
109impl TryCastFrom<Scalar> for IfRef {
110    fn can_cast_from(scalar: &Scalar) -> bool {
111        scalar.matches::<(TCRef, Scalar, Scalar)>()
112    }
113
114    fn opt_cast_from(scalar: Scalar) -> Option<Self> {
115        scalar.opt_cast_into().map(|(cond, then, or_else)| Self {
116            cond,
117            then,
118            or_else,
119        })
120    }
121}
122
123#[async_trait]
124impl de::FromStream for IfRef {
125    type Context = ();
126
127    async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
128        let (cond, then, or_else) =
129            <(TCRef, Scalar, Scalar) as de::FromStream>::from_stream(context, decoder).await?;
130
131        Ok(Self {
132            cond,
133            then,
134            or_else,
135        })
136    }
137}
138
139impl<'en> en::IntoStream<'en> for IfRef {
140    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
141        (self.cond, self.then, self.or_else).into_stream(encoder)
142    }
143}
144
145impl<'en> en::ToStream<'en> for IfRef {
146    fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
147        en::IntoStream::into_stream((&self.cond, &self.then, &self.or_else), encoder)
148    }
149}
150
151impl fmt::Debug for IfRef {
152    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
153        write!(
154            f,
155            "if {:?} then {:?} else {:?}",
156            self.cond, self.then, self.or_else
157        )
158    }
159}