1use std::collections::HashSet;
4use std::convert::TryFrom;
5use std::fmt;
6use std::ops::Deref;
7
8use async_hash::{Digest, Hash, Output};
9use async_trait::async_trait;
10use destream::de::{self, Decoder, FromStream};
11use destream::en::{EncodeMap, Encoder, IntoStream, ToStream};
12use futures::TryFutureExt;
13use get_size::GetSize;
14use get_size_derive::*;
15use log::debug;
16use safecast::TryCastFrom;
17
18use tc_error::*;
19use tc_transact::public::{Public, StateInstance, ToState};
20use tcgeneric::*;
21
22use super::{OpDef, Scalar, Scope, Value};
23
24pub use after::After;
25pub use case::Case;
26pub use id::*;
27pub use op::*;
28pub use r#if::IfRef;
29pub use r#while::While;
30pub use with::With;
31
32mod after;
33mod case;
34mod r#if;
35mod r#while;
36mod with;
37
38pub mod id;
39pub mod op;
40
/// Path prefix `/state/scalar/ref` under which [`RefType`] places its non-`Op` classes
/// (e.g. `/state/scalar/ref/after`); `Op` classes report their own path via `OpRefType`.
const PREFIX: PathLabel = path_label(&["state", "scalar", "ref"]);
42
/// A value which may need to be resolved, within some context, before it can be used as a
/// `State`.
#[async_trait]
pub trait Refer<State: StateInstance> {
    /// Consume this value and return it with its references to `self` rewritten relative to
    /// `path`.
    /// NOTE(review): presumably the inverse of [`Refer::reference_self`]; confirm against the
    /// implementors.
    fn dereference_self(self, path: &TCPathBuf) -> Self;

    /// Return `true` if this value is conditional.
    /// NOTE(review): the exact criterion is defined by each implementor; confirm before relying
    /// on it.
    fn is_conditional(&self) -> bool;

    /// Return `true` if this value would write to a service other than the one at
    /// `cluster_path`.
    /// NOTE(review): inferred from the method name; confirm against the implementors.
    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool;

    /// Return `true` if this value is a reference (the `TCRef` impl always returns `true`).
    fn is_ref(&self) -> bool;

    /// Consume this value and return it with occurrences of `path` rewritten as references to
    /// `self`.
    /// NOTE(review): inferred from the method name; confirm against the implementors.
    fn reference_self(self, path: &TCPathBuf) -> Self;

    /// Add to `deps` the [`Id`]s this value depends on.
    /// NOTE(review): presumably these are names in the enclosing scope which must be resolved
    /// first; confirm.
    fn requires(&self, deps: &mut HashSet<Id>);

    /// Consume this value and resolve it to a `State`, using `context` and the transaction `txn`.
    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
        self,
        context: &'a Scope<'a, State, T>,
        txn: &'a State::Txn,
    ) -> TCResult<State>;
}
75
/// The class of a [`TCRef`], one variant per kind of ref.
#[derive(Clone, Copy, Eq, PartialEq)]
pub enum RefType {
    After,
    Case,
    Id,
    If,
    /// The class of an [`OpRef`], further distinguished by its [`OpRefType`].
    Op(OpRefType),
    While,
    With,
}

