Files
mobile-ios/Rosetta/Core/Network/Protocol/ProtocolManager.swift

1190 lines
45 KiB
Swift
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import Foundation
import os
import Observation
import UIKit
// MARK: - Connection State
/// Lifecycle of the protocol connection, from a closed socket to an authenticated session.
///
/// Cases, in the order a healthy session passes through them:
/// - `disconnected`: no socket.
/// - `connecting`: a socket connection has been requested.
/// - `connected`: the socket is open.
/// - `handshaking`: the handshake packet has been sent and a reply is awaited.
/// - `deviceVerificationRequired`: the server wants this device approved from another device.
/// - `authenticated`: handshake completed.
///
/// Raw values equal the case names (they appear in log output).
enum ConnectionState: String {
    case disconnected, connecting, connected, handshaking, deviceVerificationRequired, authenticated
}
/// Diagnostic summary of a message packet (0x06) that failed to parse and was dropped.
struct MalformedMessagePacketInfo: Sendable {
    /// Size in bytes of the raw frame.
    let packetSize: Int
    /// Parser-provided failure fingerprint, or a fallback tag when the parser gave none.
    let fingerprint: String
    /// Short prefix of the message id (ProtocolManager passes "-" when the id is empty).
    let messageIdHint: String
}
/// Diagnostic summary of a critical packet (handshake, message or sync) dropped as malformed.
struct MalformedCriticalPacketInfo: Sendable {
    /// Numeric protocol id of the dropped packet (e.g. 0x00, 0x06, 0x19).
    let packetId: Int
    /// Size in bytes of the raw frame.
    let packetSize: Int
    /// Log-friendly (single-line, truncated) failure fingerprint.
    let fingerprint: String
}
// MARK: - ProtocolManager
/// Central networking coordinator. Owns WebSocket, routes packets, manages handshake.
@Observable
final class ProtocolManager: @unchecked Sendable {
/// Process-wide instance; the initializer is private.
static let shared = ProtocolManager()
/// Logger for all protocol events.
private static let logger = Logger(subsystem: "com.rosetta.messenger", category: "Protocol")
// MARK: - Public State
/// Connection lifecycle state. Read-only outside this class. Inside it, some paths write it
/// synchronously (e.g. `connect`), others from `Task { @MainActor in … }` blocks.
private(set) var connectionState: ConnectionState = .disconnected
// MARK: - Device Verification State
/// Device waiting for approval from this device (shown as banner on primary device).
private(set) var pendingDeviceVerification: DeviceEntry?
/// All connected devices.
private(set) var devices: [DeviceEntry] = []
// MARK: - Callbacks
// Single-subscriber hooks, one per decoded packet type, invoked directly from
// `handleIncomingData(_:)`.
// NOTE(review): they presumably run on the WebSocket receive (URLSession delegate) queue,
// not the main thread — subscribers touching UI state should hop to the MainActor.
var onMessageReceived: ((PacketMessage) -> Void)?
var onDeliveryReceived: ((PacketDelivery) -> Void)?
var onReadReceived: ((PacketRead) -> Void)?
var onOnlineStateReceived: ((PacketOnlineState) -> Void)?
var onUserInfoReceived: ((PacketUserInfo) -> Void)?
var onSearchResult: ((PacketSearch) -> Void)?
var onTypingReceived: ((PacketTyping) -> Void)?
var onRequestUpdateReceived: ((PacketRequestUpdate) -> Void)?
var onCreateGroupReceived: ((PacketCreateGroup) -> Void)?
var onGroupInfoReceived: ((PacketGroupInfo) -> Void)?
var onGroupInviteInfoReceived: ((PacketGroupInviteInfo) -> Void)?
var onGroupJoinReceived: ((PacketGroupJoin) -> Void)?
var onGroupLeaveReceived: ((PacketGroupLeave) -> Void)?
var onGroupBanReceived: ((PacketGroupBan) -> Void)?
var onSyncReceived: ((PacketSync) -> Void)?
var onDeviceNewReceived: ((PacketDeviceNew) -> Void)?
var onSignalPeerReceived: ((PacketSignalPeer) -> Void)?
var onWebRTCReceived: ((PacketWebRTC) -> Void)?
var onIceServersReceived: ((PacketIceServers) -> Void)?
/// Invoked on the MainActor, in the same task that sets `connectionState = .authenticated`.
var onHandshakeCompleted: ((PacketHandshake) -> Void)?
/// Invoked when a message packet (0x06) fails to parse and is dropped.
var onMalformedMessageReceived: ((MalformedMessagePacketInfo) -> Void)?
/// Invoked when a handshake (0x00), message (0x06) or sync (0x19) packet is dropped as malformed.
var onMalformedCriticalPacketReceived: ((MalformedCriticalPacketInfo) -> Void)?
// MARK: - Private
/// WebSocket transport owned by this manager.
private let client = WebSocketClient()
/// Outbound packets held until the next completed handshake; guarded by `packetQueueLock`.
private var packetQueue: [any Packet] = []
/// Dedup keys of queued packets, so the same logical packet is queued at most once;
/// guarded by `packetQueueLock`.
private var queuedPacketKeys: Set<String> = []
/// True once the server reports a completed handshake; cleared on every disconnect/reconnect path.
private var handshakeComplete = false
/// Periodic heartbeat loop.
private var heartbeatTask: Task<Void, Never>?
/// Forces a reconnect when no handshake reply arrives in time.
private var handshakeTimeoutTask: Task<Void, Never>?
/// Safety timeout for the foreground ping check.
private var pingTimeoutTask: Task<Void, Never>?
/// Guards against overlapping ping-first verifications on foreground.
private var pingVerificationInProgress = false
/// Android parity: sync batch flag set SYNCHRONOUSLY on receive queue.
/// Prevents race where MainActor Task for BATCH_START runs after message Task.
/// Written on URLSession delegate queue, read on MainActor protected by lock.
private let syncBatchLock = NSLock()
/// Backing storage for the sync-batch flag; only touch while holding `syncBatchLock`.
private var _syncBatchActive = false
/// Whether the server is currently streaming a sync batch.
///
/// Set when a BATCH_START sync packet arrives and cleared on NOT_NEEDED or `disconnect()`.
/// The read takes `syncBatchLock`, so any thread (e.g. SessionManager on the MainActor) can
/// call it without racing the receive queue.
var isSyncBatchActive: Bool {
    syncBatchLock.lock()
    defer { syncBatchLock.unlock() }
    return _syncBatchActive
}
// One NSLock per registry below; each guards the matching handler dictionary
// (packetQueueLock guards packetQueue + queuedPacketKeys).
private let resultHandlersLock = NSLock()
private let signalPeerHandlersLock = NSLock()
private let webRTCHandlersLock = NSLock()
private let iceServersHandlersLock = NSLock()
private let groupOneShotLock = NSLock()
private let packetQueueLock = NSLock()
/// Matches incoming search responses (0x03) to handlers registered per channel.
private let searchRouter = SearchPacketRouter()
/// One-shot PacketResult (0x02) handlers; the whole set is cleared after each dispatch.
private var resultHandlers: [UUID: (PacketResult) -> Void] = [:]
/// Persistent PacketSignalPeer (0x1A) handlers.
private var signalPeerHandlers: [UUID: (PacketSignalPeer) -> Void] = [:]
/// Persistent PacketWebRTC (0x1B) handlers.
private var webRTCHandlers: [UUID: (PacketWebRTC) -> Void] = [:]
/// Persistent PacketIceServers (0x1C) handlers.
private var iceServersHandlers: [UUID: (PacketIceServers) -> Void] = [:]
/// Generic one-shot handlers for group packets, keyed by handler id; each entry stores the
/// packet id it listens for plus the handler.
/// Handler returns `true` if it consumed the packet (auto-removed).
private var groupOneShotHandlers: [UUID: (packetId: Int, handler: (any Packet) -> Bool)] = [:]
/// Background task to keep WebSocket alive during brief background periods (active call).
/// iOS gives ~30s; enough for the call to survive app switching / notification interactions.
private var callBackgroundTask: UIBackgroundTaskIdentifier = .invalid
// Saved credentials for auto-reconnect (cleared by disconnect()).
private var savedPublicKey: String?
private var savedPrivateHash: String?
/// Pre-built handshake packet for instant send on socket open.
/// Rebuilt on every connect() call and reused across reconnects until the next one.
private var cachedHandshakeData: Data?
/// Timestamp of last successful authentication used to decide whether to reset backoff.
/// If connection was short-lived (<10s), don't reset backoff counter (server RST loop).
private var lastAuthenticatedTime: CFAbsoluteTime = 0
/// Public key of the account being (re)connected, or nil after `disconnect()`.
var publicKey: String? { savedPublicKey }
/// Private-key hash sent in the handshake, or nil after `disconnect()`.
var privateHash: String? { savedPrivateHash }
/// Private: use `ProtocolManager.shared`. Installs the WebSocket client callbacks exactly once.
private init() {
setupClientCallbacks()
}
// MARK: - Connection
/// Connect to server and perform handshake.
///
/// Saves the credentials used by every reconnect path and pre-builds the handshake packet.
/// A socket is only opened when the state is `.disconnected`, or `.connected` with a socket
/// the client reports as closed; every other state returns early. A call with a different
/// public key first resets the old session via `disconnect()`.
/// - Parameters:
///   - publicKey: Account public key sent in the handshake.
///   - privateKeyHash: Private-key hash sent in the handshake.
func connect(publicKey: String, privateKeyHash: String) {
let switchingAccount = savedPublicKey != nil && savedPublicKey != publicKey
if switchingAccount {
Self.logger.info("Account switch detected — resetting protocol session before reconnect")
disconnect()
}
savedPublicKey = publicKey
savedPrivateHash = privateKeyHash
// Rebuilt even when the early returns below skip connecting, so the next reconnect
// always sends the latest credentials.
cachedHandshakeData = buildHandshakeData()
switch connectionState {
case .authenticated, .handshaking, .deviceVerificationRequired:
Self.logger.info("Already connected/handshaking, skipping")
return
case .connected:
if client.isConnected {
Self.logger.info("Socket already connected, skipping duplicate connect()")
return
}
case .connecting:
// Always skip if already in connecting state. Previous code only
// checked client.isConnecting which has a race gap between
// setting connectionState=.connecting and client.connect() setting
// isConnecting=true, a second call would slip through.
// This caused 3 parallel WebSocket connections (3x every packet).
Self.logger.info("Connect already in progress, skipping duplicate connect()")
return
case .disconnected:
break
}
connectionState = .connecting
client.connect()
}
/// Tears the session down completely.
///
/// Cancels the heartbeat, handshake-timeout and ping-timeout tasks. Drops queued packets,
/// one-shot result handlers and pending search routing, and clears the sync-batch flag and
/// device state. Closes the socket and sets `.disconnected`. Finally forgets the saved
/// credentials and cached handshake, so the reconnect entry points that check
/// `savedPublicKey` return early until `connect` is called again. The transport manager is
/// reset on the MainActor.
/// Persistent signal-peer / WebRTC / ICE and group one-shot handlers are NOT cleared here.
func disconnect() {
Self.logger.info("Disconnecting")
heartbeatTask?.cancel()
heartbeatTask = nil
handshakeTimeoutTask?.cancel()
handshakeTimeoutTask = nil
pingTimeoutTask?.cancel()
pingTimeoutTask = nil
pingVerificationInProgress = false
handshakeComplete = false
clearPacketQueue()
clearResultHandlers()
searchRouter.resetPending()
syncBatchLock.lock()
_syncBatchActive = false
syncBatchLock.unlock()
pendingDeviceVerification = nil
devices = []
client.disconnect()
connectionState = .disconnected
savedPublicKey = nil
savedPrivateHash = nil
cachedHandshakeData = nil
// Forget the last auth time so the next successful handshake resets backoff.
lastAuthenticatedTime = 0
Task { @MainActor in
TransportManager.shared.reset()
}
}
/// Android parity: `reconnectNowIfNeeded(reason: "foreground")`.
/// On foreground resume, always force reconnect — iOS suspends the process in
/// background, server RSTs TCP, but `didCloseWith` never fires (zombie socket).
/// Android doesn't have this because OkHttp fires onFailure in background.
/// Previously iOS used ping-first (3s timeout) which was too slow.
///
/// Returns early when no credentials are saved, when a call makes the teardown unsafe (see
/// below), or while a connection attempt, handshake or device verification is in progress.
func forceReconnectOnForeground() {
guard savedPublicKey != nil, savedPrivateHash != nil else { return }
// During an active call, WebRTC media flows via DTLS/SRTP (not WebSocket).
// Tearing down the socket would trigger server re-delivery of .call and
// cause unnecessary signaling disruption (endCallBecauseBusy).
// For .active phase: skip entirely — WS is only needed for endCall signal,
// which will work after natural reconnect or ICE timeout ends the call.
// For other call phases: skip only if WS is authenticated (still alive).
if CallManager.shared.uiState.phase == .active {
Self.logger.info("⚡ Foreground reconnect skipped — call active, media via DTLS")
return
}
if CallManager.shared.uiState.phase != .idle,
connectionState == .authenticated {
Self.logger.info("⚡ Foreground reconnect skipped — call in progress, WS authenticated")
return
}
// Android parity: skip if handshake or device verification is in progress.
// These are active flows that should not be interrupted.
switch connectionState {
case .handshaking, .deviceVerificationRequired:
return
case .connecting:
// Same fix as connect(): unconditional return to prevent triple connections.
return
case .authenticated, .connected, .disconnected:
break // Always reconnect — .authenticated/.connected may be a zombie on iOS.
}
Self.logger.info("⚡ Foreground reconnect — tearing down potential zombie socket")
pingVerificationInProgress = false
pingTimeoutTask?.cancel()
pingTimeoutTask = nil
handshakeComplete = false
heartbeatTask?.cancel()
searchRouter.resetPending()
// User-initiated foreground — allow fast retry on next disconnect.
lastAuthenticatedTime = 0
connectionState = .connecting
client.forceReconnect()
}
// MARK: - Call Background Task
/// Keeps the process alive during active calls so the WebSocket survives a brief trip to the background.
///
/// No-op while a call background task is already held. The expiration handler only releases
/// the task; it deliberately does not end the call (CallKit keeps active calls alive).
/// `beginBackgroundTask` returns `.invalid` when the system refuses extra time. That case is
/// now logged as a failure instead of "started", and a later call may retry.
func beginCallBackgroundTask() {
    guard callBackgroundTask == .invalid else { return }
    let remaining = UIApplication.shared.backgroundTimeRemaining
    Self.logger.info("📞 Background task starting — remaining=\(remaining, privacy: .public)s wsState=\(String(describing: self.connectionState), privacy: .public)")
    callBackgroundTask = UIApplication.shared.beginBackgroundTask(withName: "RosettaCall") { [weak self] in
        // Don't end the call here — CallKit keeps the process alive for active calls.
        // This background task only buys time for WebSocket reconnection.
        // Killing the call on expiry was causing premature call termination
        // during keyExchange phase (~30s before Desktop could respond).
        Self.logger.info("📞 Background task EXPIRED — OS reclaiming")
        self?.endCallBackgroundTask()
    }
    if callBackgroundTask == .invalid {
        // The system declined to grant background time; nothing to end later.
        Self.logger.warning("📞 Background task could not be started for call")
    } else {
        Self.logger.info("📞 Background task started for call")
    }
}
/// Releases the call background task taken by `beginCallBackgroundTask()`; no-op if none is held.
///
/// The stored identifier is reset before `endBackgroundTask` so a re-entrant call sees `.invalid`.
func endCallBackgroundTask() {
    let heldTask = callBackgroundTask
    if heldTask == .invalid { return }
    callBackgroundTask = .invalid
    UIApplication.shared.endBackgroundTask(heldTask)
    Self.logger.info("📞 Background task ended for call")
}
/// Android parity: `reconnectNowIfNeeded()` if already in an active state,
/// skip reconnect. Otherwise reset backoff and connect immediately.
func reconnectIfNeeded() {
guard savedPublicKey != nil, savedPrivateHash != nil else { return }
// Android parity (Protocol.kt:651-658): skip if already in any active state.
switch connectionState {
case .authenticated, .handshaking, .deviceVerificationRequired, .connected:
return
case .connecting:
// Unconditional return prevent duplicate connections (same fix as connect())
return
case .disconnected:
break
}
// Reset backoff and connect immediately.
Self.logger.info("⚡ Fast reconnect — state=\(self.connectionState.rawValue)")
handshakeComplete = false
heartbeatTask?.cancel()
searchRouter.resetPending()
connectionState = .connecting
client.forceReconnect()
}
/// Ping-first zombie socket detection for foreground resume.
/// iOS suspends the process in background — server RSTs TCP, but `didCloseWith`
/// delegate never fires. `connectionState` stays `.authenticated` (stale).
/// This method sends a WebSocket ping: pong → alive, no pong → force reconnect.
///
/// Runs only with saved credentials, in `.authenticated` or `.connected`, and when no other
/// verification is in flight. The ping callback and the 3 s timeout both check and clear
/// `pingVerificationInProgress`, so whichever finishes first decides the outcome and the
/// other becomes a no-op.
/// NOTE(review): that flag is read/written from the ping completion and from the timeout task
/// without a lock — confirm the two cannot run concurrently.
func verifyConnectionOrReconnect() {
guard savedPublicKey != nil, savedPrivateHash != nil else { return }
guard connectionState == .authenticated || connectionState == .connected else { return }
guard !pingVerificationInProgress else { return }
pingVerificationInProgress = true
Self.logger.info("🏓 Verifying connection with ping after foreground...")
client.sendPing { [weak self] error in
// Ignore a late completion after the timeout already handled the outcome.
guard let self, self.pingVerificationInProgress else { return }
self.pingVerificationInProgress = false
self.pingTimeoutTask?.cancel()
self.pingTimeoutTask = nil
if let error {
Self.logger.warning("🏓 Ping failed — zombie socket: \(error.localizedDescription)")
self.handlePingFailure()
} else {
Self.logger.info("🏓 Pong received — connection alive")
}
}
// Safety timeout: if sendPing never calls back (completely dead socket), force reconnect.
pingTimeoutTask?.cancel()
pingTimeoutTask = Task { [weak self] in
try? await Task.sleep(for: .seconds(3))
guard let self, !Task.isCancelled, self.pingVerificationInProgress else { return }
self.pingVerificationInProgress = false
Self.logger.warning("🏓 Ping timeout (3s) — zombie socket, forcing reconnect")
self.handlePingFailure()
}
}
/// Recovery after a failed or timed-out foreground ping.
///
/// Cancels the ping timeout and heartbeat, clears the handshake flag and pending search
/// routing, then forces the client to reconnect. `connectionState` is downgraded to
/// `.connecting` asynchronously, and only if it has not already moved on.
private func handlePingFailure() {
pingTimeoutTask?.cancel()
pingTimeoutTask = nil
handshakeComplete = false
heartbeatTask?.cancel()
searchRouter.resetPending()
Task { @MainActor in
// Guard: only downgrade to .connecting if reconnect hasn't already progressed.
// forceReconnect() is called synchronously below — if it completes fast,
// this async Task could overwrite .authenticated/.handshaking with .connecting.
let s = self.connectionState
if s != .authenticated && s != .handshaking && s != .connected {
self.connectionState = .connecting
}
}
client.forceReconnect()
}
// MARK: - Sending
/// Sends a search request and records which channel its response should be routed to.
///
/// The router entry is registered *before* the packet is handed to `sendPacket`, so the
/// response can never arrive ahead of its routing record.
/// - Parameters:
///   - packet: The search request.
///   - channel: Channel whose handlers should receive the reply (default `.unscoped`).
func sendSearchPacket(_ packet: PacketSearch, channel: SearchPacketChannel = .unscoped) {
    searchRouter.enqueueOutgoingRequest(channel: channel)
    sendPacket(packet)
}
/// Builds a `PacketSignalPeer` (0x1A) from the given call-signaling fields and sends it
/// (or queues it until the session is authenticated).
///
/// All string fields default to empty, so callers only pass what the signal type needs.
func sendCallSignal(
    signalType: SignalType,
    src: String = "",
    dst: String = "",
    sharedPublic: String = "",
    callId: String = "",
    joinToken: String = "",
    roomId: String = ""
) {
    var signal = PacketSignalPeer()
    signal.signalType = signalType
    signal.src = src
    signal.dst = dst
    signal.sharedPublic = sharedPublic
    signal.callId = callId
    signal.joinToken = joinToken
    signal.roomId = roomId
    sendPacket(signal)
}
/// Sends a WebRTC signaling packet (0x1B) stamped with this account's public key and device id.
/// - Parameters:
///   - signalType: Kind of WebRTC signal.
///   - sdpOrCandidate: SDP body or ICE candidate string, depending on `signalType`.
func sendWebRtcSignal(signalType: WebRTCSignalType, sdpOrCandidate: String) {
    var signal = PacketWebRTC()
    signal.signalType = signalType
    signal.sdpOrCandidate = sdpOrCandidate
    signal.publicKey = SessionManager.shared.currentPublicKey
    signal.deviceId = DeviceIdentityManager.shared.currentDeviceId()
    sendPacket(signal)
}
/// Asks the server for ICE servers; the reply arrives as a 0x1C packet.
func requestIceServers() {
    let request = PacketIceServers()
    sendPacket(request)
}
/// Sends `packet` immediately when the session is fully usable; otherwise queues it.
///
/// "Usable" means the socket is connected, `connectionState` is `.authenticated`, and the
/// handshake has completed (a `PacketHandshake` is exempt from that last check). Queued
/// packets are flushed after the next completed handshake.
/// - Parameter packet: Any protocol packet.
func sendPacket(_ packet: any Packet) {
    PerformanceLogger.shared.track("protocol.sendPacket")
    let hexId = String(type(of: packet).packetId, radix: 16)
    // Android parity (Protocol.kt:436-448): triple check handshakeComplete + socket alive + authenticated.
    let isAuth = connectionState == .authenticated
    let handshakeSatisfied = handshakeComplete || packet is PacketHandshake
    let canSendNow = handshakeSatisfied && client.isConnected && isAuth
    guard canSendNow else {
        Self.logger.info("⏳ Queueing packet 0x\(hexId) — connected=\(self.client.isConnected), handshake=\(self.handshakeComplete), auth=\(isAuth)")
        enqueuePacket(packet)
        return
    }
    Self.logger.info("📤 Sending packet 0x\(hexId) directly")
    sendPacketDirect(packet)
}
// MARK: - Search Handlers (Android-like wait/unwait)
/// Registers a handler for search responses routed to `channel`.
/// - Returns: Token to pass to `removeSearchResultHandler(_:)`.
@discardableResult
func addSearchResultHandler(
    channel: SearchPacketChannel = .unscoped,
    _ handler: @escaping (PacketSearch) -> Void
) -> UUID {
    return searchRouter.addHandler(channel: channel, handler)
}

/// Unregisters a search handler previously returned by `addSearchResultHandler`.
func removeSearchResultHandler(_ id: UUID) {
    searchRouter.removeHandler(id)
}
// MARK: - Call Packet Handlers (Android-like wait/unwait)
/// Registers a persistent handler for every incoming `PacketSignalPeer` (0x1A).
/// - Returns: Token for `removeSignalPeerHandler(_:)`.
@discardableResult
func addSignalPeerHandler(_ handler: @escaping (PacketSignalPeer) -> Void) -> UUID {
    let token = UUID()
    signalPeerHandlersLock.lock()
    defer { signalPeerHandlersLock.unlock() }
    signalPeerHandlers[token] = handler
    return token
}

/// Unregisters a signal-peer handler; unknown tokens are ignored.
func removeSignalPeerHandler(_ id: UUID) {
    signalPeerHandlersLock.lock()
    defer { signalPeerHandlersLock.unlock() }
    signalPeerHandlers.removeValue(forKey: id)
}

/// Registers a persistent handler for every incoming `PacketWebRTC` (0x1B).
/// - Returns: Token for `removeWebRtcHandler(_:)`.
@discardableResult
func addWebRtcHandler(_ handler: @escaping (PacketWebRTC) -> Void) -> UUID {
    let token = UUID()
    webRTCHandlersLock.lock()
    defer { webRTCHandlersLock.unlock() }
    webRTCHandlers[token] = handler
    return token
}

/// Unregisters a WebRTC handler; unknown tokens are ignored.
func removeWebRtcHandler(_ id: UUID) {
    webRTCHandlersLock.lock()
    defer { webRTCHandlersLock.unlock() }
    webRTCHandlers.removeValue(forKey: id)
}

/// Registers a persistent handler for every incoming `PacketIceServers` (0x1C).
/// - Returns: Token for `removeIceServersHandler(_:)`.
@discardableResult
func addIceServersHandler(_ handler: @escaping (PacketIceServers) -> Void) -> UUID {
    let token = UUID()
    iceServersHandlersLock.lock()
    defer { iceServersHandlersLock.unlock() }
    iceServersHandlers[token] = handler
    return token
}

/// Unregisters an ICE-servers handler; unknown tokens are ignored.
func removeIceServersHandler(_ id: UUID) {
    iceServersHandlersLock.lock()
    defer { iceServersHandlersLock.unlock() }
    iceServersHandlers.removeValue(forKey: id)
}
// MARK: - Result Handlers (Android parity: waitPacket(0x02))
/// Registers a one-shot handler for the next `PacketResult` (0x02).
///
/// All result handlers are cleared after any single result is dispatched (Android parity).
/// - Returns: Token for `removeResultHandler(_:)`.
@discardableResult
func addResultHandler(_ handler: @escaping (PacketResult) -> Void) -> UUID {
    let token = UUID()
    resultHandlersLock.lock()
    defer { resultHandlersLock.unlock() }
    resultHandlers[token] = handler
    return token
}

/// Cancels a pending result handler; unknown tokens are ignored.
func removeResultHandler(_ id: UUID) {
    resultHandlersLock.lock()
    defer { resultHandlersLock.unlock() }
    resultHandlers.removeValue(forKey: id)
}
// MARK: - Group One-Shot Handlers
/// Registers a handler for packets with id `packetId` (used for group packets 0x11–0x16).
///
/// The handler returns `true` to consume the packet, which unregisters it. Returning `false`
/// keeps it registered for the next matching packet.
/// - Returns: Token for `removeGroupOneShotHandler(_:)`.
@discardableResult
func addGroupOneShotHandler(packetId: Int, handler: @escaping (any Packet) -> Bool) -> UUID {
    let token = UUID()
    groupOneShotLock.lock()
    defer { groupOneShotLock.unlock() }
    groupOneShotHandlers[token] = (packetId: packetId, handler: handler)
    return token
}

/// Unregisters a group one-shot handler; unknown tokens are ignored.
func removeGroupOneShotHandler(_ id: UUID) {
    groupOneShotLock.lock()
    defer { groupOneShotLock.unlock() }
    groupOneShotHandlers.removeValue(forKey: id)
}
/// Called from `handleIncomingData` for group packets: runs every handler registered for
/// `packetId` and unregisters those that report they consumed the packet.
///
/// Handlers run outside the lock, on a snapshot, so they may (un)register handlers themselves.
private func notifyGroupOneShotHandlers(packetId: Int, packet: any Packet) {
    groupOneShotLock.lock()
    let candidates = groupOneShotHandlers.filter { $0.value.packetId == packetId }
    groupOneShotLock.unlock()

    let consumedIds: [UUID] = candidates.compactMap { id, entry in
        entry.handler(packet) ? id : nil
    }
    guard !consumedIds.isEmpty else { return }

    groupOneShotLock.lock()
    defer { groupOneShotLock.unlock() }
    for id in consumedIds {
        groupOneShotHandlers.removeValue(forKey: id)
    }
}
// MARK: - Private Setup
/// Wires the WebSocket client's connect / disconnect / data / network-restored callbacks.
///
/// - Connected: publishes `.connected` (and `.handshaking` when a handshake is about to be
///   sent), then sends the cached handshake and starts the handshake timeout. It falls back to
///   building the handshake on the fly from saved credentials.
/// - Disconnected: cancels heartbeat, handshake and ping timers, clears the handshake flag and
///   pending search routing, then publishes `.disconnected`.
/// - Data: forwards each frame to `handleIncomingData(_:)`.
/// - Network restored: forces an immediate reconnect when credentials are saved.
private func setupClientCallbacks() {
    client.onConnected = { [weak self] in
        guard let self else { return }
        Self.logger.info("WebSocket connected")
        let cachedHandshake = self.cachedHandshakeData
        let willHandshake = cachedHandshake != nil
            || (self.savedPublicKey != nil && self.savedPrivateHash != nil)
        // Publish both transitions from ONE MainActor task. Two separate unstructured tasks
        // have no guaranteed ordering, so `.connected` could land after `.handshaking` and
        // misreport the state for the whole handshake. That would in turn let
        // forceReconnectOnForeground() tear a live handshake down.
        Task { @MainActor in
            self.connectionState = .connected
            if willHandshake {
                self.connectionState = .handshaking
            }
        }
        // Send pre-built handshake immediately — no packet construction on critical path.
        if let data = cachedHandshake {
            Self.logger.info("⚡ Sending pre-built handshake packet")
            self.client.send(data)
            self.startHandshakeTimeout()
        } else if let pk = self.savedPublicKey, let hash = self.savedPrivateHash {
            // Fallback: build handshake on the fly
            self.startHandshake(publicKey: pk, privateHash: hash)
        }
    }
    client.onDisconnected = { [weak self] error in
        guard let self else { return }
        if let error {
            Self.logger.error("Disconnected: \(error.localizedDescription)")
        }
        self.heartbeatTask?.cancel()
        self.heartbeatTask = nil
        self.handshakeTimeoutTask?.cancel()
        self.handshakeTimeoutTask = nil
        self.handshakeComplete = false
        self.pingVerificationInProgress = false
        self.pingTimeoutTask?.cancel()
        self.pingTimeoutTask = nil
        self.searchRouter.resetPending()
        Task { @MainActor in
            self.connectionState = .disconnected
        }
    }
    client.onDataReceived = { [weak self] data in
        self?.handleIncomingData(data)
    }
    // Instant reconnect when network is restored (Wi-Fi ↔ cellular, airplane mode off, etc.)
    client.onNetworkRestored = { [weak self] in
        guard let self, self.savedPublicKey != nil else { return }
        Self.logger.info("Network restored — force reconnecting")
        self.handshakeComplete = false
        self.heartbeatTask?.cancel()
        Task { @MainActor in
            // Guard: only downgrade to .connecting if reconnect hasn't already progressed.
            let s = self.connectionState
            if s != .authenticated && s != .handshaking && s != .connected {
                self.connectionState = .connecting
            }
        }
        self.client.forceReconnect()
    }
}
// MARK: - Handshake
/// Serializes a handshake packet built from the saved credentials.
///
/// Used by `connect()` to cache the handshake so it can be sent the instant the socket opens.
/// NOTE(review): reads `UIDevice.current`, which is UIKit main-thread API — the existing
/// design assumes `connect()` runs on the MainActor; confirm at call sites.
/// - Returns: The encoded packet, or `nil` when no credentials are saved.
private func buildHandshakeData() -> Data? {
    guard let pk = savedPublicKey, let hash = savedPrivateHash else { return nil }
    return PacketRegistry.encode(makeHandshakePacket(publicKey: pk, privateHash: hash))
}

/// Fallback handshake, used when no pre-built handshake is cached.
///
/// Publishes `.handshaking`, builds and sends the packet, and arms the handshake timeout.
/// NOTE(review): this runs from the socket-connected callback, so the `UIDevice` reads in
/// `makeHandshakePacket` may happen off the main thread.
/// - Parameters:
///   - publicKey: Account public key.
///   - privateHash: Private-key hash.
private func startHandshake(publicKey: String, privateHash: String) {
    Self.logger.info("Starting handshake for \(publicKey.prefix(20))...")
    Task { @MainActor in
        self.connectionState = .handshaking
    }
    sendPacketDirect(makeHandshakePacket(publicKey: publicKey, privateHash: privateHash))
    startHandshakeTimeout()
}

/// Builds the handshake packet shared by the cached and fallback paths, so the two can never
/// drift apart (protocol version 1, requested heartbeat 15 s, device identity, and an initial
/// `.needDeviceVerification` state).
private func makeHandshakePacket(publicKey: String, privateHash: String) -> PacketHandshake {
    let device = HandshakeDevice(
        deviceId: DeviceIdentityManager.shared.currentDeviceId(),
        deviceName: UIDevice.current.name,
        deviceOs: "iOS \(UIDevice.current.systemVersion)"
    )
    return PacketHandshake(
        privateKey: privateHash,
        publicKey: publicKey,
        protocolVersion: 1,
        heartbeatInterval: 15,
        device: device,
        handshakeState: .needDeviceVerification
    )
}
/// Arms (or re-arms) a 5 s timer that forces a reconnect if the handshake has not completed.
///
/// 5s is generous for a single packet round-trip. Faster detection means faster recovery via
/// the instant first retry (0ms backoff). The state is only downgraded to `.connecting` when
/// it has not already advanced.
private func startHandshakeTimeout() {
    handshakeTimeoutTask?.cancel()
    handshakeTimeoutTask = Task { [weak self] in
        do {
            try await Task.sleep(nanoseconds: 5_000_000_000)
        } catch {
            return // cancelled: a handshake reply (or a newer timer) arrived
        }
        guard let self, !Task.isCancelled, !self.handshakeComplete else { return }
        Self.logger.error("Handshake timeout (5s) — forcing reconnect")
        self.handshakeComplete = false
        self.heartbeatTask?.cancel()
        Task { @MainActor in
            let current = self.connectionState
            let alreadyProgressed = current == .authenticated || current == .handshaking || current == .connected
            if !alreadyProgressed {
                self.connectionState = .connecting
            }
        }
        self.client.forceReconnect()
    }
}
// MARK: - Packet Handling
/// Decodes one inbound WebSocket frame and dispatches it by packet id.
///
/// Called directly from `client.onDataReceived`. Frames that `PacketRegistry` cannot decode
/// are dropped (logged only in DEBUG builds). Malformed handshake (0x00), message (0x06) and
/// sync (0x19) packets are reported via `reportMalformedCriticalPacket` and never reach their
/// normal callbacks. Decoded ids without a case below are ignored.
/// - Parameter data: Raw binary frame; the DEBUG peek reads its first field as an Int16 packet id.
private func handleIncomingData(_ data: Data) {
// DEBUG-only trace of the raw packet id before decoding.
#if DEBUG
if data.count >= 2 {
let peekStream = Stream(data: data)
let rawId = peekStream.readInt16()
Self.logger.debug("📥 Incoming packet 0x\(String(rawId, radix: 16)), size: \(data.count)")
}
#endif
guard let (packetId, packet) = PacketRegistry.decode(from: data) else {
#if DEBUG
if data.count >= 2 {
let stream = Stream(data: data)
let rawId = stream.readInt16()
Self.logger.debug("Unknown packet ID: 0x\(String(rawId, radix: 16)), size: \(data.count)")
}
#endif
return
}
switch packetId {
case 0x00:
// Handshake response.
if let p = packet as? PacketHandshake {
if p.isMalformed {
reportMalformedCriticalPacket(
packetId: packetId,
packetSize: data.count,
fingerprint: p.malformedFingerprint,
fallbackFingerprint: "packet00_parse_failed"
)
return
}
handleHandshakeResponse(p)
}
case 0x01:
if let p = packet as? PacketUserInfo {
onUserInfoReceived?(p)
}
case 0x02:
// Result: dispatched to (and clears) all one-shot result handlers.
if let p = packet as? PacketResult {
Self.logger.info("📥 PacketResult: code=\(p.resultCode)")
notifyResultHandlers(p)
}
case 0x03:
// Search response: routed to channel handlers first, then the global callback.
if let p = packet as? PacketSearch {
let routedChannel = routeIncomingSearchPacket(p)
Self.logger.debug(
"📥 Search result: \(p.users.count) users, callback=\(self.onSearchResult != nil), routed=\(String(describing: routedChannel))"
)
onSearchResult?(p)
}
case 0x05:
if let p = packet as? PacketOnlineState {
onOnlineStateReceived?(p)
}
case 0x06:
// Chat message. A malformed one is reported through both malformed-packet callbacks and dropped.
if let p = packet as? PacketMessage {
if p.isMalformed {
let messageIdHint = p.messageId.isEmpty ? "-" : String(p.messageId.prefix(8))
let fingerprint = p.malformedFingerprint.isEmpty ? "packet06_parse_failed" : p.malformedFingerprint
reportMalformedCriticalPacket(
packetId: packetId,
packetSize: data.count,
fingerprint: fingerprint,
fallbackFingerprint: "packet06_parse_failed",
messageIdHint: messageIdHint
)
// NOTE(review): this callback gets the raw fingerprint, not the compacted one.
onMalformedMessageReceived?(
MalformedMessagePacketInfo(
packetSize: data.count,
fingerprint: fingerprint,
messageIdHint: messageIdHint
)
)
return
}
onMessageReceived?(p)
}
case 0x07:
if let p = packet as? PacketRead {
onReadReceived?(p)
}
case 0x08:
if let p = packet as? PacketDelivery {
onDeliveryReceived?(p)
}
case 0x09:
if let p = packet as? PacketDeviceNew {
onDeviceNewReceived?(p)
}
case 0x0A:
if let p = packet as? PacketRequestUpdate {
onRequestUpdateReceived?(p)
}
case 0x0B:
if let p = packet as? PacketTyping {
onTypingReceived?(p)
}
// Group packets 0x11–0x16: persistent callback first, then one-shot waiters for that id.
case 0x11:
if let p = packet as? PacketCreateGroup {
onCreateGroupReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x11, packet: p)
}
case 0x12:
if let p = packet as? PacketGroupInfo {
onGroupInfoReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x12, packet: p)
}
case 0x13:
if let p = packet as? PacketGroupInviteInfo {
onGroupInviteInfoReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x13, packet: p)
}
case 0x14:
if let p = packet as? PacketGroupJoin {
onGroupJoinReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x14, packet: p)
}
case 0x15:
if let p = packet as? PacketGroupLeave {
onGroupLeaveReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x15, packet: p)
}
case 0x16:
if let p = packet as? PacketGroupBan {
onGroupBanReceived?(p)
notifyGroupOneShotHandlers(packetId: 0x16, packet: p)
}
case 0x0F:
// Transport server URL (requested right after a completed handshake); applied on the MainActor.
if let p = packet as? PacketRequestTransport {
Self.logger.info("📥 Transport server: \(p.transportServer)")
Task { @MainActor in
TransportManager.shared.setTransportServer(p.transportServer)
}
}
case 0x17:
if let p = packet as? PacketDeviceList {
handleDeviceList(p)
}
case 0x18:
if let p = packet as? PacketDeviceResolve {
handleDeviceResolve(p)
}
case 0x19:
if let p = packet as? PacketSync {
if p.isMalformed {
reportMalformedCriticalPacket(
packetId: packetId,
packetSize: data.count,
fingerprint: p.malformedFingerprint,
fallbackFingerprint: "packet19_parse_failed"
)
return
}
// Android parity: set sync flag SYNCHRONOUSLY on receive queue
// BEFORE dispatching to MainActor callback. This prevents the race
// where a 0x06 message Task runs on MainActor before BATCH_START Task.
// Only BATCH_START sets and NOT_NEEDED clears the flag; other statuses leave it as-is.
if p.status == .batchStart {
syncBatchLock.lock()
_syncBatchActive = true
syncBatchLock.unlock()
} else if p.status == .notNeeded {
syncBatchLock.lock()
_syncBatchActive = false
syncBatchLock.unlock()
}
onSyncReceived?(p)
}
// Call signaling 0x1A–0x1C: global callback, then every registered persistent handler.
case 0x1A:
if let p = packet as? PacketSignalPeer {
onSignalPeerReceived?(p)
notifySignalPeerHandlers(p)
}
case 0x1B:
if let p = packet as? PacketWebRTC {
onWebRTCReceived?(p)
notifyWebRtcHandlers(p)
}
case 0x1C:
if let p = packet as? PacketIceServers {
onIceServersReceived?(p)
notifyIceServersHandlers(p)
}
default:
// Decoded but unhandled packet ids are ignored.
break
}
}
/// Hands a search response to the router and reports which channel it was delivered to.
private func routeIncomingSearchPacket(_ packet: PacketSearch) -> SearchPacketChannel {
    let deliveredChannel = searchRouter.dispatchIncomingResponse(packet)
    return deliveredChannel
}
/// Delivers a `PacketResult` to every registered result handler, then forgets them all.
///
/// Handlers are snapshotted and cleared under the lock and run outside it (one-shot semantics,
/// Android parity), so a handler may safely register a new one.
private func notifyResultHandlers(_ packet: PacketResult) {
    resultHandlersLock.lock()
    let pending = resultHandlers
    resultHandlers = [:]
    resultHandlersLock.unlock()
    pending.values.forEach { $0(packet) }
}
/// Fans a signal-peer packet out to every persistent handler (snapshot taken under the lock,
/// handlers invoked outside it).
private func notifySignalPeerHandlers(_ packet: PacketSignalPeer) {
    signalPeerHandlersLock.lock()
    let snapshot = Array(signalPeerHandlers.values)
    signalPeerHandlersLock.unlock()
    snapshot.forEach { $0(packet) }
}

