1use std::sync::Arc;
4use std::sync::atomic::{AtomicI64, Ordering};
5use std::time::Instant;
6
7use dashmap::DashMap;
8use tokio::net::TcpListener;
9use tokio::sync::{mpsc, watch};
10
11use futu_codec::header::ProtoFmtType;
12use futu_core::proto_id;
13
14use crate::conn::{ClientConn, ConnState, DisconnectNotify, IncomingRequest};
15use crate::metrics::GatewayMetrics;
16use crate::router::RequestRouter;
17
/// Upper bound on pooled client connections. Once the pool holds this many
/// entries, newly accepted sockets are dropped immediately and counted in
/// `rejected_connections`.
pub const MAX_CONNECTIONS: usize = 128;

/// Capacity of the bounded request queue drained by the single
/// `process_requests` task; every connection receives a sender for it via
/// `run_connection`.
pub(crate) const REQUEST_QUEUE_CAPACITY: usize = 4096;
26
/// Static settings for an [`ApiServer`].
///
/// `Debug` is implemented by hand so the RSA private key never ends up in
/// logs or panic messages; only its presence is shown.
#[derive(Clone)]
pub struct ServerConfig {
    /// Address the TCP listener binds to (passed to `TcpListener::bind`).
    pub listen_addr: String,
    /// Forwarded to InitConnect handling (`ClientConn::handle_init_connect`);
    /// presumably the server version reported to clients.
    pub server_ver: i32,
    /// Forwarded to InitConnect handling; presumably the logged-in user id
    /// reported to clients.
    pub login_user_id: u64,
    /// Forwarded to InitConnect handling; presumably the KeepAlive interval
    /// (seconds?) clients are asked to use. The keepalive timeout sweep in
    /// this module does not read this value.
    pub keepalive_interval: i32,
    /// Optional RSA private key, handed to InitConnect handling as `&str`.
    /// Redacted from `Debug` output.
    pub rsa_private_key: Option<String>,
}

