import Foundation
import os
import Observation
import UIKit

// MARK: - Connection State

/// Lifecycle of the connection to the server as tracked by `ProtocolManager`.
enum ConnectionState: String {
    /// No socket is open (initial state, after `disconnect()`, or after the socket dropped).
    case disconnected
    /// `connect(publicKey:privateKeyHash:)` was called and the socket is being opened.
    case connecting
    /// The WebSocket client reported that the socket is open.
    case connected
    /// A handshake packet was sent and the server's answer is awaited.
    case handshaking
    /// The server answered the handshake with `.needDeviceVerification`.
    case deviceVerificationRequired
    /// The server answered the handshake with `.completed`.
    case authenticated
}

// MARK: - ProtocolManager

/// Central networking coordinator. Owns the WebSocket, routes incoming packets to
/// callbacks, performs the handshake, and queues outgoing packets until the session
/// is authenticated.
///
/// NOTE(review): declared `@unchecked Sendable`, but only the packet queue and the
/// search-handler table are lock-protected; other mutable state (`handshakeComplete`,
/// the task handles, saved credentials) is accessed without synchronization.
@Observable
final class ProtocolManager: @unchecked Sendable {

    static let shared = ProtocolManager()

    private static let logger = Logger(subsystem: "com.rosetta.messenger", category: "Protocol")

    /// Heartbeat interval (seconds) requested in our handshake. Also used as the
    /// fallback when the server answers with a non-positive interval.
    private static let defaultHeartbeatInterval = 15

    // MARK: - Public State

    private(set) var connectionState: ConnectionState = .disconnected

    // MARK: - Device Verification State

    /// Device waiting for approval from this device (shown as banner on primary device).
    private(set) var pendingDeviceVerification: DeviceEntry?

    /// All devices reported by the most recent device-list packet.
    private(set) var devices: [DeviceEntry] = []

    // MARK: - Callbacks

    var onMessageReceived: ((PacketMessage) -> Void)?
    var onDeliveryReceived: ((PacketDelivery) -> Void)?
    var onReadReceived: ((PacketRead) -> Void)?
    var onOnlineStateReceived: ((PacketOnlineState) -> Void)?
    var onUserInfoReceived: ((PacketUserInfo) -> Void)?
    var onSearchResult: ((PacketSearch) -> Void)?
    var onTypingReceived: ((PacketTyping) -> Void)?
    var onSyncReceived: ((PacketSync) -> Void)?
    var onHandshakeCompleted: ((PacketHandshake) -> Void)?

    // MARK: - Private

    private let client = WebSocketClient()
    /// Outgoing packets waiting for an authenticated, connected session. Guarded by `packetQueueLock`.
    private var packetQueue: [any Packet] = []
    /// Dedupe keys (see `packetQueueKey(_:)`) of packets currently in `packetQueue`. Guarded by `packetQueueLock`.
    private var queuedPacketKeys: Set<String> = []
    private var handshakeComplete = false
    private var heartbeatTask: Task<Void, Never>?
    private var handshakeTimeoutTask: Task<Void, Never>?
    private let searchHandlersLock = NSLock()
    private let packetQueueLock = NSLock()
    /// Guarded by `searchHandlersLock`.
    private var searchResultHandlers: [UUID: (PacketSearch) -> Void] = [:]

    // Saved credentials for auto-reconnect
    private var savedPublicKey: String?
    private var savedPrivateHash: String?

    var publicKey: String? { savedPublicKey }
    var privateHash: String? { savedPrivateHash }

    private init() {
        setupClientCallbacks()
    }

    // MARK: - Connection

    /// Opens the socket; the handshake starts automatically once it is open.
    ///
    /// The credentials are saved before the state check, so the handshake sent on
    /// the next socket open (including automatic reconnects) uses the latest ones.
    /// - Parameters:
    ///   - publicKey: Account public key sent in the handshake.
    ///   - privateKeyHash: Value sent as the handshake's `privateKey` field.
    func connect(publicKey: String, privateKeyHash: String) {
        savedPublicKey = publicKey
        savedPrivateHash = privateKeyHash

        if connectionState == .authenticated || connectionState == .handshaking {
            Self.logger.info("Already connected/handshaking, skipping")
            return
        }

        connectionState = .connecting
        client.connect()
    }

