1use std::sync::Arc;
2use std::sync::atomic::{AtomicU32, AtomicU64, Ordering};
3use std::time::Duration;
4
5use bytes::Bytes;
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot};
8
9use futu_codec::frame::FutuFrame;
10use futu_core::error::{FutuError, Result};
11use futu_core::proto_id;
12
13use crate::connection::Connection;
14use crate::encrypt;
15use crate::reconnect::ReconnectPolicy;
16
/// How long `FutuClient::request` waits overall (queueing, write ack and response).
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_millis(12_000);
/// Keep-alive period, in seconds, used when the server reports a non-positive interval.
const DEFAULT_KEEP_ALIVE_INTERVAL_SECS: u64 = 10;
/// Smallest keep-alive period, in seconds, accepted from the server.
const MIN_KEEP_ALIVE_INTERVAL_SECS: u64 = 1;
/// Largest keep-alive period, in seconds, accepted from the server.
const MAX_KEEP_ALIVE_INTERVAL_SECS: u64 = 60;
/// Consecutive heartbeat send failures tolerated before the event loop stops.
const MAX_HEARTBEAT_SEND_FAILURES: u32 = 3;
/// Capacity of the bounded push queue; pushes arriving while it is full are dropped.
pub(crate) const PUSH_CHANNEL_CAPACITY: usize = 4096;

