agentmux_srv\backend/
userinput.rs

1#![allow(dead_code)]
2// Copyright 2025-2026, AgentMux Corp.
3// SPDX-License-Identifier: Apache-2.0
4
5//! User input: modal dialogs for interactive user prompts.
6//! Port of Go's pkg/userinput/userinput.go.
7//!
8//! Provides request/response types for:
9//! - Text input prompts
10//! - Confirmation dialogs
11//! - Checkbox-bearing dialogs
12//!
13//! The actual display is handled by the frontend (Tauri webview);
14//! this module defines the wire format and a registry for pending requests.
15
16
17use std::collections::HashMap;
18use std::sync::Mutex;
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use tokio::sync::oneshot;
23use tokio::time;
24
25// ---- Request/Response types ----
26
27/// Request for user input, sent to the frontend for display.
28/// Matches Go's `userinput.UserInputRequest` JSON tags.
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct UserInputRequest {
31    /// Unique request ID.
32    #[serde(rename = "requestid")]
33    pub request_id: String,
34
35    /// Prompt text to display.
36    #[serde(rename = "querytext")]
37    pub query_text: String,
38
39    /// Expected response type: "text", "confirm".
40    #[serde(rename = "responsetype")]
41    pub response_type: String,
42
43    /// Dialog title.
44    pub title: String,
45
46    /// Whether to render title as markdown.
47    #[serde(default)]
48    pub markdown: bool,
49
50    /// Timeout in milliseconds (0 = no timeout).
51    #[serde(rename = "timeoutms", default)]
52    pub timeout_ms: i64,
53
54    /// Optional checkbox label.
55    #[serde(rename = "checkboxmsg", default, skip_serializing_if = "String::is_empty")]
56    pub checkbox_msg: String,
57
58    /// Whether the input text is public (can be logged).
59    #[serde(rename = "publictext", default)]
60    pub public_text: bool,
61
62    /// Custom "OK" button label.
63    #[serde(rename = "oklabel", default, skip_serializing_if = "String::is_empty")]
64    pub ok_label: String,
65
66    /// Custom "Cancel" button label.
67    #[serde(rename = "cancellabel", default, skip_serializing_if = "String::is_empty")]
68    pub cancel_label: String,
69}
70
71/// Response from the frontend after user interaction.
72/// Matches Go's `userinput.UserInputResponse` JSON tags.
73#[derive(Debug, Clone, Default, Serialize, Deserialize)]
74pub struct UserInputResponse {
75    /// Response type.
76    #[serde(rename = "type", default)]
77    pub response_type: String,
78
79    /// Request ID this responds to.
80    #[serde(rename = "requestid")]
81    pub request_id: String,
82
83    /// Text input value (for text responses).
84    #[serde(default, skip_serializing_if = "String::is_empty")]
85    pub text: String,
86
87    /// Confirmation result (for confirm responses).
88    #[serde(default)]
89    pub confirm: bool,
90
91    /// Error message (if the dialog was cancelled or errored).
92    #[serde(rename = "errormsg", default, skip_serializing_if = "String::is_empty")]
93    pub error_msg: String,
94
95    /// Checkbox state.
96    #[serde(rename = "checkboxstat", default)]
97    pub checkbox_stat: bool,
98}
99
100impl UserInputResponse {
101    /// Check if the response indicates an error.
102    pub fn is_error(&self) -> bool {
103        !self.error_msg.is_empty()
104    }
105
106    /// Check if the response is a confirmation.
107    pub fn is_confirmed(&self) -> bool {
108        self.confirm && self.error_msg.is_empty()
109    }
110}
111
112// ---- Response type constants ----
113
/// `response_type` value for a free-form text prompt.
pub const RESPONSE_TYPE_TEXT: &str = "text";

/// `response_type` value for a confirmation dialog.
pub const RESPONSE_TYPE_CONFIRM: &str = "confirm";

// ---- Timeouts ----