    /// Closes the socket, stops the heartbeat and handshake timers, and forgets the
    /// saved credentials so no handshake is sent on a later socket open.
    ///
    /// NOTE(review): packets already in the outgoing queue are kept, and
    /// `clearPacketQueue()` appears to have no caller — confirm whether an explicit
    /// disconnect (logout, declined device) should drop them.
    func disconnect() {
        Self.logger.info("Disconnecting")
        heartbeatTask?.cancel()
        heartbeatTask = nil
        handshakeTimeoutTask?.cancel()
        handshakeTimeoutTask = nil
        handshakeComplete = false
        client.disconnect()
        connectionState = .disconnected
        savedPublicKey = nil
        savedPrivateHash = nil
    }

    // MARK: - Sending

    /// Sends `packet` immediately when possible, otherwise queues it until the next
    /// completed handshake.
    ///
    /// Before the handshake completes, only a `PacketHandshake` may be sent; nothing is
    /// sent while the socket is down. Handshake packets are never queued (they are
    /// rebuilt on every socket open), so one sent while disconnected is dropped.
    func sendPacket(_ packet: any Packet) {
        let id = String(type(of: packet).packetId, radix: 16)
        if (!handshakeComplete && !(packet is PacketHandshake)) || !client.isConnected {
            Self.logger.info("⏳ Queueing packet 0x\(id) — connected=\(self.client.isConnected), handshake=\(self.handshakeComplete)")
            enqueuePacket(packet)
            return
        }
        Self.logger.info("📤 Sending packet 0x\(id) directly")
        sendPacketDirect(packet)
    }

    // MARK: - Search Handlers (Android-like wait/unwait)

    /// Registers a handler invoked for every incoming search result until removed.
    /// - Returns: Token to pass to `removeSearchResultHandler(_:)`.
    @discardableResult
    func addSearchResultHandler(_ handler: @escaping (PacketSearch) -> Void) -> UUID {
        let id = UUID()
        searchHandlersLock.lock()
        searchResultHandlers[id] = handler
        searchHandlersLock.unlock()
        return id
    }

    /// Unregisters a handler previously returned by `addSearchResultHandler(_:)`.
    func removeSearchResultHandler(_ id: UUID) {
        searchHandlersLock.lock()
        searchResultHandlers.removeValue(forKey: id)
        searchHandlersLock.unlock()
    }

    // MARK: - Private Setup

    /// Wires the WebSocket client's connect/disconnect/data callbacks to this manager.
    private func setupClientCallbacks() {
        client.onConnected = { [weak self] in
            guard let self else { return }
            Self.logger.info("WebSocket connected")
            Task { @MainActor in
                self.connectionState = .connected
            }
            // Auto-handshake with saved credentials
            if let pk = savedPublicKey, let hash = savedPrivateHash {
                startHandshake(publicKey: pk, privateHash: hash)
            }
        }

        client.onDisconnected = { [weak self] error in
            guard let self else { return }
            if let error {
                Self.logger.error("Disconnected: \(error.localizedDescription)")
            }
            heartbeatTask?.cancel()
            heartbeatTask = nil
            // The pending handshake timeout belongs to the socket that just died. If it
            // fired after a reconnect had begun, its `client.disconnect()` would tear
            // down the new connection before `startHandshake` could replace it.
            handshakeTimeoutTask?.cancel()
            handshakeTimeoutTask = nil
            handshakeComplete = false
            Task { @MainActor in
                self.connectionState = .disconnected
            }
        }

        client.onDataReceived = { [weak self] data in
            self?.handleIncomingData(data)
        }
    }

    // MARK: - Handshake

