1#![allow(dead_code)]
2use std::collections::HashMap;
10use std::sync::Mutex;
11
12use tokio::sync::mpsc;
13
14use super::super::rpc_types::{
15 CommandMessageData, RpcMessage, COMMAND_MESSAGE, COMMAND_ROUTE_ANNOUNCE,
16 COMMAND_ROUTE_UNANNOUNCE,
17};
18
/// Route a command is delivered to when its `route` field is empty.
pub const DEFAULT_ROUTE: &str = "wavesrv";
/// Reserved route id; `WshRouter::register_route` refuses to register it.
pub const UPSTREAM_ROUTE: &str = "upstream";
/// Reserved route id used as the origin of messages the router generates itself
/// (e.g. "no route" notices); `WshRouter::register_route` refuses to register it.
pub const SYS_ROUTE: &str = "sys";
/// Route id of the frontend host process (presumably — not referenced in this file).
// NOTE(review): the constant is named for Tauri but its value is "electron";
// presumably kept for wire compatibility with peers using the old id — confirm
// before changing the value.
pub const TAURI_ROUTE: &str = "electron";
25
/// Prefix of connection route ids (see [`make_connection_route_id`]).
pub const ROUTE_PREFIX_CONN: &str = "conn:";
/// Prefix of block-controller route ids (see [`make_controller_route_id`]).
pub const ROUTE_PREFIX_CONTROLLER: &str = "controller:";
/// Prefix of process route ids (see [`make_proc_route_id`]).
pub const ROUTE_PREFIX_PROC: &str = "proc:";
/// Prefix of tab route ids (see [`make_tab_route_id`]).
pub const ROUTE_PREFIX_TAB: &str = "tab:";
/// Prefix of frontend-block route ids (see [`make_fe_block_route_id`]).
pub const ROUTE_PREFIX_FE_BLOCK: &str = "feblock:";

// The helpers below build ids from the prefix constants above (rather than
// repeating the literals) so a prefix can never drift from its helper.

/// Returns the route id for connection `conn_id`, e.g. `"conn:myhost"`.
pub fn make_connection_route_id(conn_id: &str) -> String {
    format!("{}{}", ROUTE_PREFIX_CONN, conn_id)
}

/// Returns the route id for the controller of block `block_id`, e.g. `"controller:blk-1"`.
pub fn make_controller_route_id(block_id: &str) -> String {
    format!("{}{}", ROUTE_PREFIX_CONTROLLER, block_id)
}

/// Returns the route id for process `proc_id`, e.g. `"proc:p-1"`.
pub fn make_proc_route_id(proc_id: &str) -> String {
    format!("{}{}", ROUTE_PREFIX_PROC, proc_id)
}

/// Returns the route id for tab `tab_id`, e.g. `"tab:tab-1"`.
pub fn make_tab_route_id(tab_id: &str) -> String {
    format!("{}{}", ROUTE_PREFIX_TAB, tab_id)
}

