1use std::sync::atomic::{AtomicU32, Ordering};
2use std::sync::Arc;
3use std::time::Duration;
4
5use bytes::Bytes;
6use dashmap::DashMap;
7use tokio::sync::{mpsc, oneshot};
8
9use futu_codec::frame::FutuFrame;
10use futu_core::error::{FutuError, Result};
11use futu_core::proto_id;
12
13use crate::connection::Connection;
14use crate::encrypt;
15use crate::reconnect::ReconnectPolicy;
16
/// Settings used by [`FutuClient`] to open and initialise a connection.
#[derive(Clone, Debug)]
pub struct ClientConfig {
    /// Address handed to `Connection::connect`.
    pub addr: String,
    /// Client version string.
    ///
    /// NOTE(review): not currently read by `FutuClient::connect`, which sends a hard-coded
    /// version of `100` in InitConnect.
    pub client_ver: String,
    /// Client identifier sent in the InitConnect request.
    pub client_id: String,
    /// Sent as `recv_notify` in the InitConnect request.
    pub recv_notify: bool,
    /// When `Some`, InitConnect is RSA-encrypted/decrypted with this key and later request
    /// bodies are AES-encrypted with the key returned by the server.
    /// (Expected key encoding is defined by `encrypt::rsa_*` — confirm there.)
    pub rsa_key: Option<String>,
}
30
31#[derive(Debug, Clone)]
33pub struct PushMessage {
34 pub proto_id: u32,
35 pub body: Bytes,
36}
37
38struct PendingRequest {
40 tx: oneshot::Sender<FutuFrame>,
41}
42
43pub struct FutuClient {
52 config: ClientConfig,
53 serial_no: AtomicU32,
54 pending: Arc<DashMap<u32, PendingRequest>>,
55 push_tx: mpsc::UnboundedSender<PushMessage>,
56 cmd_tx: Option<mpsc::Sender<ClientCommand>>,
57 conn_aes_key: parking_lot::Mutex<Option<[u8; 16]>>,
58}
59
60enum ClientCommand {
61 Send(FutuFrame, oneshot::Sender<()>),
62 #[expect(dead_code)]
63 Shutdown,
64}
65
66impl FutuClient {
67 pub fn new(config: ClientConfig) -> (Self, mpsc::UnboundedReceiver<PushMessage>) {
71 let (push_tx, push_rx) = mpsc::unbounded_channel();
72
73 let client = Self {
74 config,
75 serial_no: AtomicU32::new(0),
76 pending: Arc::new(DashMap::new()),
77 push_tx,
78 cmd_tx: None,
79 conn_aes_key: parking_lot::Mutex::new(None),
80 };
81
82 (client, push_rx)
83 }
84
85 pub async fn connect(&mut self) -> Result<InitConnectInfo> {
87 let mut conn = Connection::connect(&self.config.addr).await?;
88
89 let serial = self.next_serial();
91 let req = futu_proto::init_connect::Request {
92 c2s: futu_proto::init_connect::C2s {
93 client_ver: 100,
94 client_id: self.config.client_id.clone(),
95 recv_notify: Some(self.config.recv_notify),
96 packet_enc_algo: Some(0),
97 push_proto_fmt: Some(0),
98 programming_language: Some(String::new()),
99 },
100 };
101 let raw_body = prost::Message::encode_to_vec(&req);
102
103 let send_body = if let Some(ref rsa_key) = self.config.rsa_key {
105 tracing::debug!("encrypting InitConnect with RSA");
106 encrypt::rsa_public_encrypt(rsa_key, &raw_body)?
107 } else {
108 raw_body.clone()
109 };
110
111 let mut frame = Connection::build_frame(proto_id::INIT_CONNECT, serial, send_body);
113 {
114 use sha1::{Digest, Sha1};
115 let mut hasher = Sha1::new();
116 hasher.update(&raw_body);
117 let hash = hasher.finalize();
118 frame.header.body_sha1.copy_from_slice(&hash);
119 }
120 conn.send(frame).await?;
121
122 let resp_frame = conn.recv().await?.ok_or(FutuError::Codec(
124 "connection closed during InitConnect".into(),
125 ))?;
126
127 let resp_body = if let Some(ref rsa_key) = self.config.rsa_key {
129 tracing::debug!("decrypting InitConnect response with RSA");
130 encrypt::rsa_private_decrypt(rsa_key, &resp_frame.body)?
131 } else {
132 resp_frame.body.to_vec()
133 };
134
135 let resp: futu_proto::init_connect::Response =
136 prost::Message::decode(resp_body.as_slice()).map_err(FutuError::Proto)?;
137
138 let ret_type = resp.ret_type;
140 if ret_type != 0 {
141 return Err(FutuError::ServerError {
142 ret_type,
143 msg: resp.ret_msg.unwrap_or_default(),
144 });
145 }
146
147 let s2c = resp.s2c.ok_or(FutuError::Codec(
148 "missing s2c in InitConnect response".into(),
149 ))?;
150
151 let info = InitConnectInfo {
152 server_ver: s2c.server_ver,
153 login_user_id: s2c.login_user_id,
154 conn_id: s2c.conn_id,
155 conn_aes_key: s2c.conn_aes_key.clone(),
156 keep_alive_interval: s2c.keep_alive_interval,
157 };
158
159 if !info.conn_aes_key.is_empty() {
161 let key_bytes = info.conn_aes_key.as_bytes();
162 tracing::debug!(
163 key_len = key_bytes.len(),
164 key_hex = hex_str(key_bytes),
165 "received AES key"
166 );
167 if key_bytes.len() == 16 {
168 let mut key = [0u8; 16];
170 key.copy_from_slice(key_bytes);
171 *self.conn_aes_key.lock() = Some(key);
172 } else if key_bytes.len() == 32 {
173 if let Some(key) = hex_decode_16(key_bytes) {
175 *self.conn_aes_key.lock() = Some(key);
176 } else {
177 tracing::warn!(
178 "AES key is 32 chars but not valid hex, using raw first 16 bytes"
179 );
180 let mut key = [0u8; 16];
181 key.copy_from_slice(&key_bytes[..16]);
182 *self.conn_aes_key.lock() = Some(key);
183 }
184 } else {
185 tracing::warn!(
186 key_len = key_bytes.len(),
187 "unexpected AES key length, using raw bytes (truncated/padded to 16)"
188 );
189 let mut key = [0u8; 16];
190 let copy_len = key_bytes.len().min(16);
191 key[..copy_len].copy_from_slice(&key_bytes[..copy_len]);
192 *self.conn_aes_key.lock() = Some(key);
193 }
194 }
195
196 tracing::info!(
197 server_ver = info.server_ver,
198 conn_id = info.conn_id,
199 keep_alive_interval = info.keep_alive_interval,
200 "InitConnect succeeded"
201 );
202
203 let keep_alive_interval = Duration::from_secs(info.keep_alive_interval as u64);
205 self.start_background_tasks(conn, keep_alive_interval);
206
207 Ok(info)
208 }
209
210 pub async fn request(&self, proto_id: u32, body: Vec<u8>) -> Result<FutuFrame> {
212 let serial = self.next_serial();
213
214 let (final_body, sha1) = self.prepare_body(&body);
216
217 let mut frame = FutuFrame::new(proto_id, serial, Bytes::from(final_body));
218 frame.header.body_sha1 = sha1;
220
221 let (resp_tx, resp_rx) = oneshot::channel();
222 self.pending.insert(serial, PendingRequest { tx: resp_tx });
223
224 if let Some(cmd_tx) = &self.cmd_tx {
226 let (ack_tx, ack_rx) = oneshot::channel();
227 cmd_tx
228 .send(ClientCommand::Send(frame, ack_tx))
229 .await
230 .map_err(|_| FutuError::NotInitialized)?;
231 ack_rx
232 .await
233 .map_err(|_| FutuError::Codec("send ack failed".into()))?;
234 } else {
235 self.pending.remove(&serial);
236 return Err(FutuError::NotInitialized);
237 }
238
239 let resp = tokio::time::timeout(Duration::from_secs(12), resp_rx)
241 .await
242 .map_err(|_| {
243 self.pending.remove(&serial);
244 FutuError::Timeout
245 })?
246 .map_err(|_| FutuError::Codec("response channel closed".into()))?;
247
248 Ok(resp)
249 }
250
251 fn next_serial(&self) -> u32 {
252 self.serial_no.fetch_add(1, Ordering::Relaxed) + 1
253 }
254
255 fn prepare_body(&self, plaintext: &[u8]) -> (Vec<u8>, [u8; 20]) {
256 use sha1::{Digest, Sha1};
257
258 let mut hasher = Sha1::new();
260 hasher.update(plaintext);
261 let sha1_result = hasher.finalize();
262 let mut sha1 = [0u8; 20];
263 sha1.copy_from_slice(&sha1_result);
264
265 let body = if self.config.rsa_key.is_some() {
267 let key = self.conn_aes_key.lock();
268 match key.as_ref() {
269 Some(k) => encrypt::aes_ecb_encrypt(k, plaintext),
270 None => plaintext.to_vec(),
271 }
272 } else {
273 plaintext.to_vec()
274 };
275
276 (body, sha1)
277 }
278
279 fn start_background_tasks(&mut self, conn: Connection, keep_alive_interval: Duration) {
280 let (cmd_tx, cmd_rx) = mpsc::channel(256);
281 self.cmd_tx = Some(cmd_tx.clone());
282
283 let pending = Arc::clone(&self.pending);
284 let push_tx = self.push_tx.clone();
285 let aes_key = if self.config.rsa_key.is_some() {
286 *self.conn_aes_key.lock()
287 } else {
288 None
289 };
290
291 tokio::spawn(async move {
292 run_event_loop(conn, cmd_rx, pending, push_tx, aes_key, keep_alive_interval).await;
293 });
294 }
295}
296
297async fn run_event_loop(
299 mut conn: Connection,
300 mut cmd_rx: mpsc::Receiver<ClientCommand>,
301 pending: Arc<DashMap<u32, PendingRequest>>,
302 push_tx: mpsc::UnboundedSender<PushMessage>,
303 aes_key: Option<[u8; 16]>,
304 keep_alive_interval: Duration,
305) {
306 let mut heartbeat = tokio::time::interval(keep_alive_interval);
307 heartbeat.tick().await; let mut heartbeat_serial: u32 = 10_000_000; loop {
311 tokio::select! {
312 result = conn.recv() => {
314 match result {
315 Ok(Some(frame)) => {
316 let body = match &aes_key {
318 Some(key) if !frame.body.is_empty() => {
319 match encrypt::aes_ecb_decrypt(key, &frame.body) {
320 Ok(decrypted) => Bytes::from(decrypted),
321 Err(e) => {
322 tracing::warn!(error = %e, "decrypt failed, using raw body");
323 frame.body.clone()
324 }
325 }
326 }
327 _ => frame.body.clone(),
328 };
329
330 let proto_id = frame.header.proto_id;
331
332 if proto_id::is_push_proto(proto_id) {
333 let _ = push_tx.send(PushMessage { proto_id, body });
335 } else {
336 let serial = frame.header.serial_no;
338 if let Some((_, req)) = pending.remove(&serial) {
339 let resp_frame = FutuFrame {
340 header: frame.header,
341 body,
342 };
343 let _ = req.tx.send(resp_frame);
344 } else {
345 tracing::debug!(serial = serial, proto_id = proto_id, "unmatched response");
346 }
347 }
348 }
349 Ok(None) => {
350 tracing::warn!("connection closed by server");
351 break;
352 }
353 Err(e) => {
354 tracing::error!(error = %e, "recv error");
355 break;
356 }
357 }
358 }
359
360 cmd = cmd_rx.recv() => {
362 match cmd {
363 Some(ClientCommand::Send(frame, ack)) => {
364 let result = conn.send(frame).await;
365 if let Err(e) = &result {
366 tracing::error!(error = %e, "send failed");
367 }
368 let _ = ack.send(());
369 if result.is_err() {
370 break;
371 }
372 }
373 Some(ClientCommand::Shutdown) | None => {
374 tracing::info!("shutting down event loop");
375 break;
376 }
377 }
378 }
379
380 _ = heartbeat.tick() => {
382 heartbeat_serial += 1;
383 let req = futu_proto::keep_alive::Request {
384 c2s: futu_proto::keep_alive::C2s {
385 time: chrono::Utc::now().timestamp(),
386 },
387 };
388 let body = prost::Message::encode_to_vec(&req);
389 let frame = Connection::build_frame(
390 proto_id::KEEP_ALIVE,
391 heartbeat_serial,
392 body,
393 );
394 if let Err(e) = conn.send(frame).await {
395 tracing::error!(error = %e, "heartbeat send failed");
396 break;
397 }
398 tracing::trace!("heartbeat sent");
399 }
400 }
401 }
402
403 pending.clear();
405 tracing::info!("event loop exited");
406}
407
/// Connection details returned by the server in the InitConnect response.
///
/// NOTE(review): the derived `Debug` includes `conn_aes_key`; avoid `{:?}`-logging this value.
#[derive(Clone, Debug)]
pub struct InitConnectInfo {
    /// Server version reported in the response.
    pub server_ver: i32,
    /// Logged-in user id reported by the server.
    pub login_user_id: u64,
    /// Server-assigned connection id.
    pub conn_id: u64,
    /// Session AES key exactly as received; `FutuClient::connect` derives the 16-byte key
    /// from it (raw 16 bytes, 32 hex chars, or truncated/padded).
    pub conn_aes_key: String,
    /// Heartbeat interval in seconds.
    pub keep_alive_interval: i32,
}
417
418pub struct ReconnectingClient {
420 config: ClientConfig,
421 policy: ReconnectPolicy,
422}
423
424impl ReconnectingClient {
425 pub fn new(config: ClientConfig) -> Self {
426 Self {
427 config,
428 policy: ReconnectPolicy::default_policy(),
429 }
430 }
431
432 pub fn with_policy(mut self, policy: ReconnectPolicy) -> Self {
433 self.policy = policy;
434 self
435 }
436
437 pub async fn connect(
442 &mut self,
443 ) -> Result<(
444 FutuClient,
445 mpsc::UnboundedReceiver<PushMessage>,
446 InitConnectInfo,
447 )> {
448 loop {
449 let (mut client, push_rx) = FutuClient::new(self.config.clone());
450 match client.connect().await {
451 Ok(info) => {
452 self.policy.reset();
453 return Ok((client, push_rx, info));
454 }
455 Err(e) => {
456 tracing::warn!(
457 error = %e,
458 attempt = self.policy.attempts(),
459 "connection failed"
460 );
461 match self.policy.next_delay() {
462 Some(delay) => {
463 tracing::info!(delay_ms = delay.as_millis(), "reconnecting...");
464 tokio::time::sleep(delay).await;
465 }
466 None => {
467 return Err(FutuError::Codec(format!(
468 "max retries reached after {} attempts",
469 self.policy.attempts()
470 )));
471 }
472 }
473 }
474 }
475 }
476 }
477}
478
/// Renders `bytes` as lowercase hexadecimal, two digits per byte.
fn hex_str(bytes: &[u8]) -> String {
    const DIGITS: &[u8; 16] = b"0123456789abcdef";

    let mut rendered = String::with_capacity(bytes.len() * 2);
    for &byte in bytes {
        rendered.push(char::from(DIGITS[usize::from(byte >> 4)]));
        rendered.push(char::from(DIGITS[usize::from(byte & 0x0f)]));
    }
    rendered
}
482
/// Decodes exactly 32 ASCII hex digits (either case) into 16 bytes.
///
/// Returns `None` for any other length or any non-hex byte. Decoding is byte-wise, so
/// multibyte UTF-8 input cannot cause a slicing panic, and sign characters such as `+` are
/// rejected (unlike `u8::from_str_radix`, which accepts a leading `+`).
fn hex_decode_16(hex_bytes: &[u8]) -> Option<[u8; 16]> {
    // Value of a single ASCII hex digit; `None` for anything else, including non-ASCII bytes.
    fn nibble(digit: u8) -> Option<u8> {
        char::from(digit)
            .to_digit(16)
            .and_then(|value| u8::try_from(value).ok())
    }

    if hex_bytes.len() != 32 {
        return None;
    }

    let mut key = [0u8; 16];
    for (byte, pair) in key.iter_mut().zip(hex_bytes.chunks_exact(2)) {
        *byte = (nibble(pair[0])? << 4) | nibble(pair[1])?;
    }
    Some(key)
}