    /// Sends the handshake for the given credentials and arms a 10-second timeout that
    /// disconnects the socket if no handshake response has completed the session.
    private func startHandshake(publicKey: String, privateHash: String) {
        Self.logger.info("Starting handshake for \(publicKey.prefix(20))...")
        Task { @MainActor in
            self.connectionState = .handshaking
        }

        let device = HandshakeDevice(
            deviceId: UIDevice.current.identifierForVendor?.uuidString ?? "unknown",
            deviceName: UIDevice.current.name,
            deviceOs: "iOS \(UIDevice.current.systemVersion)"
        )

        let handshake = PacketHandshake(
            privateKey: privateHash,
            publicKey: publicKey,
            protocolVersion: 1,
            heartbeatInterval: Self.defaultHeartbeatInterval,
            device: device,
            handshakeState: .completed
        )

        sendPacketDirect(handshake)

        // Timeout
        handshakeTimeoutTask?.cancel()
        handshakeTimeoutTask = Task { [weak self] in
            do {
                try await Task.sleep(nanoseconds: 10_000_000_000)
            } catch {
                return
            }
            guard let self, !Task.isCancelled else { return }
            if !self.handshakeComplete {
                Self.logger.error("Handshake timeout")
                self.client.disconnect()
            }
        }
    }

    // MARK: - Packet Handling

    /// Decodes one binary frame and dispatches it by packet id. Unknown or
    /// undecodable frames are ignored (logged in DEBUG builds only).
    private func handleIncomingData(_ data: Data) {
        #if DEBUG
        if data.count >= 2 {
            let peekStream = Stream(data: data)
            let rawId = peekStream.readInt16()
            Self.logger.debug("📥 Incoming packet 0x\(String(rawId, radix: 16)), size: \(data.count)")
        }
        #endif

        guard let (packetId, packet) = PacketRegistry.decode(from: data) else {
            #if DEBUG
            if data.count >= 2 {
                let stream = Stream(data: data)
                let rawId = stream.readInt16()
                Self.logger.debug("Unknown packet ID: 0x\(String(rawId, radix: 16)), size: \(data.count)")
            }
            #endif
            return
        }

        switch packetId {
        case 0x00:
            if let p = packet as? PacketHandshake {
                handleHandshakeResponse(p)
            }
        case 0x01:
            if let p = packet as? PacketUserInfo {
                onUserInfoReceived?(p)
            }
        case 0x02:
            // NOTE(review): the result code is decoded and then discarded.
            if let p = packet as? PacketResult {
                let _ = ResultCode(rawValue: p.resultCode)
            }
        case 0x03:
            if let p = packet as? PacketSearch {
                Self.logger.debug("📥 Search result: \(p.users.count) users, callback=\(self.onSearchResult != nil)")
                onSearchResult?(p)
                notifySearchResultHandlers(p)
            }
        case 0x05:
            if let p = packet as? PacketOnlineState {
                onOnlineStateReceived?(p)
            }
        case 0x06:
            if let p = packet as? PacketMessage {
                onMessageReceived?(p)
            }
        case 0x07:
            if let p = packet as? PacketRead {
                onReadReceived?(p)
            }
        case 0x08:
            if let p = packet as? PacketDelivery {
                onDeliveryReceived?(p)
            }
        case 0x0B:
            if let p = packet as? PacketTyping {
                onTypingReceived?(p)
            }
        case 0x17:
            if let p = packet as? PacketDeviceList {
                handleDeviceList(p)
            }
        case 0x18:
            if let p = packet as? PacketDeviceResolve {
                handleDeviceResolve(p)
            }
        case 0x19:
            if let p = packet as? PacketSync {
                onSyncReceived?(p)
            }
        default:
            break
        }
    }

    /// Invokes every registered search handler. The table is snapshotted under the
    /// lock so handlers may add/remove handlers without deadlocking.
    private func notifySearchResultHandlers(_ packet: PacketSearch) {
        searchHandlersLock.lock()
        let handlers = Array(searchResultHandlers.values)
        searchHandlersLock.unlock()

        for handler in handlers {
            handler(packet)
        }
    }

