1use serde::{
2 de::{MapAccess, Visitor},
3 Deserialize, Deserializer, Serialize,
4};
5use std::{collections::BTreeMap, future::Future, pin::Pin};
6
7use crate::prelude::{holder::DatasetHolder, types::LogString, SiemResult};
8
9use super::common::UserRole;
10
/// Trait-based builder that turns a [`SiemTask`] into a runnable future.
///
/// NOTE(review): this looks like a trait-object counterpart to the
/// [`TaskBuilder`] function pointer; confirm the intended use before relying on it.
pub trait TaskBuilder2: std::fmt::Debug {
    /// Converts `task` into a boxed, `Send` future that resolves to a
    /// [`SiemTaskResult`]. The outer `SiemResult` lets construction fail
    /// before any future exists.
    ///
    /// The `Self: Sized` bound excludes this method from the trait object,
    /// so it cannot be called through a `Box<dyn TaskBuilder2>`.
    fn build(
        &self,
        task: SiemTask,
    ) -> SiemResult<Pin<Box<dyn Future<Output = SiemTaskResult> + Send>>>
    where
        Self: Sized;
    /// Returns a boxed trait object, presumably a copy of this builder.
    ///
    /// NOTE(review): the name collides with `Clone::clone`. For a type that
    /// implements both traits, `x.clone()` may be ambiguous while this trait
    /// is in scope.
    fn clone(&self) -> Box<dyn TaskBuilder2>;
}
20
/// Function pointer that turns a [`SiemTask`] plus a [`DatasetHolder`] into a
/// boxed, `Send` future resolving to a [`SiemTaskResult`].
///
/// The outer `SiemResult` lets building fail before any future is created.
/// Non-capturing closures coerce to this type.
pub type TaskBuilder = fn(
    SiemTask,
    &DatasetHolder,
) -> SiemResult<Pin<Box<dyn Future<Output = SiemTaskResult> + Send>>>;
25
/// Description of a task kind.
///
/// It records the payload template, display name and description, the
/// permission and fire-mode settings, a duration limit, and the function used
/// to build the runnable future.
///
/// `builder` is a function pointer and is not serialized (`#[serde(skip)]`).
/// The hand-written `Deserialize` impl in this module fills it with a
/// placeholder builder.
#[derive(Serialize)]
pub struct TaskDefinition {
    /// Payload template; its variant determines the task class.
    data: SiemTaskData,
    /// Human-readable task name.
    name: LogString,
    /// Human-readable description of what the task does.
    description: LogString,
    // NOTE(review): presumably the lowest role allowed to launch this task;
    // enforcement is not in this file.
    min_permission: UserRole,
    /// How and when the task is launched.
    fire_mode: TaskFireMode,
    // NOTE(review): duration limit; the unit (seconds? milliseconds?) is not
    // visible here — confirm.
    max_duration: u64,
    // Skipped by serde: function pointers cannot be serialized.
    #[serde(skip)]
    builder: TaskBuilder,
}
38
39impl TaskDefinition {
40 pub fn new(
41 data: SiemTaskData,
42 name: LogString,
43 description: LogString,
44 min_permission: UserRole,
45 fire_mode: TaskFireMode,
46 max_duration: u64,
47 builder: TaskBuilder,
48 ) -> TaskDefinition {
49 TaskDefinition {
50 data,
51 name,
52 description,
53 min_permission,
54 fire_mode,
55 max_duration,
56 builder,
57 }
58 }
59
60 pub fn data(&self) -> &SiemTaskData {
61 &self.data
62 }
63 pub fn name(&self) -> &str {
64 &self.name
65 }
66 pub fn description(&self) -> &str {
67 &self.description
68 }
69 pub fn min_permission(&self) -> &UserRole {
70 &self.min_permission
71 }
72 pub fn fire_mode(&self) -> &TaskFireMode {
73 &self.fire_mode
74 }
75 pub fn max_duration(&self) -> u64 {
76 self.max_duration
77 }
78 pub fn builder(&self) -> TaskBuilder {
79 self.builder
80 }
81}
82
83impl std::fmt::Debug for TaskDefinition {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.debug_struct("TaskDefinition")
86 .field("data", &self.data)
87 .field("name", &self.name)
88 .field("description", &self.description)
89 .field("min_permission", &self.min_permission)
90 .field("fire_mode", &self.fire_mode)
91 .field("max_duration", &self.max_duration)
92 .finish()
93 }
94}
95
96impl Clone for TaskDefinition {
97 fn clone(&self) -> Self {
98 Self {
99 data: self.data.clone(),
100 name: self.name.clone(),
101 description: self.description.clone(),
102 min_permission: self.min_permission.clone(),
103 fire_mode: self.fire_mode.clone(),
104 max_duration: self.max_duration,
105 builder: self.builder,
106 }
107 }
108}
109
110impl<'de> Deserialize<'de> for TaskDefinition {
111 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112 where
113 D: Deserializer<'de>,
114 {
115 deserializer.deserialize_map(TaskDefinitionVisitor::new())
116 }
117}
118
/// Stateless serde visitor that assembles a `TaskDefinition` from map entries.
struct TaskDefinitionVisitor;

