1use std::collections::HashMap;
8use std::sync::Arc;
9use std::sync::atomic::{AtomicU32, Ordering};
10use std::time::SystemTime;
11
12use bytes::Bytes;
13use futures::{SinkExt, StreamExt};
14use parking_lot::Mutex;
15use tokio::net::TcpStream;
16use tokio::sync::{mpsc, oneshot, watch};
17use tokio_util::codec::Framed;
18
19use futu_core::error::{FutuError, Result};
20use futu_core::log_redact::endpoint_log_fingerprint;
21use futu_net::encrypt::{aes_cbc_md5_decrypt_var, aes_cbc_md5_encrypt_var};
22
23use crate::nn_codec::{NNCodec, NNFrame, NNHeader, should_skip_encryption};
24
/// Connection to a backend server that multiplexes request/response pairs and
/// server pushes over a single framed stream.
///
/// Construction (see `from_stream_inner`) spawns two background tasks: a send
/// task that writes frames queued on `cmd_tx`, and a receive task that either
/// routes inbound frames to the waiter registered in `pending` under their
/// serial number or hands them to the push callback.
///
/// Note: field declaration order also determines field drop order.
pub struct BackendConn {
    // Counter for outbound frame serial numbers; starts at 0 so the first frame gets 1.
    serial_no: AtomicU32,
    // Counter prepended (big-endian) to each encrypted outbound body; starts at 1
    // and can be overwritten via `set_sec_data`.
    sec_data: AtomicU32,
    // Connection flag shared with both background tasks; set to `false` by
    // `mark_disconnected` and never set back to `true` in this file.
    connected: Arc<std::sync::atomic::AtomicBool>,
    // Publishes connection-state changes to `subscribe_connection_state` callers.
    connected_tx: watch::Sender<bool>,
    // Key used to encrypt outbound and decrypt inbound bodies; `None` until
    // `set_session_key` is called.
    session_key: Arc<Mutex<Option<Vec<u8>>>>,
    // Queue feeding the background send task.
    cmd_tx: mpsc::Sender<BackendCmd>,
    // Response waiters keyed by request serial number.
    pending: PendingResponses,
    // Tells both background tasks to stop; signalled from `Drop`.
    shutdown_tx: watch::Sender<bool>,
    /// User id written into every outbound header.
    pub user_id: AtomicU32,
    // Client IP recorded via `set_client_ip` and returned by `client_ip()`.
    client_ip: Mutex<String>,
    /// Client type written into every outbound header (initialized to 40).
    pub client_type: u8,
    /// Client version written into every outbound header
    /// (initialized to `CLIENT_VER_FTGTW`).
    pub client_ver: u16,
    /// Language id written into every outbound header (initialized to 0).
    pub lang_id: u8,
}
50
/// Commands accepted by the background send task.
enum BackendCmd {
    /// Write this frame to the backend sink.
    Send(NNFrame),
}

/// Response waiters keyed by request serial number.
///
/// Shared by the connection handle and both background tasks. An entry is removed
/// when its response arrives, when its request fails or times out, when the
/// matching inbound frame fails to decode, or when the connection is torn down.
type PendingResponses = Arc<Mutex<HashMap<u32, oneshot::Sender<NNFrame>>>>;

/// Callback invoked by the receive task for every push frame, with the
/// command id, the decoded body and the local time at which it was delivered.
pub type PushCallback = Arc<dyn Fn(u16, Bytes, SystemTime) + Send + Sync + 'static>;

