agentmux_srv\backend\wshutil/
wshrpc.rs

1#![allow(dead_code)]
2// Copyright 2025-2026, AgentMux Corp.
3// SPDX-License-Identifier: Apache-2.0
4
5//! WshRpc — main RPC client with message routing and response handling.
6//! Port of Go's `pkg/wshutil/wshrpc.go`.
7//!
8//! Provides the core RPC communication layer:
9//! - Message send/receive via channels
10//! - Request/response correlation via request IDs
11//! - Streaming responses (continued flag)
12//! - Context cancellation propagation
13//! - Auth token injection
14
15
16use std::collections::HashMap;
17use std::sync::{Arc, Mutex};
18use std::sync::atomic::{AtomicBool, Ordering};
19use serde_json::Value;
20use tokio::sync::mpsc;
21use uuid::Uuid;
22
23use super::event::EventListener;
24use super::proxy::{RpcContext, RpcMessage};
25use super::osc::{DEFAULT_INPUT_CH_SIZE, DEFAULT_OUTPUT_CH_SIZE};
26
/// Default timeout for RPC calls in milliseconds.
///
/// Used by `WshRpc::send_rpc_request` when the caller passes `None` as the timeout.
pub const DEFAULT_TIMEOUT_MS: u64 = 5000;
/// Channel buffer size for response channels.
///
/// This is the capacity of the per-request channel that carries responses back to the caller.
pub const RESP_CH_SIZE: usize = 32;
31
/// RPC response data (single response or stream item).
#[derive(Debug, Clone)]
pub struct RpcResponse {
    /// Payload of the response, if the message carried one.
    pub data: Option<Value>,
    /// Error text reported by the remote side, if any.
    pub error: Option<String>,
    /// `true` when the originating message did not have its `cont` flag set,
    /// i.e. no further stream items are expected for this request.
    pub is_final: bool,
}
39
/// Pending RPC request state.
///
/// One instance is stored in `WshRpc::rpc_map` per in-flight request, keyed by request id.
struct RpcData {
    /// Sender half of the per-request channel on which incoming responses are delivered.
    resp_tx: mpsc::Sender<RpcResponse>,
}
44
/// Handler for incoming RPC requests (server side).
///
/// Carries the request's identity and payload, plus the channel on which
/// JSON-encoded responses are written back.
pub struct RpcResponseHandler {
    /// Request id to echo back as the response id; empty when no response is expected.
    pub req_id: String,
    /// Name of the requested command.
    pub command: String,
    /// Raw command payload, if any.
    pub data: Option<Value>,
    /// RPC context associated with the request.
    pub rpc_context: RpcContext,
    /// Channel that receives the JSON-serialized response messages.
    response_tx: mpsc::Sender<Vec<u8>>,
    /// Set once a final or error response has been sent, or `finalize` was called.
    finalized: AtomicBool,
}
54
55impl RpcResponseHandler {
56    /// Get the command name.
57    pub fn get_command(&self) -> &str {
58        &self.command
59    }
60
61    /// Get the raw command data.
62    pub fn get_command_raw_data(&self) -> Option<&Value> {
63        self.data.as_ref()
64    }
65
66    /// Get the RPC context.
67    pub fn get_rpc_context(&self) -> &RpcContext {
68        &self.rpc_context
69    }
70
71    /// Check if this request needs a response.
72    pub fn needs_response(&self) -> bool {
73        !self.req_id.is_empty()
74    }
75
76    /// Send a response (data + done flag).
77    pub async fn send_response(&self, data: Option<Value>, done: bool) -> Result<(), String> {
78        if !self.needs_response() {
79            return Ok(());
80        }
81
82        let msg = RpcMessage {
83            res_id: self.req_id.clone(),
84            data,
85            cont: !done,
86            ..Default::default()
87        };
88
89        let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
90        self.response_tx
91            .send(json)
92            .await
93            .map_err(|e| format!("send response: {}", e))?;
94
95        if done {
96            self.finalized.store(true, Ordering::SeqCst);
97        }
98        Ok(())
99    }
100
101    /// Send an error response.
102    pub async fn send_response_error(&self, err: &str) -> Result<(), String> {
103        if !self.needs_response() {
104            return Ok(());
105        }
106
107        let msg = RpcMessage {
108            res_id: self.req_id.clone(),
109            error: Some(err.to_string()),
110            ..Default::default()
111        };
112
113        let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
114        self.response_tx
115            .send(json)
116            .await
117            .map_err(|e| format!("send error response: {}", e))?;
118
119        self.finalized.store(true, Ordering::SeqCst);
120        Ok(())
121    }
122
123    /// Mark the handler as finalized.
124    pub fn finalize(&self) {
125        self.finalized.store(true, Ordering::SeqCst);
126    }
127
128    /// Check if the handler has been finalized.
129    pub fn is_finalized(&self) -> bool {
130        self.finalized.load(Ordering::SeqCst)
131    }
132}
133
/// Callback type for command handlers.
///
/// The callback takes ownership of the [`RpcResponseHandler`] for one incoming request
/// and must be callable from multiple threads (`Send + Sync`).
// NOTE(review): the meaning of the returned `bool` is not visible in this file — confirm with callers.
pub type CommandHandlerFn = Box<dyn Fn(RpcResponseHandler) -> bool + Send + Sync>;
136
/// Main WshRpc client.
///
/// Provides bidirectional RPC over message channels:
/// - Send requests and wait for responses
/// - Handle incoming requests with registered handlers
/// - Support streaming responses
/// - Auth token management
pub struct WshRpc {
    /// Sender half of the input channel (a clone is handed out by `new`).
    input_ch: mpsc::Sender<Vec<u8>>,
    /// Sender half of the output channel; JSON-encoded outgoing messages are written here.
    output_ch: mpsc::Sender<Vec<u8>>,
    /// Optional RPC context, set via `set_rpc_context`.
    rpc_context: Arc<Mutex<Option<RpcContext>>>,
    /// Auth token attached to outgoing messages when non-empty.
    auth_token: Arc<Mutex<String>>,
    /// In-flight requests keyed by request id.
    rpc_map: Arc<Mutex<HashMap<String, RpcData>>>,
    /// Event listener exposed through `get_event_listener`.
    event_listener: Arc<EventListener>,
    /// When set, responses for unknown request ids are logged.
    debug: AtomicBool,
    /// Name included in debug log lines.
    debug_name: String,
    /// Flag reported by `is_server_done` and set by `set_server_done`.
    server_done: AtomicBool,
}
155
156impl WshRpc {
157    /// Create a new WshRpc client with input/output channels.
158    pub fn new(debug_name: &str) -> (Self, mpsc::Receiver<Vec<u8>>, mpsc::Sender<Vec<u8>>) {
159        let (input_tx, _input_rx) = mpsc::channel(DEFAULT_INPUT_CH_SIZE);
160        let (output_tx, output_rx) = mpsc::channel(DEFAULT_OUTPUT_CH_SIZE);
161
162        let input_tx_clone = input_tx.clone();
163        let rpc = Self {
164            input_ch: input_tx,
165            output_ch: output_tx,
166            rpc_context: Arc::new(Mutex::new(None)),
167            auth_token: Arc::new(Mutex::new(String::new())),
168            rpc_map: Arc::new(Mutex::new(HashMap::new())),
169            event_listener: Arc::new(EventListener::new()),
170            debug: AtomicBool::new(false),
171            debug_name: debug_name.to_string(),
172            server_done: AtomicBool::new(false),
173        };
174
175        (rpc, output_rx, input_tx_clone)
176    }
177
178    /// Set the RPC context.
179    pub fn set_rpc_context(&self, ctx: RpcContext) {
180        *self.rpc_context.lock().unwrap() = Some(ctx);
181    }
182
183    /// Get the current RPC context.
184    pub fn get_rpc_context(&self) -> Option<RpcContext> {
185        self.rpc_context.lock().unwrap().clone()
186    }
187
188    /// Set the auth token.
189    pub fn set_auth_token(&self, token: &str) {
190        *self.auth_token.lock().unwrap() = token.to_string();
191    }
192
193    /// Get the auth token.
194    pub fn get_auth_token(&self) -> String {
195        self.auth_token.lock().unwrap().clone()
196    }
197
198    /// Enable or disable debug logging.
199    pub fn set_debug(&self, debug: bool) {
200        self.debug.store(debug, Ordering::SeqCst);
201    }
202
203    /// Get the event listener.
204    pub fn get_event_listener(&self) -> &EventListener {
205        &self.event_listener
206    }
207
208    /// Send an RPC request and wait for a single response.
209    pub async fn send_rpc_request(
210        &self,
211        command: &str,
212        data: Option<Value>,
213        timeout_ms: Option<u64>,
214    ) -> Result<Option<Value>, String> {
215        let req_id = Uuid::new_v4().to_string();
216        let timeout = timeout_ms.unwrap_or(DEFAULT_TIMEOUT_MS);
217
218        // Create response channel
219        let (resp_tx, mut resp_rx) = mpsc::channel(RESP_CH_SIZE);
220        self.rpc_map.lock().unwrap().insert(req_id.clone(), RpcData { resp_tx });
221
222        // Build request message
223        let mut msg = RpcMessage {
224            command: command.to_string(),
225            req_id: req_id.clone(),
226            data,
227            timeout: Some(timeout),
228            ..Default::default()
229        };
230
231        // Inject auth token
232        let auth_token = self.get_auth_token();
233        if !auth_token.is_empty() {
234            msg.auth_token = Some(auth_token);
235        }
236
237        // Send request
238        let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
239        self.output_ch
240            .send(json)
241            .await
242            .map_err(|e| format!("send request: {}", e))?;
243
244        // Wait for response with timeout
245        let result = tokio::time::timeout(
246            std::time::Duration::from_millis(timeout),
247            resp_rx.recv(),
248        )
249        .await;
250
251        // Clean up
252        self.rpc_map.lock().unwrap().remove(&req_id);
253
254        match result {
255            Ok(Some(resp)) => {
256                if let Some(err) = resp.error {
257                    Err(err)
258                } else {
259                    Ok(resp.data)
260                }
261            }
262            Ok(None) => Err("response channel closed".to_string()),
263            Err(_) => Err(format!("RPC timeout after {}ms", timeout)),
264        }
265    }
266
267    /// Send a fire-and-forget message (no response expected).
268    pub async fn send_message(&self, command: &str, data: Option<Value>) -> Result<(), String> {
269        let mut msg = RpcMessage {
270            command: command.to_string(),
271            req_id: String::new(), // no response expected
272            data,
273            ..Default::default()
274        };
275
276        let auth_token = self.get_auth_token();
277        if !auth_token.is_empty() {
278            msg.auth_token = Some(auth_token);
279        }
280
281        let json = serde_json::to_vec(&msg).map_err(|e| format!("json encode: {}", e))?;
282        self.output_ch
283            .send(json)
284            .await
285            .map_err(|e| format!("send message: {}", e))
286    }
287
288    /// Process an incoming message (response or request).
289    pub fn process_incoming_message(&self, raw_msg: &[u8]) -> Result<(), String> {
290        let msg: RpcMessage =
291            serde_json::from_slice(raw_msg).map_err(|e| format!("json decode: {}", e))?;
292
293        if msg.is_response() {
294            self.handle_response(msg)
295        } else if msg.is_request() {
296            // Request handling would be delegated to registered handlers
297            tracing::debug!("incoming request: {} ({})", msg.command, msg.req_id);
298            Ok(())
299        } else {
300            Err("message is neither request nor response".to_string())
301        }
302    }
303
304    /// Handle an incoming response message.
305    fn handle_response(&self, msg: RpcMessage) -> Result<(), String> {
306        let rpc_map = self.rpc_map.lock().unwrap();
307        if let Some(rpc_data) = rpc_map.get(&msg.res_id) {
308            let resp = RpcResponse {
309                data: msg.data,
310                error: msg.error,
311                is_final: !msg.cont,
312            };
313            let _ = rpc_data.resp_tx.try_send(resp);
314        } else if self.debug.load(Ordering::SeqCst) {
315            tracing::warn!(
316                "[{}] received response for unknown req_id: {}",
317                self.debug_name,
318                msg.res_id
319            );
320        }
321        Ok(())
322    }
323
324    /// Check if the server is done.
325    pub fn is_server_done(&self) -> bool {
326        self.server_done.load(Ordering::SeqCst)
327    }
328
329    /// Mark the server as done.
330    pub fn set_server_done(&self) {
331        self.server_done.store(true, Ordering::SeqCst);
332    }
333
334    /// Get the count of pending RPC requests.
335    pub fn pending_count(&self) -> usize {
336        self.rpc_map.lock().unwrap().len()
337    }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly created client is idle: not done, nothing pending, no token.
    #[tokio::test]
    async fn test_wshrpc_create() {
        let (rpc, _output_rx, _input_tx) = WshRpc::new("test");
        assert!(!rpc.is_server_done());
        assert_eq!(rpc.pending_count(), 0);
        assert_eq!(rpc.get_auth_token(), "");
    }

    /// The auth token round-trips through its setter and getter.
    #[tokio::test]
    async fn test_wshrpc_auth_token() {
        let (rpc, _output_rx, _input_tx) = WshRpc::new("test");
        rpc.set_auth_token("secret123");
        assert_eq!(rpc.get_auth_token(), "secret123");
    }

    /// The RPC context starts unset and round-trips once stored.
    #[tokio::test]
    async fn test_wshrpc_rpc_context() {
        let (rpc, _output_rx, _input_tx) = WshRpc::new("test");
        assert!(rpc.get_rpc_context().is_none());

        rpc.set_rpc_context(RpcContext {
            block_id: "block1".to_string(),
            tab_id: "tab1".to_string(),
            conn: "local".to_string(),
        });

        let ctx = rpc.get_rpc_context().unwrap();
        assert_eq!(ctx.block_id, "block1");
        assert_eq!(ctx.tab_id, "tab1");
    }

    /// Fire-and-forget messages carry the command and token but no request id.
    #[tokio::test]
    async fn test_wshrpc_send_message() {
        let (rpc, mut output_rx, _input_tx) = WshRpc::new("test");
        rpc.set_auth_token("token123");

        rpc.send_message("notify", Some(serde_json::json!({"msg": "hello"})))
            .await
            .unwrap();

        let raw = output_rx.recv().await.unwrap();
        let msg: RpcMessage = serde_json::from_slice(&raw).unwrap();
        assert_eq!(msg.command, "notify");
        assert!(msg.req_id.is_empty()); // fire-and-forget
        assert_eq!(msg.auth_token.unwrap(), "token123");
    }

    /// An incoming response is delivered to the pending request with the same id.
    #[tokio::test]
    async fn test_wshrpc_process_response() {
        // The output receiver is not read here; bind it only to keep the channel open.
        let (rpc, _output_rx, _input_tx) = WshRpc::new("test");

        // Simulate a pending request
        let (resp_tx, mut resp_rx) = mpsc::channel(RESP_CH_SIZE);
        rpc.rpc_map
            .lock()
            .unwrap()
            .insert("req-1".to_string(), RpcData { resp_tx });

        // Process a response
        let response = RpcMessage {
            res_id: "req-1".to_string(),
            data: Some(serde_json::json!({"result": "success"})),
            ..Default::default()
        };
        let raw = serde_json::to_vec(&response).unwrap();
        rpc.process_incoming_message(&raw).unwrap();

        // Check response was delivered
        let resp = resp_rx.recv().await.unwrap();
        assert!(resp.error.is_none());
        assert!(resp.is_final);
        assert_eq!(
            resp.data.unwrap(),
            serde_json::json!({"result": "success"})
        );
    }

    /// Bytes that are not valid JSON are rejected with a decode error.
    #[tokio::test]
    async fn test_wshrpc_rejects_invalid_json() {
        let (rpc, _output_rx, _input_tx) = WshRpc::new("test");

        let err = rpc.process_incoming_message(b"not json").unwrap_err();
        assert!(err.starts_with("json decode: "));
    }
}
420}