1use std::collections::{HashMap, HashSet};
17use std::sync::Arc;
18use std::time::Instant;
19
20use bytes::BytesMut;
21use chrono::Utc;
22use dashmap::DashMap;
23use futures::{SinkExt, StreamExt};
24use tokio::net::TcpListener;
25use tokio::sync::mpsc;
26use tokio_tungstenite::tungstenite::handshake::server::{ErrorResponse, Request, Response};
27use tokio_tungstenite::tungstenite::http::StatusCode;
28use tokio_tungstenite::tungstenite::protocol::Message;
29
30use futu_auth::{CheckCtx, KeyRecord, KeyStore, LimitOutcome, RuntimeCounters, Scope};
31use futu_codec::frame::FutuFrame;
32use futu_codec::header::{FutuHeader, ProtoFmtType, HEADER_SIZE};
33use futu_core::proto_id;
34
35use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
36use crate::listener::{ServerConfig, MAX_CONNECTIONS};
37use crate::router::RequestRouter;
38
39pub struct WsServer {
41 listen_addr: String,
42 config: ServerConfig,
43 connections: Arc<DashMap<u64, ClientConn>>,
44 router: Arc<RequestRouter>,
45 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
46 key_store: Option<Arc<KeyStore>>,
48 counters: Option<Arc<RuntimeCounters>>,
50}
51
52impl WsServer {
53 pub fn new(
55 listen_addr: String,
56 config: ServerConfig,
57 connections: Arc<DashMap<u64, ClientConn>>,
58 router: Arc<RequestRouter>,
59 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
60 ) -> Self {
61 Self::with_auth(
62 listen_addr,
63 config,
64 connections,
65 router,
66 subscriptions,
67 None,
68 None,
69 )
70 }
71
72 #[allow(clippy::too_many_arguments)]
75 pub fn with_auth(
76 listen_addr: String,
77 config: ServerConfig,
78 connections: Arc<DashMap<u64, ClientConn>>,
79 router: Arc<RequestRouter>,
80 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
81 key_store: Option<Arc<KeyStore>>,
82 counters: Option<Arc<RuntimeCounters>>,
83 ) -> Self {
84 Self {
85 listen_addr,
86 config,
87 connections,
88 router,
89 subscriptions,
90 key_store,
91 counters,
92 }
93 }
94
95 pub async fn run(&self) -> anyhow::Result<()> {
97 let listener = TcpListener::bind(&self.listen_addr).await?;
98 tracing::info!(addr = %self.listen_addr, "WebSocket server listening");
99
100 let (req_tx, req_rx) = mpsc::unbounded_channel::<IncomingRequest>();
101 let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
102
103 let connections = Arc::clone(&self.connections);
105 let router = Arc::clone(&self.router);
106 let config = self.config.clone();
107 let counters_for_process = self.counters.clone();
108 let key_store_for_process = self.key_store.clone();
109 let scope_mode = self.key_store.as_ref().is_some_and(|ks| ks.is_configured());
110 tokio::spawn(async move {
111 ws_process_requests(
112 req_rx,
113 connections,
114 router,
115 config,
116 counters_for_process,
117 key_store_for_process,
118 scope_mode,
119 )
120 .await;
121 });
122
123 let cleanup_connections = Arc::clone(&self.connections);
125 let cleanup_subs = self.subscriptions.clone();
126 tokio::spawn(async move {
127 while let Some(notify) = disconnect_rx.recv().await {
128 let removed = cleanup_connections.remove(¬ify.conn_id);
129 if removed.is_some() {
130 if let Some(ref subs) = cleanup_subs {
131 subs.on_disconnect(notify.conn_id);
132 }
133 tracing::info!(
134 conn_id = notify.conn_id,
135 remaining = cleanup_connections.len(),
136 "ws connection removed from pool"
137 );
138 }
139 }
140 });
141
142 let connections = Arc::clone(&self.connections);
144 let key_store_accept = self.key_store.clone();
145 if !scope_mode {
146 tracing::warn!(
147 "WS server running WITHOUT API key auth (legacy mode); \
148 all WS clients are open. Pass KeyStore via with_auth() to enable."
149 );
150 }
151 loop {
152 let (stream, peer_addr) = listener.accept().await?;
153
154 if connections.len() >= MAX_CONNECTIONS {
155 tracing::warn!(
156 peer = %peer_addr,
157 "max connections reached ({}), rejecting ws client",
158 MAX_CONNECTIONS,
159 );
160 drop(stream);
161 continue;
162 }
163
164 let conn_id = ClientConn::generate_conn_id();
165 let aes_key = ClientConn::generate_aes_key();
166 stream.set_nodelay(true).ok();
167
168 tracing::info!(
169 conn_id = conn_id,
170 peer = %peer_addr,
171 total = connections.len() + 1,
172 "ws client connected"
173 );
174
175 let (tx, authed) = run_ws_connection(
176 stream,
177 conn_id,
178 aes_key,
179 req_tx.clone(),
180 disconnect_tx.clone(),
181 key_store_accept.clone(),
182 )
183 .await;
184
185 let Some(authed) = authed else {
187 continue;
188 };
189
190 let (key_id, scopes) = match authed {
191 AuthResult::Authenticated(rec) => (Some(rec.id.clone()), rec.scopes.clone()),
192 AuthResult::Legacy => (None, HashSet::new()),
193 };
194
195 let conn = ClientConn {
196 conn_id,
197 state: ConnState::Connected,
198 aes_key,
199 aes_encrypt_enabled: false,
200 proto_fmt_type: ProtoFmtType::Protobuf,
201 last_keepalive: Instant::now(),
202 keepalive_count: std::sync::atomic::AtomicU32::new(0),
203 tx,
204 key_id,
205 scopes,
206 };
207
208 connections.insert(conn_id, conn);
209 }
210 }
211}
212
213enum AuthResult {
215 Authenticated(Arc<KeyRecord>),
216 Legacy,
217}
218
219async fn run_ws_connection(
227 stream: tokio::net::TcpStream,
228 conn_id: u64,
229 _aes_key: [u8; 16],
230 req_tx: mpsc::UnboundedSender<IncomingRequest>,
231 disconnect_tx: mpsc::UnboundedSender<DisconnectNotify>,
232 key_store: Option<Arc<KeyStore>>,
233) -> (mpsc::Sender<FutuFrame>, Option<AuthResult>) {
234 let (frame_tx, mut frame_rx) = mpsc::channel::<FutuFrame>(256);
235
236 let authed_slot: Arc<std::sync::Mutex<Option<AuthResult>>> =
239 Arc::new(std::sync::Mutex::new(None));
240 let slot_cb = Arc::clone(&authed_slot);
241 let store_cb = key_store.clone();
242
243 #[allow(clippy::result_large_err)] let callback = move |req: &Request, resp: Response| -> Result<Response, ErrorResponse> {
245 let Some(store) = store_cb.as_ref() else {
247 *slot_cb.lock().unwrap() = Some(AuthResult::Legacy);
248 return Ok(resp);
249 };
250 if !store.is_configured() {
251 *slot_cb.lock().unwrap() = Some(AuthResult::Legacy);
252 return Ok(resp);
253 }
254
255 let token = extract_ws_token(req);
257 let Some(token) = token else {
258 futu_auth::audit::reject("ws", "/ws", "<missing>", "missing token");
259 return Err(make_err_response(
260 StatusCode::UNAUTHORIZED,
261 "missing api key (use ?token=... or Authorization: Bearer ...)",
262 ));
263 };
264
265 let Some(rec) = store.verify(&token) else {
266 futu_auth::audit::reject("ws", "/ws", "<invalid>", "invalid api key");
267 return Err(make_err_response(
268 StatusCode::UNAUTHORIZED,
269 "invalid api key",
270 ));
271 };
272
273 if rec.is_expired(Utc::now()) {
274 futu_auth::audit::reject("ws", "/ws", &rec.id, "key expired");
275 return Err(make_err_response(StatusCode::UNAUTHORIZED, "key expired"));
276 }
277
278 if !rec.scopes.contains(&Scope::QotRead) {
280 futu_auth::audit::reject("ws", "/ws", &rec.id, "missing qot:read");
281 return Err(make_err_response(
282 StatusCode::FORBIDDEN,
283 "missing qot:read scope",
284 ));
285 }
286
287 futu_auth::audit::allow("ws", "/ws", &rec.id, Some("qot:read"));
288 *slot_cb.lock().unwrap() = Some(AuthResult::Authenticated(rec));
289 Ok(resp)
290 };
291
292 let ws_stream = match tokio_tungstenite::accept_hdr_async(stream, callback).await {
293 Ok(ws) => ws,
294 Err(e) => {
295 tracing::warn!(conn_id = conn_id, error = %e, "ws handshake failed");
296 let _ = disconnect_tx.send(DisconnectNotify { conn_id });
297 return (frame_tx, None);
298 }
299 };
300
301 let authed = authed_slot
303 .lock()
304 .unwrap()
305 .take()
306 .expect("authed_slot must be filled after successful handshake");
307
308 let (mut ws_sink, mut ws_stream_rx) = ws_stream.split();
309
310 tokio::spawn(async move {
312 while let Some(frame) = frame_rx.recv().await {
313 let mut buf = BytesMut::new();
314 frame.header.encode(&mut buf);
315 buf.extend_from_slice(&frame.body);
316 let msg = Message::Binary(buf.freeze().into());
317 if let Err(e) = ws_sink.send(msg).await {
318 tracing::warn!(conn_id = conn_id, error = %e, "ws send failed");
319 break;
320 }
321 }
322 });
323
324 tokio::spawn(async move {
326 while let Some(result) = ws_stream_rx.next().await {
327 match result {
328 Ok(msg) => {
329 let data = match msg {
330 Message::Binary(data) => data,
331 Message::Close(_) => {
332 tracing::info!(conn_id = conn_id, "ws client sent close");
333 break;
334 }
335 Message::Ping(_) | Message::Pong(_) => {
336 continue;
338 }
339 _ => {
340 continue;
342 }
343 };
344
345 if data.len() < HEADER_SIZE {
347 tracing::warn!(
348 conn_id = conn_id,
349 len = data.len(),
350 "ws message too short for futu header"
351 );
352 continue;
353 }
354
355 let header_buf = BytesMut::from(&data[..]);
356 let header = match FutuHeader::peek(&header_buf) {
357 Ok(Some(h)) => h,
358 Ok(None) => {
359 tracing::warn!(conn_id = conn_id, "ws header peek returned None");
360 continue;
361 }
362 Err(e) => {
363 tracing::warn!(conn_id = conn_id, error = %e, "ws invalid futu header");
364 continue;
365 }
366 };
367
368 let expected_len = HEADER_SIZE + header.body_len as usize;
369 if data.len() < expected_len {
370 tracing::warn!(
371 conn_id = conn_id,
372 expected = expected_len,
373 actual = data.len(),
374 "ws message shorter than expected frame size"
375 );
376 continue;
377 }
378
379 let body = bytes::Bytes::copy_from_slice(&data[HEADER_SIZE..expected_len]);
380
381 let req = IncomingRequest {
382 conn_id,
383 proto_id: header.proto_id,
384 serial_no: header.serial_no,
385 proto_fmt_type: header.proto_fmt_type,
386 body,
387 };
388
389 if req_tx.send(req).is_err() {
390 break;
391 }
392 }
393 Err(e) => {
394 tracing::warn!(conn_id = conn_id, error = %e, "ws recv error");
395 break;
396 }
397 }
398 }
399 tracing::info!(conn_id = conn_id, "ws connection closed");
400 let _ = disconnect_tx.send(DisconnectNotify { conn_id });
401 });
402
403 (frame_tx, Some(authed))
404}
405
406fn extract_ws_token(req: &Request) -> Option<String> {
411 if let Some(q) = req.uri().query() {
412 let params: HashMap<&str, &str> =
414 q.split('&').filter_map(|kv| kv.split_once('=')).collect();
415 if let Some(v) = params.get("token") {
416 if !v.is_empty() {
417 return Some((*v).to_string());
418 }
419 }
420 }
421 req.headers()
422 .get("authorization")
423 .and_then(|v| v.to_str().ok())
424 .and_then(|v| v.strip_prefix("Bearer ").map(|s| s.trim().to_string()))
425 .filter(|s| !s.is_empty())
426}
427
428fn make_err_response(code: StatusCode, msg: &str) -> ErrorResponse {
430 let body = Some(format!(r#"{{"error":"{msg}"}}"#));
431 let mut resp = tokio_tungstenite::tungstenite::http::Response::new(body);
432 *resp.status_mut() = code;
433 resp.headers_mut().insert(
434 "content-type",
435 tokio_tungstenite::tungstenite::http::HeaderValue::from_static("application/json"),
436 );
437 resp
438}
439
440async fn ws_process_requests(
446 mut req_rx: mpsc::UnboundedReceiver<IncomingRequest>,
447 connections: Arc<DashMap<u64, ClientConn>>,
448 router: Arc<RequestRouter>,
449 config: ServerConfig,
450 counters: Option<Arc<RuntimeCounters>>,
451 key_store: Option<Arc<KeyStore>>,
452 scope_mode: bool,
453) {
454 use crate::listener::ApiServer;
455
456 while let Some(mut req) = req_rx.recv().await {
457 let conn_id = req.conn_id;
458 let proto_id_val = req.proto_id;
459 let serial_no = req.serial_no;
460
461 if let Some(mut conn) = connections.get_mut(&conn_id) {
463 conn.last_keepalive = Instant::now();
464 }
465
466 if scope_mode {
468 if let Some(needed) = futu_auth::scope_for_proto_id(proto_id_val) {
469 let (scopes, key_id_snap) = match connections.get(&conn_id) {
471 Some(conn) => (conn.scopes.clone(), conn.key_id.clone()),
472 None => {
473 tracing::warn!(conn_id, proto_id = proto_id_val, "ws req on unknown conn");
474 continue;
475 }
476 };
477 let key_id_str = key_id_snap.as_deref().unwrap_or("<none>");
478 let limits_snap = match (&key_store, &key_id_snap) {
481 (Some(ks), Some(id)) => {
482 ks.get_by_id(id).map(|r| r.limits()).unwrap_or_default()
483 }
484 _ => futu_auth::Limits::default(),
485 };
486 if !scopes.contains(&needed) {
487 futu_auth::audit::reject(
488 "ws",
489 &format!("proto_id={proto_id_val}"),
490 key_id_str,
491 &format!("missing scope {needed}"),
492 );
493 continue;
494 }
495 if needed == Scope::TradeReal {
497 if let Some(c) = &counters {
498 let ctx = CheckCtx {
499 market: String::new(),
500 symbol: String::new(),
501 order_value: None,
502 trd_side: None,
503 };
504 if let LimitOutcome::Reject(reason) =
505 c.check_and_commit(key_id_str, &limits_snap, &ctx, Utc::now())
506 {
507 futu_auth::audit::reject(
508 "ws",
509 &format!("proto_id={proto_id_val}"),
510 key_id_str,
511 &format!("limit: {reason}"),
512 );
513 continue;
514 }
515 }
516 }
517 futu_auth::audit::allow(
518 "ws",
519 &format!("proto_id={proto_id_val}"),
520 key_id_str,
521 Some(needed.as_str()),
522 );
523 }
524 }
526
527 if proto_id_val != proto_id::INIT_CONNECT {
529 if let Some(conn) = connections.get(&conn_id) {
530 if conn.aes_encrypt_enabled {
531 match conn.decrypt_body(&req.body) {
532 Ok(decrypted) => {
533 req.body = bytes::Bytes::from(decrypted);
534 }
535 Err(e) => {
536 tracing::warn!(
537 conn_id = conn_id,
538 proto_id = proto_id_val,
539 error = %e,
540 "ws AES decrypt request failed, dropping"
541 );
542 continue;
543 }
544 }
545 }
546 }
547 }
548
549 let response_body = match proto_id_val {
550 proto_id::INIT_CONNECT => {
551 if let Some(mut conn) = connections.get_mut(&conn_id) {
552 conn.handle_init_connect(
553 &req.body,
554 config.server_ver,
555 config.login_user_id,
556 config.keepalive_interval,
557 config.rsa_private_key.as_deref(),
558 )
559 .ok()
560 } else {
561 None
562 }
563 }
564 proto_id::KEEP_ALIVE => {
565 if let Some(conn) = connections.get(&conn_id) {
566 conn.handle_keepalive(&req.body).ok()
567 } else {
568 None
569 }
570 }
571 _ => router.dispatch(conn_id, &req).await,
572 };
573
574 if let Some(body) = response_body {
575 ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
576 }
577 }
578}
579
#[cfg(test)]
mod tests {
    use super::*;
    use tokio_tungstenite::tungstenite::http::Request as HttpRequest;

    /// Builds an upgrade request for `uri`, optionally with an `authorization` header.
    fn mk_req(uri: &str, auth_header: Option<&str>) -> HttpRequest<()> {
        let builder = HttpRequest::builder().uri(uri);
        let builder = match auth_header {
            Some(value) => builder.header("authorization", value),
            None => builder,
        };
        builder.body(()).unwrap()
    }

    /// Runs `extract_ws_token` on a request built by `mk_req`.
    fn token_of(uri: &str, auth_header: Option<&str>) -> Option<String> {
        extract_ws_token(&mk_req(uri, auth_header))
    }

    #[test]
    fn extract_token_from_query() {
        assert_eq!(token_of("/ws?token=abc123", None).as_deref(), Some("abc123"));
    }

    #[test]
    fn extract_token_from_bearer_header() {
        assert_eq!(token_of("/ws", Some("Bearer xyz789")).as_deref(), Some("xyz789"));
    }

    #[test]
    fn extract_token_query_preferred_over_header() {
        let got = token_of("/ws?token=from-query", Some("Bearer from-header"));
        assert_eq!(got.as_deref(), Some("from-query"));
    }

    #[test]
    fn extract_token_empty_query_falls_back_to_header() {
        let got = token_of("/ws?token=", Some("Bearer from-header"));
        assert_eq!(got.as_deref(), Some("from-header"));
    }

    #[test]
    fn extract_token_missing_everywhere() {
        for uri in ["/ws", "/ws?foo=bar"] {
            assert_eq!(token_of(uri, None), None, "uri = {uri}");
        }
    }

    #[test]
    fn extract_token_non_bearer_auth_ignored() {
        assert_eq!(token_of("/ws", Some("Basic asdf")), None);
    }

    #[test]
    fn extract_token_bearer_empty_after_prefix() {
        // A bare scheme with no credentials must not yield an empty token.
        assert_eq!(token_of("/ws", Some("Bearer ")), None);
    }
}