/// Outcome of `decode_inbound_frame_body`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum InboundFrameDecision {
    /// The body was decoded successfully; route the frame normally.
    Deliver,
    /// Decryption or decompression failed; discard the frame.
    Drop,
}
65
66fn decode_inbound_frame_body(
67 frame: &mut NNFrame,
68 session_key: Option<&[u8]>,
69) -> InboundFrameDecision {
70 if !should_skip_encryption(frame.header.cmd_id)
71 && let Some(key) = session_key
72 {
73 let body_len = frame.body.len();
74 if body_len >= 32 && body_len.is_multiple_of(16) {
78 match aes_cbc_md5_decrypt_var(key, &frame.body) {
79 Ok(decrypted) => {
80 frame.body = Bytes::from(decrypted);
82 }
83 Err(e) => {
84 tracing::warn!(
88 cmd_id = frame.header.cmd_id,
89 body_len = body_len,
90 key_len = key.len(),
91 error = %e,
92 "decrypt failed, dropping inbound frame"
93 );
94 return InboundFrameDecision::Drop;
95 }
96 }
97 } else {
98 tracing::debug!(
99 cmd_id = frame.header.cmd_id,
100 body_len = body_len,
101 "body not encrypted (len not aligned to 16)"
102 );
103 }
104 }
105
106 if frame.header.is_compressed() {
107 let compressed_body_len = frame.body.len();
108 match crate::ftlogin_wire::decode_inbound_body(true, frame.body.as_ref()) {
109 Ok(decompressed) => {
110 frame.body = Bytes::from(decompressed);
111 frame.header.body_len = frame.body.len() as u32;
112 tracing::debug!(
113 cmd_id = frame.header.cmd_id,
114 serial_no = frame.header.serial_no,
115 compressed_body_len,
116 body_len = frame.body.len(),
117 "decompressed inbound frame after decrypt"
118 );
119 }
120 Err(e) => {
121 tracing::warn!(
126 cmd_id = frame.header.cmd_id,
127 serial_no = frame.header.serial_no,
128 body_len = compressed_body_len,
129 error = %e,
130 "decompress failed after decrypt, dropping inbound frame"
131 );
132 return InboundFrameDecision::Drop;
133 }
134 }
135 }
136
137 InboundFrameDecision::Deliver
138}
139
140fn release_pending_on_inbound_decode_error(frame: &NNFrame, pending: &PendingResponses) {
141 if !frame.header.is_push {
142 pending.lock().remove(&frame.header.serial_no);
143 }
144}
145
146fn mark_disconnected(
147 connected: &Arc<std::sync::atomic::AtomicBool>,
148 connected_tx: &watch::Sender<bool>,
149) {
150 let was_connected = connected.swap(false, Ordering::AcqRel);
151 if was_connected {
152 let _ = connected_tx.send(false);
153 }
154}
155
156impl BackendConn {
157 pub const CONNECT_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(10);
164
165 pub const CLIENT_VER_FTGTW: u16 = 1030;
172
173 async fn establish_stream(addr: &str, timeout: std::time::Duration) -> Result<TcpStream> {
175 let stream = match tokio::time::timeout(timeout, TcpStream::connect(addr)).await {
176 Ok(Ok(s)) => s,
177 Ok(Err(e)) => return Err(e.into()),
178 Err(_elapsed) => {
179 return Err(FutuError::Network(std::io::Error::new(
180 std::io::ErrorKind::TimedOut,
181 format!("connect to {addr} timed out after {}s", timeout.as_secs()),
182 )));
183 }
184 };
185 stream.set_nodelay(true)?;
186 Ok(stream)
187 }
188
189 pub async fn connect(addr: &str, push_callback: PushCallback) -> Result<Self> {
191 let stream = Self::establish_stream(addr, Self::CONNECT_TIMEOUT).await?;
192 tracing::info!(
193 addr_fp = %endpoint_log_fingerprint(addr),
194 "connected to backend"
195 );
196 Ok(Self::from_stream(stream, push_callback))
197 }
198
199 pub async fn connect_race(
209 addrs: &[String],
210 push_callback: PushCallback,
211 ) -> Result<(Self, String)> {
212 use futures::stream::{FuturesUnordered, StreamExt};
213
214 if addrs.is_empty() {
215 return Err(FutuError::Network(std::io::Error::new(
216 std::io::ErrorKind::InvalidInput,
217 "connect_race: empty address list",
218 )));
219 }
220
221 tracing::info!(
222 candidates = addrs.len(),
223 candidate_fps = ?addrs
224 .iter()
225 .map(|addr| endpoint_log_fingerprint(addr))
226 .collect::<Vec<_>>(),
227 "racing parallel connects"
228 );
229
230 let mut attempts: FuturesUnordered<_> = addrs
231 .iter()
232 .cloned()
233 .map(|addr| async move {
234 let result = Self::establish_stream(&addr, Self::CONNECT_TIMEOUT).await;
235 (addr, result)
236 })
237 .collect();
238
239 let mut last_err: Option<FutuError> = None;
240 while let Some((addr, result)) = attempts.next().await {
241 match result {
242 Ok(stream) => {
243 tracing::info!(
244 addr_fp = %endpoint_log_fingerprint(&addr),
245 remaining_losers = attempts.len(),
246 "connect race winner"
247 );
248 drop(attempts); let conn = Self::from_stream(stream, push_callback);
250 return Ok((conn, addr));
251 }
252 Err(e) => {
253 tracing::debug!(
254 addr_fp = %endpoint_log_fingerprint(&addr),
255 error = %e,
256 "candidate failed"
257 );
258 last_err = Some(e);
259 }
260 }
261 }
262
263 Err(last_err.unwrap_or_else(|| {
264 FutuError::Network(std::io::Error::other("connect_race: all candidates failed"))
265 }))
266 }
267
268 #[cfg(feature = "test-util")]
273 pub fn from_duplex(stream: tokio::io::DuplexStream, push_callback: PushCallback) -> Self {
274 Self::from_stream_inner(stream, push_callback)
275 }
276
277 fn from_stream(stream: TcpStream, push_callback: PushCallback) -> Self {
279 Self::from_stream_inner(stream, push_callback)
280 }
281
282 pub(crate) fn from_stream_inner<S>(stream: S, push_callback: PushCallback) -> Self
285 where
286 S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Send + Unpin + 'static,
287 {
288 let framed = Framed::new(stream, NNCodec);
289 let (mut sink, mut stream_rx) = framed.split();
290
291 let (cmd_tx, mut cmd_rx) = mpsc::channel::<BackendCmd>(256);
292
293 let pending: PendingResponses =
294 Arc::new(Mutex::new(HashMap::<u32, oneshot::Sender<NNFrame>>::new()));
295 let pending_recv = pending.clone();
296 let pending_send = pending.clone();
297 let connected = Arc::new(std::sync::atomic::AtomicBool::new(true));
298 let (connected_tx, _) = watch::channel(true);
299 let (shutdown_tx, mut shutdown_rx_recv) = watch::channel(false);
300 let mut shutdown_rx_send = shutdown_tx.subscribe();
301 let connected_recv = Arc::clone(&connected);
302 let connected_send = Arc::clone(&connected);
303 let connected_tx_recv = connected_tx.clone();
304 let connected_tx_send = connected_tx.clone();
305 let session_key: Arc<Mutex<Option<Vec<u8>>>> = Arc::new(Mutex::new(None));
306 let session_key_for_recv = session_key.clone();
307
308 tokio::spawn(async move {
310 loop {
311 let next_frame = tokio::select! {
312 _ = shutdown_rx_recv.changed() => {
313 break;
314 }
315 next = stream_rx.next() => next,
316 };
317 let Some(Ok(mut frame)) = next_frame else {
318 break;
319 };
320
321 tracing::debug!(
323 cmd_id = frame.header.cmd_id,
324 serial_no = frame.header.serial_no,
325 is_push = frame.header.is_push,
326 is_compressed = frame.header.is_compressed,
327 ex_head_len = frame.header.ex_head_len,
328 body_len = frame.header.body_len,
329 actual_body_len = frame.body.len(),
330 "recv frame"
331 );
332
333 if let Some(err) = frame.parse_ex_head_error()
341 && (err.cmd_result != 0
342 || err.code != 0
343 || !err.message.is_empty()
344 || !err.source.is_empty())
345 {
346 let is_soft_fail = err.code == -102;
352 if is_soft_fail {
353 tracing::debug!(
354 cmd_id = frame.header.cmd_id,
355 code = err.code,
356 source = %err.source,
357 message = %err.message,
358 "server: cmd not available on this channel (soft-fail)"
359 );
360 } else {
361 tracing::warn!(
362 cmd_id = frame.header.cmd_id,
363 cmd_result = err.cmd_result,
364 code = err.code,
365 source = %err.source,
366 message = %err.message,
367 "server returned err_info in ex_head"
368 );
369 }
370 }
371
372 let session_key = session_key_for_recv.lock().clone();
373 if decode_inbound_frame_body(&mut frame, session_key.as_deref())
374 == InboundFrameDecision::Drop
375 {
376 release_pending_on_inbound_decode_error(&frame, &pending_recv);
377 continue;
378 }
379
380 let is_push = frame.header.is_push
385 || (frame.header.serial_no == 0 && !pending_recv.lock().contains_key(&0));
386 if is_push {
387 tracing::debug!(
388 cmd_id = frame.header.cmd_id,
389 body_len = frame.body.len(),
390 is_push = frame.header.is_push,
391 is_compressed = frame.header.is_compressed,
392 reserved = ?frame.header.reserved,
393 "backend push received"
394 );
395 push_callback(frame.header.cmd_id, frame.body, SystemTime::now());
396 } else {
397 let tx = pending_recv.lock().remove(&frame.header.serial_no);
398 if let Some(tx) = tx
399 && let Err(frame) = tx.send(frame)
400 {
401 tracing::debug!(
402 cmd_id = frame.header.cmd_id,
403 serial_no = frame.header.serial_no,
404 body_len = frame.body.len(),
405 "backend response receiver dropped before frame delivery"
406 );
407 }
408 }
409 }
410 mark_disconnected(&connected_recv, &connected_tx_recv);
412 tracing::warn!("backend connection closed");
413 let mut pending = pending_recv.lock();
414 let count = pending.len();
415 if count > 0 {
416 tracing::warn!(
417 pending_count = count,
418 "aborting pending requests due to disconnect"
419 );
420 }
421 pending.clear();
425 });
426
427 tokio::spawn(async move {
429 loop {
430 let next_cmd = tokio::select! {
431 _ = shutdown_rx_send.changed() => {
432 break;
433 }
434 cmd = cmd_rx.recv() => cmd,
435 };
436 let Some(cmd) = next_cmd else {
437 break;
438 };
439
440 match cmd {
441 BackendCmd::Send(frame) => {
442 if let Err(e) = sink.send(frame).await {
443 mark_disconnected(&connected_send, &connected_tx_send);
444 tracing::error!(error = %e, "backend send failed");
445 let mut pending = pending_send.lock();
446 let count = pending.len();
447 if count > 0 {
448 tracing::warn!(
449 pending_count = count,
450 "aborting pending requests due to send failure"
451 );
452 }
453 pending.clear();
454 break;
455 }
456 }
457 }
458 }
459 });
460
461 Self {
462 serial_no: AtomicU32::new(0),
463 sec_data: AtomicU32::new(1),
464 connected,
465 connected_tx,
466 session_key, cmd_tx,
468 pending,
469 shutdown_tx,
470 user_id: AtomicU32::new(0),
471 client_ip: Mutex::new(String::new()),
472 client_type: 40, client_ver: Self::CLIENT_VER_FTGTW,
474 lang_id: 0,
475 }
476 }
477
478 pub fn set_session_key(&self, key: Vec<u8>) {
482 *self.session_key.lock() = Some(key);
483 }
484
485 pub fn set_sec_data(&self, val: u32) {
487 self.sec_data.store(val, Ordering::Relaxed);
488 }
489
490 pub fn set_client_ip(&self, ip: String) {
492 *self.client_ip.lock() = ip;
493 }
494
495 pub fn client_ip(&self) -> String {
497 self.client_ip.lock().clone()
498 }
499
500 pub fn is_connected(&self) -> bool {
501 self.connected.load(Ordering::Acquire)
502 }
503
504 pub fn subscribe_connection_state(&self) -> watch::Receiver<bool> {
512 self.connected_tx.subscribe()
513 }
514
515 pub async fn request(&self, cmd_id: u16, body: Vec<u8>) -> Result<NNFrame> {
517 self.request_with_reserved(cmd_id, body, [0u8; 10]).await
518 }
519
520 pub async fn request_with_reserved(
522 &self,
523 cmd_id: u16,
524 body: Vec<u8>,
525 reserved: [u8; 10],
526 ) -> Result<NNFrame> {
527 self.request_with_reserved_timeout(
528 cmd_id,
529 body,
530 reserved,
531 std::time::Duration::from_secs(10),
532 )
533 .await
534 }
535
536 pub(crate) async fn request_with_reserved_timeout(
537 &self,
538 cmd_id: u16,
539 body: Vec<u8>,
540 reserved: [u8; 10],
541 timeout: std::time::Duration,
542 ) -> Result<NNFrame> {
543 let deadline = tokio::time::Instant::now() + timeout;
544 let frame = self.build_outbound_frame(cmd_id, body, reserved)?;
545 let serial_no = frame.header.serial_no;
546
547 let (resp_tx, resp_rx) = oneshot::channel();
548 self.pending.lock().insert(serial_no, resp_tx);
549 match tokio::time::timeout_at(deadline, self.cmd_tx.send(BackendCmd::Send(frame))).await {
550 Ok(Ok(())) => {}
551 Ok(Err(_closed)) => {
552 self.pending.lock().remove(&serial_no);
553 mark_disconnected(&self.connected, &self.connected_tx);
554 return Err(FutuError::NotInitialized);
555 }
556 Err(_elapsed) => {
557 self.pending.lock().remove(&serial_no);
558 return Err(FutuError::Timeout);
559 }
560 }
561
562 let resp = crate::delay_stats::trace_backend_request(cmd_id, async {
563 match tokio::time::timeout_at(deadline, resp_rx).await {
564 Ok(Ok(resp)) => Ok(resp),
565 Ok(Err(_closed)) => {
566 self.pending.lock().remove(&serial_no);
567 Err(FutuError::Codec("response channel closed".into()))
568 }
569 Err(_elapsed) => {
570 self.pending.lock().remove(&serial_no);
571 Err(FutuError::Timeout)
572 }
573 }
574 })
575 .await?;
576
577 Ok(resp)
578 }
579
580 pub async fn send_fire_and_forget(&self, cmd_id: u16, body: Vec<u8>) -> Result<()> {
582 let frame = self.build_outbound_frame(cmd_id, body, [0u8; 10])?;
583 if self.cmd_tx.send(BackendCmd::Send(frame)).await.is_err() {
584 mark_disconnected(&self.connected, &self.connected_tx);
585 return Err(FutuError::NotInitialized);
586 }
587
588 Ok(())
589 }
590
591 fn build_outbound_frame(
592 &self,
593 cmd_id: u16,
594 body: Vec<u8>,
595 reserved: [u8; 10],
596 ) -> Result<NNFrame> {
597 let serial = self.next_serial();
598 let mut header = NNHeader::new(cmd_id, serial);
599 header.user_id = self.user_id.load(Ordering::Relaxed);
600 header.client_type = self.client_type;
601 header.client_ver = self.client_ver;
602 header.lang_id = self.lang_id;
603 header.reserved.copy_from_slice(&reserved[..8]);
607
608 let final_body = self.encode_outbound_body(cmd_id, body)?;
609 header.body_len = final_body.len() as u32;
610
611 Ok(NNFrame {
612 header,
613 body: Bytes::from(final_body),
614 ex_head: Bytes::new(),
615 })
616 }
617
618 fn encode_outbound_body(&self, cmd_id: u16, body: Vec<u8>) -> Result<Vec<u8>> {
619 if should_skip_encryption(cmd_id) {
620 return Ok(body);
621 }
622
623 let key = self.session_key.lock().clone();
624 match key {
625 Some(key) => {
626 let sec = self.sec_data.fetch_add(1, Ordering::Relaxed) + 1;
628 let mut plaintext = Vec::with_capacity(4 + body.len());
629 plaintext.extend_from_slice(&sec.to_be_bytes());
630 plaintext.extend_from_slice(&body);
631 aes_cbc_md5_encrypt_var(&key, &plaintext)
633 }
634 None => Ok(body),
635 }
636 }
637
638 fn next_serial(&self) -> u32 {
639 self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
640 }
641}
642
643impl Drop for BackendConn {
644 fn drop(&mut self) {
645 let _ = self.shutdown_tx.send(true);
646 mark_disconnected(&self.connected, &self.connected_tx);
647 self.pending.lock().clear();
648 }
649}
650
651#[cfg(test)]
652mod tests;