// tc_scalar/reference/case.rs

1use std::collections::HashSet;
2use std::convert::TryFrom;
3use std::fmt;
4use std::ops::Deref;
5
6use async_hash::{Digest, Hash, Output};
7use async_trait::async_trait;
8use destream::{de, en};
9use get_size::GetSize;
10use get_size_derive::*;
11use safecast::{Match, TryCastFrom};
12
13use tc_error::*;
14use tc_transact::public::{Public, StateInstance, ToState};
15use tc_value::Value;
16use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf, Tuple};
17
18use crate::{OpDef, Scalar, Scope};
19
20use super::{Refer, TCRef};
21
22/// A switch-case flow control
23#[derive(Clone, Eq, PartialEq, GetSize)]
24pub struct Case {
25    cond: TCRef,
26    switch: Tuple<Scalar>,
27    case: Tuple<Scalar>,
28}
29
30#[async_trait]
31impl<State> Refer<State> for Case
32where
33    State: StateInstance + Refer<State> + From<Scalar>,
34    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
35    Map<State>: TryFrom<State, Error = TCError>,
36    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
37    bool: TryCastFrom<State>,
38{
39    fn dereference_self(self, path: &TCPathBuf) -> Self {
40        Self {
41            cond: self.cond.dereference_self(path),
42
43            switch: self
44                .switch
45                .into_iter()
46                .map(|scalar| scalar.dereference_self(path))
47                .collect(),
48
49            case: self
50                .case
51                .into_iter()
52                .map(|scalar| scalar.dereference_self(path))
53                .collect(),
54        }
55    }
56
57    fn is_conditional(&self) -> bool {
58        true
59    }
60
61    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
62        self.cond.is_inter_service_write(cluster_path)
63            || self
64                .switch
65                .iter()
66                .any(|scalar| scalar.is_inter_service_write(cluster_path))
67            || self
68                .case
69                .iter()
70                .any(|scalar| scalar.is_inter_service_write(cluster_path))
71    }
72
73    fn is_ref(&self) -> bool {
74        true
75    }
76
77    fn reference_self(self, path: &TCPathBuf) -> Self {
78        Self {
79            cond: self.cond.reference_self(path),
80
81            switch: self
82                .switch
83                .into_iter()
84                .map(|scalar| scalar.reference_self(path))
85                .collect(),
86
87            case: self
88                .case
89                .into_iter()
90                .map(|scalar| scalar.reference_self(path))
91                .collect(),
92        }
93    }
94
95    fn requires(&self, deps: &mut HashSet<Id>) {
96        self.cond.requires(deps);
97
98        for switch in self.switch.iter() {
99            switch.requires(deps);
100        }
101    }
102
103    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
104        mut self,
105        context: &'a Scope<'a, State, T>,
106        txn: &'a State::Txn,
107    ) -> TCResult<State> {
108        assert_eq!(self.switch.len() + 1, self.case.len());
109
110        if self.cond.is_conditional() {
111            return Err(bad_request!(
112                "Case does not allow a nested conditional {:?}",
113                self.cond,
114            ));
115        }
116
117        for switch in self.switch.iter() {
118            if switch.is_conditional() {
119                return Err(bad_request!(
120                    "Case does not allow a nested conditional {:?}",
121                    switch,
122                ));
123            }
124        }
125
126        let cond = self.cond.resolve(context, txn).await?;
127        let cond = Value::try_from(cond)?;
128        for (i, switch) in self.switch.into_iter().enumerate() {
129            let switch = switch.resolve(context, txn).await?;
130            let switch = Value::try_from(switch)?;
131            if cond == switch {
132                return Ok(self.case.remove(i).into());
133            }
134        }
135
136        Ok(self.case.pop().unwrap().into())
137    }
138}
139
140impl<'a, D: Digest> Hash<D> for &'a Case {
141    fn hash(self) -> Output<D> {
142        Hash::<D>::hash((&self.cond, self.switch.deref(), self.case.deref()))
143    }
144}
145
146impl TryCastFrom<Scalar> for Case {
147    fn can_cast_from(scalar: &Scalar) -> bool {
148        scalar.matches::<(TCRef, Tuple<Scalar>, Tuple<Scalar>)>()
149    }
150
151    fn opt_cast_from(scalar: Scalar) -> Option<Self> {
152        if let Some((cond, switch, case)) =
153            <(TCRef, Tuple<Scalar>, Tuple<Scalar>)>::opt_cast_from(scalar)
154        {
155            if case.len() == switch.len() + 1 {
156                Some(Case { cond, switch, case })
157            } else {
158                None
159            }
160        } else {
161            None
162        }
163    }
164}
165
166#[async_trait]
167impl de::FromStream for Case {
168    type Context = ();
169
170    async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
171        let (cond, switch, case) =
172            <(TCRef, Tuple<Scalar>, Tuple<Scalar>) as de::FromStream>::from_stream(
173                context, decoder,
174            )
175            .await?;
176
177        if case.len() == switch.len() + 1 {
178            Ok(Self { cond, switch, case })
179        } else {
180            Err(de::Error::custom(
181                "case length must equal switch length plus one",
182            ))
183        }
184    }
185}
186
187impl<'en> en::IntoStream<'en> for Case {
188    fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
189        (self.cond, self.switch.into_inner(), self.case.into_inner()).into_stream(encoder)
190    }
191}
192
193impl<'en> en::ToStream<'en> for Case {
194    fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
195        en::IntoStream::into_stream(
196            (&self.cond, self.switch.deref(), self.case.deref()),
197            encoder,
198        )
199    }
200}
201
202impl fmt::Debug for Case {
203    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204        write!(f, "switch ({:?})...", self.cond)
205    }
206}