/// Returns the route id for frontend block `block_id`, e.g. `"feblock:blk-2"`.
pub fn make_fe_block_route_id(block_id: &str) -> String {
    format!("{}{}", ROUTE_PREFIX_FE_BLOCK, block_id)
}
53
/// Bookkeeping for one in-flight RPC, stored in `RouterInner::rpc_map` under the
/// request id.
#[derive(Debug, Clone)]
struct RouteInfo {
    // Taken from the command's `source` field (may be empty); responses carrying
    // the matching `resid` are forwarded here.
    source_route_id: String,
    // Route the command was delivered to; later messages that carry the same
    // `reqid` (and no command) are forwarded here.
    dest_route_id: String,
}
61
/// One entry of the router's input queue: a raw message plus the route it came from.
struct MsgAndRoute {
    // JSON-encoded `RpcMessage`; the router forwards these exact bytes to the
    // client it selects.
    msg_bytes: Vec<u8>,
    // Route id of the sender; route announcements map their `source` to this id.
    from_route_id: String,
}
66
/// A destination the router can deliver raw (JSON-encoded) messages to.
///
/// `WshRouter` invokes `send_rpc_message` while holding its internal
/// `std::sync::Mutex`, so implementations should return promptly and must not
/// call back into the router (the lock is not reentrant).
pub trait RpcClient: Send + Sync {
    /// Delivers one message. There is no return value, so a failed delivery cannot
    /// be reported back to the router.
    fn send_rpc_message(&self, msg: &[u8]);
}
74
75const CLIENT_CHANNEL_CAPACITY: usize = 256;
81
82pub struct ChannelRpcClient {
83 tx: mpsc::Sender<Vec<u8>>,
84}
85
86impl ChannelRpcClient {
87 pub fn new() -> (Self, mpsc::Receiver<Vec<u8>>) {
88 let (tx, rx) = mpsc::channel(CLIENT_CHANNEL_CAPACITY);
89 (Self { tx }, rx)
90 }
91}
92
93impl RpcClient for ChannelRpcClient {
94 fn send_rpc_message(&self, msg: &[u8]) {
95 if let Err(_) = self.tx.try_send(msg.to_vec()) {
96 tracing::warn!("ChannelRpcClient: message dropped (channel full or closed)");
97 }
98 }
99}
100
101struct RouterInner {
104 route_map: HashMap<String, Box<dyn RpcClient>>,
105 announced_routes: HashMap<String, String>, rpc_map: HashMap<String, RouteInfo>, simple_request_map: HashMap<String, tokio::sync::oneshot::Sender<RpcMessage>>,
108}
109
/// Routes JSON RPC messages between registered (and announced) routes.
///
/// Messages enter through a bounded input queue and are processed one at a time by
/// the server task that `WshRouter::new` spawns (`Default` does not start one).
pub struct WshRouter {
    // Routing tables; see `RouterInner`.
    inner: Mutex<RouterInner>,
    // Sending half of the input queue drained by the server task.
    input_tx: mpsc::Sender<MsgAndRoute>,
}

