// Source: mobile-ios/Rosetta/Core/Network/Protocol/ProtocolManager.swift
// (680 lines, 24 KiB, Swift)

import Foundation
import os
import Observation
import UIKit
// MARK: - Connection State
/// Phases of the protocol connection as published by `ProtocolManager`.
///
/// Raw values are identical to the case names; they are what shows up in log output.
enum ConnectionState: String {
    /// No socket is open.
    case disconnected = "disconnected"
    /// A socket connection attempt (or forced reconnect) has been started.
    case connecting = "connecting"
    /// The WebSocket reported that it is connected.
    case connected = "connected"
    /// The handshake packet has been sent and a reply is awaited.
    case handshaking = "handshaking"
    /// The server answered the handshake with "need device verification".
    case deviceVerificationRequired = "deviceVerificationRequired"
    /// The server answered the handshake with "completed".
    case authenticated = "authenticated"
}
// MARK: - ProtocolManager
/// Central networking coordinator. Owns WebSocket, routes packets, manages handshake.
@Observable
final class ProtocolManager: @unchecked Sendable {
static let shared = ProtocolManager()
private static let logger = Logger(subsystem: "com.rosetta.messenger", category: "Protocol")
// MARK: - Public State
/// Current connection phase. Written by this class only (directly or from MainActor tasks).
private(set) var connectionState: ConnectionState = .disconnected
// MARK: - Device Verification State
/// Device waiting for approval from this device (shown as banner on primary device).
/// Set to the first `.notVerified` entry of the latest device list; cleared by accept/decline.
private(set) var pendingDeviceVerification: DeviceEntry?
/// Devices from the most recently received device-list packet (0x17).
private(set) var devices: [DeviceEntry] = []
// MARK: - Callbacks
// Unless noted otherwise, these are invoked synchronously from `handleIncomingData`,
// i.e. on whatever thread `WebSocketClient` delivers data on.
/// Called for every decoded message packet (0x06).
var onMessageReceived: ((PacketMessage) -> Void)?
/// Called for every decoded delivery packet (0x08).
var onDeliveryReceived: ((PacketDelivery) -> Void)?
/// Called for every decoded read packet (0x07).
var onReadReceived: ((PacketRead) -> Void)?
/// Called for every decoded online-state packet (0x05).
var onOnlineStateReceived: ((PacketOnlineState) -> Void)?
/// Called for every decoded user-info packet (0x01).
var onUserInfoReceived: ((PacketUserInfo) -> Void)?
/// Called for every decoded search packet (0x03), before the registered search handlers.
var onSearchResult: ((PacketSearch) -> Void)?
/// Called for every decoded typing packet (0x0B).
var onTypingReceived: ((PacketTyping) -> Void)?
/// Called for every decoded sync packet (0x19), after the sync-batch flag has been updated.
var onSyncReceived: ((PacketSync) -> Void)?
/// Called for every decoded new-device packet (0x09).
var onDeviceNewReceived: ((PacketDeviceNew) -> Void)?
/// Called on the MainActor right after `connectionState` becomes `.authenticated`.
var onHandshakeCompleted: ((PacketHandshake) -> Void)?
// MARK: - Private
/// Underlying WebSocket transport.
private let client = WebSocketClient()
/// Packets waiting to be sent once the handshake completes. Guarded by `packetQueueLock`.
private var packetQueue: [any Packet] = []
/// De-duplication keys (see `packetQueueKey`) of the packets in `packetQueue`. Guarded by `packetQueueLock`.
private var queuedPacketKeys: Set<String> = []
/// True only between a `.completed` handshake reply and the next disconnect/reconnect.
private var handshakeComplete = false
/// Periodic heartbeat loop started after a handshake reply.
private var heartbeatTask: Task<Void, Never>?
/// 10-second watchdog that forces a reconnect when no handshake reply arrives.
private var handshakeTimeoutTask: Task<Void, Never>?
/// 3-second watchdog for the foreground ping verification.
private var pingTimeoutTask: Task<Void, Never>?
/// Guards against overlapping ping-first verifications on foreground.
private var pingVerificationInProgress = false
/// Android parity: the sync-batch flag is set SYNCHRONOUSLY on the receive queue.
/// This prevents the race where the MainActor task for BATCH_START runs after a message task.
/// Written on the receive path, read elsewhere — every access goes through `syncBatchLock`.
private let syncBatchLock = NSLock()
private var _syncBatchActive = false
/// Lock-protected snapshot of the sync-batch flag, so SessionManager can check
/// the sync state without hopping to the MainActor.
var isSyncBatchActive: Bool {
    syncBatchLock.lock()
    defer { syncBatchLock.unlock() }
    return _syncBatchActive
}
/// Guards `searchResultHandlers`.
private let searchHandlersLock = NSLock()
/// Guards `resultHandlers`.
private let resultHandlersLock = NSLock()
/// Guards `packetQueue` and `queuedPacketKeys`.
private let packetQueueLock = NSLock()
/// Persistent search-result listeners keyed by the token returned from `addSearchResultHandler`.
private var searchResultHandlers: [UUID: (PacketSearch) -> Void] = [:]
/// One-shot PacketResult listeners; emptied every time a result packet is dispatched.
private var resultHandlers: [UUID: (PacketResult) -> Void] = [:]
// Saved credentials for auto-reconnect.
// Set by `connect(publicKey:privateKeyHash:)`, cleared by `disconnect()`.
private var savedPublicKey: String?
private var savedPrivateHash: String?
/// Public key passed to the last `connect` call, or nil after `disconnect()`.
var publicKey: String? { savedPublicKey }
/// Private-key hash passed to the last `connect` call, or nil after `disconnect()`.
var privateHash: String? { savedPrivateHash }
/// Private: the class is used through `ProtocolManager.shared`. Wires the client callbacks once.
private init() {
setupClientCallbacks()
}
// MARK: - Connection
/// Remembers the credentials and opens the socket; the handshake itself is sent
/// from the client's `onConnected` callback.
///
/// The credentials are stored even when the call is otherwise a no-op.
/// - Parameters:
///   - publicKey: Account public key used for the handshake.
///   - privateKeyHash: Private-key hash used for the handshake.
// NOTE(review): `.connecting`, `.connected` and `.deviceVerificationRequired` are not
// treated as "already active" here, unlike in `reconnectIfNeeded()` — confirm this is intended.
func connect(publicKey: String, privateKeyHash: String) {
    savedPublicKey = publicKey
    savedPrivateHash = privateKeyHash

    switch connectionState {
    case .authenticated, .handshaking:
        Self.logger.info("Already connected/handshaking, skipping")
        return
    case .disconnected, .connecting, .connected, .deviceVerificationRequired:
        break
    }

    connectionState = .connecting
    client.connect()
}
/// Tears the session down: cancels all watchdog/heartbeat tasks, closes the socket,
/// forgets the saved credentials and resets the transport manager.
// NOTE(review): the packet queue is not cleared here, so packets queued before this
// call are flushed after the next successful handshake — confirm that is intended
// when a different account connects afterwards.
func disconnect() {
    Self.logger.info("Disconnecting")

    // Stop every background task owned by this manager.
    for task in [heartbeatTask, handshakeTimeoutTask, pingTimeoutTask] {
        task?.cancel()
    }
    pingTimeoutTask = nil

    pingVerificationInProgress = false
    handshakeComplete = false

    client.disconnect()
    connectionState = .disconnected

    // Without credentials no automatic reconnect/handshake can happen.
    savedPublicKey = nil
    savedPrivateHash = nil

    Task { @MainActor in
        TransportManager.shared.reset()
    }
}
/// Android parity with `reconnectNowIfNeeded()`: does nothing while the connection is
/// in any active state; otherwise resets backoff and reconnects immediately.
///
/// Requires saved credentials; without them the call is a no-op.
func reconnectIfNeeded() {
    guard savedPublicKey != nil, savedPrivateHash != nil else { return }

    // Android parity (Protocol.kt:651-658): decide from the current state.
    let shouldReconnect: Bool
    switch connectionState {
    case .disconnected:
        shouldReconnect = true
    case .connecting:
        // Android parity: `(CONNECTING && isConnecting)` — skip while connect() is in progress.
        shouldReconnect = !client.isConnecting
    case .authenticated, .handshaking, .deviceVerificationRequired, .connected:
        shouldReconnect = false
    }
    guard shouldReconnect else { return }

    Self.logger.info("⚡ Fast reconnect — state=\(self.connectionState.rawValue)")
    handshakeComplete = false
    heartbeatTask?.cancel()
    connectionState = .connecting
    client.forceReconnect()
}
/// Ping-first zombie socket detection for foreground resume.
///
/// iOS suspends the process in background; the server may RST the TCP connection
/// without the `didCloseWith` delegate ever firing, so `connectionState` stays
/// `.authenticated` (stale). This method sends a WebSocket ping:
/// pong → connection is alive; error or no answer within 3 s → force reconnect.
///
/// No-op when there are no saved credentials, when the state is neither
/// `.authenticated` nor `.connected`, or while a previous verification is still running.
func verifyConnectionOrReconnect() {
guard savedPublicKey != nil, savedPrivateHash != nil else { return }
guard connectionState == .authenticated || connectionState == .connected else { return }
guard !pingVerificationInProgress else { return }
pingVerificationInProgress = true
Self.logger.info("🏓 Verifying connection with ping after foreground...")
client.sendPing { [weak self] error in
// Ignore the callback if the verification was already resolved
// (by the timeout below, a disconnect, or `disconnect()`).
guard let self, self.pingVerificationInProgress else { return }
self.pingVerificationInProgress = false
// The ping answered (either way), so the safety timeout is no longer needed.
self.pingTimeoutTask?.cancel()
self.pingTimeoutTask = nil
if let error {
Self.logger.warning("🏓 Ping failed — zombie socket: \(error.localizedDescription)")
self.handlePingFailure()
} else {
Self.logger.info("🏓 Pong received — connection alive")
}
}
// Safety timeout: if sendPing never calls back (completely dead socket), force reconnect.
pingTimeoutTask?.cancel()
pingTimeoutTask = Task { [weak self] in
try? await Task.sleep(for: .seconds(3))
// A cancelled sleep returns early; the `Task.isCancelled` check below covers that case.
guard let self, !Task.isCancelled, self.pingVerificationInProgress else { return }
self.pingVerificationInProgress = false
Self.logger.warning("🏓 Ping timeout (3s) — zombie socket, forcing reconnect")
self.handlePingFailure()
}
}
/// Reaction to a failed or timed-out foreground ping: drops the session flags,
/// publishes `.connecting` and forces the client to reconnect.
private func handlePingFailure() {
    // Session is no longer trustworthy.
    handshakeComplete = false

    // Stop the ping watchdog and the heartbeat loop.
    pingTimeoutTask?.cancel()
    pingTimeoutTask = nil
    heartbeatTask?.cancel()

    Task { @MainActor in
        self.connectionState = .connecting
    }
    client.forceReconnect()
}
// MARK: - Sending
/// Sends `packet` immediately when the session is fully usable, otherwise queues it
/// for the flush that follows the next completed handshake.
/// - Parameter packet: Packet to send.
func sendPacket(_ packet: any Packet) {
    PerformanceLogger.shared.track("protocol.sendPacket")
    let id = String(type(of: packet).packetId, radix: 16)

    // Android parity (Protocol.kt:436-448): triple check — handshake done
    // (handshake packets are exempt), socket alive, state authenticated.
    let isAuth = connectionState == .authenticated
    let handshakeSatisfied = handshakeComplete || packet is PacketHandshake

    guard handshakeSatisfied, client.isConnected, isAuth else {
        Self.logger.info("⏳ Queueing packet 0x\(id) — connected=\(self.client.isConnected), handshake=\(self.handshakeComplete), auth=\(isAuth)")
        enqueuePacket(packet)
        return
    }

    Self.logger.info("📤 Sending packet 0x\(id) directly")
    sendPacketDirect(packet)
}
// MARK: - Search Handlers (Android-like wait/unwait)
/// Registers a listener that is called for every search packet until it is removed.
/// - Parameter handler: Closure receiving each `PacketSearch`.
/// - Returns: Token to pass to `removeSearchResultHandler(_:)`.
@discardableResult
func addSearchResultHandler(_ handler: @escaping (PacketSearch) -> Void) -> UUID {
    let token = UUID()
    searchHandlersLock.lock()
    defer { searchHandlersLock.unlock() }
    searchResultHandlers[token] = handler
    return token
}
/// Unregisters a search listener; unknown tokens are ignored.
/// - Parameter id: Token returned by `addSearchResultHandler(_:)`.
func removeSearchResultHandler(_ id: UUID) {
    searchHandlersLock.lock()
    defer { searchHandlersLock.unlock() }
    searchResultHandlers[id] = nil
}
// MARK: - Result Handlers (Android parity: waitPacket(0x02))
/// Registers a one-shot listener for the next PacketResult (0x02).
/// All registered listeners are dropped after that packet has been dispatched.
/// - Parameter handler: Closure receiving the `PacketResult`.
/// - Returns: Token to pass to `removeResultHandler(_:)`.
@discardableResult
func addResultHandler(_ handler: @escaping (PacketResult) -> Void) -> UUID {
    let token = UUID()
    resultHandlersLock.lock()
    defer { resultHandlersLock.unlock() }
    resultHandlers[token] = handler
    return token
}
/// Unregisters a result listener before it fired; unknown tokens are ignored.
/// - Parameter id: Token returned by `addResultHandler(_:)`.
func removeResultHandler(_ id: UUID) {
    resultHandlersLock.lock()
    defer { resultHandlersLock.unlock() }
    resultHandlers[id] = nil
}
// MARK: - Private Setup
/// Wires the `WebSocketClient` callbacks (connected, disconnected, data, network restored)
/// to this manager. Called once from `init`.
private func setupClientCallbacks() {
    client.onConnected = { [weak self] in
        guard let self else { return }
        Self.logger.info("WebSocket connected")
        if let pk = self.savedPublicKey, let hash = self.savedPrivateHash {
            // Auto-handshake with saved credentials. `startHandshake` publishes
            // `.handshaking` from its own MainActor task; publishing `.connected` from a
            // second, separate task here could land AFTER it (unstructured tasks are not
            // guaranteed FIFO) and leave a stale `.connected` state behind.
            self.startHandshake(publicKey: pk, privateHash: hash)
        } else {
            Task { @MainActor in
                self.connectionState = .connected
            }
        }
    }
    client.onDisconnected = { [weak self] error in
        guard let self else { return }
        if let error {
            Self.logger.error("Disconnected: \(error.localizedDescription)")
        }
        self.heartbeatTask?.cancel()
        // A handshake that was in flight on the dead socket can no longer complete;
        // stop its watchdog so it does not force another reconnect later
        // (same cleanup as `disconnect()`).
        self.handshakeTimeoutTask?.cancel()
        self.handshakeTimeoutTask = nil
        self.handshakeComplete = false
        self.pingVerificationInProgress = false
        self.pingTimeoutTask?.cancel()
        self.pingTimeoutTask = nil
        Task { @MainActor in
            self.connectionState = .disconnected
        }
    }
    client.onDataReceived = { [weak self] data in
        self?.handleIncomingData(data)
    }
    // Instant reconnect when network is restored (Wi-Fi → cellular, airplane mode off, etc.)
    client.onNetworkRestored = { [weak self] in
        guard let self, self.savedPublicKey != nil else { return }
        Self.logger.info("Network restored — force reconnecting")
        self.handshakeComplete = false
        self.heartbeatTask?.cancel()
        Task { @MainActor in
            self.connectionState = .connecting
        }
        self.client.forceReconnect()
    }
}
// MARK: - Handshake
/// Publishes `.handshaking`, sends the handshake packet and arms a 10-second watchdog
/// that forces a reconnect if no handshake reply has completed the session by then.
/// - Parameters:
///   - publicKey: Account public key placed into the handshake.
///   - privateHash: Private-key hash placed into the handshake.
private func startHandshake(publicKey: String, privateHash: String) {
    Self.logger.info("Starting handshake for \(publicKey.prefix(20))...")
    Task { @MainActor in
        self.connectionState = .handshaking
    }

    let handshake = PacketHandshake(
        privateKey: privateHash,
        publicKey: publicKey,
        protocolVersion: 1,
        heartbeatInterval: 15,
        device: HandshakeDevice(
            deviceId: DeviceIdentityManager.shared.currentDeviceId(),
            deviceName: UIDevice.current.name,
            deviceOs: "iOS \(UIDevice.current.systemVersion)"
        ),
        handshakeState: .needDeviceVerification
    )
    sendPacketDirect(handshake)

    // On timeout, force a reconnect instead of disconnecting permanently:
    // `client.disconnect()` sets `isManuallyClosed = true`, which kills all
    // future reconnection attempts, whereas `forceReconnect()` retries.
    handshakeTimeoutTask?.cancel()
    handshakeTimeoutTask = Task { [weak self] in
        do {
            try await Task.sleep(nanoseconds: 10_000_000_000)
        } catch {
            // Sleep was cancelled: the handshake was answered or superseded.
            return
        }
        guard let self, !Task.isCancelled else { return }
        guard !self.handshakeComplete else { return }

        Self.logger.error("Handshake timeout — forcing reconnect")
        self.handshakeComplete = false
        self.heartbeatTask?.cancel()
        Task { @MainActor in
            self.connectionState = .connecting
        }
        self.client.forceReconnect()
    }
}
// MARK: - Packet Handling
/// Decodes one incoming frame and routes it by packet id to the matching
/// handler or callback. Frames that cannot be decoded, ids without a route, and
/// packets whose decoded type does not match their id are ignored.
/// - Parameter data: Raw frame received from the WebSocket.
private func handleIncomingData(_ data: Data) {
    #if DEBUG
    if data.count >= 2 {
        let rawId = Stream(data: data).readInt16()
        Self.logger.debug("📥 Incoming packet 0x\(String(rawId, radix: 16)), size: \(data.count)")
    }
    #endif

    guard let decoded = PacketRegistry.decode(from: data) else {
        #if DEBUG
        if data.count >= 2 {
            let rawId = Stream(data: data).readInt16()
            Self.logger.debug("Unknown packet ID: 0x\(String(rawId, radix: 16)), size: \(data.count)")
        }
        #endif
        return
    }
    let (packetId, packet) = decoded

    // Match on id AND concrete type; anything else falls through to `default`.
    switch (packetId, packet) {
    case (0x00, let p as PacketHandshake):
        handleHandshakeResponse(p)

    case (0x01, let p as PacketUserInfo):
        onUserInfoReceived?(p)

    case (0x02, let p as PacketResult):
        Self.logger.info("📥 PacketResult: code=\(p.resultCode)")
        notifyResultHandlers(p)

    case (0x03, let p as PacketSearch):
        Self.logger.debug("📥 Search result: \(p.users.count) users, callback=\(self.onSearchResult != nil)")
        onSearchResult?(p)
        notifySearchResultHandlers(p)

    case (0x05, let p as PacketOnlineState):
        onOnlineStateReceived?(p)

    case (0x06, let p as PacketMessage):
        onMessageReceived?(p)

    case (0x07, let p as PacketRead):
        onReadReceived?(p)

    case (0x08, let p as PacketDelivery):
        onDeliveryReceived?(p)

    case (0x09, let p as PacketDeviceNew):
        onDeviceNewReceived?(p)

    case (0x0B, let p as PacketTyping):
        onTypingReceived?(p)

    case (0x0F, let p as PacketRequestTransport):
        Self.logger.info("📥 Transport server: \(p.transportServer)")
        Task { @MainActor in
            TransportManager.shared.setTransportServer(p.transportServer)
        }

    case (0x17, let p as PacketDeviceList):
        handleDeviceList(p)

    case (0x18, let p as PacketDeviceResolve):
        handleDeviceResolve(p)

    case (0x19, let p as PacketSync):
        // Android parity: update the sync flag SYNCHRONOUSLY on the receive path
        // BEFORE the callback is invoked. This prevents the race where a 0x06
        // message task runs on the MainActor before the BATCH_START task.
        // Only `.batchStart` (set) and `.notNeeded` (clear) touch the flag.
        if p.status == .batchStart || p.status == .notNeeded {
            syncBatchLock.lock()
            _syncBatchActive = (p.status == .batchStart)
            syncBatchLock.unlock()
        }
        onSyncReceived?(p)

    default:
        break
    }
}
/// Delivers a search packet to every registered search listener.
/// The listeners are snapshotted under the lock and invoked outside of it,
/// so a listener may add or remove handlers without deadlocking.
private func notifySearchResultHandlers(_ packet: PacketSearch) {
    let snapshot: [(PacketSearch) -> Void] = {
        searchHandlersLock.lock()
        defer { searchHandlersLock.unlock() }
        return Array(searchResultHandlers.values)
    }()
    snapshot.forEach { $0(packet) }
}
/// Delivers a result packet to every pending result listener and forgets them all.
/// One-shot semantics (Android parity): the table is emptied under the lock
/// before any listener runs; listeners are invoked outside the lock.
private func notifyResultHandlers(_ packet: PacketResult) {
    resultHandlersLock.lock()
    let pending = Array(resultHandlers.values)
    resultHandlers.removeAll()
    resultHandlersLock.unlock()

    pending.forEach { $0(packet) }
}
/// Handles the server's reply to our handshake.
///
/// - `.completed`: marks the session usable, flushes queued packets, starts the
///   heartbeat, requests the transport server, then publishes `.authenticated`
///   and fires `onHandshakeCompleted` on the MainActor.
/// - `.needDeviceVerification`: publishes `.deviceVerificationRequired`, drops the
///   packet queue and keeps the socket alive with heartbeats while waiting for approval.
/// - Parameter packet: Decoded handshake reply (0x00).
private func handleHandshakeResponse(_ packet: PacketHandshake) {
    // A reply arrived, so the handshake watchdog is no longer needed.
    handshakeTimeoutTask?.cancel()
    handshakeTimeoutTask = nil
    switch packet.handshakeState {
    case .completed:
        handshakeComplete = true
        // Android parity: reset backoff counter on successful authentication.
        client.resetReconnectAttempts()
        Self.logger.info("Handshake completed. Protocol v\(packet.protocolVersion), heartbeat \(packet.heartbeatInterval)s")
        flushPacketQueue()
        startHeartbeat(interval: packet.heartbeatInterval)
        // Desktop parity: request transport server URL after handshake.
        sendPacketDirect(PacketRequestTransport())
        // CRITICAL: set .authenticated and fire callback in ONE MainActor task.
        // Previously these were separate tasks — Swift doesn't guarantee FIFO
        // ordering of unstructured tasks, so requestSynchronize() could race
        // with the state change and silently drop the sync request.
        let callback = self.onHandshakeCompleted
        Task { @MainActor in
            self.connectionState = .authenticated
            // `sendPacket` queues while the state is not yet `.authenticated`. Anything
            // queued between the flush above and this state change would otherwise sit
            // in the queue until the next handshake, so flush once more now.
            self.flushPacketQueue()
            callback?(packet)
        }
    case .needDeviceVerification:
        handshakeComplete = false
        Self.logger.info("Server requires device verification — approve this device from your other Rosetta app")
        Task { @MainActor in
            self.connectionState = .deviceVerificationRequired
        }
        // Android parity (Protocol.kt:163): clear packet queue on device verification.
        clearPacketQueue()
        startHeartbeat(interval: packet.heartbeatInterval)
    }
}
// MARK: - Heartbeat
/// (Re)starts the heartbeat loop: one heartbeat immediately, then one every
/// `interval / 3` seconds until the task is cancelled.
/// - Parameter interval: Heartbeat interval in seconds as announced by the server.
///   The value comes off the wire, so it is clamped to 1...86_400 before use:
///   a negative value would trap in the `UInt64` conversion and `0` would turn the
///   loop into a zero-delay busy loop.
private func startHeartbeat(interval: Int) {
    heartbeatTask?.cancel()
    let clampedSeconds = min(max(interval, 1), 86_400)
    // Android parity: heartbeat at 1/3 the server-specified interval (more aggressive keep-alive).
    let intervalNs = UInt64(clampedSeconds) * 1_000_000_000 / 3
    heartbeatTask = Task { [weak self] in
        // Send first heartbeat immediately
        self?.sendHeartbeat()
        while !Task.isCancelled {
            // A cancelled sleep returns early; the check right after ends the loop.
            try? await Task.sleep(nanoseconds: intervalNs)
            guard !Task.isCancelled else { break }
            self?.sendHeartbeat()
        }
    }
}
/// Android parity: sends one heartbeat, or triggers the failure path when the socket is gone.
/// Heartbeats are only sent in the `.authenticated` and `.deviceVerificationRequired`
/// states; in every other state the call does nothing.
private func sendHeartbeat() {
    switch connectionState {
    case .authenticated, .deviceVerificationRequired:
        break
    case .disconnected, .connecting, .connected, .handshaking:
        return
    }

    if client.isConnected {
        client.sendText("heartbeat")
    } else {
        Self.logger.warning("💔 Heartbeat failed: socket not connected — triggering reconnect")
        handleHeartbeatFailure()
    }
}
/// Android parity: a failed heartbeat is treated like a disconnect.
/// Stops the heartbeat loop, invalidates the handshake, publishes `.disconnected`
/// and asks the client to reconnect right away.
// NOTE(review): the other forced-reconnect paths in this class publish `.connecting`
// rather than `.disconnected` — confirm the difference is intentional.
private func handleHeartbeatFailure() {
    handshakeComplete = false
    heartbeatTask?.cancel()

    Task { @MainActor in
        self.connectionState = .disconnected
    }
    client.forceReconnect()
}
// MARK: - Packet Queue
/// Encodes `packet` and hands it to the socket without any state checks.
///
/// When the socket refuses the data, or reports a send failure later, the packet is
/// put back into the queue so it goes out after the next completed handshake.
/// Handshake packets are the exception: every connection builds a fresh handshake in
/// `startHandshake`, so re-queueing a failed one would only replay a stale handshake
/// to the server after the session is already authenticated. They are dropped instead.
/// - Parameter packet: Packet to encode and send.
private func sendPacketDirect(_ packet: any Packet) {
    let data = PacketRegistry.encode(packet)
    let hexId = String(type(of: packet).packetId, radix: 16)
    let isHandshake = packet is PacketHandshake
    Self.logger.info("Sending packet 0x\(hexId) (\(data.count) bytes)")

    let accepted = client.send(data, onFailure: { [weak self] _ in
        guard let self else { return }
        if isHandshake {
            Self.logger.warning("Send failed, dropping handshake packet (a fresh one is sent on reconnect)")
            return
        }
        Self.logger.warning("Send failed, re-queueing packet 0x\(hexId)")
        self.enqueuePacket(packet)
    })
    guard !accepted else { return }

    if isHandshake {
        Self.logger.warning("WebSocket unavailable, dropping handshake packet (a fresh one is sent on reconnect)")
        return
    }
    Self.logger.warning("WebSocket unavailable, re-queueing packet 0x\(hexId)")
    enqueuePacket(packet)
}
/// Empties the queue and sends every packet that was in it, in queue order.
private func flushPacketQueue() {
    let pending = drainPacketQueue()
    Self.logger.info("Flushing \(pending.count) queued packets")
    pending.forEach { sendPacketDirect($0) }
}
/// Appends `packet` to the send queue unless a packet with the same
/// de-duplication key (see `packetQueueKey`) is already queued.
/// Packets without a key are always appended.
private func enqueuePacket(_ packet: any Packet) {
    let dedupKey = packetQueueKey(packet)

    packetQueueLock.lock()
    if let dedupKey, !queuedPacketKeys.insert(dedupKey).inserted {
        // Same logical packet is already waiting — nothing to do.
        packetQueueLock.unlock()
        return
    }
    packetQueue.append(packet)
    let depth = packetQueue.count
    packetQueueLock.unlock()

    Self.logger.info("Queueing packet 0x\(String(type(of: packet).packetId, radix: 16)) (queue=\(depth))")
}
/// Atomically takes all queued packets, leaving the queue and its key set empty.
/// - Returns: The packets that were queued, in queue order.
private func drainPacketQueue() -> [any Packet] {
    packetQueueLock.lock()
    defer { packetQueueLock.unlock() }

    let drained = packetQueue
    packetQueue.removeAll()
    queuedPacketKeys.removeAll()
    return drained
}
/// Discards all queued packets and their de-duplication keys without sending them.
private func clearPacketQueue() {
    packetQueueLock.lock()
    defer { packetQueueLock.unlock() }
    packetQueue = []
    queuedPacketKeys = []
}
// MARK: - Device Verification
/// Logs the received device list and publishes it on the MainActor, together
/// with the first device that is still `.notVerified` (if any) as the pending one.
private func handleDeviceList(_ packet: PacketDeviceList) {
    Self.logger.info("📱 Device list received: \(packet.devices.count) devices")
    packet.devices.forEach { device in
        Self.logger.info(" - \(device.deviceName) (\(device.deviceOs)) status=\(device.deviceStatus.rawValue) verify=\(device.deviceVerify.rawValue)")
    }

    Task { @MainActor in
        self.devices = packet.devices
        self.pendingDeviceVerification = packet.devices.first(where: { $0.deviceVerify == .notVerified })
    }
}
/// Handles a device-resolve packet (0x18): a `.decline` solution disconnects this manager.
// NOTE(review): `packet.deviceId` is not compared with this device's id before
// disconnecting — presumably the server only sends this to the affected device; confirm.
private func handleDeviceResolve(_ packet: PacketDeviceResolve) {
    Self.logger.info("🔐 Device resolve received: deviceId=\(packet.deviceId.prefix(20)), solution=\(packet.solution.rawValue)")

    guard packet.solution == .decline else {
        // If accepted, the server re-sends the handshake with `.completed`,
        // which is handled by handleHandshakeResponse.
        return
    }

    Self.logger.info("🚫 This device was DECLINED — disconnecting")
    disconnect()
}
/// Approves a pending login from another device and clears the pending banner state.
/// - Parameter deviceId: Id of the device to approve.
func acceptDevice(_ deviceId: String) {
    Self.logger.info("✅ Accepting device: \(deviceId.prefix(20))")
    sendDeviceResolution(for: deviceId) { $0.solution = .accept }
}
/// Rejects a pending login from another device and clears the pending banner state.
/// - Parameter deviceId: Id of the device to reject.
func declineDevice(_ deviceId: String) {
    Self.logger.info("❌ Declining device: \(deviceId.prefix(20))")
    sendDeviceResolution(for: deviceId) { $0.solution = .decline }
}
/// Builds a device-resolve packet for `deviceId`, lets `configure` set the solution,
/// sends it directly and resets `pendingDeviceVerification` on the MainActor.
private func sendDeviceResolution(for deviceId: String, configure: (inout PacketDeviceResolve) -> Void) {
    var packet = PacketDeviceResolve()
    packet.deviceId = deviceId
    configure(&packet)
    sendPacketDirect(packet)
    Task { @MainActor in
        self.pendingDeviceVerification = nil
    }
}
/// De-duplication key for queued packets.
/// - Returns: A string identifying the logical packet for message, delivery, read,
///   typing and sync packets; `nil` for every other packet type and whenever an
///   identifying field is empty (such packets are never de-duplicated).
private func packetQueueKey(_ packet: any Packet) -> String? {
    if let message = packet as? PacketMessage {
        guard !message.messageId.isEmpty else { return nil }
        return "0x06:\(message.messageId)"
    }
    if let delivery = packet as? PacketDelivery {
        guard !delivery.messageId.isEmpty else { return nil }
        return "0x08:\(delivery.toPublicKey):\(delivery.messageId)"
    }
    if let read = packet as? PacketRead {
        guard !read.fromPublicKey.isEmpty, !read.toPublicKey.isEmpty else { return nil }
        return "0x07:\(read.fromPublicKey):\(read.toPublicKey)"
    }
    if let typing = packet as? PacketTyping {
        guard !typing.fromPublicKey.isEmpty, !typing.toPublicKey.isEmpty else { return nil }
        return "0x0b:\(typing.fromPublicKey):\(typing.toPublicKey)"
    }
    if let sync = packet as? PacketSync {
        return "0x19:\(sync.status.rawValue):\(sync.timestamp)"
    }
    return nil
}
}