impl TaskDefinitionVisitor {
    /// Creates the visitor; it carries no state.
    fn new() -> Self {
        Self
    }
}
126
127impl<'de> Visitor<'de> for TaskDefinitionVisitor {
128 type Value = TaskDefinition;
130
131 fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
135 where
136 M: MapAccess<'de>,
137 {
138 let mut data = SiemTaskData::UPDATE_GEOIP;
141 let mut name = String::new();
142 let mut description = String::new();
143 let mut min_permission = UserRole::Administrator;
144 let mut fire_mode = TaskFireMode::Inmediate;
145 let mut max_duration = 0;
146 while let Some(key) = access.next_key::<&str>()? {
147 if key == "name" {
148 name = access.next_value()?;
149 } else if key == "description" {
150 description = access.next_value()?;
151 } else if key == "min_permission" {
152 min_permission = access.next_value()?;
153 } else if key == "fire_mode" {
154 fire_mode = access.next_value()?;
155 } else if key == "max_duration" {
156 max_duration = access.next_value()?;
157 } else if key == "data" {
158 data = access.next_value()?;
159 }
160 }
161 Ok(TaskDefinition::new(
162 data,
163 LogString::Owned(name),
164 LogString::Owned(description),
165 min_permission,
166 fire_mode,
167 max_duration,
168 |task: SiemTask, _datasets: &DatasetHolder| {
169 Ok(Box::pin(async move {
170 SiemTaskResult {
171 data: Some(Ok("OK".into())),
172 id: task.id,
173 }
174 }))
175 },
176 ))
177 }
178
179 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
180 write!(formatter, "A valid command result")
181 }
182}
183
/// How and when a task is launched.
///
/// NOTE(review): `Inmediate` is a misspelling of "Immediate". The derived
/// serde impls use variant names as tags, so renaming it would change the
/// serialized form.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum TaskFireMode {
    /// Presumably: launch as soon as possible — confirm with the scheduler.
    Inmediate,
    /// Cron-style schedule.
    /// NOTE(review): the five fields presumably map to cron's
    /// minute/hour/day/month/weekday — confirm the order.
    Cron(u32, u32, u32, u32, u32),
    /// Periodic execution.
    /// NOTE(review): the interval unit is not stated here — confirm.
    Repetitive(u64),
    /// One-shot execution at a point in time.
    /// NOTE(review): presumably a timestamp — confirm epoch and unit.
    Future(i64),
}
195
/// Payload-free classification of a task; see `SiemTaskData::class`.
///
/// The derived `PartialOrd`/`Ord` order variants by declaration order, so
/// reordering the variants changes comparisons.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
// Variant names are SCREAMING_CASE; they are also the serialized tags.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemTaskType {
    EXECUTE_ENDPOINT_SCRIPT,
    REMEDIATE_EMAILS,
    REPORT_ABUSE,
    UPDATE_GEOIP,
    UPDATE_CLOUD_PROVIDER,
    /// Custom task class identified by name; its `Display` output is that name.
    OTHER(LogString),
}
212
213impl std::fmt::Display for SiemTaskType {
214 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
215 match self {
216 SiemTaskType::EXECUTE_ENDPOINT_SCRIPT => write!(f, "EXECUTE_ENDPOINT_SCRIPT"),
217 SiemTaskType::REMEDIATE_EMAILS => write!(f, "REMEDIATE_EMAILS"),
218 SiemTaskType::REPORT_ABUSE => write!(f, "REPORT_ABUSE"),
219 SiemTaskType::UPDATE_GEOIP => write!(f, "UPDATE_GEOIP"),
220 SiemTaskType::UPDATE_CLOUD_PROVIDER => write!(f, "UPDATE_CLOUD_PROVIDER"),
221 SiemTaskType::OTHER(name) => write!(f, "{}", name),
222 }
223 }
224}
225
/// Task payload. Each variant carries the parameters for one task class;
/// `class()` maps a payload to its [`SiemTaskType`].
#[derive(Serialize, Deserialize, Debug, Clone)]
// Variant names are SCREAMING_CASE; they are also the serialized tags.
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemTaskData {
    /// NOTE(review): presumably (script identifier, script parameters) — confirm.
    EXECUTE_ENDPOINT_SCRIPT(LogString, BTreeMap<LogString, LogString>),
    /// Parameters for an email-remediation task.
    REMEDIATE_EMAILS(BTreeMap<LogString, LogString>),
    /// Parameters for an abuse-report task.
    REPORT_ABUSE(BTreeMap<LogString, LogString>),
    UPDATE_GEOIP,
    UPDATE_CLOUD_PROVIDER,
    /// Custom task: the first field is the class name, the map its parameters.
    OTHER(LogString, BTreeMap<LogString, LogString>),
}
243
244impl SiemTaskData {
245 pub fn class(&self) -> SiemTaskType {
246 match self {
247 SiemTaskData::EXECUTE_ENDPOINT_SCRIPT(_, _) => SiemTaskType::EXECUTE_ENDPOINT_SCRIPT,
248 SiemTaskData::REMEDIATE_EMAILS(_) => SiemTaskType::REMEDIATE_EMAILS,
249 SiemTaskData::REPORT_ABUSE(_) => SiemTaskType::REPORT_ABUSE,
250 SiemTaskData::UPDATE_GEOIP => SiemTaskType::UPDATE_GEOIP,
251 SiemTaskData::UPDATE_CLOUD_PROVIDER => SiemTaskType::UPDATE_CLOUD_PROVIDER,
252 SiemTaskData::OTHER(v, _) => SiemTaskType::OTHER(v.clone()),
253 }
254 }
255}
256
/// A concrete task instance handed to a [`TaskBuilder`].
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SiemTask {
    /// NOTE(review): presumably a creation timestamp — confirm the unit.
    pub created_at: i64,
    /// NOTE(review): presumably an enqueue timestamp — confirm the unit.
    pub enqueued_at: i64,
    /// Free-form origin label of the task.
    pub origin: String,
    /// Task identifier; the builders in this module copy it into
    /// [`SiemTaskResult::id`].
    pub id: u64,
    /// Payload describing what to run.
    pub data: SiemTaskData,
}
269
/// Outcome produced by a task future.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SiemTaskResult {
    /// Id of the [`SiemTask`] this result belongs to.
    pub id: u64,
    /// NOTE(review): presumably `None` = no output, `Ok`/`Err` = success or
    /// failure message — confirm with consumers.
    pub data: Option<Result<String, String>>,
}
275
#[test]
fn task_builder_should_generate_async_task() {
    // A non-capturing closure coerces to the `TaskBuilder` function pointer.
    let builder: TaskBuilder = |task: SiemTask, _datasets: &DatasetHolder| {
        Ok(Box::pin(async move {
            SiemTaskResult {
                data: Some(Ok(String::from("OK"))),
                id: task.id,
            }
        }))
    };

    let sample = SiemTask {
        created_at: 0,
        enqueued_at: 1,
        origin: String::from("123"),
        id: 12345,
        data: SiemTaskData::REPORT_ABUSE(BTreeMap::new()),
    };
    let datasets = DatasetHolder::default();
    let pending = builder(sample, &datasets).unwrap();

    // The builder only creates the future; nothing runs until it is polled.
    let outcome = async_std::task::block_on(pending);
    assert_eq!(outcome.id, 12345);
    assert_eq!(outcome.data.unwrap(), Ok(String::from("OK")));
}