/// Default timeout for user input prompts: 30 seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Timeout for SSH passphrase/password prompts: 60 seconds.
pub const SSH_PROMPT_TIMEOUT: Duration = Duration::from_millis(60_000);
124
125// ---- Input handler registry ----
126
127/// Registry for pending user input requests.
128/// Routes responses from the frontend back to waiting async tasks.
129pub struct UserInputHandler {
130    channels: Mutex<HashMap<String, oneshot::Sender<UserInputResponse>>>,
131}
132
133impl UserInputHandler {
134    pub fn new() -> Self {
135        Self {
136            channels: Mutex::new(HashMap::new()),
137        }
138    }
139
140    /// Register a new user input request.
141    /// Returns the request ID and a receiver for the response.
142    pub fn register(&self, request_id: &str) -> oneshot::Receiver<UserInputResponse> {
143        let (tx, rx) = oneshot::channel();
144        let mut channels = self.channels.lock().unwrap();
145        channels.insert(request_id.to_string(), tx);
146        rx
147    }
148
149    /// Deliver a response from the frontend.
150    /// Returns error if no pending request with the given ID.
151    pub fn deliver(&self, response: UserInputResponse) -> Result<(), String> {
152        let mut channels = self.channels.lock().unwrap();
153        let tx = channels
154            .remove(&response.request_id)
155            .ok_or_else(|| format!("no pending request: {}", response.request_id))?;
156        tx.send(response)
157            .map_err(|_| "receiver dropped".to_string())
158    }
159
160    /// Cancel a pending request (removes it from the registry).
161    pub fn cancel(&self, request_id: &str) {
162        let mut channels = self.channels.lock().unwrap();
163        channels.remove(request_id);
164    }
165
166    /// Check if there's a pending request with the given ID.
167    pub fn has_pending(&self, request_id: &str) -> bool {
168        let channels = self.channels.lock().unwrap();
169        channels.contains_key(request_id)
170    }
171
172    /// Get count of pending requests.
173    pub fn pending_count(&self) -> usize {
174        let channels = self.channels.lock().unwrap();
175        channels.len()
176    }
177}
178
179impl Default for UserInputHandler {
180    fn default() -> Self {
181        Self::new()
182    }
183}
184
185/// Wait for a user input response with timeout.
186pub async fn wait_for_response(
187    rx: oneshot::Receiver<UserInputResponse>,
188    timeout: Duration,
189) -> Result<UserInputResponse, String> {
190    match time::timeout(timeout, rx).await {
191        Ok(Ok(response)) => {
192            if response.is_error() {
193                Err(response.error_msg)
194            } else {
195                Ok(response)
196            }
197        }
198        Ok(Err(_)) => Err("user input handler was cancelled".to_string()),
199        Err(_) => Err("timed out waiting for user input".to_string()),
200    }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_input_request_serde() {
        let req = UserInputRequest {
            request_id: "req-1".to_string(),
            query_text: "Enter password:".to_string(),
            response_type: RESPONSE_TYPE_TEXT.to_string(),
            title: "SSH Authentication".to_string(),
            markdown: true,
            timeout_ms: 60000,
            checkbox_msg: String::new(),
            public_text: false,
            ok_label: String::new(),
            cancel_label: String::new(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"requestid\":\"req-1\""));
        assert!(json.contains("\"querytext\":\"Enter password:\""));
        assert!(json.contains("\"responsetype\":\"text\""));
        assert!(json.contains("\"timeoutms\":60000"));
        // Empty optional strings are omitted from the JSON.
        assert!(!json.contains("\"checkboxmsg\""));
        assert!(!json.contains("\"oklabel\""));

        let parsed: UserInputRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.request_id, "req-1");
        assert!(parsed.markdown);
    }

    #[test]
    fn test_user_input_response_serde() {
        let resp = UserInputResponse {
            response_type: RESPONSE_TYPE_TEXT.to_string(),
            request_id: "req-1".to_string(),
            text: "my_password".to_string(),
            confirm: false,
            error_msg: String::new(),
            checkbox_stat: false,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"requestid\":\"req-1\""));
        assert!(json.contains("\"text\":\"my_password\""));
        assert!(!json.contains("\"errormsg\"")); // empty, omitted

        let parsed: UserInputResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.text, "my_password");
    }

    #[test]
    fn test_user_input_response_deserialize_defaults() {
        // Only `requestid` is required; every other field falls back to its default.
        let parsed: UserInputResponse = serde_json::from_str(r#"{"requestid":"req-9"}"#).unwrap();
        assert_eq!(parsed.request_id, "req-9");
        assert!(parsed.response_type.is_empty());
        assert!(parsed.text.is_empty());
        assert!(!parsed.confirm);
        assert!(!parsed.is_error());
        assert!(!parsed.checkbox_stat);
    }

    #[test]
    fn test_response_is_error() {
        let resp = UserInputResponse {
            error_msg: "cancelled".to_string(),
            ..Default::default()
        };
        assert!(resp.is_error());
        assert!(!resp.is_confirmed());
    }

    #[test]
    fn test_response_is_confirmed() {
        let resp = UserInputResponse {
            confirm: true,
            ..Default::default()
        };
        assert!(resp.is_confirmed());
        assert!(!resp.is_error());
    }

    #[test]
    fn test_response_confirm_with_error_is_not_confirmed() {
        let resp = UserInputResponse {
            confirm: true,
            error_msg: "dialog closed".to_string(),
            ..Default::default()
        };
        assert!(resp.is_error());
        assert!(!resp.is_confirmed());
    }

    #[test]
    fn test_handler_register_deliver() {
        let handler = UserInputHandler::new();
        let mut rx = handler.register("req-1");
        assert!(handler.has_pending("req-1"));
        assert_eq!(handler.pending_count(), 1);

        let response = UserInputResponse {
            request_id: "req-1".to_string(),
            text: "hello".to_string(),
            ..Default::default()
        };
        handler.deliver(response).unwrap();
        assert!(!handler.has_pending("req-1"));

        let received = rx.try_recv().unwrap();
        assert_eq!(received.text, "hello");
    }

    #[test]
    fn test_handler_deliver_no_pending() {
        let handler = UserInputHandler::new();
        let response = UserInputResponse {
            request_id: "nonexistent".to_string(),
            ..Default::default()
        };
        assert!(handler.deliver(response).is_err());
    }

    #[test]
    fn test_handler_deliver_after_receiver_dropped() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");
        drop(rx);

        let err = handler
            .deliver(UserInputResponse {
                request_id: "req-1".to_string(),
                ..Default::default()
            })
            .unwrap_err();
        assert_eq!(err, "receiver dropped");
        // The entry is consumed even though delivery failed.
        assert!(!handler.has_pending("req-1"));
    }

    #[test]
    fn test_handler_reregister_replaces_previous() {
        let handler = UserInputHandler::new();
        let mut first = handler.register("req-1");
        let mut second = handler.register("req-1");
        assert_eq!(handler.pending_count(), 1);
        // The first sender was replaced and dropped, closing the first channel.
        assert!(first.try_recv().is_err());

        handler
            .deliver(UserInputResponse {
                request_id: "req-1".to_string(),
                text: "latest".to_string(),
                ..Default::default()
            })
            .unwrap();
        assert_eq!(second.try_recv().unwrap().text, "latest");
    }

    #[test]
    fn test_handler_cancel() {
        let handler = UserInputHandler::new();
        let _rx = handler.register("req-1");
        assert!(handler.has_pending("req-1"));
        handler.cancel("req-1");
        assert!(!handler.has_pending("req-1"));
    }

    #[tokio::test]
    async fn test_wait_for_response_success() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");

        // Deliver the response from a background task.
        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            handler
                .deliver(UserInputResponse {
                    request_id: "req-1".to_string(),
                    text: "password123".to_string(),
                    ..Default::default()
                })
                .unwrap();
        });

        let result = wait_for_response(rx, Duration::from_secs(5)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text, "password123");
    }

    #[tokio::test]
    async fn test_wait_for_response_error() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");

        tokio::spawn(async move {
            handler
                .deliver(UserInputResponse {
                    request_id: "req-1".to_string(),
                    error_msg: "user cancelled".to_string(),
                    ..Default::default()
                })
                .unwrap();
        });

        let result = wait_for_response(rx, Duration::from_secs(5)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("user cancelled"));
    }

    #[tokio::test]
    async fn test_wait_for_response_cancelled() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");
        handler.cancel("req-1");

        let err = wait_for_response(rx, Duration::from_secs(5))
            .await
            .unwrap_err();
        assert!(err.contains("cancelled"));
    }

    #[tokio::test]
    async fn test_wait_for_response_timeout() {
        let (_tx, rx) = oneshot::channel::<UserInputResponse>();
        let result = wait_for_response(rx, Duration::from_millis(10)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("timed out"));
    }

    #[test]
    fn test_confirm_request_serde() {
        let req = UserInputRequest {
            request_id: "req-2".to_string(),
            query_text: "Add host key?".to_string(),
            response_type: RESPONSE_TYPE_CONFIRM.to_string(),
            title: "SSH Host Key".to_string(),
            markdown: true,
            timeout_ms: 30000,
            checkbox_msg: "Don't ask again".to_string(),
            public_text: true,
            ok_label: "Accept".to_string(),
            cancel_label: "Reject".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"checkboxmsg\":\"Don't ask again\""));
        assert!(json.contains("\"oklabel\":\"Accept\""));
        assert!(json.contains("\"cancellabel\":\"Reject\""));
    }
}