/// Converts the server-supplied keep-alive interval into a usable number of seconds.
///
/// Non-positive values fall back to `DEFAULT_KEEP_ALIVE_INTERVAL_SECS`; positive values
/// are clamped into `MIN_KEEP_ALIVE_INTERVAL_SECS..=MAX_KEEP_ALIVE_INTERVAL_SECS`.
fn sanitize_keep_alive_interval_secs(raw: i32) -> u64 {
    match raw {
        // Strictly positive here, so widening to u64 is lossless.
        secs if secs > 0 => {
            (secs as u64).clamp(MIN_KEEP_ALIVE_INTERVAL_SECS, MAX_KEEP_ALIVE_INTERVAL_SECS)
        }
        _ => DEFAULT_KEEP_ALIVE_INTERVAL_SECS,
    }
}
30
31fn heartbeat_send_failure_should_exit(consecutive_failures: &mut u32, result: &Result<()>) -> bool {
32 if result.is_ok() {
33 *consecutive_failures = 0;
34 return false;
35 }
36
37 *consecutive_failures = consecutive_failures.saturating_add(1);
38 *consecutive_failures >= MAX_HEARTBEAT_SEND_FAILURES
39}
40
41fn init_connect_server_err(ret_type: i32, ret_msg: Option<String>) -> FutuError {
42 let msg = match ret_msg {
43 Some(msg) if !msg.is_empty() => msg,
44 _ => "<missing ret_msg>".to_string(),
45 };
46 FutuError::ServerError { ret_type, msg }
47}
48
/// Connection settings for a [`FutuClient`].
#[derive(Debug, Clone)]
pub struct ClientConfig {
    /// Address handed to `Connection::connect` (presumably `host:port` — confirm).
    pub addr: String,
    /// Client version string.
    /// NOTE(review): not read anywhere in this file; `build_init_connect_request`
    /// sends a hard-coded `client_ver: 100` instead — confirm this is intentional.
    pub client_ver: String,
    /// Sent verbatim as `client_id` in the InitConnect request.
    pub client_id: String,
    /// Sent as `recv_notify` in the InitConnect request.
    pub recv_notify: bool,
    /// RSA key material. When set, the InitConnect request is encrypted with
    /// `rsa_public_encrypt` and its response decrypted with `rsa_private_decrypt`,
    /// and later traffic uses the AES key returned by the server (if any).
    pub rsa_key: Option<String>,
}
62
/// A server-initiated (push) frame delivered through the [`PushReceiver`].
#[derive(Debug, Clone)]
pub struct PushMessage {
    /// Protocol id taken from the frame header.
    pub proto_id: u32,
    /// Plaintext body: AES-decrypted when a connection key is active, then SHA-1 checked.
    pub body: Bytes,
}
69
/// Sending half of the bounded push queue; held by the client and cloned into its event loop.
type PushSender = mpsc::Sender<PushMessage>;
/// Receiving half of the push queue, handed out by [`FutuClient::new`].
pub type PushReceiver = mpsc::Receiver<PushMessage>;
72
/// A request that is waiting for its response frame.
struct PendingRequest {
    /// Completed by the event loop with the decoded response; dropping it (e.g. on
    /// decode failure or loop exit) makes the waiter see a closed channel.
    tx: oneshot::Sender<FutuFrame>,
}
77
/// Request/response client over `FutuFrame`s; pushes are delivered through a separate
/// [`PushReceiver`] returned by [`FutuClient::new`].
pub struct FutuClient {
    /// Settings supplied at construction.
    config: ClientConfig,
    /// Last serial number handed out; `next_serial` increments it before use.
    serial_no: AtomicU32,
    /// Waiters keyed by request serial; shared with the event loop, which completes them.
    pending: Arc<DashMap<u32, PendingRequest>>,
    /// Sending half of the push queue; a clone is moved into the event loop.
    push_tx: PushSender,
    /// Command queue into the event loop; `None` until `connect` has started it.
    cmd_tx: Option<mpsc::Sender<ClientCommand>>,
    /// AES key from InitConnect; only used when `config.rsa_key` is set.
    conn_aes_key: parking_lot::Mutex<Option<[u8; 16]>>,
    /// Server-assigned connection id; 0 means "not connected yet" (see `conn_id()`).
    conn_id: AtomicU64,
    /// Optional `ai_type` placed in the InitConnect request.
    init_ai_type: Option<i32>,
}
96
/// Commands sent from [`FutuClient`] to its event loop.
enum ClientCommand {
    /// Write the frame to the connection, then complete the oneshot. The ack is sent
    /// whether or not the write succeeded; a failed write also stops the event loop.
    Send(FutuFrame, oneshot::Sender<()>),
}
100
101fn decode_client_inbound_body(frame: &FutuFrame, aes_key: Option<&[u8; 16]>) -> Result<Bytes> {
102 let body = match aes_key {
103 Some(key) if !frame.body.is_empty() => {
104 encrypt::aes_ecb_decrypt(key, &frame.body).map(Bytes::from)
105 }
106 _ => Ok(frame.body.clone()),
107 }?;
108 verify_plaintext_body_sha1(frame, body)
109}
110
111fn verify_plaintext_body_sha1(frame: &FutuFrame, plaintext: Bytes) -> Result<Bytes> {
112 let plaintext_frame = FutuFrame {
113 header: frame.header.clone(),
114 body: plaintext.clone(),
115 };
116 if plaintext_frame.verify_sha1() {
117 Ok(plaintext)
118 } else {
119 Err(FutuError::Sha1Mismatch)
120 }
121}
122
123fn build_init_connect_request(
124 config: &ClientConfig,
125 ai_type: Option<i32>,
126) -> futu_proto::init_connect::Request {
127 futu_proto::init_connect::Request {
128 c2s: futu_proto::init_connect::C2s {
129 client_ver: 100,
130 client_id: config.client_id.clone(),
131 recv_notify: Some(config.recv_notify),
132 packet_enc_algo: Some(0),
133 push_proto_fmt: Some(0),
134 programming_language: Some(String::new()),
135 ai_type,
136 },
137 }
138}
139
140fn parse_conn_aes_key(conn_aes_key: &str) -> Result<[u8; 16]> {
141 let key_bytes = conn_aes_key.as_bytes();
142 match key_bytes.len() {
143 16 => {
144 let mut key = [0u8; 16];
145 key.copy_from_slice(key_bytes);
146 Ok(key)
147 }
148 32 => hex_decode_16(key_bytes)
149 .ok_or_else(|| FutuError::Codec("invalid hex InitConnect AES key".into())),
150 len => Err(FutuError::Codec(format!(
151 "unexpected AES key length: {len}"
152 ))),
153 }
154}
155
156fn release_pending_on_inbound_decode_error(
157 frame: &FutuFrame,
158 pending: &DashMap<u32, PendingRequest>,
159) {
160 if !proto_id::is_push_proto(frame.header.proto_id) {
161 pending.remove(&frame.header.serial_no);
162 }
163}
164
165fn push_channel() -> (PushSender, PushReceiver) {
166 mpsc::channel(PUSH_CHANNEL_CAPACITY)
167}
168
169fn send_push_message(push_tx: &PushSender, message: PushMessage) -> bool {
170 let proto_id = message.proto_id;
171 match push_tx.try_send(message) {
172 Ok(()) => true,
173 Err(mpsc::error::TrySendError::Full(_)) => {
174 tracing::warn!(
175 proto_id,
176 capacity = PUSH_CHANNEL_CAPACITY,
177 "push receiver is slow; dropping push message from bounded client queue"
178 );
179 false
180 }
181 Err(mpsc::error::TrySendError::Closed(_)) => {
182 tracing::debug!(proto_id, "push receiver dropped before delivery");
183 false
184 }
185 }
186}
187
188fn send_response_frame(tx: oneshot::Sender<FutuFrame>, frame: FutuFrame) -> bool {
189 let serial = frame.header.serial_no;
190 let proto_id = frame.header.proto_id;
191 match tx.send(frame) {
192 Ok(()) => true,
193 Err(_) => {
194 tracing::debug!(
195 serial,
196 proto_id,
197 "response receiver dropped before delivery"
198 );
199 false
200 }
201 }
202}
203
204fn send_command_ack(ack: oneshot::Sender<()>) -> bool {
205 match ack.send(()) {
206 Ok(()) => true,
207 Err(_) => {
208 tracing::debug!("command ack receiver dropped before delivery");
209 false
210 }
211 }
212}
213
214impl FutuClient {
215 pub fn new(config: ClientConfig) -> (Self, PushReceiver) {
219 let (push_tx, push_rx) = push_channel();
220
221 let client = Self {
222 config,
223 serial_no: AtomicU32::new(0),
224 pending: Arc::new(DashMap::new()),
225 push_tx,
226 cmd_tx: None,
227 conn_aes_key: parking_lot::Mutex::new(None),
228 conn_id: AtomicU64::new(0),
229 init_ai_type: None,
230 };
231
232 (client, push_rx)
233 }
234
235 pub fn with_init_ai_type(mut self, ai_type: Option<i32>) -> Self {
240 self.init_ai_type = ai_type;
241 self
242 }
243
244 pub async fn connect(&mut self) -> Result<InitConnectInfo> {
246 let mut conn = Connection::connect(&self.config.addr).await?;
247
248 let serial = self.next_serial();
250 let req = build_init_connect_request(&self.config, self.init_ai_type);
251 let raw_body = prost::Message::encode_to_vec(&req);
252
253 let send_body = if let Some(ref rsa_key) = self.config.rsa_key {
255 tracing::debug!("encrypting InitConnect with RSA");
256 encrypt::rsa_public_encrypt(rsa_key, &raw_body)?
257 } else {
258 raw_body.clone()
259 };
260
261 let mut frame = Connection::build_frame(proto_id::INIT_CONNECT, serial, send_body);
263 {
264 use sha1::{Digest, Sha1};
265 let mut hasher = Sha1::new();
266 hasher.update(&raw_body);
267 let hash = hasher.finalize();
268 frame.header.body_sha1.copy_from_slice(&hash);
269 }
270 conn.send(frame).await?;
271
272 let resp_frame = conn.recv().await?.ok_or(FutuError::Codec(
274 "connection closed during InitConnect".into(),
275 ))?;
276
277 let resp_body = if let Some(ref rsa_key) = self.config.rsa_key {
279 tracing::debug!("decrypting InitConnect response with RSA");
280 Bytes::from(encrypt::rsa_private_decrypt(rsa_key, &resp_frame.body)?)
281 } else {
282 resp_frame.body.clone()
283 };
284 let resp_body = verify_plaintext_body_sha1(&resp_frame, resp_body)?;
285
286 let resp: futu_proto::init_connect::Response =
287 prost::Message::decode(resp_body.as_ref()).map_err(FutuError::Proto)?;
288
289 let ret_type = resp.ret_type;
291 if ret_type != 0 {
292 return Err(init_connect_server_err(ret_type, resp.ret_msg));
293 }
294
295 let s2c = resp.s2c.ok_or(FutuError::Codec(
296 "missing s2c in InitConnect response".into(),
297 ))?;
298
299 let info = InitConnectInfo {
300 server_ver: s2c.server_ver,
301 login_user_id: s2c.login_user_id,
302 conn_id: s2c.conn_id,
303 conn_aes_key: s2c.conn_aes_key.clone(),
304 keep_alive_interval: s2c.keep_alive_interval,
305 };
306 self.conn_id.store(info.conn_id, Ordering::Relaxed);
307
308 if !info.conn_aes_key.is_empty() {
310 let key_bytes = info.conn_aes_key.as_bytes();
311 tracing::debug!(key_len = key_bytes.len(), "received AES key");
312 *self.conn_aes_key.lock() = Some(parse_conn_aes_key(&info.conn_aes_key)?);
313 }
314
315 tracing::info!(
316 server_ver = info.server_ver,
317 conn_id = info.conn_id,
318 keep_alive_interval = info.keep_alive_interval,
319 "InitConnect succeeded"
320 );
321
322 let keep_alive_interval_secs = sanitize_keep_alive_interval_secs(info.keep_alive_interval);
327 if info.keep_alive_interval <= 0
328 || keep_alive_interval_secs != info.keep_alive_interval as u64
329 {
330 tracing::warn!(
331 raw_keep_alive_interval = info.keep_alive_interval,
332 sanitized_keep_alive_interval = keep_alive_interval_secs,
333 "sanitized InitConnect keep_alive_interval"
334 );
335 }
336 let keep_alive_interval = Duration::from_secs(keep_alive_interval_secs);
337 self.start_background_tasks(conn, keep_alive_interval);
338
339 Ok(info)
340 }
341
342 pub async fn request(&self, proto_id: u32, body: Vec<u8>) -> Result<FutuFrame> {
344 self.request_with_timeout(proto_id, body, DEFAULT_REQUEST_TIMEOUT)
345 .await
346 }
347
348 pub async fn request_with_timeout(
354 &self,
355 proto_id: u32,
356 body: Vec<u8>,
357 timeout: Duration,
358 ) -> Result<FutuFrame> {
359 let deadline = tokio::time::Instant::now() + timeout;
360 let serial = self.next_serial();
361
362 let (final_body, sha1) = self.prepare_body(&body);
364
365 let frame = FutuFrame::with_sha1(proto_id, serial, Bytes::from(final_body), sha1);
366
367 let (resp_tx, resp_rx) = oneshot::channel();
368 self.pending.insert(serial, PendingRequest { tx: resp_tx });
369
370 if let Some(cmd_tx) = &self.cmd_tx {
372 let (ack_tx, ack_rx) = oneshot::channel();
373 match tokio::time::timeout_at(deadline, cmd_tx.send(ClientCommand::Send(frame, ack_tx)))
374 .await
375 {
376 Ok(Ok(())) => {}
377 Ok(Err(_)) => {
378 self.pending.remove(&serial);
379 return Err(FutuError::NotInitialized);
380 }
381 Err(_) => {
382 self.pending.remove(&serial);
383 return Err(FutuError::Timeout);
384 }
385 }
386
387 match tokio::time::timeout_at(deadline, ack_rx).await {
388 Ok(Ok(())) => {}
389 Ok(Err(_)) => {
390 self.pending.remove(&serial);
391 return Err(FutuError::Codec("send ack failed".into()));
392 }
393 Err(_) => {
394 self.pending.remove(&serial);
395 return Err(FutuError::Timeout);
396 }
397 }
398 } else {
399 self.pending.remove(&serial);
400 return Err(FutuError::NotInitialized);
401 }
402
403 let resp = tokio::time::timeout_at(deadline, resp_rx)
405 .await
406 .map_err(|elapsed| {
407 self.pending.remove(&serial);
408 tracing::warn!(
409 serial,
410 error = %elapsed,
411 "request response wait timed out; pending entry removed"
412 );
413 FutuError::Timeout
414 })?
415 .map_err(|err| {
416 FutuError::Codec(format!(
417 "response channel closed for serial {serial}: {err}"
418 ))
419 })?;
420
421 Ok(resp)
422 }
423
424 pub fn conn_id(&self) -> Option<u64> {
429 let conn_id = self.conn_id.load(Ordering::Relaxed);
430 (conn_id != 0).then_some(conn_id)
431 }
432
433 fn next_serial(&self) -> u32 {
434 self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
435 }
436
437 fn prepare_body(&self, plaintext: &[u8]) -> (Vec<u8>, [u8; 20]) {
438 use sha1::{Digest, Sha1};
439
440 let mut hasher = Sha1::new();
442 hasher.update(plaintext);
443 let sha1_result = hasher.finalize();
444 let mut sha1 = [0u8; 20];
445 sha1.copy_from_slice(&sha1_result);
446
447 let body = if self.config.rsa_key.is_some() {
449 let key = self.conn_aes_key.lock();
450 match key.as_ref() {
451 Some(k) => encrypt::aes_ecb_encrypt(k, plaintext),
452 None => plaintext.to_vec(),
453 }
454 } else {
455 plaintext.to_vec()
456 };
457
458 (body, sha1)
459 }
460
461 fn start_background_tasks(&mut self, conn: Connection, keep_alive_interval: Duration) {
462 let (cmd_tx, cmd_rx) = mpsc::channel(256);
463 self.cmd_tx = Some(cmd_tx.clone());
464
465 let pending = Arc::clone(&self.pending);
466 let push_tx = self.push_tx.clone();
467 let aes_key = if self.config.rsa_key.is_some() {
468 *self.conn_aes_key.lock()
469 } else {
470 None
471 };
472
473 tokio::spawn(async move {
474 run_event_loop(conn, cmd_rx, pending, push_tx, aes_key, keep_alive_interval).await;
475 });
476 }
477}
478
479async fn run_event_loop(
481 mut conn: Connection,
482 mut cmd_rx: mpsc::Receiver<ClientCommand>,
483 pending: Arc<DashMap<u32, PendingRequest>>,
484 push_tx: PushSender,
485 aes_key: Option<[u8; 16]>,
486 keep_alive_interval: Duration,
487) {
488 let mut heartbeat = tokio::time::interval(keep_alive_interval);
489 heartbeat.tick().await; let mut heartbeat_serial: u32 = 10_000_000; let mut heartbeat_send_failures = 0u32;
492
493 loop {
494 tokio::select! {
495 result = conn.recv() => {
497 match result {
498 Ok(Some(frame)) => {
499 let proto_id = frame.header.proto_id;
500 let body = match decode_client_inbound_body(&frame, aes_key.as_ref()) {
501 Ok(body) => body,
502 Err(e) => {
503 tracing::warn!(
504 error = %e,
505 serial = frame.header.serial_no,
506 proto_id,
507 "decrypt failed, dropping inbound frame"
508 );
509 release_pending_on_inbound_decode_error(&frame, &pending);
510 continue;
511 }
512 };
513
514 if proto_id::is_push_proto(proto_id) {
515 send_push_message(&push_tx, PushMessage { proto_id, body });
517 } else {
518 let serial = frame.header.serial_no;
520 match pending.remove(&serial) {
521 Some((_, req)) => {
522 let resp_frame = FutuFrame {
523 header: frame.header,
524 body,
525 };
526 send_response_frame(req.tx, resp_frame);
527 }
528 _ => {
529 tracing::debug!(
530 serial = serial,
531 proto_id = proto_id,
532 "unmatched response"
533 );
534 }
535 }
536 }
537 }
538 Ok(None) => {
539 tracing::warn!("connection closed by server");
540 break;
541 }
542 Err(e) => {
543 tracing::error!(error = %e, "recv error");
544 break;
545 }
546 }
547 }
548
549 cmd = cmd_rx.recv() => {
551 match cmd {
552 Some(ClientCommand::Send(frame, ack)) => {
553 let result = conn.send(frame).await;
554 if let Err(e) = &result {
555 tracing::error!(error = %e, "send failed");
556 }
557 send_command_ack(ack);
558 if result.is_err() {
559 break;
560 }
561 }
562 None => {
563 tracing::info!("shutting down event loop");
564 break;
565 }
566 }
567 }
568
569 _ = heartbeat.tick() => {
571 heartbeat_serial += 1;
572 let req = futu_proto::keep_alive::Request {
573 c2s: futu_proto::keep_alive::C2s {
574 time: chrono::Utc::now().timestamp(),
575 },
576 };
577 let body = prost::Message::encode_to_vec(&req);
578 let frame = Connection::build_frame(
579 proto_id::KEEP_ALIVE,
580 heartbeat_serial,
581 body,
582 );
583 let result = conn.send(frame).await;
584 if let Err(e) = &result {
585 tracing::warn!(
586 error = %e,
587 consecutive_failures = heartbeat_send_failures.saturating_add(1),
588 max_failures = MAX_HEARTBEAT_SEND_FAILURES,
589 "heartbeat send failed"
590 );
591 }
592 if heartbeat_send_failure_should_exit(&mut heartbeat_send_failures, &result) {
593 tracing::error!(
594 consecutive_failures = heartbeat_send_failures,
595 max_failures = MAX_HEARTBEAT_SEND_FAILURES,
596 "heartbeat send failure threshold reached; closing event loop"
597 );
598 break;
599 }
600 if result.is_ok() {
601 tracing::trace!("heartbeat sent");
602 }
603 }
604 }
605 }
606
607 pending.clear();
609 tracing::info!("event loop exited");
610}
611
/// Values copied from the `s2c` part of a successful InitConnect response.
#[derive(Debug, Clone)]
pub struct InitConnectInfo {
    /// Server version reported by the server.
    pub server_ver: i32,
    /// Logged-in user id reported by the server.
    pub login_user_id: u64,
    /// Connection id; also exposed through `FutuClient::conn_id`.
    pub conn_id: u64,
    /// AES key string as sent by the server (may be empty). `connect` accepts 16 raw
    /// bytes or 32 hex characters.
    pub conn_aes_key: String,
    /// Keep-alive interval in seconds, as reported. The client sanitizes it before use.
    pub keep_alive_interval: i32,
}
621
/// Builds [`FutuClient`]s and retries `connect` as directed by a [`ReconnectPolicy`].
pub struct ReconnectingClient {
    /// Configuration cloned into every new client.
    config: ClientConfig,
    /// Consulted after each failed attempt to decide whether, and how long, to wait.
    policy: ReconnectPolicy,
    /// `ai_type` forwarded to each client's InitConnect request.
    init_ai_type: Option<i32>,
}
628
629impl ReconnectingClient {
630 pub fn new(config: ClientConfig) -> Self {
631 Self {
632 config,
633 policy: ReconnectPolicy::default_policy(),
634 init_ai_type: None,
635 }
636 }
637
638 pub fn with_policy(mut self, policy: ReconnectPolicy) -> Self {
639 self.policy = policy;
640 self
641 }
642
643 pub fn with_init_ai_type(mut self, ai_type: Option<i32>) -> Self {
648 self.init_ai_type = ai_type;
649 self
650 }
651
652 pub async fn connect(&mut self) -> Result<(FutuClient, PushReceiver, InitConnectInfo)> {
657 loop {
658 let (mut client, push_rx) = FutuClient::new(self.config.clone());
659 client = client.with_init_ai_type(self.init_ai_type);
660 match client.connect().await {
661 Ok(info) => {
662 self.policy.reset();
663 return Ok((client, push_rx, info));
664 }
665 Err(e) => {
666 tracing::warn!(
667 error = %e,
668 attempt = self.policy.attempts(),
669 "connection failed"
670 );
671 match self.policy.next_delay() {
672 Some(delay) => {
673 tracing::info!(delay_ms = delay.as_millis(), "reconnecting...");
674 tokio::time::sleep(delay).await;
675 }
676 None => {
677 return Err(FutuError::Codec(format!(
678 "max retries reached after {} attempts",
679 self.policy.attempts()
680 )));
681 }
682 }
683 }
684 }
685 }
686 }
687}
688
// Unit tests live in the separate `tests` submodule file; compiled only for test builds.
#[cfg(test)]
mod tests;
691
/// Decodes exactly 32 ASCII hex digits (either case) into a 16-byte key.
///
/// Returns `None` when the input is not 32 bytes long or contains any byte other than
/// `0-9`, `a-f` or `A-F`. The input is processed as raw bytes, so non-ASCII data is
/// rejected rather than causing a panic from slicing a `str` inside a multi-byte
/// character. Sign characters such as `+` are not accepted as part of a digit pair.
fn hex_decode_16(hex_bytes: &[u8]) -> Option<[u8; 16]> {
    if hex_bytes.len() != 32 {
        return None;
    }

    let mut key = [0u8; 16];
    for (byte, pair) in key.iter_mut().zip(hex_bytes.chunks_exact(2)) {
        // `char::from(u8)` maps bytes >= 0x80 to non-ASCII chars, which `to_digit` rejects.
        let hi = char::from(pair[0]).to_digit(16)?;
        let lo = char::from(pair[1]).to_digit(16)?;
        // Both nibbles are < 16, so the combined value always fits in a u8.
        *byte = ((hi << 4) | lo) as u8;
    }
    Some(key)
}