1use std::sync::Arc;
7
8use axum::{
9 extract::{Path, Query, State},
10 http::StatusCode,
11 Json,
12};
13use velesdb_core::collection::graph::TraversalConfig;
14
15use crate::types::ErrorResponse;
16use crate::AppState;
17
18use super::handlers::graph_preamble;
19use super::types::{
20 EdgeCountResponse, EdgeResponse, EdgesResponse, GraphSearchRequest, GraphSearchResponse,
21 GraphSearchResultItem, NodeEdgeQueryParams, NodeListResponse, NodePayloadResponse,
22 ParallelTraverseRequest, TraversalStats, TraverseResponse, UpsertNodePayloadRequest,
23};
24
25#[utoipa::path(
27 delete,
28 path = "/collections/{name}/graph/edges/{edge_id}",
29 params(
30 ("name" = String, Path, description = "Collection name"),
31 ("edge_id" = u64, Path, description = "Edge ID to remove")
32 ),
33 responses(
34 (status = 204, description = "Edge removed successfully"),
35 (status = 404, description = "Edge or collection not found", body = ErrorResponse),
36 (status = 500, description = "Internal server error", body = ErrorResponse)
37 ),
38 tag = "graph"
39)]
40pub async fn remove_edge(
41 Path((name, edge_id)): Path<(String, u64)>,
42 State(state): State<Arc<AppState>>,
43) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
44 let coll = graph_preamble(&state, &name)?;
45 if coll.remove_edge(edge_id) {
46 Ok(StatusCode::NO_CONTENT)
47 } else {
48 let err = velesdb_core::Error::EdgeNotFound(edge_id);
54 Err((
55 StatusCode::NOT_FOUND,
56 Json(ErrorResponse {
57 error: format!("{err} in collection '{name}'"),
58 code: Some(err.code().to_string()),
59 }),
60 ))
61 }
62}
63
64#[utoipa::path(
66 get,
67 path = "/collections/{name}/graph/edges/count",
68 params(
69 ("name" = String, Path, description = "Collection name")
70 ),
71 responses(
72 (status = 200, description = "Edge count retrieved", body = EdgeCountResponse),
73 (status = 404, description = "Collection not found", body = ErrorResponse)
74 ),
75 tag = "graph"
76)]
77pub async fn get_edge_count(
78 Path(name): Path<String>,
79 State(state): State<Arc<AppState>>,
80) -> Result<Json<EdgeCountResponse>, (StatusCode, Json<ErrorResponse>)> {
81 let coll = graph_preamble(&state, &name)?;
82 Ok(Json(EdgeCountResponse {
83 count: coll.edge_count(),
84 }))
85}
86
87#[utoipa::path(
89 get,
90 path = "/collections/{name}/graph/nodes",
91 params(
92 ("name" = String, Path, description = "Collection name")
93 ),
94 responses(
95 (status = 200, description = "Node list retrieved", body = NodeListResponse),
96 (status = 404, description = "Collection not found", body = ErrorResponse)
97 ),
98 tag = "graph"
99)]
100pub async fn list_nodes(
101 Path(name): Path<String>,
102 State(state): State<Arc<AppState>>,
103) -> Result<Json<NodeListResponse>, (StatusCode, Json<ErrorResponse>)> {
104 let coll = graph_preamble(&state, &name)?;
105 let node_ids = coll.all_node_ids();
106 let count = node_ids.len();
107 Ok(Json(NodeListResponse { node_ids, count }))
108}
109
110#[utoipa::path(
112 get,
113 path = "/collections/{name}/graph/nodes/{node_id}/edges",
114 params(
115 ("name" = String, Path, description = "Collection name"),
116 ("node_id" = u64, Path, description = "Node ID"),
117 NodeEdgeQueryParams
118 ),
119 responses(
120 (status = 200, description = "Node edges retrieved", body = EdgesResponse),
121 (status = 404, description = "Collection not found", body = ErrorResponse)
122 ),
123 tag = "graph"
124)]
125pub async fn get_node_edges(
126 Path((name, node_id)): Path<(String, u64)>,
127 Query(params): Query<NodeEdgeQueryParams>,
128 State(state): State<Arc<AppState>>,
129) -> Result<Json<EdgesResponse>, (StatusCode, Json<ErrorResponse>)> {
130 let coll = graph_preamble(&state, &name)?;
131
132 let raw_edges = match params.direction.to_lowercase().as_str() {
133 "in" => coll.get_incoming(node_id),
134 "both" => {
135 let mut all = coll.get_outgoing(node_id);
136 all.extend(coll.get_incoming(node_id));
137 all
138 }
139 _ => coll.get_outgoing(node_id),
140 };
141
142 let edges: Vec<EdgeResponse> = raw_edges
143 .into_iter()
144 .filter(|e| {
145 params
146 .label
147 .as_ref()
148 .is_none_or(|lbl| e.label() == lbl.as_str())
149 })
150 .map(|e| EdgeResponse {
151 id: e.id(),
152 source: e.source(),
153 target: e.target(),
154 label: e.label().to_string(),
155 properties: serde_json::to_value(e.properties()).unwrap_or_default(),
156 })
157 .collect();
158
159 let count = edges.len();
160 Ok(Json(EdgesResponse { edges, count }))
161}
162
163#[utoipa::path(
165 put,
166 path = "/collections/{name}/graph/nodes/{node_id}/payload",
167 params(
168 ("name" = String, Path, description = "Collection name"),
169 ("node_id" = u64, Path, description = "Node ID")
170 ),
171 request_body = UpsertNodePayloadRequest,
172 responses(
173 (status = 204, description = "Payload stored successfully"),
174 (status = 404, description = "Collection not found", body = ErrorResponse),
175 (status = 500, description = "Internal server error", body = ErrorResponse)
176 ),
177 tag = "graph"
178)]
179pub async fn upsert_node_payload(
180 Path((name, node_id)): Path<(String, u64)>,
181 State(state): State<Arc<AppState>>,
182 Json(request): Json<UpsertNodePayloadRequest>,
183) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
184 let coll = graph_preamble(&state, &name)?;
185 coll.upsert_node_payload(node_id, &request.payload)
186 .map_err(|e| {
187 (
188 StatusCode::INTERNAL_SERVER_ERROR,
189 Json(ErrorResponse {
190 error: format!("Failed to store payload: {e}"),
191 code: None,
192 }),
193 )
194 })?;
195 Ok(StatusCode::NO_CONTENT)
196}
197
198#[utoipa::path(
200 get,
201 path = "/collections/{name}/graph/nodes/{node_id}/payload",
202 params(
203 ("name" = String, Path, description = "Collection name"),
204 ("node_id" = u64, Path, description = "Node ID")
205 ),
206 responses(
207 (status = 200, description = "Payload retrieved", body = NodePayloadResponse),
208 (status = 404, description = "Collection not found", body = ErrorResponse),
209 (status = 500, description = "Internal server error", body = ErrorResponse)
210 ),
211 tag = "graph"
212)]
213pub async fn get_node_payload(
214 Path((name, node_id)): Path<(String, u64)>,
215 State(state): State<Arc<AppState>>,
216) -> Result<Json<NodePayloadResponse>, (StatusCode, Json<ErrorResponse>)> {
217 let coll = graph_preamble(&state, &name)?;
218 let payload = coll.get_node_payload(node_id).map_err(|e| {
219 (
220 StatusCode::INTERNAL_SERVER_ERROR,
221 Json(ErrorResponse {
222 error: format!("Failed to get payload: {e}"),
223 code: None,
224 }),
225 )
226 })?;
227 Ok(Json(NodePayloadResponse { node_id, payload }))
228}
229
230#[utoipa::path(
232 post,
233 path = "/collections/{name}/graph/traverse/parallel",
234 request_body = ParallelTraverseRequest,
235 responses(
236 (status = 200, description = "Parallel traversal completed", body = TraverseResponse),
237 (status = 400, description = "Invalid request", body = ErrorResponse),
238 (status = 404, description = "Collection not found", body = ErrorResponse)
239 ),
240 tag = "graph"
241)]
242pub async fn traverse_parallel(
243 Path(name): Path<String>,
244 State(state): State<Arc<AppState>>,
245 Json(request): Json<ParallelTraverseRequest>,
246) -> Result<Json<TraverseResponse>, (StatusCode, Json<ErrorResponse>)> {
247 if request.sources.is_empty() {
248 return Err((
249 StatusCode::BAD_REQUEST,
250 Json(ErrorResponse {
251 error: "At least one source node ID is required".to_string(),
252 code: None,
253 }),
254 ));
255 }
256
257 let coll = graph_preamble(&state, &name)?;
258
259 let config = TraversalConfig::with_range(1, request.max_depth)
260 .with_limit(request.limit)
261 .with_rel_types(request.rel_types);
262
263 let raw_results = coll.traverse_bfs_parallel(&request.sources, &config);
264
265 let results: Vec<super::types::TraversalResultItem> = raw_results
266 .into_iter()
267 .map(|r| super::types::TraversalResultItem {
268 target_id: r.target_id,
269 depth: r.depth,
270 path: r.path,
271 })
272 .collect();
273
274 let depth_reached = results.iter().map(|r| r.depth).max().unwrap_or(0);
275 let visited = results.len();
276 let has_more = visited >= request.limit;
277
278 Ok(Json(TraverseResponse {
279 results,
280 has_more,
281 stats: TraversalStats {
282 visited,
283 depth_reached,
284 },
285 }))
286}
287
288#[utoipa::path(
290 post,
291 path = "/collections/{name}/graph/search",
292 request_body = GraphSearchRequest,
293 responses(
294 (status = 200, description = "Graph search results", body = GraphSearchResponse),
295 (status = 400, description = "Invalid request", body = ErrorResponse),
296 (status = 404, description = "Collection not found", body = ErrorResponse),
297 (status = 500, description = "Internal server error", body = ErrorResponse)
298 ),
299 tag = "graph"
300)]
301pub async fn graph_search(
302 Path(name): Path<String>,
303 State(state): State<Arc<AppState>>,
304 Json(request): Json<GraphSearchRequest>,
305) -> Result<Json<GraphSearchResponse>, (StatusCode, Json<ErrorResponse>)> {
306 let coll = graph_preamble(&state, &name)?;
307
308 if !coll.has_embeddings() {
309 return Err((
310 StatusCode::BAD_REQUEST,
311 Json(ErrorResponse {
312 error: format!(
313 "Graph collection '{name}' does not have embeddings. \
314 Create it with create_graph_collection_with_embeddings() to enable search."
315 ),
316 code: None,
317 }),
318 ));
319 }
320
321 let search_results = coll
322 .search_by_embedding(&request.vector, request.top_k)
323 .map_err(|e| {
324 (
325 StatusCode::INTERNAL_SERVER_ERROR,
326 Json(ErrorResponse {
327 error: format!("Graph search failed: {e}"),
328 code: None,
329 }),
330 )
331 })?;
332
333 let results: Vec<GraphSearchResultItem> = search_results
334 .into_iter()
335 .map(|r| GraphSearchResultItem {
336 id: r.point.id,
337 score: r.score,
338 payload: r.point.payload,
339 })
340 .collect();
341
342 Ok(Json(GraphSearchResponse { results }))
343}