1#![allow(dead_code)]
2use std::collections::HashMap;
18use std::sync::Mutex;
19use std::time::Duration;
20
21use serde::{Deserialize, Serialize};
22use tokio::sync::oneshot;
23use tokio::time;
24
25#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct UserInputRequest {
31 #[serde(rename = "requestid")]
33 pub request_id: String,
34
35 #[serde(rename = "querytext")]
37 pub query_text: String,
38
39 #[serde(rename = "responsetype")]
41 pub response_type: String,
42
43 pub title: String,
45
46 #[serde(default)]
48 pub markdown: bool,
49
50 #[serde(rename = "timeoutms", default)]
52 pub timeout_ms: i64,
53
54 #[serde(rename = "checkboxmsg", default, skip_serializing_if = "String::is_empty")]
56 pub checkbox_msg: String,
57
58 #[serde(rename = "publictext", default)]
60 pub public_text: bool,
61
62 #[serde(rename = "oklabel", default, skip_serializing_if = "String::is_empty")]
64 pub ok_label: String,
65
66 #[serde(rename = "cancellabel", default, skip_serializing_if = "String::is_empty")]
68 pub cancel_label: String,
69}
70
71#[derive(Debug, Clone, Default, Serialize, Deserialize)]
74pub struct UserInputResponse {
75 #[serde(rename = "type", default)]
77 pub response_type: String,
78
79 #[serde(rename = "requestid")]
81 pub request_id: String,
82
83 #[serde(default, skip_serializing_if = "String::is_empty")]
85 pub text: String,
86
87 #[serde(default)]
89 pub confirm: bool,
90
91 #[serde(rename = "errormsg", default, skip_serializing_if = "String::is_empty")]
93 pub error_msg: String,
94
95 #[serde(rename = "checkboxstat", default)]
97 pub checkbox_stat: bool,
98}
99
100impl UserInputResponse {
101 pub fn is_error(&self) -> bool {
103 !self.error_msg.is_empty()
104 }
105
106 pub fn is_confirmed(&self) -> bool {
108 self.confirm && self.error_msg.is_empty()
109 }
110}
111
/// `response_type` value requesting a free-text answer.
pub const RESPONSE_TYPE_TEXT: &str = "text";
/// `response_type` value requesting a confirmation answer.
pub const RESPONSE_TYPE_CONFIRM: &str = "confirm";

/// General-purpose timeout for a user-input prompt: 30 seconds.
pub const DEFAULT_TIMEOUT: Duration = Duration::from_secs(30);

