wick_invocation_server/
invocation_server.rs1use std::collections::HashMap;
2use std::time::{Duration, Instant};
3
4use flow_component::SharedComponent;
5use parking_lot::RwLock;
6use tokio::sync::mpsc;
7use tokio_stream::wrappers::ReceiverStream;
8use tokio_stream::StreamExt;
9use tonic::{Response, Status};
10use wick_packet::PacketStream;
11use wick_rpc::rpc::invocation_service_server::InvocationService;
12use wick_rpc::rpc::{InvocationRequest, ListResponse, Packet, StatsResponse};
13use wick_rpc::{rpc, DurationStatistics, Statistics};
14
15pub struct InvocationServer {
17 pub collection: SharedComponent,
19
20 stats: RwLock<HashMap<String, Statistics>>,
21}
22
23impl std::fmt::Debug for InvocationServer {
24 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
25 f.debug_struct("InvocationServer").finish()
26 }
27}
28
29impl InvocationServer {
30 #[must_use]
32 pub fn new(collection: SharedComponent) -> Self {
33 Self {
34 collection,
35 stats: RwLock::new(HashMap::new()),
36 }
37 }
38}
39
/// Outcome of dispatching one invocation, as counted in the statistics table.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum JobResult {
  /// The component accepted the invocation.
  Success,
  /// The component returned an error for the invocation.
  Error,
}
45
46impl InvocationServer {
47 fn record_execution<T: Into<String>>(&self, job: T, status: JobResult, time: Duration) {
48 let mut stats = self.stats.write();
49 let job = job.into();
50 let stat = stats.entry(job.clone()).or_insert_with(Statistics::default);
51 stat.runs += 1;
52 if status == JobResult::Error {
53 stat.errors += 1;
54 }
55 let durations = if stat.execution_duration.is_some() {
56 let mut durations = stat.execution_duration.take().unwrap();
57 if time < durations.min_time {
58 durations.min_time = time;
59 } else if time > durations.max_time {
60 durations.max_time = time;
61 }
62 let average = ((durations.average_time * (stat.runs - 1)) + time) / stat.runs;
63 durations.average_time = average;
64 let total = durations.total_time + time;
65 durations.total_time = total;
66
67 durations
68 } else {
69 DurationStatistics::new(time, time, time, time)
70 };
71 stat.execution_duration.replace(durations);
72 }
73}
74
75fn convert_invocation_stream(mut streaming: tonic::Streaming<InvocationRequest>) -> PacketStream {
76 let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
77 tokio::spawn(async move {
78 while let Some(p) = streaming.next().await {
79 let result = p.map_err(|e| wick_packet::Error::Component(e.to_string())).map(|p| {
80 p.data.map_or_else(
81 || unreachable!(),
82 |p| match p {
83 rpc::invocation_request::Data::Invocation(_) => unreachable!(),
84 rpc::invocation_request::Data::Packet(p) => wick_packet::Packet::from(p),
85 },
86 )
87 });
88
89 let _ = tx.send(result);
90 }
91 });
92
93 wick_packet::PacketStream::new(Box::new(tokio_stream::wrappers::UnboundedReceiverStream::new(rx)))
94}
95
96#[async_trait::async_trait]
97impl InvocationService for InvocationServer {
98 type InvokeStream = ReceiverStream<Result<Packet, Status>>;
99
100 async fn invoke(
101 &self,
102 request: tonic::Request<tonic::Streaming<InvocationRequest>>,
103 ) -> Result<Response<Self::InvokeStream>, Status> {
104 let start = Instant::now();
105
106 let (tx, rx) = mpsc::channel(4);
107 let mut stream = request.into_inner();
108 let first = stream.next().await;
109 let invocation: wick_packet::InvocationData = if let Some(Ok(inv)) = first {
110 if let Some(rpc::invocation_request::Data::Invocation(inv)) = inv.data {
111 inv
112 .try_into()
113 .map_err(|_| Status::invalid_argument("First message must be a valid invocation"))?
114 } else {
115 return Err(Status::invalid_argument("First message must be an invocation"));
116 }
117 } else {
118 return Err(Status::invalid_argument("First message must be an invocation"));
119 };
120 let stream = convert_invocation_stream(stream);
121 let packet_stream = PacketStream::new(Box::new(stream));
122 let invocation = invocation.with_stream(packet_stream);
123
124 let op_id = invocation.target().operation_id().to_owned();
125
126 let result = self
127 .collection
128 .handle(invocation, Default::default(), Default::default())
129 .await;
130 if let Err(e) = result {
131 let message = e.to_string();
132 error!("invocation failed: {}", message);
133 tx.send(Err(Status::internal(message))).await.unwrap();
134 self.record_execution(op_id, JobResult::Error, start.elapsed());
135 } else {
136 tokio::spawn(async move {
137 let mut receiver = result.unwrap();
138 while let Some(next) = receiver.next().await {
139 if next.is_err() {
140 todo!("Handle error");
141 }
142 let next = next.unwrap();
143
144 tx.send(Ok(next.into())).await.unwrap();
145 }
146 });
147 self.record_execution(op_id, JobResult::Success, start.elapsed());
148 }
149
150 Ok(Response::new(ReceiverStream::new(rx)))
151 }
152
153 async fn list(&self, _request: tonic::Request<rpc::ListRequest>) -> Result<Response<ListResponse>, Status> {
154 let response = ListResponse {
155 components: vec![self.collection.signature().clone().try_into().unwrap()],
156 };
157 Ok(Response::new(response))
158 }
159
160 async fn stats(&self, _request: tonic::Request<rpc::StatsRequest>) -> Result<Response<StatsResponse>, Status> {
161 Ok(Response::new(StatsResponse {
162 stats: self.stats.read().values().cloned().map(From::from).collect(),
163 }))
164 }
165}
166
// Placeholder test module: no tests are defined here yet.
#[cfg(test)]
mod tests {}