1use std::collections::HashSet;
2use std::convert::TryFrom;
3use std::fmt;
4use std::ops::Deref;
5
6use async_hash::{Digest, Hash, Output};
7use async_trait::async_trait;
8use destream::{de, en};
9use get_size::GetSize;
10use get_size_derive::*;
11use safecast::{Match, TryCastFrom};
12
13use tc_error::*;
14use tc_transact::public::{Public, StateInstance, ToState};
15use tc_value::Value;
16use tcgeneric::{Id, Instance, Map, PathSegment, TCPathBuf, Tuple};
17
18use crate::{OpDef, Scalar, Scope};
19
20use super::{Refer, TCRef};
21
/// A multi-way conditional: compares the value of `cond` against each entry of `switch`
/// and selects the entry of `case` at the matching position.
///
/// Invariant: `case.len() == switch.len() + 1`. Both constructors in this file
/// (`TryCastFrom<Scalar>` and `FromStream`) reject input that violates it; the extra
/// trailing element of `case` is what gets returned when no `switch` entry matches.
#[derive(Clone, Eq, PartialEq, GetSize)]
pub struct Case {
    /// The reference whose resolved value is compared against each `switch` entry.
    cond: TCRef,
    /// Candidate values, tested in order for equality with the resolved `cond`.
    switch: Tuple<Scalar>,
    /// Results: `case[i]` pairs with `switch[i]`; the final element is the fallback.
    case: Tuple<Scalar>,
}
29
30#[async_trait]
31impl<State> Refer<State> for Case
32where
33 State: StateInstance + Refer<State> + From<Scalar>,
34 State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
35 Map<State>: TryFrom<State, Error = TCError>,
36 Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
37 bool: TryCastFrom<State>,
38{
39 fn dereference_self(self, path: &TCPathBuf) -> Self {
40 Self {
41 cond: self.cond.dereference_self(path),
42
43 switch: self
44 .switch
45 .into_iter()
46 .map(|scalar| scalar.dereference_self(path))
47 .collect(),
48
49 case: self
50 .case
51 .into_iter()
52 .map(|scalar| scalar.dereference_self(path))
53 .collect(),
54 }
55 }
56
57 fn is_conditional(&self) -> bool {
58 true
59 }
60
61 fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
62 self.cond.is_inter_service_write(cluster_path)
63 || self
64 .switch
65 .iter()
66 .any(|scalar| scalar.is_inter_service_write(cluster_path))
67 || self
68 .case
69 .iter()
70 .any(|scalar| scalar.is_inter_service_write(cluster_path))
71 }
72
73 fn is_ref(&self) -> bool {
74 true
75 }
76
77 fn reference_self(self, path: &TCPathBuf) -> Self {
78 Self {
79 cond: self.cond.reference_self(path),
80
81 switch: self
82 .switch
83 .into_iter()
84 .map(|scalar| scalar.reference_self(path))
85 .collect(),
86
87 case: self
88 .case
89 .into_iter()
90 .map(|scalar| scalar.reference_self(path))
91 .collect(),
92 }
93 }
94
95 fn requires(&self, deps: &mut HashSet<Id>) {
96 self.cond.requires(deps);
97
98 for switch in self.switch.iter() {
99 switch.requires(deps);
100 }
101 }
102
103 async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
104 mut self,
105 context: &'a Scope<'a, State, T>,
106 txn: &'a State::Txn,
107 ) -> TCResult<State> {
108 assert_eq!(self.switch.len() + 1, self.case.len());
109
110 if self.cond.is_conditional() {
111 return Err(bad_request!(
112 "Case does not allow a nested conditional {:?}",
113 self.cond,
114 ));
115 }
116
117 for switch in self.switch.iter() {
118 if switch.is_conditional() {
119 return Err(bad_request!(
120 "Case does not allow a nested conditional {:?}",
121 switch,
122 ));
123 }
124 }
125
126 let cond = self.cond.resolve(context, txn).await?;
127 let cond = Value::try_from(cond)?;
128 for (i, switch) in self.switch.into_iter().enumerate() {
129 let switch = switch.resolve(context, txn).await?;
130 let switch = Value::try_from(switch)?;
131 if cond == switch {
132 return Ok(self.case.remove(i).into());
133 }
134 }
135
136 Ok(self.case.pop().unwrap().into())
137 }
138}
139
140impl<'a, D: Digest> Hash<D> for &'a Case {
141 fn hash(self) -> Output<D> {
142 Hash::<D>::hash((&self.cond, self.switch.deref(), self.case.deref()))
143 }
144}
145
146impl TryCastFrom<Scalar> for Case {
147 fn can_cast_from(scalar: &Scalar) -> bool {
148 scalar.matches::<(TCRef, Tuple<Scalar>, Tuple<Scalar>)>()
149 }
150
151 fn opt_cast_from(scalar: Scalar) -> Option<Self> {
152 if let Some((cond, switch, case)) =
153 <(TCRef, Tuple<Scalar>, Tuple<Scalar>)>::opt_cast_from(scalar)
154 {
155 if case.len() == switch.len() + 1 {
156 Some(Case { cond, switch, case })
157 } else {
158 None
159 }
160 } else {
161 None
162 }
163 }
164}
165
166#[async_trait]
167impl de::FromStream for Case {
168 type Context = ();
169
170 async fn from_stream<D: de::Decoder>(context: (), decoder: &mut D) -> Result<Self, D::Error> {
171 let (cond, switch, case) =
172 <(TCRef, Tuple<Scalar>, Tuple<Scalar>) as de::FromStream>::from_stream(
173 context, decoder,
174 )
175 .await?;
176
177 if case.len() == switch.len() + 1 {
178 Ok(Self { cond, switch, case })
179 } else {
180 Err(de::Error::custom(
181 "case length must equal switch length plus one",
182 ))
183 }
184 }
185}
186
187impl<'en> en::IntoStream<'en> for Case {
188 fn into_stream<E: en::Encoder<'en>>(self, encoder: E) -> Result<E::Ok, E::Error> {
189 (self.cond, self.switch.into_inner(), self.case.into_inner()).into_stream(encoder)
190 }
191}
192
193impl<'en> en::ToStream<'en> for Case {
194 fn to_stream<E: en::Encoder<'en>>(&'en self, encoder: E) -> Result<E::Ok, E::Error> {
195 en::IntoStream::into_stream(
196 (&self.cond, self.switch.deref(), self.case.deref()),
197 encoder,
198 )
199 }
200}
201
202impl fmt::Debug for Case {
203 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
204 write!(f, "switch ({:?})...", self.cond)
205 }
206}