/// Fans a WebRTC packet out to every persistent handler (snapshot taken under the lock).
private func notifyWebRtcHandlers(_ packet: PacketWebRTC) {
    webRTCHandlersLock.lock()
    let snapshot = Array(webRTCHandlers.values)
    webRTCHandlersLock.unlock()
    snapshot.forEach { $0(packet) }
}

/// Fans an ICE-servers packet out to every persistent handler (snapshot taken under the lock).
private func notifyIceServersHandlers(_ packet: PacketIceServers) {
    iceServersHandlersLock.lock()
    let snapshot = Array(iceServersHandlers.values)
    iceServersHandlersLock.unlock()
    snapshot.forEach { $0(packet) }
}
/// Logs a dropped malformed packet and notifies `onMalformedCriticalPacketReceived`.
/// - Parameters:
///   - packetId: Protocol id of the dropped packet.
///   - packetSize: Frame size in bytes.
///   - fingerprint: Parser fingerprint; may be empty.
///   - fallbackFingerprint: Used instead of `fingerprint` when that is empty.
///   - messageIdHint: Optional message-id prefix, included in the log when present.
private func reportMalformedCriticalPacket(
    packetId: Int,
    packetSize: Int,
    fingerprint: String,
    fallbackFingerprint: String,
    messageIdHint: String? = nil
) {
    let rawFingerprint = fingerprint.isEmpty ? fallbackFingerprint : fingerprint
    let normalizedFingerprint = Self.compactFingerprint(rawFingerprint)
    let packetHex = String(format: "0x%02X", packetId)
    switch messageIdHint {
    case .some(let hint):
        Self.logger.error(
            "Dropping malformed \(packetHex) packet size=\(packetSize) msgHint=\(hint) fp=\(normalizedFingerprint)"
        )
    case .none:
        Self.logger.error(
            "Dropping malformed \(packetHex) packet size=\(packetSize) fp=\(normalizedFingerprint)"
        )
    }
    let info = MalformedCriticalPacketInfo(
        packetId: packetId,
        packetSize: packetSize,
        fingerprint: normalizedFingerprint
    )
    onMalformedCriticalPacketReceived?(info)
}
/// Collapses a malformed-packet fingerprint into one short, log-friendly line.
///
/// Tabs and every Unicode newline scalar (LF, CR, VT, FF, NEL, LINE/PARAGRAPH SEPARATOR)
/// become spaces. The previous version replaced only `\n` and `\t`, so carriage returns
/// leaked into log lines. The result is then truncated to `maxLength` characters.
/// - Parameters:
///   - fingerprint: Raw fingerprint text from the packet parser.
///   - maxLength: Maximum number of characters kept (default 120; negative treated as 0).
/// - Returns: The sanitized, truncated fingerprint.
private static func compactFingerprint(_ fingerprint: String, maxLength: Int = 120) -> String {
    var sanitized = ""
    sanitized.unicodeScalars.append(contentsOf: fingerprint.unicodeScalars.map { scalar -> Unicode.Scalar in
        (scalar == "\t" || CharacterSet.newlines.contains(scalar)) ? " " : scalar
    })
    // prefix(_:) is a no-op for strings already within the limit.
    return String(sanitized.prefix(max(0, maxLength)))
}
/// Applies the server's handshake reply.
///
/// - `.completed`: marks the handshake done and resets reconnect backoff only if the previous
///   authenticated session lived more than 10 s. It then flushes queued packets, starts the
///   heartbeat and requests the transport server URL. Finally it publishes `.authenticated`,
///   re-flushes anything queued meanwhile, and fires `onHandshakeCompleted` in one MainActor task.
/// - `.needDeviceVerification`: publishes `.deviceVerificationRequired`, drops queued packets
///   and keeps the socket alive with heartbeats while waiting for approval.
/// - Parameter packet: Decoded, well-formed handshake (0x00) reply.
private func handleHandshakeResponse(_ packet: PacketHandshake) {
    handshakeTimeoutTask?.cancel()
    handshakeTimeoutTask = nil
    switch packet.handshakeState {
    case .completed:
        handshakeComplete = true
        // Reset backoff only if previous connection was stable (>10s).
        // Prevents tight reconnect loop when server/proxy RSTs connections
        // shortly after sync. Without this, resetReconnectAttempts on every auth
        // means backoff always starts at 1s (attempt #1) = infinite 1s loop.
        let connectionAge = CFAbsoluteTimeGetCurrent() - lastAuthenticatedTime
        if lastAuthenticatedTime == 0 || connectionAge > 10 {
            client.resetReconnectAttempts()
        } else {
            Self.logger.info("Short-lived connection (\(Int(connectionAge))s) — keeping backoff counter")
        }
        lastAuthenticatedTime = CFAbsoluteTimeGetCurrent()
        Self.logger.info("Handshake completed. Protocol v\(packet.protocolVersion), heartbeat \(packet.heartbeatInterval)s")
        flushPacketQueue()
        startHeartbeat(interval: packet.heartbeatInterval)
        // Desktop parity: request transport server URL after handshake.
        sendPacketDirect(PacketRequestTransport())
        // CRITICAL: set .authenticated and fire callback in ONE MainActor task.
        // Separate tasks have no guaranteed FIFO ordering, so requestSynchronize()
        // could race with the state change and silently drop the sync request.
        let callback = self.onHandshakeCompleted
        Task { @MainActor in
            self.connectionState = .authenticated
            // sendPacket() keeps queueing until connectionState == .authenticated, but the
            // queue was already flushed above on the receive thread. Packets enqueued in that
            // gap would otherwise wait for the NEXT handshake, so drain them again now —
            // only while this session is still up (a failed send re-queues the packet).
            if self.handshakeComplete && self.client.isConnected {
                let stranded = self.drainPacketQueue()
                if !stranded.isEmpty {
                    Self.logger.info("Flushing \(stranded.count) packets queued before auth state was published")
                    for queued in stranded {
                        self.sendPacketDirect(queued)
                    }
                }
            }
            callback?(packet)
        }
    case .needDeviceVerification:
        handshakeComplete = false
        Self.logger.info("Server requires device verification — approve this device from your other Rosetta app")
        searchRouter.resetPending()
        Task { @MainActor in
            self.connectionState = .deviceVerificationRequired
        }
        // Android parity (Protocol.kt:163): clear packet queue on device verification.
        clearPacketQueue()
        startHeartbeat(interval: packet.heartbeatInterval)
    }
}
// MARK: - Heartbeat
/// Starts (or restarts) the heartbeat loop.
///
/// The client beats aggressively — every 5 s by default — to avoid server/proxy idle
/// timeouts. Previously the server-provided `interval` was silently ignored. It is now
/// honored when it is *stricter* (a positive value below 5 s), so the client never beats less
/// often than the server asked.
/// NOTE(review): the handshake requests a 15 s interval; confirm the server's idle timeout.
/// - Parameter interval: Heartbeat interval in seconds from the server's handshake reply.
private func startHeartbeat(interval: Int) {
    heartbeatTask?.cancel()
    let defaultIntervalSeconds = 5
    let intervalSeconds = interval > 0 ? min(interval, defaultIntervalSeconds) : defaultIntervalSeconds
    let intervalNs = UInt64(intervalSeconds) * 1_000_000_000
    // Send first heartbeat SYNCHRONOUSLY on current thread (URLSession delegate queue).
    // This bypasses the connectionState race: startHeartbeat() is called BEFORE
    // the MainActor task sets .authenticated, so sendHeartbeat()'s guard would
    // skip the first heartbeat. Direct sendText avoids this.
    if client.isConnected {
        client.sendText("heartbeat")
    }
    heartbeatTask = Task { [weak self] in
        while !Task.isCancelled {
            try? await Task.sleep(nanoseconds: intervalNs)
            // Stop the loop (instead of spinning forever) once cancelled or deallocated.
            guard !Task.isCancelled, let self else { break }
            self.sendHeartbeat()
        }
    }
}
/// Sends one heartbeat, or starts reconnect recovery when the socket is found closed.
///
/// Heartbeats are allowed once the handshake has completed (which covers the gap before the
/// MainActor publishes `.authenticated`) or while waiting for device verification.
private func sendHeartbeat() {
    let heartbeatAllowed = handshakeComplete || connectionState == .deviceVerificationRequired
    if !heartbeatAllowed { return }
    if client.isConnected {
        client.sendText("heartbeat")
        return
    }
    Self.logger.warning("💔 Heartbeat failed: socket not connected — triggering reconnect")
    handleHeartbeatFailure()
}
/// Android parity: a failed heartbeat is handled like a disconnect.
///
/// Stops the heartbeat loop, clears the handshake flag, publishes `.disconnected` on the
/// MainActor, and asks the client to reconnect (its own disconnect handling schedules the retry).
private func handleHeartbeatFailure() {
    heartbeatTask?.cancel()
    handshakeComplete = false
    Task { @MainActor [self] in
        connectionState = .disconnected
    }
    client.forceReconnect()
}
// MARK: - Packet Queue
/// Encodes `packet` and writes it to the socket immediately.
///
/// If the socket refuses the write up front, the packet goes back into the
/// outbound queue. It is also re-queued if the client later reports an
/// asynchronous send failure.
private func sendPacketDirect(_ packet: any Packet) {
    let packetIdHex = String(type(of: packet).packetId, radix: 16)
    let data = PacketRegistry.encode(packet)
    Self.logger.info("Sending packet 0x\(packetIdHex) (\(data.count) bytes)")

    let accepted = client.send(data, onFailure: { [weak self] _ in
        guard let self else { return }
        Self.logger.warning("Send failed, re-queueing packet 0x\(packetIdHex)")
        self.enqueuePacket(packet)
    })
    guard !accepted else { return }

    Self.logger.warning("WebSocket unavailable, re-queueing packet 0x\(packetIdHex)")
    enqueuePacket(packet)
}
/// Atomically empties the outbound queue and sends every packet it held, in order.
private func flushPacketQueue() {
    let pending = drainPacketQueue()
    Self.logger.info("Flushing \(pending.count) queued packets")
    pending.forEach { sendPacketDirect($0) }
}
/// Appends `packet` to the outbound queue unless an equivalent packet is already queued.
///
/// Equivalence is determined by `packetQueueKey(_:)`. Packets without a key
/// are always queued.
private func enqueuePacket(_ packet: any Packet) {
    // The key depends only on the packet's own fields, so compute it before locking.
    let dedupKey = packetQueueKey(packet)

    packetQueueLock.lock()
    if let dedupKey, !queuedPacketKeys.insert(dedupKey).inserted {
        // An equivalent packet is already waiting; drop this duplicate.
        packetQueueLock.unlock()
        return
    }
    packetQueue.append(packet)
    let queueSize = packetQueue.count
    packetQueueLock.unlock()

    Self.logger.info("Queueing packet 0x\(String(type(of: packet).packetId, radix: 16)) (queue=\(queueSize))")
}
/// Returns every queued packet and leaves the queue and its dedup keys empty.
private func drainPacketQueue() -> [any Packet] {
    packetQueueLock.lock()
    defer { packetQueueLock.unlock() }
    let snapshot = packetQueue
    queuedPacketKeys.removeAll()
    packetQueue.removeAll()
    return snapshot
}
/// Discards all queued packets along with their dedup keys.
private func clearPacketQueue() {
    packetQueueLock.lock()
    defer { packetQueueLock.unlock() }
    queuedPacketKeys.removeAll()
    packetQueue.removeAll()
}
/// Drops every registered result handler under `resultHandlersLock`.
private func clearResultHandlers() {
    resultHandlersLock.lock()
    defer { resultHandlersLock.unlock() }
    resultHandlers.removeAll()
}
// MARK: - Device Verification
/// Logs the received device list and publishes it on the main actor.
///
/// It also selects the first `.notVerified` device as the pending
/// verification request (or `nil` if there is none).
private func handleDeviceList(_ packet: PacketDeviceList) {
    let incoming = packet.devices
    Self.logger.info("📱 Device list received: \(incoming.count) devices")
    incoming.forEach { device in
        Self.logger.info(" - \(device.deviceName) (\(device.deviceOs)) status=\(device.deviceStatus.rawValue) verify=\(device.deviceVerify.rawValue)")
    }

    let awaitingApproval = incoming.first { $0.deviceVerify == .notVerified }
    Task { @MainActor in
        self.devices = incoming
        self.pendingDeviceVerification = awaitingApproval
    }
}
/// Reacts to the server's verdict on a device-verification request.
///
/// Only `.decline` is acted on here: it disconnects. On accept, the server
/// re-sends the handshake with `.completed`, which is handled by
/// handleHandshakeResponse.
/// NOTE(review): `packet.deviceId` is not compared with this device's id
/// before disconnecting. Confirm the server only sends this packet to the
/// device being resolved.
private func handleDeviceResolve(_ packet: PacketDeviceResolve) {
    Self.logger.info("🔐 Device resolve received: deviceId=\(packet.deviceId.prefix(20)), solution=\(packet.solution.rawValue)")
    guard packet.solution == .decline else { return }
    Self.logger.info("🚫 This device was DECLINED — disconnecting")
    disconnect()
}
/// Accept a pending device login from another device.
func acceptDevice(_ deviceId: String) {
    Self.logger.info("✅ Accepting device: \(deviceId.prefix(20))")
    var resolution = PacketDeviceResolve()
    resolution.solution = .accept
    resolution.deviceId = deviceId
    sendPacketDirect(resolution)
    // The verification banner is main-actor state; dismiss it once the answer is sent.
    Task { @MainActor in
        self.pendingDeviceVerification = nil
    }
}
/// Decline a pending device login from another device.
func declineDevice(_ deviceId: String) {
    Self.logger.info("❌ Declining device: \(deviceId.prefix(20))")
    var resolution = PacketDeviceResolve()
    resolution.solution = .decline
    resolution.deviceId = deviceId
    sendPacketDirect(resolution)
    // The verification banner is main-actor state; dismiss it once the answer is sent.
    Task { @MainActor in
        self.pendingDeviceVerification = nil
    }
}
/// Builds the dedup key used by the outbound queue, or `nil` when the packet
/// should never be deduplicated.
///
/// A packet also gets `nil` when an identifying field the key needs is empty.
/// Keys are prefixed with the packet id (e.g. "0x06") so different packet
/// kinds never collide.
private func packetQueueKey(_ packet: any Packet) -> String? {
    switch packet {
    case let message as PacketMessage where !message.messageId.isEmpty:
        return "0x06:\(message.messageId)"
    case let delivery as PacketDelivery where !delivery.messageId.isEmpty:
        return "0x08:\(delivery.toPublicKey):\(delivery.messageId)"
    case let read as PacketRead where !read.fromPublicKey.isEmpty && !read.toPublicKey.isEmpty:
        return "0x07:\(read.fromPublicKey):\(read.toPublicKey)"
    case let typing as PacketTyping where !typing.fromPublicKey.isEmpty && !typing.toPublicKey.isEmpty:
        return "0x0b:\(typing.fromPublicKey):\(typing.toPublicKey)"
    case let sync as PacketSync:
        return "0x19:\(sync.status.rawValue):\(sync.timestamp)"
    default:
        // Other packet types, and recognised types missing their identifying
        // fields, are never deduplicated.
        return nil
    }
}
// MARK: - Test Support
/// Test hook: feeds raw bytes straight into the normal incoming-data path.
func testHandleIncomingData(_ data: Data) {
    self.handleIncomingData(data)
}
}