agentmux_srv\backend\wshutil/
proxy.rs

1#![allow(dead_code)]
2// Copyright 2025-2026, AgentMux Corp.
3// SPDX-License-Identifier: Apache-2.0
4
5//! RPC proxy types for forwarding messages between connections.
6//! Port of Go's `pkg/wshutil/wshproxy.go` and `wshmultiproxy.go`.
7
8
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use serde::{Deserialize, Serialize};
12use tokio::sync::mpsc;
13use super::osc::{DEFAULT_INPUT_CH_SIZE, DEFAULT_OUTPUT_CH_SIZE};
14
15/// RPC context passed with each message.
16#[derive(Debug, Clone, Default, Serialize, Deserialize)]
17pub struct RpcContext {
18    #[serde(default, skip_serializing_if = "String::is_empty", rename = "blockid")]
19    pub block_id: String,
20    #[serde(default, skip_serializing_if = "String::is_empty", rename = "tabid")]
21    pub tab_id: String,
22    #[serde(default, skip_serializing_if = "String::is_empty", rename = "conn")]
23    pub conn: String,
24}
25
26/// RPC message format for JSON-RPC communication.
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct RpcMessage {
29    #[serde(default, skip_serializing_if = "String::is_empty")]
30    pub command: String,
31    #[serde(default, skip_serializing_if = "String::is_empty", rename = "reqid")]
32    pub req_id: String,
33    #[serde(default, skip_serializing_if = "String::is_empty", rename = "resid")]
34    pub res_id: String,
35    #[serde(default, skip_serializing_if = "Option::is_none")]
36    pub data: Option<serde_json::Value>,
37    #[serde(default, skip_serializing_if = "Option::is_none")]
38    pub error: Option<String>,
39    #[serde(default)]
40    pub cont: bool,
41    #[serde(default, skip_serializing_if = "Option::is_none")]
42    pub cancel: Option<bool>,
43    #[serde(default, skip_serializing_if = "Option::is_none", rename = "route")]
44    pub route: Option<String>,
45    #[serde(default, skip_serializing_if = "Option::is_none", rename = "source")]
46    pub source: Option<String>,
47    #[serde(default, skip_serializing_if = "Option::is_none", rename = "authtoken")]
48    pub auth_token: Option<String>,
49    #[serde(default, skip_serializing_if = "Option::is_none", rename = "timeout")]
50    pub timeout: Option<u64>,
51}
52
53impl RpcMessage {
54    /// Check if this is a request (has command and reqid).
55    pub fn is_request(&self) -> bool {
56        !self.command.is_empty() && !self.req_id.is_empty()
57    }
58
59    /// Check if this is a response (has resid).
60    pub fn is_response(&self) -> bool {
61        !self.res_id.is_empty()
62    }
63
64    /// Check if this is an error response.
65    pub fn is_error(&self) -> bool {
66        self.error.is_some()
67    }
68
69    /// Check if this is the final response (not continued).
70    pub fn is_final(&self) -> bool {
71        !self.cont
72    }
73}
74
75/// Single-connection RPC proxy.
76/// Forwards messages between a local connection and a remote endpoint.
77pub struct WshRpcProxy {
78    rpc_context: Arc<Mutex<Option<RpcContext>>>,
79    auth_token: Arc<Mutex<String>>,
80    pub to_remote: mpsc::Sender<Vec<u8>>,
81    pub from_remote: mpsc::Receiver<Vec<u8>>,
82    to_remote_rx: Option<mpsc::Receiver<Vec<u8>>>,
83    from_remote_tx: mpsc::Sender<Vec<u8>>,
84}
85
86impl WshRpcProxy {
87    pub fn new() -> Self {
88        let (to_remote_tx, to_remote_rx) = mpsc::channel(DEFAULT_INPUT_CH_SIZE);
89        let (from_remote_tx, from_remote_rx) = mpsc::channel(DEFAULT_OUTPUT_CH_SIZE);
90        Self {
91            rpc_context: Arc::new(Mutex::new(None)),
92            auth_token: Arc::new(Mutex::new(String::new())),
93            to_remote: to_remote_tx,
94            from_remote: from_remote_rx,
95            to_remote_rx: Some(to_remote_rx),
96            from_remote_tx,
97        }
98    }
99
100    pub fn set_rpc_context(&self, ctx: RpcContext) {
101        *self.rpc_context.lock().unwrap() = Some(ctx);
102    }
103
104    pub fn get_rpc_context(&self) -> Option<RpcContext> {
105        self.rpc_context.lock().unwrap().clone()
106    }
107
108    pub fn set_auth_token(&self, token: &str) {
109        *self.auth_token.lock().unwrap() = token.to_string();
110    }
111
112    pub fn get_auth_token(&self) -> String {
113        self.auth_token.lock().unwrap().clone()
114    }
115
116    /// Take the receiver end of to_remote channel (for proxy loop).
117    pub fn take_to_remote_rx(&mut self) -> Option<mpsc::Receiver<Vec<u8>>> {
118        self.to_remote_rx.take()
119    }
120
121    /// Get a clone of the from_remote sender (for injecting messages).
122    pub fn from_remote_sender(&self) -> mpsc::Sender<Vec<u8>> {
123        self.from_remote_tx.clone()
124    }
125
126    /// Send a message to the remote endpoint.
127    pub async fn send_to_remote(&self, msg: Vec<u8>) -> Result<(), String> {
128        self.to_remote
129            .send(msg)
130            .await
131            .map_err(|e| format!("failed to send to remote: {}", e))
132    }
133
134    /// Inject an RPC message (encode as JSON and send to remote).
135    pub async fn send_rpc_message(&self, msg: &RpcMessage) -> Result<(), String> {
136        let json = serde_json::to_vec(msg).map_err(|e| format!("json encode: {}", e))?;
137        self.send_to_remote(json).await
138    }
139
140    /// Send an error response for a given request.
141    pub async fn send_response_error(&self, req_id: &str, err_msg: &str) -> Result<(), String> {
142        if req_id.is_empty() {
143            return Ok(());
144        }
145        let msg = RpcMessage {
146            res_id: req_id.to_string(),
147            error: Some(err_msg.to_string()),
148            ..Default::default()
149        };
150        self.send_rpc_message(&msg).await
151    }
152}
153
154impl Default for RpcMessage {
155    fn default() -> Self {
156        Self {
157            command: String::new(),
158            req_id: String::new(),
159            res_id: String::new(),
160            data: None,
161            error: None,
162            cont: false,
163            cancel: None,
164            route: None,
165            source: None,
166            auth_token: None,
167            timeout: None,
168        }
169    }
170}
171
172/// Multi-connection broadcast proxy.
173/// Sends messages to multiple remote connections simultaneously.
174pub struct WshMultiProxy {
175    proxies: Arc<Mutex<HashMap<String, mpsc::Sender<Vec<u8>>>>>,
176}
177
178impl WshMultiProxy {
179    pub fn new() -> Self {
180        Self {
181            proxies: Arc::new(Mutex::new(HashMap::new())),
182        }
183    }
184
185    /// Add a named proxy connection.
186    pub fn add_proxy(&self, name: &str, sender: mpsc::Sender<Vec<u8>>) {
187        self.proxies.lock().unwrap().insert(name.to_string(), sender);
188    }
189
190    /// Remove a named proxy connection.
191    pub fn remove_proxy(&self, name: &str) {
192        self.proxies.lock().unwrap().remove(name);
193    }
194
195    /// Broadcast a message to all connected proxies.
196    pub async fn broadcast(&self, msg: Vec<u8>) {
197        let senders: Vec<mpsc::Sender<Vec<u8>>> = {
198            let proxies = self.proxies.lock().unwrap();
199            proxies.values().cloned().collect()
200        };
201
202        for sender in senders {
203            let msg_clone = msg.clone();
204            let _ = sender.send(msg_clone).await;
205        }
206    }
207
208    /// Broadcast an RPC message to all connected proxies.
209    pub async fn broadcast_rpc_message(&self, msg: &RpcMessage) -> Result<(), String> {
210        let json = serde_json::to_vec(msg).map_err(|e| format!("json encode: {}", e))?;
211        self.broadcast(json).await;
212        Ok(())
213    }
214
215    /// Get the count of connected proxies.
216    pub fn proxy_count(&self) -> usize {
217        self.proxies.lock().unwrap().len()
218    }
219
220    /// Get names of all connected proxies.
221    pub fn proxy_names(&self) -> Vec<String> {
222        self.proxies.lock().unwrap().keys().cloned().collect()
223    }
224}
225
226impl Default for WshMultiProxy {
227    fn default() -> Self {
228        Self::new()
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_message_request() {
        let msg = RpcMessage {
            command: "test".to_string(),
            req_id: "abc123".to_string(),
            ..Default::default()
        };
        assert!(msg.is_request());
        assert!(!msg.is_response());
        assert!(!msg.is_error());
        assert!(msg.is_final());
    }

    #[test]
    fn test_rpc_message_response() {
        let msg = RpcMessage {
            res_id: "abc123".to_string(),
            data: Some(serde_json::json!({"result": "ok"})),
            ..Default::default()
        };
        assert!(!msg.is_request());
        assert!(msg.is_response());
        assert!(!msg.is_error());
    }

    #[test]
    fn test_rpc_message_error() {
        let msg = RpcMessage {
            res_id: "abc123".to_string(),
            error: Some("something failed".to_string()),
            ..Default::default()
        };
        assert!(msg.is_error());
    }

    #[test]
    fn test_rpc_message_serde() {
        let msg = RpcMessage {
            command: "getblock".to_string(),
            req_id: "req-1".to_string(),
            data: Some(serde_json::json!({"id": "block-1"})),
            route: Some("conn:local".to_string()),
            ..Default::default()
        };
        let json = serde_json::to_string(&msg).unwrap();
        let parsed: RpcMessage = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.command, "getblock");
        assert_eq!(parsed.req_id, "req-1");
        assert_eq!(parsed.route.unwrap(), "conn:local");
    }

    #[test]
    fn test_rpc_context_serde_renames() {
        let ctx: RpcContext =
            serde_json::from_str(r#"{"blockid":"b1","tabid":"t1"}"#).unwrap();
        assert_eq!(ctx.block_id, "b1");
        assert_eq!(ctx.tab_id, "t1");
        assert!(ctx.conn.is_empty());
        // Empty fields are skipped when serializing.
        assert_eq!(serde_json::to_string(&RpcContext::default()).unwrap(), "{}");
    }

    #[test]
    fn test_proxy_context_and_token() {
        let proxy = WshRpcProxy::new();
        assert!(proxy.get_rpc_context().is_none());
        assert!(proxy.get_auth_token().is_empty());

        proxy.set_rpc_context(RpcContext {
            block_id: "b1".to_string(),
            ..Default::default()
        });
        proxy.set_auth_token("tok");
        assert_eq!(proxy.get_rpc_context().unwrap().block_id, "b1");
        assert_eq!(proxy.get_auth_token(), "tok");
    }

    #[tokio::test]
    async fn test_send_response_error() {
        let mut proxy = WshRpcProxy::new();
        let mut rx = proxy.take_to_remote_rx().expect("receiver is available once");
        assert!(proxy.take_to_remote_rx().is_none());

        // An empty req id must not put anything on the wire...
        proxy.send_response_error("", "ignored").await.unwrap();
        // ...so the first message received is this one.
        proxy.send_response_error("req-9", "boom").await.unwrap();

        let bytes = rx.recv().await.unwrap();
        let value: serde_json::Value = serde_json::from_slice(&bytes).unwrap();
        assert_eq!(value["resid"], "req-9");
        assert_eq!(value["error"], "boom");
        assert!(value.get("command").is_none());
    }

    #[tokio::test]
    async fn test_multi_proxy_broadcast() {
        let multi = WshMultiProxy::new();
        let (tx1, mut rx1) = mpsc::channel(10);
        let (tx2, mut rx2) = mpsc::channel(10);

        multi.add_proxy("conn1", tx1);
        multi.add_proxy("conn2", tx2);
        assert_eq!(multi.proxy_count(), 2);

        multi.broadcast(b"hello".to_vec()).await;

        let msg1 = rx1.recv().await.unwrap();
        let msg2 = rx2.recv().await.unwrap();
        assert_eq!(msg1, b"hello");
        assert_eq!(msg2, b"hello");
    }

    #[tokio::test]
    async fn test_multi_proxy_broadcast_with_dropped_receiver() {
        let multi = WshMultiProxy::new();
        let (tx_live, mut rx_live) = mpsc::channel(10);
        let (tx_gone, rx_gone) = mpsc::channel(10);
        multi.add_proxy("live", tx_live);
        multi.add_proxy("gone", tx_gone);
        drop(rx_gone);

        // A dead receiver must not prevent delivery to the live one.
        multi.broadcast(b"ping".to_vec()).await;
        assert_eq!(rx_live.recv().await.unwrap(), b"ping");
        assert!(multi.proxy_names().contains(&"live".to_string()));
    }

    #[tokio::test]
    async fn test_multi_proxy_broadcast_rpc_message() {
        let multi = WshMultiProxy::new();
        let (tx, mut rx) = mpsc::channel(10);
        multi.add_proxy("conn1", tx);

        let msg = RpcMessage {
            command: "ping".to_string(),
            req_id: "r1".to_string(),
            ..Default::default()
        };
        multi.broadcast_rpc_message(&msg).await.unwrap();

        let parsed: RpcMessage = serde_json::from_slice(&rx.recv().await.unwrap()).unwrap();
        assert!(parsed.is_request());
        assert_eq!(parsed.command, "ping");
        assert_eq!(parsed.req_id, "r1");
    }

    #[test]
    fn test_multi_proxy_add_remove() {
        let multi = WshMultiProxy::new();
        let (tx, _rx) = mpsc::channel(10);

        multi.add_proxy("conn1", tx);
        assert_eq!(multi.proxy_count(), 1);

        multi.remove_proxy("conn1");
        assert_eq!(multi.proxy_count(), 0);
    }
}