usiem/components/
task.rs

1use serde::{
2    de::{MapAccess, Visitor},
3    Deserialize, Deserializer, Serialize,
4};
5use std::{collections::BTreeMap, future::Future, pin::Pin};
6
7use crate::prelude::{holder::DatasetHolder, types::LogString, SiemResult};
8
9use super::common::UserRole;
10
/// Trait-based task builder: turns a [`SiemTask`] into a future that yields a
/// [`SiemTaskResult`].
///
/// Note that `build` carries a `Self: Sized` bound, so it cannot be invoked
/// through a `dyn TaskBuilder2`; only `clone` is callable on a trait object.
pub trait TaskBuilder2: std::fmt::Debug {
    /// Creates the boxed, pinned, `Send` future for `task`, wrapped in a
    /// [`SiemResult`].
    fn build(
        &self,
        task: SiemTask,
    ) -> SiemResult<Pin<Box<dyn Future<Output = SiemTaskResult> + Send>>>
    where
        Self: Sized;
    /// Returns a boxed builder as a trait object.
    // NOTE(review): presumably a copy of `self`, going by the name — confirm
    // against the implementors.
    fn clone(&self) -> Box<dyn TaskBuilder2>;
}
20
/// Function-pointer task builder.
///
/// Receives the [`SiemTask`] to run and a reference to the [`DatasetHolder`],
/// and returns (inside a [`SiemResult`]) a boxed, pinned, `Send` future that
/// resolves to the [`SiemTaskResult`].
pub type TaskBuilder = fn(
    SiemTask,
    &DatasetHolder,
) -> SiemResult<Pin<Box<dyn Future<Output = SiemTaskResult> + Send>>>;
25
/// Description of a task together with the function that builds it.
///
/// All fields except `builder` are serialized; `builder` is a function
/// pointer and is skipped by serde.
#[derive(Serialize)]
pub struct TaskDefinition {
    /// Kind of task and its parameters.
    data: SiemTaskData,
    /// Name of the task.
    name: LogString,
    /// Description of the task.
    description: LogString,
    /// Role attached to the task.
    // NOTE(review): presumably the minimum role allowed to launch it — confirm.
    min_permission: UserRole,
    /// When/how the task is fired.
    fire_mode: TaskFireMode,
    /// Time after which the task can be killed
    // NOTE(review): the time unit is not visible in this file — confirm.
    max_duration: u64,
    /// Function used to create the task future; never serialized.
    #[serde(skip)]
    builder: TaskBuilder,
}
38
39impl TaskDefinition {
40    pub fn new(
41        data: SiemTaskData,
42        name: LogString,
43        description: LogString,
44        min_permission: UserRole,
45        fire_mode: TaskFireMode,
46        max_duration: u64,
47        builder: TaskBuilder,
48    ) -> TaskDefinition {
49        TaskDefinition {
50            data,
51            name,
52            description,
53            min_permission,
54            fire_mode,
55            max_duration,
56            builder,
57        }
58    }
59
60    pub fn data(&self) -> &SiemTaskData {
61        &self.data
62    }
63    pub fn name(&self) -> &str {
64        &self.name
65    }
66    pub fn description(&self) -> &str {
67        &self.description
68    }
69    pub fn min_permission(&self) -> &UserRole {
70        &self.min_permission
71    }
72    pub fn fire_mode(&self) -> &TaskFireMode {
73        &self.fire_mode
74    }
75    pub fn max_duration(&self) -> u64 {
76        self.max_duration
77    }
78    pub fn builder(&self) -> TaskBuilder {
79        self.builder
80    }
81}
82
83impl std::fmt::Debug for TaskDefinition {
84    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85        f.debug_struct("TaskDefinition")
86            .field("data", &self.data)
87            .field("name", &self.name)
88            .field("description", &self.description)
89            .field("min_permission", &self.min_permission)
90            .field("fire_mode", &self.fire_mode)
91            .field("max_duration", &self.max_duration)
92            .finish()
93    }
94}
95
96impl Clone for TaskDefinition {
97    fn clone(&self) -> Self {
98        Self {
99            data: self.data.clone(),
100            name: self.name.clone(),
101            description: self.description.clone(),
102            min_permission: self.min_permission.clone(),
103            fire_mode: self.fire_mode.clone(),
104            max_duration: self.max_duration,
105            builder: self.builder,
106        }
107    }
108}
109
110impl<'de> Deserialize<'de> for TaskDefinition {
111    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
112    where
113        D: Deserializer<'de>,
114    {
115        deserializer.deserialize_map(TaskDefinitionVisitor::new())
116    }
117}
118
/// Stateless serde visitor that produces a `TaskDefinition` from a map.
struct TaskDefinitionVisitor {}
120
121impl TaskDefinitionVisitor {
122    fn new() -> Self {
123        TaskDefinitionVisitor {}
124    }
125}
126
127impl<'de> Visitor<'de> for TaskDefinitionVisitor {
128    // The type that our Visitor is going to produce.
129    type Value = TaskDefinition;
130
131    // Deserialize MyMap from an abstract "map" provided by the
132    // Deserializer. The MapAccess input is a callback provided by
133    // the Deserializer to let us see each entry in the map.
134    fn visit_map<M>(self, mut access: M) -> Result<Self::Value, M::Error>
135    where
136        M: MapAccess<'de>,
137    {
138        // While there are entries remaining in the input, add them
139        // into our map.
140        let mut data = SiemTaskData::UPDATE_GEOIP;
141        let mut name = String::new();
142        let mut description = String::new();
143        let mut min_permission = UserRole::Administrator;
144        let mut fire_mode = TaskFireMode::Inmediate;
145        let mut max_duration = 0;
146        while let Some(key) = access.next_key::<&str>()? {
147            if key == "name" {
148                name = access.next_value()?;
149            } else if key == "description" {
150                description = access.next_value()?;
151            } else if key == "min_permission" {
152                min_permission = access.next_value()?;
153            } else if key == "fire_mode" {
154                fire_mode = access.next_value()?;
155            } else if key == "max_duration" {
156                max_duration = access.next_value()?;
157            } else if key == "data" {
158                data = access.next_value()?;
159            }
160        }
161        Ok(TaskDefinition::new(
162            data,
163            LogString::Owned(name),
164            LogString::Owned(description),
165            min_permission,
166            fire_mode,
167            max_duration,
168            |task: SiemTask, _datasets: &DatasetHolder| {
169                Ok(Box::pin(async move {
170                    SiemTaskResult {
171                        data: Some(Ok("OK".into())),
172                        id: task.id,
173                    }
174                }))
175            },
176        ))
177    }
178
179    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
180        write!(formatter, "A valid command result")
181    }
182}
183
/// When a task should be executed.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub enum TaskFireMode {
    /// Execute this task as soon as possible
    Inmediate,
    /// Execute this task using a cron definition
    // NOTE(review): the meaning and order of the five numeric fields is not
    // documented in this file — confirm against the scheduler.
    Cron(u32, u32, u32, u32, u32),
    /// Execute each X milliseconds
    Repetitive(u64),
    /// Execute this task once in the future
    // NOTE(review): presumably a timestamp; unit and epoch are not visible
    // here — confirm.
    Future(i64),
}
195
/// Kind of a task, without its parameters.
///
/// Mirrors the variants of `SiemTaskData`; see `SiemTaskData::class`.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemTaskType {
    /// Execute a script in an endpoint
    EXECUTE_ENDPOINT_SCRIPT,
    /// Remediate a list of emails
    REMEDIATE_EMAILS,
    /// Report IP, email to abuse mail
    REPORT_ABUSE,
    /// Update GeoIP database
    UPDATE_GEOIP,
    /// Update CloudProvider dataset
    UPDATE_CLOUD_PROVIDER,
    /// Any other task, identified by its task name
    OTHER(LogString),
}
212
213impl std::fmt::Display for SiemTaskType {
214    fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
215        match self {
216            SiemTaskType::EXECUTE_ENDPOINT_SCRIPT => write!(f, "EXECUTE_ENDPOINT_SCRIPT"),
217            SiemTaskType::REMEDIATE_EMAILS => write!(f, "REMEDIATE_EMAILS"),
218            SiemTaskType::REPORT_ABUSE => write!(f, "REPORT_ABUSE"),
219            SiemTaskType::UPDATE_GEOIP => write!(f, "UPDATE_GEOIP"),
220            SiemTaskType::UPDATE_CLOUD_PROVIDER => write!(f, "UPDATE_CLOUD_PROVIDER"),
221            SiemTaskType::OTHER(name) => write!(f, "{}", name),
222        }
223    }
224}
225
/// Kind of a task together with the parameters it needs.
#[derive(Serialize, Deserialize, Debug, Clone)]
#[allow(non_camel_case_types)]
#[non_exhaustive]
pub enum SiemTaskData {
    /// Script name and Script parameters
    EXECUTE_ENDPOINT_SCRIPT(LogString, BTreeMap<LogString, LogString>),
    /// Remediate a list of emails. List of parameters
    REMEDIATE_EMAILS(BTreeMap<LogString, LogString>),
    /// Report IP, email to abuse mail. Needed provider name and parameters
    REPORT_ABUSE(BTreeMap<LogString, LogString>),
    /// Update GeoIP dataset
    UPDATE_GEOIP,
    /// Update CloudProvider dataset
    UPDATE_CLOUD_PROVIDER,
    /// Task name, `Map<ParamName, Description>`
    OTHER(LogString, BTreeMap<LogString, LogString>),
}
243
244impl SiemTaskData {
245    pub fn class(&self) -> SiemTaskType {
246        match self {
247            SiemTaskData::EXECUTE_ENDPOINT_SCRIPT(_, _) => SiemTaskType::EXECUTE_ENDPOINT_SCRIPT,
248            SiemTaskData::REMEDIATE_EMAILS(_) => SiemTaskType::REMEDIATE_EMAILS,
249            SiemTaskData::REPORT_ABUSE(_) => SiemTaskType::REPORT_ABUSE,
250            SiemTaskData::UPDATE_GEOIP => SiemTaskType::UPDATE_GEOIP,
251            SiemTaskData::UPDATE_CLOUD_PROVIDER => SiemTaskType::UPDATE_CLOUD_PROVIDER,
252            SiemTaskData::OTHER(v, _) => SiemTaskType::OTHER(v.clone()),
253        }
254    }
255}
256
/// Enqueued task with data.
///
/// If the task has finished, then the result has `Some` data.
/// This data can be an `Ok` with the output (not the data) or the error.
/// The `id` is used to get the task result.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SiemTask {
    /// Creation time of the task.
    // NOTE(review): presumably a timestamp; unit and epoch are not visible
    // here — confirm.
    pub created_at: i64,
    /// Time at which the task was enqueued.
    // NOTE(review): presumably the same unit as `created_at` — confirm.
    pub enqueued_at: i64,
    /// Origin of the task.
    // NOTE(review): the expected format of this string is not visible here.
    pub origin: String,
    /// Identifier of the task, used to get the task result.
    pub id: u64,
    /// Kind of task and its parameters.
    pub data: SiemTaskData,
}
269
/// Result of a task, identified by the task `id`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct SiemTaskResult {
    /// Identifier of the task this result belongs to.
    pub id: u64,
    /// `Some` once the task has finished: `Ok` with the output or `Err` with
    /// the error.
    // NOTE(review): `None` presumably means "not finished yet" — confirm.
    pub data: Option<Result<String, String>>,
}
275
/// A `TaskBuilder` closure must yield a future that resolves to a result
/// carrying the id of the task it was built from.
#[test]
fn task_builder_should_generate_async_task() {
    let builder: TaskBuilder = |task: SiemTask, _datasets: &DatasetHolder| {
        Ok(Box::pin(async move {
            SiemTaskResult {
                data: Some(Ok(String::from("OK"))),
                id: task.id,
            }
        }))
    };

    let input = SiemTask {
        created_at: 0,
        enqueued_at: 1,
        origin: String::from("123"),
        id: 12345,
        data: SiemTaskData::REPORT_ABUSE(BTreeMap::new()),
    };
    let datasets = DatasetHolder::default();
    let pending = builder(input, &datasets).unwrap();

    async_std::task::block_on(async move {
        let outcome = pending.await;
        assert_eq!(12345, outcome.id);
        assert_eq!(Ok(String::from("OK")), outcome.data.unwrap());
    });
}