impl Class for RefType {}
89
90impl NativeClass for RefType {
91 fn from_path(path: &[PathSegment]) -> Option<Self> {
92 if path.len() == 4 && &path[0..3] == &PREFIX[..] {
93 match path[3].as_str() {
94 "after" => Some(Self::After),
95 "case" => Some(Self::Case),
96 "id" => Some(Self::Id),
97 "if" => Some(Self::If),
98 "while" => Some(Self::While),
99 "with" => Some(Self::With),
100 _ => None,
101 }
102 } else if let Some(ort) = OpRefType::from_path(path) {
103 Some(RefType::Op(ort))
104 } else {
105 None
106 }
107 }
108
109 fn path(&self) -> TCPathBuf {
110 let suffix = match self {
111 Self::After => "after",
112 Self::Case => "case",
113 Self::Id => "id",
114 Self::If => "if",
115 Self::Op(ort) => return ort.path(),
116 Self::While => "while",
117 Self::With => "with",
118 };
119
120 TCPathBuf::from(PREFIX).append(label(suffix))
121 }
122}
123
124impl fmt::Debug for RefType {
125 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
126 match self {
127 Self::After => f.write_str("After"),
128 Self::Case => f.write_str("Case"),
129 Self::Id => f.write_str("Id"),
130 Self::If => f.write_str("If"),
131 Self::Op(ort) => fmt::Debug::fmt(ort, f),
132 Self::While => f.write_str("While"),
133 Self::With => f.write_str("With"),
134 }
135 }
136}
137
/// A reference which can be resolved into a `State` (see the [`Refer`] impl below);
/// [`Instance::class`] maps each variant to the matching [`RefType`].
///
/// NOTE(review): all variants except `Id` and `Op` are boxed, presumably to keep `TCRef` small
/// or to break type recursion; confirm.
#[derive(Clone, Eq, PartialEq, GetSize)]
pub enum TCRef {
    After(Box<After>),
    Case(Box<Case>),
    Id(IdRef),
    If(Box<IfRef>),
    Op(OpRef),
    While(Box<While>),
    With(Box<With>),
}
149
150impl Instance for TCRef {
151 type Class = RefType;
152
153 fn class(&self) -> Self::Class {
154 match self {
155 Self::After(_) => RefType::After,
156 Self::Case(_) => RefType::Case,
157 Self::Id(_) => RefType::Id,
158 Self::If(_) => RefType::If,
159 Self::Op(op_ref) => RefType::Op(op_ref.class()),
160 Self::While(_) => RefType::While,
161 Self::With(_) => RefType::With,
162 }
163 }
164}
165
/// Dispatches every [`Refer`] method to the ref wrapped by each [`TCRef`] variant.
#[async_trait]
impl<State> Refer<State> for TCRef
where
    State: StateInstance + Refer<State> + From<Scalar>,
    State::Closure: From<(Map<State>, OpDef)> + TryCastFrom<State>,
    Map<State>: TryFrom<State, Error = TCError>,
    Value: TryFrom<State, Error = TCError> + TryCastFrom<State>,
    bool: TryCastFrom<State>,
{
    /// Dereference `self` within the wrapped ref; boxed variants are rewritten and re-boxed.
    fn dereference_self(self, path: &TCPathBuf) -> Self {
        match self {
            Self::After(after) => {
                let after = after.dereference_self(path);
                Self::After(Box::new(after))
            }
            Self::Case(case) => {
                let case = case.dereference_self(path);
                Self::Case(Box::new(case))
            }
            // NOTE(review): `IdRef` calls name the trait explicitly (`Refer::<State>::`) here and
            // below, presumably so the `State` parameter can be inferred; confirm.
            Self::Id(id_ref) => Self::Id(Refer::<State>::dereference_self(id_ref, path)),
            Self::If(if_ref) => {
                let if_ref = if_ref.dereference_self(path);
                Self::If(Box::new(if_ref))
            }
            Self::Op(op_ref) => Self::Op(op_ref.dereference_self(path)),
            Self::While(while_ref) => {
                let while_ref = while_ref.dereference_self(path);
                Self::While(Box::new(while_ref))
            }
            Self::With(with) => {
                let with = with.dereference_self(path);
                Self::With(Box::new(with))
            }
        }
    }

    /// Delegate to the wrapped ref.
    fn is_conditional(&self) -> bool {
        match self {
            Self::After(after) => after.is_conditional(),
            Self::Case(case) => case.is_conditional(),
            Self::Id(id_ref) => Refer::<State>::is_conditional(id_ref),
            Self::If(if_ref) => if_ref.is_conditional(),
            Self::Op(op_ref) => op_ref.is_conditional(),
            Self::While(while_ref) => while_ref.is_conditional(),
            Self::With(with) => with.is_conditional(),
        }
    }

    /// Delegate to the wrapped ref.
    fn is_inter_service_write(&self, cluster_path: &[PathSegment]) -> bool {
        match self {
            Self::After(after) => after.is_inter_service_write(cluster_path),
            Self::Case(case) => case.is_inter_service_write(cluster_path),
            Self::Id(id_ref) => Refer::<State>::is_inter_service_write(id_ref, cluster_path),
            Self::If(if_ref) => if_ref.is_inter_service_write(cluster_path),
            Self::Op(op_ref) => op_ref.is_inter_service_write(cluster_path),
            Self::While(while_ref) => while_ref.is_inter_service_write(cluster_path),
            Self::With(with) => with.is_inter_service_write(cluster_path),
        }
    }

    /// Every `TCRef` is a reference.
    fn is_ref(&self) -> bool {
        true
    }

    /// Reference `self` within the wrapped ref; boxed variants are rewritten and re-boxed.
    fn reference_self(self, path: &TCPathBuf) -> Self {
        match self {
            Self::After(after) => {
                let after = after.reference_self(path);
                Self::After(Box::new(after))
            }
            Self::Case(case) => {
                let case = case.reference_self(path);
                Self::Case(Box::new(case))
            }
            Self::Id(id_ref) => Self::Id(Refer::<State>::reference_self(id_ref, path)),
            Self::If(if_ref) => {
                let if_ref = if_ref.reference_self(path);
                Self::If(Box::new(if_ref))
            }
            Self::Op(op_ref) => Self::Op(op_ref.reference_self(path)),
            Self::While(while_ref) => {
                let while_ref = while_ref.reference_self(path);
                Self::While(Box::new(while_ref))
            }
            Self::With(with) => {
                let with = with.reference_self(path);
                Self::With(Box::new(with))
            }
        }
    }

    /// Collect the dependencies of the wrapped ref into `deps`.
    fn requires(&self, deps: &mut HashSet<Id>) {
        match self {
            Self::After(after) => after.requires(deps),
            Self::Case(case) => case.requires(deps),
            Self::Id(id_ref) => Refer::<State>::requires(id_ref, deps),
            Self::If(if_ref) => if_ref.requires(deps),
            Self::Op(op_ref) => op_ref.requires(deps),
            Self::While(while_ref) => while_ref.requires(deps),
            Self::With(with) => with.requires(deps),
        }
    }

    /// Log this ref at debug level, then resolve the wrapped ref with the same `context` and
    /// `txn`.
    async fn resolve<'a, T: ToState<State> + Public<State> + Instance>(
        self,
        context: &'a Scope<'a, State, T>,
        txn: &'a State::Txn,
    ) -> TCResult<State> {
        debug!("TCRef::resolve {:?}", self);

        match self {
            Self::After(after) => after.resolve(context, txn).await,
            Self::Case(case) => case.resolve(context, txn).await,
            Self::Id(id_ref) => Refer::<State>::resolve(id_ref, context, txn).await,
            Self::If(if_ref) => if_ref.resolve(context, txn).await,
            Self::Op(op_ref) => op_ref.resolve(context, txn).await,
            Self::While(while_ref) => while_ref.resolve(context, txn).await,
            Self::With(with) => with.resolve(context, txn).await,
        }
    }
}
287
288impl<'a, D: Digest> Hash<D> for &'a TCRef {
289 fn hash(self) -> Output<D> {
290 match self {
291 TCRef::After(after) => Hash::<D>::hash(after.deref()),
292 TCRef::Case(case) => Hash::<D>::hash(case.deref()),
293 TCRef::Id(id) => Hash::<D>::hash(id),
294 TCRef::If(if_ref) => Hash::<D>::hash(if_ref.deref()),
295 TCRef::Op(op) => Hash::<D>::hash(op),
296 TCRef::While(while_ref) => Hash::<D>::hash(while_ref.deref()),
297 TCRef::With(with) => Hash::<D>::hash(with.deref()),
298 }
299 }
300}
301
302impl TryFrom<TCRef> for OpRef {
303 type Error = TCError;
304
305 fn try_from(tc_ref: TCRef) -> TCResult<Self> {
306 match tc_ref {
307 TCRef::Op(op_ref) => Ok(op_ref),
308 other => Err(TCError::unexpected(other, "an OpRef")),
309 }
310 }
311}
312
313impl TryCastFrom<TCRef> for Id {
314 fn can_cast_from(tc_ref: &TCRef) -> bool {
315 match tc_ref {
316 TCRef::Id(_) => true,
317 _ => false,
318 }
319 }
320
321 fn opt_cast_from(tc_ref: TCRef) -> Option<Self> {
322 match tc_ref {
323 TCRef::Id(id_ref) => Some(id_ref.into_id()),
324 _ => None,
325 }
326 }
327}
328
329impl TryCastFrom<TCRef> for OpRef {
330 fn can_cast_from(tc_ref: &TCRef) -> bool {
331 match tc_ref {
332 TCRef::Op(_) => true,
333 _ => false,
334 }
335 }
336
337 fn opt_cast_from(tc_ref: TCRef) -> Option<Self> {
338 match tc_ref {
339 TCRef::Op(op) => Some(op),
340 _ => None,
341 }
342 }
343}
344
345pub struct RefVisitor;
347
348impl RefVisitor {
349 pub async fn visit_map_value<A: de::MapAccess>(
351 class: RefType,
352 access: &mut A,
353 ) -> Result<TCRef, A::Error> {
354 match class {
355 RefType::After => {
356 access
357 .next_value(())
358 .map_ok(Box::new)
359 .map_ok(TCRef::After)
360 .await
361 }
362 RefType::Case => {
363 access
364 .next_value(())
365 .map_ok(Box::new)
366 .map_ok(TCRef::Case)
367 .await
368 }
369 RefType::Id => access.next_value(()).map_ok(TCRef::Id).await,
370 RefType::If => {
371 access
372 .next_value(())
373 .map_ok(Box::new)
374 .map_ok(TCRef::If)
375 .await
376 }
377 RefType::Op(ort) => {
378 OpRefVisitor::visit_map_value(ort, access)
379 .map_ok(TCRef::Op)
380 .await
381 }
382 RefType::While => {
383 access
384 .next_value(())
385 .map_ok(Box::new)
386 .map_ok(TCRef::While)
387 .await
388 }
389 RefType::With => {
390 access
391 .next_value(())
392 .map_ok(Box::new)
393 .map_ok(TCRef::With)
394 .await
395 }
396 }
397 }
398
399 pub fn visit_ref_value<E: de::Error>(subject: Subject, params: Scalar) -> Result<TCRef, E> {
401 if params.is_none() {
402 match subject {
403 Subject::Link(link) => Err(de::Error::invalid_type(link, &"a Ref")),
404 Subject::Ref(id_ref, path) if path.is_empty() => Ok(TCRef::Id(id_ref)),
405 Subject::Ref(id_ref, path) => Ok(TCRef::Op(OpRef::Get((
406 Subject::Ref(id_ref, path),
407 Value::default().into(),
408 )))),
409 }
410 } else {
411 OpRefVisitor::visit_ref_value(subject, params).map(TCRef::Op)
412 }
413 }
414}
415
416#[async_trait]
417impl de::Visitor for RefVisitor {
418 type Value = TCRef;
419
420 fn expecting() -> &'static str {
421 "a Ref, like {\"$subject\": []} or {\"/path/to/op\": [\"key\"]"
422 }
423
424 async fn visit_map<A: de::MapAccess>(self, mut access: A) -> Result<Self::Value, A::Error> {
425 let subject = access.next_key::<Subject>(()).await?;
426
427 let subject =
428 subject.ok_or_else(|| de::Error::custom("expected a Ref or Link, found empty map"))?;
429
430 if let Subject::Link(link) = &subject {
431 if link.host().is_none() {
432 if let Some(class) = RefType::from_path(link.path()) {
433 debug!("RefVisitor visiting instance of {:?}...", class);
434 return Self::visit_map_value(class, &mut access).await;
435 }
436 }
437 }
438
439 let params = access.next_value(()).await?;
440 Self::visit_ref_value(subject, params)
441 }
442}
443
444#[async_trait]
445impl FromStream for TCRef {
446 type Context = ();
447
448 async fn from_stream<D: Decoder>(_: (), d: &mut D) -> Result<Self, <D as Decoder>::Error> {
449 d.decode_map(RefVisitor).await
450 }
451}
452
453impl<'en> ToStream<'en> for TCRef {
454 fn to_stream<E: Encoder<'en>>(&'en self, e: E) -> Result<E::Ok, E::Error> {
455 if let Self::Id(id_ref) = self {
456 return id_ref.to_stream(e);
457 } else if let Self::Op(op_ref) = self {
458 return op_ref.to_stream(e);
459 };
460
461 let mut map = e.encode_map(Some(1))?;
462
463 map.encode_key(self.class().path().to_string())?;
464 match self {
465 Self::Id(_) => unreachable!("TCRef::Id to_stream"),
466 Self::Op(_) => unreachable!("TCRef::Op to_stream"),
467
468 Self::After(after) => map.encode_value(after),
469 Self::Case(case) => map.encode_value(case),
470 Self::If(if_ref) => map.encode_value(if_ref),
471 Self::While(while_ref) => map.encode_value(while_ref),
472 Self::With(with) => map.encode_value(with),
473 }?;
474
475 map.end()
476 }
477}
478
479impl<'en> IntoStream<'en> for TCRef {
480 fn into_stream<E: Encoder<'en>>(self, e: E) -> Result<E::Ok, E::Error> {
481 if let Self::Id(id_ref) = self {
482 return id_ref.into_stream(e);
483 } else if let Self::Op(op_ref) = self {
484 return op_ref.into_stream(e);
485 };
486
487 let mut map = e.encode_map(Some(1))?;
488
489 map.encode_key(self.class().path().to_string())?;
490 match self {
491 Self::Id(_) => unreachable!("TCRef::Id into_stream"),
492 Self::Op(_) => unreachable!("TCRef::Op into_stream"),
493
494 Self::After(after) => map.encode_value(after),
495 Self::Case(case) => map.encode_value(case),
496 Self::If(if_ref) => map.encode_value(if_ref),
497 Self::While(while_ref) => map.encode_value(while_ref),
498 Self::With(with) => map.encode_value(with),
499 }?;
500
501 map.end()
502 }
503}
504
505impl fmt::Debug for TCRef {
506 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
507 match self {
508 Self::After(after) => fmt::Debug::fmt(after, f),
509 Self::Case(case) => fmt::Debug::fmt(case, f),
510 Self::Id(id_ref) => fmt::Debug::fmt(id_ref, f),
511 Self::If(if_ref) => fmt::Debug::fmt(if_ref, f),
512 Self::Op(op_ref) => fmt::Debug::fmt(op_ref, f),
513 Self::While(while_ref) => fmt::Debug::fmt(while_ref, f),
514 Self::With(with) => fmt::Debug::fmt(with, f),
515 }
516 }
517}