impl std::fmt::Debug for ServerConfig {
    /// Formats like the derived impl, except that the private key is
    /// replaced with `"<redacted>"` when one is configured.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ServerConfig")
            .field("listen_addr", &self.listen_addr)
            .field("server_ver", &self.server_ver)
            .field("login_user_id", &self.login_user_id)
            .field("keepalive_interval", &self.keepalive_interval)
            .field(
                "rsa_private_key",
                &self.rsa_private_key.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
41
/// Builds a new, independently shareable server-time offset of zero seconds,
/// so `server_now_ts` reports the local clock until an offset is stored.
#[must_use]
pub(crate) fn default_server_time_offset_secs() -> Arc<AtomicI64> {
    // `AtomicI64::default()` is zero, so this is `Arc::new(AtomicI64::new(0))`.
    Arc::default()
}
46
/// Shifts a local Unix timestamp (seconds) by the stored server-time offset.
///
/// The addition saturates at the `i64` bounds instead of overflowing.
#[must_use]
pub(crate) fn server_now_ts_at(server_time_offset_secs: &AtomicI64, local_now_ts: i64) -> i64 {
    let offset = server_time_offset_secs.load(Ordering::Relaxed);
    i64::saturating_add(local_now_ts, offset)
}
51
52#[must_use]
53pub(crate) fn server_now_ts(server_time_offset_secs: &AtomicI64) -> i64 {
54 server_now_ts_at(server_time_offset_secs, chrono::Utc::now().timestamp())
55}
56
57pub(crate) fn set_nodelay_with_log(
58 stream: &tokio::net::TcpStream,
59 peer_addr: std::net::SocketAddr,
60 surface: &'static str,
61) {
62 if let Err(error) = stream.set_nodelay(true) {
63 tracing::debug!(
64 peer = %peer_addr,
65 surface,
66 error = %error,
67 "tcp nodelay setup failed"
68 );
69 }
70}
71
72pub(crate) async fn shutdown_requested(shutdown_rx: &mut watch::Receiver<bool>) {
73 loop {
74 if *shutdown_rx.borrow() {
75 return;
76 }
77 if shutdown_rx.changed().await.is_err() {
78 return;
79 }
80 }
81}
82
83pub struct ApiServer {
85 config: ServerConfig,
86 connections: Arc<DashMap<u64, ClientConn>>,
87 router: Arc<RequestRouter>,
88 subscriptions: Option<Arc<crate::subscription::SubscriptionManager>>,
89 metrics: Arc<GatewayMetrics>,
90 server_time_offset_secs: Arc<AtomicI64>,
91}
92
93impl ApiServer {
94 pub fn new(config: ServerConfig) -> Self {
96 Self {
97 config,
98 connections: Arc::new(DashMap::new()),
99 router: Arc::new(RequestRouter::new()),
100 subscriptions: None,
101 metrics: Arc::new(GatewayMetrics::new()),
102 server_time_offset_secs: default_server_time_offset_secs(),
103 }
104 }
105
106 pub fn set_subscriptions(&mut self, subs: Arc<crate::subscription::SubscriptionManager>) {
108 self.subscriptions = Some(subs);
109 }
110
111 pub fn router(&self) -> &Arc<RequestRouter> {
113 &self.router
114 }
115
116 pub fn connections(&self) -> &Arc<DashMap<u64, ClientConn>> {
118 &self.connections
119 }
120
121 pub fn set_metrics(&mut self, metrics: Arc<GatewayMetrics>) {
123 self.metrics = metrics;
124 }
125
126 pub fn set_server_time_offset_secs(&mut self, offset: Arc<AtomicI64>) {
128 self.server_time_offset_secs = offset;
129 }
130
131 pub fn metrics(&self) -> &Arc<GatewayMetrics> {
133 &self.metrics
134 }
135
136 pub async fn run(&self) -> anyhow::Result<()> {
138 let (_shutdown_tx, shutdown_rx) = watch::channel(false);
139 self.run_until_shutdown(shutdown_rx).await
140 }
141
142 pub async fn run_until_shutdown(
144 &self,
145 mut shutdown_rx: watch::Receiver<bool>,
146 ) -> anyhow::Result<()> {
147 let listener = TcpListener::bind(&self.config.listen_addr).await?;
148 tracing::info!(addr = %self.config.listen_addr, "API server listening");
149
150 let (req_tx, req_rx) = mpsc::channel::<IncomingRequest>(REQUEST_QUEUE_CAPACITY);
151 let (disconnect_tx, mut disconnect_rx) = mpsc::unbounded_channel::<DisconnectNotify>();
155
156 let connections = Arc::clone(&self.connections);
158 let router = Arc::clone(&self.router);
159 let config = self.config.clone();
160 let metrics = Arc::clone(&self.metrics);
161 let server_time_offset_secs = Arc::clone(&self.server_time_offset_secs);
162 tokio::spawn(async move {
163 process_requests(
164 req_rx,
165 connections,
166 router,
167 config,
168 metrics,
169 server_time_offset_secs,
170 )
171 .await;
172 });
173
174 let cleanup_connections = Arc::clone(&self.connections);
176 let cleanup_subs = self.subscriptions.clone();
177 let cleanup_metrics = Arc::clone(&self.metrics);
178 tokio::spawn(async move {
179 while let Some(notify) = disconnect_rx.recv().await {
180 let removed = cleanup_connections.remove(¬ify.conn_id);
181 if removed.is_some() {
182 cleanup_metrics
183 .total_disconnections
184 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
185 if let Some(ref subs) = cleanup_subs {
187 subs.on_disconnect(notify.conn_id);
188 }
189 tracing::info!(
190 conn_id = notify.conn_id,
191 remaining = cleanup_connections.len(),
192 "connection removed from pool"
193 );
194 }
195 }
196 });
197
198 let ka_connections = Arc::clone(&self.connections);
200 let ka_subs = self.subscriptions.clone();
201 let ka_metrics = Arc::clone(&self.metrics);
202 let mut ka_shutdown_rx = shutdown_rx.clone();
203 tokio::spawn(async move {
204 const CHECK_INTERVAL_SECS: u64 = 15;
205 const TIMEOUT_SECS: u64 = 66;
206 let mut interval =
207 tokio::time::interval(std::time::Duration::from_secs(CHECK_INTERVAL_SECS));
208 interval.tick().await; loop {
210 tokio::select! {
211 _ = shutdown_requested(&mut ka_shutdown_rx) => {
212 tracing::info!("API server keepalive task stopped by shutdown signal");
213 break;
214 }
215 _ = interval.tick() => {}
216 }
217 let now = Instant::now();
218 let mut timed_out = Vec::new();
219 for entry in ka_connections.iter() {
220 let conn = entry.value();
221 if now.duration_since(conn.last_keepalive).as_secs() >= TIMEOUT_SECS {
222 timed_out.push(conn.conn_id);
223 }
224 }
225 for conn_id in timed_out {
226 if ka_connections.remove(&conn_id).is_some() {
227 ka_metrics
228 .keepalive_timeouts
229 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
230 ka_metrics
231 .total_disconnections
232 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
233 if let Some(ref subs) = ka_subs {
234 subs.on_disconnect(conn_id);
235 }
236 tracing::info!(
237 conn_id = conn_id,
238 remaining = ka_connections.len(),
239 "keepalive timeout, connection removed"
240 );
241 }
242 }
243 }
244 });
245
246 let connections = Arc::clone(&self.connections);
248 let accept_metrics = Arc::clone(&self.metrics);
249 loop {
250 let (stream, peer_addr) = tokio::select! {
251 _ = shutdown_requested(&mut shutdown_rx) => {
252 tracing::info!("API server accept loop stopped by shutdown signal");
253 break;
254 }
255 accepted = listener.accept() => accepted?,
256 };
257
258 if connections.len() >= MAX_CONNECTIONS {
259 accept_metrics
260 .rejected_connections
261 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
262 tracing::warn!(
263 peer = %peer_addr,
264 "max connections reached ({}), rejecting",
265 MAX_CONNECTIONS
266 );
267 drop(stream);
268 continue;
269 }
270
271 accept_metrics
272 .total_connections
273 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
274
275 let conn_id = crate::conn::ClientConn::generate_conn_id();
276 let aes_key = crate::conn::ClientConn::generate_aes_key();
277 set_nodelay_with_log(&stream, peer_addr, "tcp");
278
279 tracing::info!(
280 conn_id = conn_id,
281 peer = %peer_addr,
282 total = connections.len() + 1,
283 "client connected"
284 );
285
286 let tx = crate::conn::run_connection(
287 stream,
288 conn_id,
289 aes_key,
290 req_tx.clone(),
291 disconnect_tx.clone(),
292 shutdown_rx.clone(),
293 )
294 .await;
295
296 let conn = ClientConn {
297 conn_id,
298 state: ConnState::Connected,
299 aes_key,
300 aes_encrypt_enabled: false,
301 proto_fmt_type: ProtoFmtType::Protobuf,
302 last_keepalive: Instant::now(),
303 recv_notify: false,
304 ai_type: 0,
305 keepalive_count: std::sync::atomic::AtomicU32::new(0),
306 tx,
307 key_id: None,
310 scopes: std::collections::HashSet::new(),
311 allowed_markets: None,
314 allowed_acc_ids: None,
316 };
317
318 connections.insert(conn_id, conn);
319 }
320
321 Ok(())
322 }
323
324 pub async fn send_response(
326 connections: &DashMap<u64, ClientConn>,
327 conn_id: u64,
328 proto_id: u32,
329 serial_no: u32,
330 body: Vec<u8>,
331 ) {
332 if let Some(conn) = connections.get(&conn_id) {
333 let frame = conn.make_frame(proto_id, serial_no, bytes::Bytes::from(body));
334 if conn.tx.send(frame).await.is_err() {
335 tracing::warn!(
336 conn_id = conn_id,
337 "failed to send response, connection closed"
338 );
339 }
340 }
341 }
342}
343
344async fn process_requests(
346 mut req_rx: mpsc::Receiver<IncomingRequest>,
347 connections: Arc<DashMap<u64, ClientConn>>,
348 router: Arc<RequestRouter>,
349 config: ServerConfig,
350 metrics: Arc<GatewayMetrics>,
351 server_time_offset_secs: Arc<AtomicI64>,
352) {
353 while let Some(mut req) = req_rx.recv().await {
354 let conn_id = req.conn_id;
355 let proto_id_val = req.proto_id;
356 let serial_no = req.serial_no;
357 let req_start = Instant::now();
358
359 metrics
360 .total_requests
361 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
362
363 if let Some(mut conn) = connections.get_mut(&conn_id) {
365 conn.last_keepalive = Instant::now();
366 }
367
368 if futu_auth::is_internal_proto_id(proto_id_val) {
373 metrics
374 .total_request_errors
375 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
376 tracing::warn!(
377 conn_id,
378 proto_id = proto_id_val,
379 "rejecting daemon-internal proto_id at raw TCP public surface (audit 0532 F3)"
380 );
381 continue;
382 }
383
384 if proto_id_val != proto_id::INIT_CONNECT
386 && let Some(conn) = connections.get(&conn_id)
387 && conn.aes_encrypt_enabled
388 {
389 match conn.decrypt_body(&req.body) {
390 Ok(decrypted) => {
391 req.body = bytes::Bytes::from(decrypted);
392 }
393 Err(e) => {
394 metrics
395 .total_request_errors
396 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
397 tracing::warn!(
398 conn_id = conn_id,
399 proto_id = proto_id_val,
400 error = %e,
401 "AES decrypt request failed, dropping"
402 );
403 continue;
404 }
405 }
406 }
407
408 let response_body = futu_backend::delay_stats::with_api_request(
410 conn_id,
411 serial_no,
412 proto_id_val,
413 || async {
414 match proto_id_val {
415 proto_id::INIT_CONNECT => match connections.get_mut(&conn_id) {
416 Some(mut conn) => match conn.handle_init_connect(
417 &req.body,
418 config.server_ver,
419 config.login_user_id,
420 config.keepalive_interval,
421 config.rsa_private_key.as_deref(),
422 ) {
423 Ok(body) => Some(body),
424 Err(error) => {
425 metrics
426 .total_request_errors
427 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
428 tracing::warn!(
429 conn_id,
430 proto_id = proto_id_val,
431 error = %error,
432 "InitConnect handling failed"
433 );
434 None
435 }
436 },
437 None => {
438 metrics
439 .total_request_errors
440 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
441 tracing::warn!(
442 conn_id,
443 proto_id = proto_id_val,
444 "InitConnect request received for missing connection"
445 );
446 None
447 }
448 },
449 proto_id::KEEP_ALIVE => match connections.get(&conn_id) {
450 Some(conn) => match conn
451 .handle_keepalive_at(&req.body, server_now_ts(&server_time_offset_secs))
452 {
453 Ok(body) => Some(body),
454 Err(error) => {
455 metrics
456 .total_request_errors
457 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
458 tracing::warn!(
459 conn_id,
460 proto_id = proto_id_val,
461 error = %error,
462 "KeepAlive handling failed"
463 );
464 None
465 }
466 },
467 None => {
468 metrics
469 .total_request_errors
470 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
471 tracing::warn!(
472 conn_id,
473 proto_id = proto_id_val,
474 "KeepAlive request received for missing connection"
475 );
476 None
477 }
478 },
479 _ => {
480 router.dispatch(conn_id, &req).await
482 }
483 }
484 },
485 )
486 .await;
487
488 metrics.record_latency_ns(req_start.elapsed().as_nanos() as u64);
490
491 if let Some(body) = response_body {
493 metrics
494 .total_response_bytes
495 .fetch_add(body.len() as u64, std::sync::atomic::Ordering::Relaxed);
496 ApiServer::send_response(&connections, conn_id, proto_id_val, serial_no, body).await;
497 } else if proto_id_val != proto_id::INIT_CONNECT && proto_id_val != proto_id::KEEP_ALIVE {
498 metrics
499 .total_request_errors
500 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
501 }
502 }
503}
504
// Unit tests for this module, compiled only for test builds and loaded from
// a separate `tests` source file.
#[cfg(test)]
mod tests;