1use std::collections::HashMap;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, Instant};
23
24use serde::{Deserialize, Serialize};
25
26use super::auth_patterns::{match_line, AuthPatternMatch};
27
/// Seconds a session may stay non-terminal before `poll_session` marks it
/// `Failed` with a timeout error (ten minutes).
const SESSION_TIMEOUT_SECS: u64 = 10 * 60;
32
33#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
34#[serde(
35 tag = "status",
36 rename_all = "kebab-case",
37 rename_all_fields = "camelCase"
38)]
39pub enum AuthSessionStatus {
40 Pending,
43 UrlAvailable { auth_url: String },
46 CodeEmitted {
48 device_code: String,
49 verification_url: String,
50 },
51 Success {
54 bundle_id: String,
55 email: Option<String>,
58 },
59 Failed { error: String },
62}
63
64impl AuthSessionStatus {
65 pub fn is_terminal(&self) -> bool {
66 matches!(self, Self::Success { .. } | Self::Failed { .. })
67 }
68}
69
70#[derive(Debug)]
71struct Session {
72 provider_id: String,
73 into_bundle_id: Option<String>,
74 status: AuthSessionStatus,
75 captured_url: Option<String>,
78 captured_device_code: Option<(String, String)>,
79 captured_email: Option<String>,
80 started_at: Instant,
81 transcript: Vec<String>,
85}
86
87impl Session {
88 fn new(provider_id: String, into_bundle_id: Option<String>) -> Self {
89 Self {
90 provider_id,
91 into_bundle_id,
92 status: AuthSessionStatus::Pending,
93 captured_url: None,
94 captured_device_code: None,
95 captured_email: None,
96 started_at: Instant::now(),
97 transcript: Vec::new(),
98 }
99 }
100
101 fn timed_out(&self) -> bool {
102 self.started_at.elapsed() > Duration::from_secs(SESSION_TIMEOUT_SECS)
103 }
104}
105
106#[derive(Debug, Clone, Serialize)]
111#[serde(rename_all = "camelCase")]
112pub struct StartSessionResult {
113 pub session_id: String,
114 pub auth_url: Option<String>,
118}
119
120#[derive(Debug, Clone, Serialize)]
123#[serde(rename_all = "camelCase")]
124pub struct PollSessionResult {
125 pub provider_id: String,
126 #[serde(flatten)]
127 pub status: AuthSessionStatus,
128}
129
130#[derive(Default)]
135struct ProcessRefs {
136 drain_tasks: HashMap<String, tokio::task::JoinHandle<()>>,
137 stdin_senders: HashMap<String, tokio::sync::mpsc::Sender<String>>,
138 pty_pids: HashMap<String, u32>,
142}
143
144#[derive(Default)]
145pub struct AuthSessionManager {
146 sessions: Arc<Mutex<HashMap<String, Session>>>,
147 process_refs: Arc<Mutex<ProcessRefs>>,
148}
149
150impl AuthSessionManager {
151 pub fn new() -> Self {
152 Self::default()
153 }
154
155 pub fn start_session(
159 &self,
160 provider_id: String,
161 into_bundle_id: Option<String>,
162 ) -> StartSessionResult {
163 let session_id = format!("auth-{}", uuid::Uuid::new_v4());
164 let session = Session::new(provider_id, into_bundle_id);
165 self.sessions
166 .lock()
167 .unwrap()
168 .insert(session_id.clone(), session);
169 StartSessionResult {
170 session_id,
171 auth_url: None,
172 }
173 }
174
175 pub fn record_line(&self, session_id: &str, line: &str) -> Option<AuthPatternMatch> {
179 let mut sessions = self.sessions.lock().unwrap();
180 let session = sessions.get_mut(session_id)?;
181 session.transcript.push(line.to_string());
182 let m = match_line(&session.provider_id, line)?;
183 match &m {
184 AuthPatternMatch::OAuthUrl(url) => {
185 if session.captured_url.is_none() {
186 session.captured_url = Some(url.clone());
187 if matches!(session.status, AuthSessionStatus::Pending) {
188 session.status = AuthSessionStatus::UrlAvailable {
189 auth_url: url.clone(),
190 };
191 }
192 }
193 }
194 AuthPatternMatch::DeviceCode {
195 code,
196 verification_url,
197 } => {
198 if session.captured_device_code.is_none() {
199 session.captured_device_code =
200 Some((code.clone(), verification_url.clone()));
201 if matches!(
202 session.status,
203 AuthSessionStatus::Pending | AuthSessionStatus::UrlAvailable { .. }
204 ) {
205 session.status = AuthSessionStatus::CodeEmitted {
206 device_code: code.clone(),
207 verification_url: verification_url.clone(),
208 };
209 }
210 }
211 }
212 AuthPatternMatch::LoginSuccess { email } => {
213 if session.captured_email.is_none() {
218 session.captured_email = email.clone();
219 }
220 }
221 AuthPatternMatch::LoginFailure { message: _ } => {
222 }
225 }
226 Some(m)
227 }
228
229 pub fn finish_success(&self, session_id: &str, bundle_id: String) -> bool {
233 let mut sessions = self.sessions.lock().unwrap();
234 let Some(session) = sessions.get_mut(session_id) else {
235 return false;
236 };
237 if session.status.is_terminal() {
238 return false;
239 }
240 session.status = AuthSessionStatus::Success {
241 bundle_id,
242 email: session.captured_email.clone(),
243 };
244 true
245 }
246
247 pub fn finish_failure(&self, session_id: &str, error: String) -> bool {
248 let mut sessions = self.sessions.lock().unwrap();
249 let Some(session) = sessions.get_mut(session_id) else {
250 return false;
251 };
252 if session.status.is_terminal() {
253 return false;
254 }
255 session.status = AuthSessionStatus::Failed { error };
256 true
257 }
258
259 pub fn poll_session(&self, session_id: &str) -> Option<PollSessionResult> {
263 let mut sessions = self.sessions.lock().unwrap();
264 let session = sessions.get_mut(session_id)?;
265 if !session.status.is_terminal() && session.timed_out() {
266 session.status = AuthSessionStatus::Failed {
267 error: format!(
268 "auth session timed out after {SESSION_TIMEOUT_SECS}s"
269 ),
270 };
271 }
272 Some(PollSessionResult {
273 provider_id: session.provider_id.clone(),
274 status: session.status.clone(),
275 })
276 }
277
278 pub fn cancel_session(&self, session_id: &str) -> bool {
284 let transitioned =
285 self.finish_failure(session_id, "cancelled by user".to_string());
286 let mut refs = self.process_refs.lock().unwrap();
287 if let Some(handle) = refs.drain_tasks.remove(session_id) {
288 handle.abort();
289 }
290 refs.stdin_senders.remove(session_id);
291 if let Some(pid) = refs.pty_pids.remove(session_id) {
295 if let Err(e) = kill_pid(pid) {
296 tracing::warn!(pid, session_id, error = %e, "cancel_session: kill_pid failed");
297 } else {
298 tracing::info!(pid, session_id, "cancel_session: PTY child killed");
299 }
300 }
301 transitioned
302 }
303
304 pub fn attach_pty_pid(&self, session_id: &str, pid: u32) {
308 let mut refs = self.process_refs.lock().unwrap();
309 refs.pty_pids.insert(session_id.to_string(), pid);
310 }
311
312 pub fn attach_process(
315 &self,
316 session_id: &str,
317 drain_task: tokio::task::JoinHandle<()>,
318 stdin_sender: tokio::sync::mpsc::Sender<String>,
319 ) {
320 let mut refs = self.process_refs.lock().unwrap();
321 refs.drain_tasks.insert(session_id.to_string(), drain_task);
322 refs.stdin_senders.insert(session_id.to_string(), stdin_sender);
323 }
324
325 pub async fn send_to_stdin(&self, session_id: &str, line: String) -> bool {
330 let sender = {
331 let refs = self.process_refs.lock().unwrap();
332 refs.stdin_senders.get(session_id).cloned()
333 };
334 match sender {
335 Some(s) => s.send(line).await.is_ok(),
336 None => false,
337 }
338 }
339
340 pub fn detach_process(&self, session_id: &str) {
344 let mut refs = self.process_refs.lock().unwrap();
345 refs.drain_tasks.remove(session_id);
346 refs.stdin_senders.remove(session_id);
347 refs.pty_pids.remove(session_id);
348 }
349
350 pub fn remove(&self, session_id: &str) {
356 self.sessions.lock().unwrap().remove(session_id);
357 }
358
359 #[allow(dead_code)]
363 pub fn transcript(&self, session_id: &str) -> Option<Vec<String>> {
364 self.sessions
365 .lock()
366 .unwrap()
367 .get(session_id)
368 .map(|s| s.transcript.clone())
369 }
370
371 #[cfg(test)]
375 fn force_age(&self, session_id: &str) {
376 if let Some(s) = self.sessions.lock().unwrap().get_mut(session_id) {
377 s.started_at = Instant::now() - Duration::from_secs(SESSION_TIMEOUT_SECS + 1);
378 }
379 }
380}
381
#[cfg(windows)]
/// Force-kills `pid` and its child process tree via `taskkill /F /T`.
///
/// Spawns `taskkill` with all stdio discarded and waits for it. A non-zero
/// exit is reported as an `io::Error` carrying the exit code.
fn kill_pid(pid: u32) -> std::io::Result<()> {
    use std::process::{Command, Stdio};

    let pid_arg = pid.to_string();
    let status = Command::new("taskkill")
        .args(["/F", "/T", "/PID", pid_arg.as_str()])
        .stdin(Stdio::null())
        .stdout(Stdio::null())
        .stderr(Stdio::null())
        .status()?;
    if !status.success() {
        return Err(std::io::Error::other(format!(
            "taskkill exit {:?}",
            status.code()
        )));
    }
    Ok(())
}
398
399#[cfg(unix)]
400fn kill_pid(pid: u32) -> std::io::Result<()> {
401 let ret = unsafe { libc::kill(pid as i32, libc::SIGTERM) };
402 if ret != 0 {
403 return Err(std::io::Error::last_os_error());
404 }
405 Ok(())
406}
407
#[cfg(test)]
mod tests {
    use super::*;