/// Timeout for SSH-related prompts: 60 seconds, twice [`DEFAULT_TIMEOUT`].
pub const SSH_PROMPT_TIMEOUT: Duration = Duration::from_secs(60);
124
125pub struct UserInputHandler {
130 channels: Mutex<HashMap<String, oneshot::Sender<UserInputResponse>>>,
131}
132
133impl UserInputHandler {
134 pub fn new() -> Self {
135 Self {
136 channels: Mutex::new(HashMap::new()),
137 }
138 }
139
140 pub fn register(&self, request_id: &str) -> oneshot::Receiver<UserInputResponse> {
143 let (tx, rx) = oneshot::channel();
144 let mut channels = self.channels.lock().unwrap();
145 channels.insert(request_id.to_string(), tx);
146 rx
147 }
148
149 pub fn deliver(&self, response: UserInputResponse) -> Result<(), String> {
152 let mut channels = self.channels.lock().unwrap();
153 let tx = channels
154 .remove(&response.request_id)
155 .ok_or_else(|| format!("no pending request: {}", response.request_id))?;
156 tx.send(response)
157 .map_err(|_| "receiver dropped".to_string())
158 }
159
160 pub fn cancel(&self, request_id: &str) {
162 let mut channels = self.channels.lock().unwrap();
163 channels.remove(request_id);
164 }
165
166 pub fn has_pending(&self, request_id: &str) -> bool {
168 let channels = self.channels.lock().unwrap();
169 channels.contains_key(request_id)
170 }
171
172 pub fn pending_count(&self) -> usize {
174 let channels = self.channels.lock().unwrap();
175 channels.len()
176 }
177}
178
179impl Default for UserInputHandler {
180 fn default() -> Self {
181 Self::new()
182 }
183}
184
185pub async fn wait_for_response(
187 rx: oneshot::Receiver<UserInputResponse>,
188 timeout: Duration,
189) -> Result<UserInputResponse, String> {
190 match time::timeout(timeout, rx).await {
191 Ok(Ok(response)) => {
192 if response.is_error() {
193 Err(response.error_msg)
194 } else {
195 Ok(response)
196 }
197 }
198 Ok(Err(_)) => Err("user input handler was cancelled".to_string()),
199 Err(_) => Err("timed out waiting for user input".to_string()),
200 }
201}
202
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_user_input_request_serde() {
        let req = UserInputRequest {
            request_id: "req-1".to_string(),
            query_text: "Enter password:".to_string(),
            response_type: RESPONSE_TYPE_TEXT.to_string(),
            title: "SSH Authentication".to_string(),
            markdown: true,
            timeout_ms: 60000,
            checkbox_msg: String::new(),
            public_text: false,
            ok_label: String::new(),
            cancel_label: String::new(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"requestid\":\"req-1\""));
        assert!(json.contains("\"querytext\":\"Enter password:\""));
        assert!(json.contains("\"responsetype\":\"text\""));
        assert!(json.contains("\"title\":\"SSH Authentication\""));
        assert!(json.contains("\"timeoutms\":60000"));
        // Booleans are always serialized, even when false.
        assert!(json.contains("\"publictext\":false"));
        // Empty optional strings are omitted from the wire format.
        assert!(!json.contains("\"checkboxmsg\""));
        assert!(!json.contains("\"oklabel\""));
        assert!(!json.contains("\"cancellabel\""));

        let parsed: UserInputRequest = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.request_id, "req-1");
        assert_eq!(parsed.query_text, "Enter password:");
        assert_eq!(parsed.timeout_ms, 60000);
        assert!(parsed.markdown);
        assert!(parsed.cancel_label.is_empty());
    }

    #[test]
    fn test_user_input_request_defaults() {
        // Only the fields without `#[serde(default)]` are supplied.
        let json = r#"{"requestid":"r","querytext":"q","responsetype":"text","title":"t"}"#;
        let parsed: UserInputRequest = serde_json::from_str(json).unwrap();
        assert!(!parsed.markdown);
        assert_eq!(parsed.timeout_ms, 0);
        assert!(parsed.checkbox_msg.is_empty());
        assert!(!parsed.public_text);
        assert!(parsed.ok_label.is_empty());
        assert!(parsed.cancel_label.is_empty());
    }

    #[test]
    fn test_user_input_response_serde() {
        let resp = UserInputResponse {
            response_type: RESPONSE_TYPE_TEXT.to_string(),
            request_id: "req-1".to_string(),
            text: "my_password".to_string(),
            confirm: false,
            error_msg: String::new(),
            checkbox_stat: false,
        };
        let json = serde_json::to_string(&resp).unwrap();
        assert!(json.contains("\"type\":\"text\""));
        assert!(json.contains("\"requestid\":\"req-1\""));
        assert!(json.contains("\"text\":\"my_password\""));
        assert!(!json.contains("\"errormsg\""));

        let parsed: UserInputResponse = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.request_id, "req-1");
        assert_eq!(parsed.text, "my_password");
        assert!(!parsed.is_error());
    }

    #[test]
    fn test_user_input_response_requires_request_id() {
        assert!(serde_json::from_str::<UserInputResponse>("{}").is_err());
        let parsed: UserInputResponse = serde_json::from_str(r#"{"requestid":"r"}"#).unwrap();
        assert!(parsed.response_type.is_empty());
        assert!(!parsed.confirm);
        assert!(!parsed.checkbox_stat);
    }

    #[test]
    fn test_response_is_error() {
        let resp = UserInputResponse {
            error_msg: "cancelled".to_string(),
            ..Default::default()
        };
        assert!(resp.is_error());
        assert!(!resp.is_confirmed());
    }

    #[test]
    fn test_response_is_confirmed() {
        let resp = UserInputResponse {
            confirm: true,
            ..Default::default()
        };
        assert!(resp.is_confirmed());
        assert!(!resp.is_error());
    }

    #[test]
    fn test_response_error_overrides_confirm() {
        let resp = UserInputResponse {
            confirm: true,
            error_msg: "boom".to_string(),
            ..Default::default()
        };
        assert!(resp.is_error());
        assert!(!resp.is_confirmed());
    }

    #[test]
    fn test_handler_register_deliver() {
        let handler = UserInputHandler::new();
        let mut rx = handler.register("req-1");
        assert!(handler.has_pending("req-1"));
        assert_eq!(handler.pending_count(), 1);

        let response = UserInputResponse {
            request_id: "req-1".to_string(),
            text: "hello".to_string(),
            ..Default::default()
        };
        handler.deliver(response).unwrap();
        assert!(!handler.has_pending("req-1"));
        assert_eq!(handler.pending_count(), 0);

        let received = rx.try_recv().unwrap();
        assert_eq!(received.text, "hello");
    }

    #[test]
    fn test_handler_deliver_no_pending() {
        let handler = UserInputHandler::new();
        let response = UserInputResponse {
            request_id: "nonexistent".to_string(),
            ..Default::default()
        };
        let err = handler.deliver(response).unwrap_err();
        assert_eq!(err, "no pending request: nonexistent");
    }

    #[test]
    fn test_handler_deliver_receiver_dropped() {
        let handler = UserInputHandler::new();
        drop(handler.register("req-1"));
        let err = handler
            .deliver(UserInputResponse {
                request_id: "req-1".to_string(),
                ..Default::default()
            })
            .unwrap_err();
        assert_eq!(err, "receiver dropped");
        // The entry is consumed even though the send failed.
        assert!(!handler.has_pending("req-1"));
    }

    #[test]
    fn test_handler_reregister_replaces_pending() {
        let handler = UserInputHandler::new();
        let mut first = handler.register("req-1");
        let mut second = handler.register("req-1");
        assert_eq!(handler.pending_count(), 1);

        handler
            .deliver(UserInputResponse {
                request_id: "req-1".to_string(),
                text: "hi".to_string(),
                ..Default::default()
            })
            .unwrap();
        // The first sender was dropped on re-registration, closing its channel.
        assert!(first.try_recv().is_err());
        assert_eq!(second.try_recv().unwrap().text, "hi");
    }

    #[test]
    fn test_handler_cancel() {
        let handler = UserInputHandler::new();
        let _rx = handler.register("req-1");
        assert!(handler.has_pending("req-1"));
        handler.cancel("req-1");
        assert!(!handler.has_pending("req-1"));
        // Cancelling an unknown id is a no-op.
        handler.cancel("req-1");
        assert_eq!(handler.pending_count(), 0);
    }

    #[tokio::test]
    async fn test_wait_for_response_success() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");

        tokio::spawn(async move {
            tokio::time::sleep(Duration::from_millis(10)).await;
            handler
                .deliver(UserInputResponse {
                    request_id: "req-1".to_string(),
                    text: "password123".to_string(),
                    ..Default::default()
                })
                .unwrap();
        });

        let result = wait_for_response(rx, Duration::from_secs(5)).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().text, "password123");
    }

    #[tokio::test]
    async fn test_wait_for_response_error() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");

        tokio::spawn(async move {
            handler
                .deliver(UserInputResponse {
                    request_id: "req-1".to_string(),
                    error_msg: "user cancelled".to_string(),
                    ..Default::default()
                })
                .unwrap();
        });

        let result = wait_for_response(rx, Duration::from_secs(5)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("user cancelled"));
    }

    #[tokio::test]
    async fn test_wait_for_response_cancelled() {
        let handler = UserInputHandler::new();
        let rx = handler.register("req-1");
        handler.cancel("req-1");
        let result = wait_for_response(rx, Duration::from_secs(5)).await;
        assert_eq!(result.unwrap_err(), "user input handler was cancelled");
    }

    #[tokio::test]
    async fn test_wait_for_response_timeout() {
        let (_tx, rx) = oneshot::channel::<UserInputResponse>();
        let result = wait_for_response(rx, Duration::from_millis(10)).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("timed out"));
    }

    #[test]
    fn test_confirm_request_serde() {
        let req = UserInputRequest {
            request_id: "req-2".to_string(),
            query_text: "Add host key?".to_string(),
            response_type: RESPONSE_TYPE_CONFIRM.to_string(),
            title: "SSH Host Key".to_string(),
            markdown: true,
            timeout_ms: 30000,
            checkbox_msg: "Don't ask again".to_string(),
            public_text: true,
            ok_label: "Accept".to_string(),
            cancel_label: "Reject".to_string(),
        };
        let json = serde_json::to_string(&req).unwrap();
        assert!(json.contains("\"responsetype\":\"confirm\""));
        assert!(json.contains("\"markdown\":true"));
        assert!(json.contains("\"publictext\":true"));
        assert!(json.contains("\"checkboxmsg\":\"Don't ask again\""));
        assert!(json.contains("\"oklabel\":\"Accept\""));
        assert!(json.contains("\"cancellabel\":\"Reject\""));
    }
}