/// Capacity of the router's input queue; `inject_message` drops (and logs) messages
/// once it is full.
const ROUTER_CHANNEL_CAPACITY: usize = 1000;
121
122impl WshRouter {
123 pub fn new() -> std::sync::Arc<Self> {
125 let (input_tx, input_rx) = mpsc::channel(ROUTER_CHANNEL_CAPACITY);
126 let router = std::sync::Arc::new(Self {
127 inner: Mutex::new(RouterInner {
128 route_map: HashMap::new(),
129 announced_routes: HashMap::new(),
130 rpc_map: HashMap::new(),
131 simple_request_map: HashMap::new(),
132 }),
133 input_tx,
134 });
135 let router_clone = router.clone();
136 tokio::spawn(async move {
137 router_clone.run_server(input_rx).await;
138 });
139 router
140 }
141
142 pub fn inject_message(&self, msg_bytes: Vec<u8>, from_route_id: &str) {
144 if let Err(_) = self.input_tx.try_send(MsgAndRoute {
145 msg_bytes,
146 from_route_id: from_route_id.to_string(),
147 }) {
148 tracing::warn!(from_route = %from_route_id, "router input queue full or closed — message dropped");
149 }
150 }
151
152 pub fn register_route(&self, route_id: &str, client: Box<dyn RpcClient>) {
154 if route_id == SYS_ROUTE || route_id == UPSTREAM_ROUTE {
155 tracing::error!("WshRouter cannot register {} route", route_id);
156 return;
157 }
158 tracing::info!("[router] registering wsh route {:?}", route_id);
159 let mut inner = self.inner.lock().unwrap();
160 if inner.route_map.contains_key(route_id) {
161 tracing::warn!("[router] route {:?} already exists (replacing)", route_id);
162 }
163 inner.route_map.insert(route_id.to_string(), client);
164 }
165
166 pub fn unregister_route(&self, route_id: &str) {
168 tracing::info!("[router] unregistering wsh route {:?}", route_id);
169 let mut inner = self.inner.lock().unwrap();
170 inner.route_map.remove(route_id);
171 inner
172 .announced_routes
173 .retain(|_, local_id| local_id != route_id);
174 }
175
176 pub fn has_route(&self, route_id: &str) -> bool {
178 let inner = self.inner.lock().unwrap();
179 inner.route_map.contains_key(route_id)
180 }
181
182 pub fn route_ids(&self) -> Vec<String> {
184 let inner = self.inner.lock().unwrap();
185 inner.route_map.keys().cloned().collect()
186 }
187
188 pub async fn wait_for_register(
190 &self,
191 route_id: &str,
192 timeout: std::time::Duration,
193 ) -> bool {
194 let deadline = tokio::time::Instant::now() + timeout;
195 loop {
196 {
197 let inner = self.inner.lock().unwrap();
198 if inner.route_map.contains_key(route_id)
199 || inner.announced_routes.contains_key(route_id)
200 {
201 return true;
202 }
203 }
204 if tokio::time::Instant::now() >= deadline {
205 return false;
206 }
207 tokio::time::sleep(std::time::Duration::from_millis(30)).await;
208 }
209 }
210
211 pub async fn run_simple_raw_command(
213 &self,
214 msg: RpcMessage,
215 from_route_id: &str,
216 timeout: std::time::Duration,
217 ) -> Result<Option<RpcMessage>, String> {
218 if msg.command.is_empty() {
219 return Err("no command".to_string());
220 }
221 let msg_bytes =
222 serde_json::to_vec(&msg).map_err(|e| format!("marshal error: {}", e))?;
223 let rx = if !msg.reqid.is_empty() {
224 Some(self.register_simple_request(&msg.reqid))
225 } else {
226 None
227 };
228 self.inject_message(msg_bytes, from_route_id);
229 match rx {
230 None => Ok(None),
231 Some(rx) => {
232 tokio::select! {
233 resp = rx => {
234 match resp {
235 Ok(resp) => {
236 if !resp.error.is_empty() {
237 Err(resp.error)
238 } else {
239 Ok(Some(resp))
240 }
241 }
242 Err(_) => Err("request cancelled".to_string()),
243 }
244 }
245 _ = tokio::time::sleep(timeout) => {
246 self.clear_simple_request(&msg.reqid);
247 Err("timeout".to_string())
248 }
249 }
250 }
251 }
252 }
253
254 fn register_route_info(
257 &self,
258 rpc_id: &str,
259 source_route_id: &str,
260 dest_route_id: &str,
261 ) {
262 if rpc_id.is_empty() {
263 return;
264 }
265 let mut inner = self.inner.lock().unwrap();
266 inner.rpc_map.insert(
267 rpc_id.to_string(),
268 RouteInfo {
269 source_route_id: source_route_id.to_string(),
270 dest_route_id: dest_route_id.to_string(),
271 },
272 );
273 }
274
275 fn unregister_route_info(&self, rpc_id: &str) {
276 let mut inner = self.inner.lock().unwrap();
277 inner.rpc_map.remove(rpc_id);
278 }
279
280 fn get_route_info(&self, rpc_id: &str) -> Option<RouteInfo> {
281 let inner = self.inner.lock().unwrap();
282 inner.rpc_map.get(rpc_id).cloned()
283 }
284
285 fn send_routed_message(&self, msg_bytes: &[u8], route_id: &str) -> bool {
286 let inner = self.inner.lock().unwrap();
287 if let Some(rpc) = inner.route_map.get(route_id) {
288 rpc.send_rpc_message(msg_bytes);
289 return true;
290 }
291 if let Some(local_route) = inner.announced_routes.get(route_id) {
293 if let Some(rpc) = inner.route_map.get(local_route.as_str()) {
294 rpc.send_rpc_message(msg_bytes);
295 return true;
296 }
297 }
298 false
299 }
300
301 fn handle_no_route(&self, msg: &RpcMessage) {
302 let err_msg = if msg.route.is_empty() {
303 "no default route".to_string()
304 } else {
305 format!("no route for {:?}", msg.route)
306 };
307 if msg.reqid.is_empty() {
308 if msg.command == COMMAND_MESSAGE {
310 return; }
312 let resp = RpcMessage {
313 command: COMMAND_MESSAGE.to_string(),
314 route: msg.source.clone(),
315 data: serde_json::to_value(CommandMessageData {
316 oref: Default::default(),
317 message: err_msg,
318 })
319 .ok(),
320 ..Default::default()
321 };
322 if let Ok(resp_bytes) = serde_json::to_vec(&resp) {
323 let _ = self.input_tx.send(MsgAndRoute {
324 msg_bytes: resp_bytes,
325 from_route_id: SYS_ROUTE.to_string(),
326 });
327 }
328 return;
329 }
330 let response = RpcMessage {
332 resid: msg.reqid.clone(),
333 error: err_msg,
334 ..Default::default()
335 };
336 if let Ok(resp_bytes) = serde_json::to_vec(&response) {
337 self.send_routed_message(&resp_bytes, &msg.source);
338 }
339 }
340
341 fn handle_announce_message(&self, msg: &RpcMessage, from_route_id: &str) {
342 if msg.source == from_route_id {
343 return;
344 }
345 let mut inner = self.inner.lock().unwrap();
346 inner
347 .announced_routes
348 .insert(msg.source.clone(), from_route_id.to_string());
349 }
350
351 fn handle_unannounce_message(&self, msg: &RpcMessage) {
352 let mut inner = self.inner.lock().unwrap();
353 inner.announced_routes.remove(&msg.source);
354 }
355
356 fn register_simple_request(
357 &self,
358 req_id: &str,
359 ) -> tokio::sync::oneshot::Receiver<RpcMessage> {
360 let (tx, rx) = tokio::sync::oneshot::channel();
361 let mut inner = self.inner.lock().unwrap();
362 inner
363 .simple_request_map
364 .insert(req_id.to_string(), tx);
365 rx
366 }
367
368 fn try_simple_response(&self, msg: &RpcMessage) -> bool {
369 let mut inner = self.inner.lock().unwrap();
370 if let Some(tx) = inner.simple_request_map.remove(&msg.resid) {
371 let _ = tx.send(msg.clone());
372 return true;
373 }
374 false
375 }
376
377 fn clear_simple_request(&self, req_id: &str) {
378 let mut inner = self.inner.lock().unwrap();
379 inner.simple_request_map.remove(req_id);
380 }
381
382 async fn run_server(&self, mut input_rx: mpsc::Receiver<MsgAndRoute>) {
383 while let Some(input) = input_rx.recv().await {
384 let msg: RpcMessage = match serde_json::from_slice(&input.msg_bytes) {
385 Ok(m) => m,
386 Err(e) => {
387 tracing::error!("error unmarshalling message: {}", e);
388 continue;
389 }
390 };
391
392 if msg.command == COMMAND_ROUTE_ANNOUNCE {
394 self.handle_announce_message(&msg, &input.from_route_id);
395 continue;
396 }
397 if msg.command == COMMAND_ROUTE_UNANNOUNCE {
398 self.handle_unannounce_message(&msg);
399 continue;
400 }
401
402 if !msg.command.is_empty() {
404 let route_id = if msg.route.is_empty() {
405 DEFAULT_ROUTE.to_string()
406 } else {
407 msg.route.clone()
408 };
409 let ok = self.send_routed_message(&input.msg_bytes, &route_id);
410 if !ok {
411 self.handle_no_route(&msg);
412 continue;
413 }
414 self.register_route_info(&msg.reqid, &msg.source, &route_id);
415 continue;
416 }
417
418 if !msg.reqid.is_empty() {
420 if let Some(info) = self.get_route_info(&msg.reqid) {
421 self.send_routed_message(&input.msg_bytes, &info.dest_route_id);
422 }
423 continue;
424 }
425
426 if !msg.resid.is_empty() {
428 if self.try_simple_response(&msg) {
429 continue;
430 }
431 if let Some(info) = self.get_route_info(&msg.resid) {
432 self.send_routed_message(&input.msg_bytes, &info.source_route_id);
433 if !msg.cont {
434 self.unregister_route_info(&msg.resid);
435 }
436 }
437 continue;
438 }
439
440 }
442 }
443}
444
445impl Default for WshRouter {
446 fn default() -> Self {
447 let (input_tx, _) = mpsc::channel(ROUTER_CHANNEL_CAPACITY);
450 Self {
451 inner: Mutex::new(RouterInner {
452 route_map: HashMap::new(),
453 announced_routes: HashMap::new(),
454 rpc_map: HashMap::new(),
455 simple_request_map: HashMap::new(),
456 }),
457 input_tx,
458 }
459 }
460}
461
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;

    // The async tests inject messages and then `sleep` briefly so that the server
    // task spawned by `WshRouter::new` can drain its input queue before asserting.

    /// Each helper produces `<prefix><id>`.
    #[test]
    fn test_route_id_helpers() {
        assert_eq!(make_connection_route_id("myconn"), "conn:myconn");
        assert_eq!(make_controller_route_id("blk-1"), "controller:blk-1");
        assert_eq!(make_proc_route_id("p-1"), "proc:p-1");
        assert_eq!(make_tab_route_id("tab-1"), "tab:tab-1");
        assert_eq!(make_fe_block_route_id("blk-2"), "feblock:blk-2");
    }

    /// Test double that records every raw message delivered to it.
    struct CollectorClient {
        messages: std::sync::Mutex<Vec<Vec<u8>>>,
    }

    impl CollectorClient {
        /// Returns the collector behind an `Arc` so the test keeps a handle while the
        /// router owns a boxed clone.
        fn new() -> Arc<Self> {
            Arc::new(Self {
                messages: std::sync::Mutex::new(Vec::new()),
            })
        }

        /// Decodes the recorded messages; entries that fail to parse are skipped.
        fn received_messages(&self) -> Vec<RpcMessage> {
            let msgs = self.messages.lock().unwrap();
            msgs.iter()
                .filter_map(|b| serde_json::from_slice(b).ok())
                .collect()
        }
    }

    // Implemented for `Arc<CollectorClient>` (not the struct itself) so a clone can be
    // boxed into the router while the test inspects the same collector.
    impl RpcClient for Arc<CollectorClient> {
        fn send_rpc_message(&self, msg: &[u8]) {
            self.messages.lock().unwrap().push(msg.to_vec());
        }
    }

    /// Registration and removal are reflected by `has_route`.
    #[tokio::test]
    async fn test_register_unregister_route() {
        let router = WshRouter::new();
        let client = CollectorClient::new();
        router.register_route("test-route", Box::new(client.clone()));
        assert!(router.has_route("test-route"));
        assert!(!router.has_route("other-route"));

        router.unregister_route("test-route");
        assert!(!router.has_route("test-route"));
    }

    /// A command with an empty `route` is delivered to `DEFAULT_ROUTE`.
    #[tokio::test]
    async fn test_route_command_to_destination() {
        let router = WshRouter::new();
        let server = CollectorClient::new();
        router.register_route(DEFAULT_ROUTE, Box::new(server.clone()));

        let msg = RpcMessage {
            command: "getmeta".to_string(),
            reqid: "req-1".to_string(),
            source: "tab:t1".to_string(),
            ..Default::default()
        };
        let msg_bytes = serde_json::to_vec(&msg).unwrap();
        router.inject_message(msg_bytes, "tab:t1");

        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let received = server.received_messages();
        assert_eq!(received.len(), 1);
        assert_eq!(received[0].command, "getmeta");
    }

    /// A response (`resid`) is routed back to the `source` recorded for the command.
    #[tokio::test]
    async fn test_route_response_back_to_source() {
        let router = WshRouter::new();

        let server = CollectorClient::new();
        let tab_client = CollectorClient::new();

        router.register_route(DEFAULT_ROUTE, Box::new(server.clone()));
        router.register_route("tab:t1", Box::new(tab_client.clone()));

        // The command records req-1 -> (source "tab:t1", dest DEFAULT_ROUTE).
        let cmd = RpcMessage {
            command: "getmeta".to_string(),
            reqid: "req-1".to_string(),
            source: "tab:t1".to_string(),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&cmd).unwrap(), "tab:t1");
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // The server's answer carries only `resid`; routing uses the recorded source.
        let resp = RpcMessage {
            resid: "req-1".to_string(),
            data: Some(serde_json::json!({"view": "term"})),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&resp).unwrap(), DEFAULT_ROUTE);
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let tab_msgs = tab_client.received_messages();
        assert_eq!(tab_msgs.len(), 1);
        assert_eq!(tab_msgs[0].resid, "req-1");
    }

    /// A `cont: true` response keeps the route info, so the final chunk still arrives.
    #[tokio::test]
    async fn test_streaming_response_cont_flag() {
        let router = WshRouter::new();
        let server = CollectorClient::new();
        let tab = CollectorClient::new();

        router.register_route(DEFAULT_ROUTE, Box::new(server.clone()));
        router.register_route("tab:t1", Box::new(tab.clone()));

        let cmd = RpcMessage {
            command: "filereadstream".to_string(),
            reqid: "req-stream".to_string(),
            source: "tab:t1".to_string(),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&cmd).unwrap(), "tab:t1");
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        // Intermediate chunk: route info must survive.
        let resp1 = RpcMessage {
            resid: "req-stream".to_string(),
            cont: true,
            data: Some(serde_json::json!({"chunk": 1})),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&resp1).unwrap(), DEFAULT_ROUTE);
        tokio::time::sleep(std::time::Duration::from_millis(20)).await;

        // Final chunk: delivered, then the route info is unregistered.
        let resp2 = RpcMessage {
            resid: "req-stream".to_string(),
            cont: false,
            data: Some(serde_json::json!({"chunk": 2})),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&resp2).unwrap(), DEFAULT_ROUTE);
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let msgs = tab.received_messages();
        assert_eq!(msgs.len(), 2);
        assert!(msgs[0].cont);
        assert!(!msgs[1].cont);
    }

    /// A command with a `reqid` to an unknown route gets an error response at its source.
    #[tokio::test]
    async fn test_no_route_returns_error() {
        let router = WshRouter::new();
        let tab = CollectorClient::new();
        router.register_route("tab:t1", Box::new(tab.clone()));

        let cmd = RpcMessage {
            command: "getmeta".to_string(),
            reqid: "req-err".to_string(),
            source: "tab:t1".to_string(),
            route: "nonexistent".to_string(),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&cmd).unwrap(), "tab:t1");
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let msgs = tab.received_messages();
        assert_eq!(msgs.len(), 1);
        assert!(!msgs[0].error.is_empty());
        assert!(msgs[0].error.contains("no route"));
    }

    /// After "proxy-1" announces "conn:myhost", messages for that id go to proxy-1.
    #[tokio::test]
    async fn test_announced_routes() {
        let router = WshRouter::new();
        let proxy = CollectorClient::new();
        router.register_route("proxy-1", Box::new(proxy.clone()));

        // Arrives from "proxy-1" with a different `source`, so a mapping is stored.
        let announce = RpcMessage {
            command: COMMAND_ROUTE_ANNOUNCE.to_string(),
            source: "conn:myhost".to_string(),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&announce).unwrap(), "proxy-1");
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let cmd = RpcMessage {
            command: "test".to_string(),
            route: "conn:myhost".to_string(),
            source: "sys".to_string(),
            ..Default::default()
        };
        router.inject_message(serde_json::to_vec(&cmd).unwrap(), "sys");
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        let msgs = proxy.received_messages();
        assert_eq!(msgs.len(), 1);
        assert_eq!(msgs[0].command, "test");
    }

    /// `run_simple_raw_command` resolves with the response whose `resid` matches.
    #[tokio::test]
    async fn test_simple_raw_command() {
        let router = WshRouter::new();
        let server = CollectorClient::new();
        router.register_route(DEFAULT_ROUTE, Box::new(server.clone()));

        // Plays the server: replies 30 ms later, well within the 1 s timeout below.
        let router_clone = router.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(30)).await;
            let resp = RpcMessage {
                resid: "simple-req".to_string(),
                data: Some(serde_json::json!({"ok": true})),
                ..Default::default()
            };
            router_clone.inject_message(
                serde_json::to_vec(&resp).unwrap(),
                DEFAULT_ROUTE,
            );
        });

        let cmd = RpcMessage {
            command: "test".to_string(),
            reqid: "simple-req".to_string(),
            ..Default::default()
        };
        let result = router
            .run_simple_raw_command(
                cmd,
                "sys",
                std::time::Duration::from_secs(1),
            )
            .await;
        assert!(result.is_ok());
        let resp = result.unwrap().unwrap();
        assert_eq!(resp.data, Some(serde_json::json!({"ok": true})));
    }

    /// A route registered 50 ms later is observed; a route never registered times out.
    #[tokio::test]
    async fn test_wait_for_register() {
        let router = WshRouter::new();

        let router_clone = router.clone();
        tokio::spawn(async move {
            tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            let client = CollectorClient::new();
            router_clone.register_route("delayed-route", Box::new(client));
        });

        let found = router
            .wait_for_register("delayed-route", std::time::Duration::from_secs(1))
            .await;
        assert!(found);

        let not_found = router
            .wait_for_register("never-route", std::time::Duration::from_millis(50))
            .await;
        assert!(!not_found);
    }
}