    /// Applies the server's handshake answer: on `.completed` the session becomes
    /// authenticated and the queue is flushed; on `.needDeviceVerification` the queue
    /// is kept and only the heartbeat is started.
    private func handleHandshakeResponse(_ packet: PacketHandshake) {
        handshakeTimeoutTask?.cancel()
        handshakeTimeoutTask = nil

        switch packet.handshakeState {
        case .completed:
            handshakeComplete = true
            Self.logger.info("Handshake completed. Protocol v\(packet.protocolVersion), heartbeat \(packet.heartbeatInterval)s")

            Task { @MainActor in
                self.connectionState = .authenticated
            }

            flushPacketQueue()
            startHeartbeat(interval: packet.heartbeatInterval)
            onHandshakeCompleted?(packet)

        case .needDeviceVerification:
            handshakeComplete = false
            Self.logger.info("Server requires device verification — approve this device from your other Rosetta app")

            Task { @MainActor in
                self.connectionState = .deviceVerificationRequired
            }

            // Keep packet queue: messages will be flushed when the other device
            // approves this login and the server re-sends handshake with .completed
            startHeartbeat(interval: packet.heartbeatInterval)
        }
    }

    // MARK: - Heartbeat

    /// (Re)starts the text heartbeat: one immediately, then one every `interval / 3`
    /// seconds until cancelled.
    ///
    /// `interval` comes from the server and is validated. `UInt64(_:)` traps on a
    /// negative value, and 0 would make the loop send heartbeats without ever sleeping.
    /// Either case falls back to `defaultHeartbeatInterval`.
    /// - Parameter interval: Heartbeat interval in seconds, as reported by the server.
    private func startHeartbeat(interval: Int) {
        heartbeatTask?.cancel()

        let fallback = Self.defaultHeartbeatInterval
        let seconds: Int
        if interval > 0 {
            seconds = interval
        } else {
            Self.logger.warning("Invalid heartbeat interval \(interval)s from server, using \(fallback)s")
            seconds = fallback
        }
        let intervalNanos = UInt64(seconds) * 1_000_000_000 / 3

        heartbeatTask = Task {
            // Send first heartbeat immediately
            self.client.sendText("heartbeat")

            while !Task.isCancelled {
                // A cancelled sleep throws; the guard below then ends the loop.
                try? await Task.sleep(nanoseconds: intervalNanos)
                guard !Task.isCancelled else { break }
                self.client.sendText("heartbeat")
            }
        }
    }

    // MARK: - Packet Queue

    /// Encodes and writes `packet` to the socket without any session checks. If the
    /// socket is unavailable or the write later fails, the packet is re-queued
    /// (handshake packets are dropped instead, see `enqueuePacket(_:)`).
    private func sendPacketDirect(_ packet: any Packet) {
        let data = PacketRegistry.encode(packet)
        Self.logger.info("Sending packet 0x\(String(type(of: packet).packetId, radix: 16)) (\(data.count) bytes)")

        if !client.send(data, onFailure: { [weak self] _ in
            guard let self else { return }
            Self.logger.warning("Send failed, re-queueing packet 0x\(String(type(of: packet).packetId, radix: 16))")
            self.enqueuePacket(packet)
        }) {
            Self.logger.warning("WebSocket unavailable, re-queueing packet 0x\(String(type(of: packet).packetId, radix: 16))")
            enqueuePacket(packet)
        }
    }

    /// Sends every queued packet, in queue order, through `sendPacketDirect(_:)`.
    private func flushPacketQueue() {
        let packets = drainPacketQueue()
        Self.logger.info("Flushing \(packets.count) queued packets")
        for packet in packets {
            sendPacketDirect(packet)
        }
    }

    /// Appends `packet` to the outgoing queue unless an equivalent packet (same
    /// `packetQueueKey(_:)`) is already queued.
    private func enqueuePacket(_ packet: any Packet) {
        // A handshake is rebuilt from the saved credentials on every socket open
        // (`startHandshake`). A queued one would be replayed by `flushPacketQueue()`
        // after the *next* successful handshake, sending a stale second handshake.
        if packet is PacketHandshake {
            Self.logger.warning("Dropping handshake packet instead of queueing it")
            return
        }

        let key = packetQueueKey(packet)

        packetQueueLock.lock()
        if let key, !queuedPacketKeys.insert(key).inserted {
            packetQueueLock.unlock()
            return
        }
        packetQueue.append(packet)
        let count = packetQueue.count
        packetQueueLock.unlock()

        Self.logger.info("Queueing packet 0x\(String(type(of: packet).packetId, radix: 16)) (queue=\(count))")
    }