    fn mgr() -> AuthSessionManager {
        AuthSessionManager::new()
    }

    #[test]
    fn start_creates_pending_session() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(!r.session_id.is_empty());
        assert!(r.auth_url.is_none());
        let p = m.poll_session(&r.session_id).expect("session exists");
        assert_eq!(p.provider_id, "claude");
        assert!(matches!(p.status, AuthSessionStatus::Pending));
    }

    #[test]
    fn url_line_transitions_to_url_available() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        let _ = m.record_line(
            &r.session_id,
            "Open https://console.anthropic.com/oauth/authorize?state=xyz",
        );
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::UrlAvailable { auth_url } => {
                assert!(auth_url.contains("anthropic.com/oauth"));
            }
            other => panic!("expected UrlAvailable, got {other:?}"),
        }
    }

    #[test]
    fn device_code_line_transitions_to_code_emitted() {
        let m = mgr();
        let r = m.start_session("copilot".to_string(), None);
        let _ = m.record_line(&r.session_id, "! Copy your one-time code: ABCD-1234");
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::CodeEmitted {
                device_code,
                verification_url,
            } => {
                assert_eq!(device_code, "ABCD-1234");
                assert_eq!(verification_url, "https://github.com/login/device");
            }
            other => panic!("expected CodeEmitted, got {other:?}"),
        }
    }

    #[test]
    fn login_success_line_does_not_transition_state_alone() {
        let m = mgr();
        // Success needs a bundle id, so only `finish_success` may declare it;
        // the success line alone must leave the session Pending.
        let r = m.start_session("claude".to_string(), None);
        let _ = m.record_line(&r.session_id, "Successfully logged in as asaf@example.com");
        let p = m.poll_session(&r.session_id).unwrap();
        assert!(matches!(p.status, AuthSessionStatus::Pending));
    }

    #[test]
    fn finish_success_carries_email_from_transcript() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        let _ = m.record_line(&r.session_id, "Successfully logged in as asaf@example.com");
        assert!(m.finish_success(&r.session_id, "bundle-1".to_string()));
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::Success { bundle_id, email } => {
                assert_eq!(bundle_id, "bundle-1");
                assert_eq!(email.as_deref(), Some("asaf@example.com"));
            }
            other => panic!("expected Success, got {other:?}"),
        }
    }

    #[test]
    fn cancel_transitions_to_failed() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(m.cancel_session(&r.session_id));
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::Failed { error } => assert!(error.contains("cancelled")),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[test]
    fn timeout_transitions_pending_to_failed_on_poll() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        m.force_age(&r.session_id);
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::Failed { error } => assert!(error.contains("timed out")),
            other => panic!("expected timeout Failed, got {other:?}"),
        }
    }

    #[test]
    fn terminal_states_cannot_be_re_transitioned() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(m.finish_success(&r.session_id, "bundle-1".to_string()));
        assert!(!m.finish_failure(&r.session_id, "should be ignored".to_string()));
        let p = m.poll_session(&r.session_id).unwrap();
        assert!(matches!(p.status, AuthSessionStatus::Success { .. }));
    }

    #[test]
    fn multiple_url_lines_keep_the_first_url() {
        let m = mgr();
        // A CLI may print the URL more than once; the first one must stick.
        let r = m.start_session("claude".to_string(), None);
        let _ = m.record_line(
            &r.session_id,
            "Open https://console.anthropic.com/oauth/authorize?state=first",
        );
        let _ = m.record_line(
            &r.session_id,
            "Open https://console.anthropic.com/oauth/authorize?state=second",
        );
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::UrlAvailable { auth_url } => {
                assert!(auth_url.contains("state=first"));
            }
            _ => panic!("expected UrlAvailable"),
        }
    }

    #[test]
    fn remove_clears_session() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        m.remove(&r.session_id);
        assert!(m.poll_session(&r.session_id).is_none());
    }

    #[test]
    fn status_serializes_with_camelcase_field_names() {
        // Variant tag is kebab-case under "status"; variant fields are camelCase.
        let s = AuthSessionStatus::UrlAvailable {
            auth_url: "https://example.com/oauth".to_string(),
        };
        let v = serde_json::to_value(&s).unwrap();
        assert_eq!(
            v,
            serde_json::json!({
                "status": "url-available",
                "authUrl": "https://example.com/oauth"
            })
        );

        let s = AuthSessionStatus::CodeEmitted {
            device_code: "ABCD-1234".to_string(),
            verification_url: "https://github.com/login/device".to_string(),
        };
        let v = serde_json::to_value(&s).unwrap();
        assert_eq!(
            v,
            serde_json::json!({
                "status": "code-emitted",
                "deviceCode": "ABCD-1234",
                "verificationUrl": "https://github.com/login/device"
            })
        );

        let s = AuthSessionStatus::Success {
            bundle_id: "bundle-1".to_string(),
            email: Some("asaf@example.com".to_string()),
        };
        let v = serde_json::to_value(&s).unwrap();
        assert_eq!(
            v,
            serde_json::json!({
                "status": "success",
                "bundleId": "bundle-1",
                "email": "asaf@example.com"
            })
        );
    }

    #[test]
    fn unknown_session_polls_to_none() {
        let m = mgr();
        assert!(m.poll_session("does-not-exist").is_none());
    }

    #[test]
    fn transcript_records_all_lines_including_non_matching() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        let _ = m.record_line(&r.session_id, "Starting auth flow...");
        let _ = m.record_line(&r.session_id, "Open https://console.anthropic.com/oauth/authorize");
        let _ = m.record_line(&r.session_id, "Waiting for callback...");
        let t = m.transcript(&r.session_id).unwrap();
        assert_eq!(t.len(), 3);
        assert_eq!(t[0], "Starting auth flow...");
        assert!(t[1].contains("anthropic.com/oauth"));
        assert_eq!(t[2], "Waiting for callback...");
    }

    #[test]
    fn unknown_session_operations_report_failure() {
        let m = mgr();
        assert!(m.record_line("does-not-exist", "Starting auth flow...").is_none());
        assert!(!m.finish_success("does-not-exist", "bundle-1".to_string()));
        assert!(!m.finish_failure("does-not-exist", "boom".to_string()));
        assert!(!m.cancel_session("does-not-exist"));
        assert!(m.transcript("does-not-exist").is_none());
    }

    #[test]
    fn cancel_after_success_keeps_success() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(m.finish_success(&r.session_id, "bundle-1".to_string()));
        assert!(!m.cancel_session(&r.session_id));
        let p = m.poll_session(&r.session_id).unwrap();
        assert!(matches!(p.status, AuthSessionStatus::Success { .. }));
    }

    #[test]
    fn timeout_does_not_override_terminal_state() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(m.finish_success(&r.session_id, "bundle-1".to_string()));
        m.force_age(&r.session_id);
        let p = m.poll_session(&r.session_id).unwrap();
        assert!(matches!(p.status, AuthSessionStatus::Success { .. }));
    }

    #[test]
    fn failure_is_terminal_and_keeps_first_error() {
        let m = mgr();
        let r = m.start_session("claude".to_string(), None);
        assert!(m.finish_failure(&r.session_id, "first".to_string()));
        assert!(!m.finish_success(&r.session_id, "bundle-1".to_string()));
        assert!(!m.finish_failure(&r.session_id, "second".to_string()));
        let p = m.poll_session(&r.session_id).unwrap();
        match p.status {
            AuthSessionStatus::Failed { error } => assert_eq!(error, "first"),
            other => panic!("expected Failed, got {other:?}"),
        }
    }

    #[test]
    fn poll_result_flattens_status_beside_provider_id() {
        let p = PollSessionResult {
            provider_id: "claude".to_string(),
            status: AuthSessionStatus::Pending,
        };
        assert_eq!(
            serde_json::to_value(&p).unwrap(),
            serde_json::json!({ "providerId": "claude", "status": "pending" })
        );

        let p = PollSessionResult {
            provider_id: "copilot".to_string(),
            status: AuthSessionStatus::Failed {
                error: "boom".to_string(),
            },
        };
        assert_eq!(
            serde_json::to_value(&p).unwrap(),
            serde_json::json!({ "providerId": "copilot", "status": "failed", "error": "boom" })
        );

        // A missing email is serialized as an explicit null, not omitted.
        let s = AuthSessionStatus::Success {
            bundle_id: "bundle-2".to_string(),
            email: None,
        };
        assert_eq!(
            serde_json::to_value(&s).unwrap(),
            serde_json::json!({ "status": "success", "bundleId": "bundle-2", "email": null })
        );
    }
}
622}