    /// Atomically removes and returns all queued packets.
    private func drainPacketQueue() -> [any Packet] {
        packetQueueLock.lock()
        let packets = packetQueue
        packetQueue.removeAll()
        queuedPacketKeys.removeAll()
        packetQueueLock.unlock()
        return packets
    }

    /// Discards all queued packets.
    private func clearPacketQueue() {
        packetQueueLock.lock()
        packetQueue.removeAll()
        queuedPacketKeys.removeAll()
        packetQueueLock.unlock()
    }

    // MARK: - Device Verification

    /// Publishes the device list and picks the first not-yet-verified device as the
    /// pending verification to show on this device.
    private func handleDeviceList(_ packet: PacketDeviceList) {
        Self.logger.info("📱 Device list received: \(packet.devices.count) devices")
        for device in packet.devices {
            Self.logger.info("  - \(device.deviceName) (\(device.deviceOs)) status=\(device.deviceStatus.rawValue) verify=\(device.deviceVerify.rawValue)")
        }
        Task { @MainActor in
            self.devices = packet.devices
            self.pendingDeviceVerification = packet.devices.first { $0.deviceVerify == .notVerified }
        }
    }

    /// Handles the verdict on this device's login: a decline disconnects.
    private func handleDeviceResolve(_ packet: PacketDeviceResolve) {
        Self.logger.info("🔐 Device resolve received: deviceId=\(packet.deviceId.prefix(20)), solution=\(packet.solution.rawValue)")

        if packet.solution == .decline {
            Self.logger.info("🚫 This device was DECLINED — disconnecting")
            disconnect()
        }
        // If accepted, server will re-send handshake with .completed
        // which is handled by handleHandshakeResponse
    }

    /// Accept a pending device login from another device.
    func acceptDevice(_ deviceId: String) {
        Self.logger.info("✅ Accepting device: \(deviceId.prefix(20))")

        var packet = PacketDeviceResolve()
        packet.deviceId = deviceId
        packet.solution = .accept
        sendPacketDirect(packet)

        Task { @MainActor in
            self.pendingDeviceVerification = nil
        }
    }

    /// Decline a pending device login from another device.
    func declineDevice(_ deviceId: String) {
        Self.logger.info("❌ Declining device: \(deviceId.prefix(20))")

        var packet = PacketDeviceResolve()
        packet.deviceId = deviceId
        packet.solution = .decline
        sendPacketDirect(packet)

        Task { @MainActor in
            self.pendingDeviceVerification = nil
        }
    }

    /// Dedupe key for the outgoing queue, or `nil` when the packet must never be
    /// deduplicated (unknown type or missing identifiers).
    private func packetQueueKey(_ packet: any Packet) -> String? {
        switch packet {
        case let message as PacketMessage:
            return message.messageId.isEmpty ? nil : "0x06:\(message.messageId)"
        case let delivery as PacketDelivery:
            return delivery.messageId.isEmpty ? nil : "0x08:\(delivery.toPublicKey):\(delivery.messageId)"
        case let read as PacketRead:
            guard !read.fromPublicKey.isEmpty, !read.toPublicKey.isEmpty else { return nil }
            return "0x07:\(read.fromPublicKey):\(read.toPublicKey)"
        case let typing as PacketTyping:
            guard !typing.fromPublicKey.isEmpty, !typing.toPublicKey.isEmpty else { return nil }
            return "0x0b:\(typing.fromPublicKey):\(typing.toPublicKey)"
        case let sync as PacketSync:
            return "0x19:\(sync.status.rawValue):\(sync.timestamp)"
        default:
            return nil
        }
    }
}