import XCTest
import P256K
@testable import Rosetta

/// Regression tests for hardened decoding of incoming packets.
///
/// Covers three areas:
/// - `ProtocolManager` must drop malformed message / handshake / sync / signal / WebRTC
///   packets before dispatching them, reporting them through the malformed-packet callbacks.
/// - `PacketSignalPeer` and `PacketWebRTC` must decode every wire layout built by the
///   fixtures below, and `PacketWebRTC` encoding must emit only the two-field layout.
/// - `SessionManager` must debounce malformed-message resync requests and hold them while a
///   sync batch is in progress.
///
/// NOTE(review): every test mutates process-wide singletons (`ProtocolManager.shared`,
/// `SessionManager.shared`, ...) and restores what it changed in `defer`; the tests are
/// therefore not safe to run concurrently with other tests touching the same singletons.
@MainActor
final class MessageDecodeHardeningTests: XCTestCase {
    /// SDP payload written by every WebRTC fixture. Shared so the fixture builder and the
    /// decode assertions cannot drift apart.
    private static let offerSdp = "{\"type\":\"offer\",\"sdp\":\"v=0\"}"

    /// Wait long enough for a debounced malformed-message resync to fire.
    /// NOTE(review): assumed to exceed `SessionManager`'s debounce window — confirm.
    private static let resyncSettleDelayNanoseconds: UInt64 = 900_000_000

    /// Wait used after a sync batch ends for the queued resync to fire.
    private static let postBatchSettleDelayNanoseconds: UInt64 = 700_000_000

    // MARK: - ProtocolManager dispatch gating

    /// A message packet followed by one unexpected trailing byte must be reported via
    /// `onMalformedMessageReceived` and must never reach `onMessageReceived`.
    func testProtocolManagerDropsMalformedMessagePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }

        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in messageDispatchCount += 1 }
        proto.onMalformedMessageReceived = { _ in malformedCount += 1 }

        // The same bytes without the trailing 0x00 are dispatched normally
        // (see testProtocolManagerStillDispatchesValidMessagePacket).
        let malformedData = makePacketMessageData(metaFieldCount: 2) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(messageDispatchCount, 0)
        XCTAssertEqual(malformedCount, 1)
    }

    /// Control case: a well-formed message packet is dispatched exactly once and is not
    /// reported as malformed.
    func testProtocolManagerStillDispatchesValidMessagePacket() {
        let proto = ProtocolManager.shared
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }

        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in messageDispatchCount += 1 }
        proto.onMalformedMessageReceived = { _ in malformedCount += 1 }

        let validData = makePacketMessageData(metaFieldCount: 2)
        proto.testHandleIncomingData(validData)

        XCTAssertEqual(messageDispatchCount, 1)
        XCTAssertEqual(malformedCount, 0)
    }

    /// A handshake packet carrying state byte 9 must be reported once via
    /// `onMalformedCriticalPacketReceived` (tagged with the handshake packet id) and must
    /// not complete the handshake.
    /// NOTE(review): assumes 9 is outside the valid handshake-state range — confirm.
    func testProtocolManagerDropsMalformedHandshakePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnHandshake = proto.onHandshakeCompleted
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onHandshakeCompleted = originalOnHandshake
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var handshakeDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onHandshakeCompleted = { _ in handshakeDispatchCount += 1 }
        proto.onMalformedCriticalPacketReceived = { info in malformedInfos.append(info) }

        let malformedData = makeHandshakeData(stateRaw: 9)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(handshakeDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketHandshake.packetId)
    }

    /// A sync packet carrying status byte 7 must be reported once as a malformed critical
    /// packet and must not reach `onSyncReceived`.
    /// NOTE(review): assumes 7 is outside the valid sync-status range — confirm.
    func testProtocolManagerDropsMalformedSyncPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSync = proto.onSyncReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onSyncReceived = originalOnSync
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var syncDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSyncReceived = { _ in syncDispatchCount += 1 }
        proto.onMalformedCriticalPacketReceived = { info in malformedInfos.append(info) }

        let malformedData = makeSyncData(statusRaw: 7)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(syncDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSync.packetId)
    }

    /// A registry-encoded signal packet followed by one trailing byte must be reported once
    /// as a malformed critical packet and must not reach `onSignalPeerReceived`.
    func testProtocolManagerDropsMalformedSignalPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSignal = proto.onSignalPeerReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onSignalPeerReceived = originalOnSignal
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var signalDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSignalPeerReceived = { _ in signalDispatchCount += 1 }
        proto.onMalformedCriticalPacketReceived = { info in malformedInfos.append(info) }

        var signal = PacketSignalPeer()
        signal.signalType = .endCallBecauseBusy
        let malformedData = PacketRegistry.encode(signal) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(signalDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSignalPeer.packetId)
    }

    /// A two-field WebRTC packet followed by one trailing byte matches neither accepted
    /// layout (two or four fields); it must be reported once as a malformed critical packet
    /// and must not reach `onWebRTCReceived`.
    func testProtocolManagerDropsMalformedWebRtcPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnWebRtc = proto.onWebRTCReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onWebRTCReceived = originalOnWebRtc
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var webRtcDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onWebRTCReceived = { _ in webRtcDispatchCount += 1 }
        proto.onMalformedCriticalPacketReceived = { info in malformedInfos.append(info) }

        let malformedData = makeWebRtcData(includeIdentity: false) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(webRtcDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketWebRTC.packetId)
    }

    // MARK: - Packet layout decoding / encoding

    /// `createRoom` signal packets decode both without a trailing room id (room id reads as
    /// `""`) and with one.
    func testSignalTypeCreateRoomDecodesWithAndWithoutRoomId() throws {
        let withoutRoomData = makeSignalCreateRoomData(roomId: nil)
        let withoutRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withoutRoomData),
            "Failed to decode signalType=4 packet without roomId"
        )
        let withoutRoomPacket = try XCTUnwrap(
            withoutRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet without roomId"
        )
        XCTAssertEqual(withoutRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertEqual(withoutRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withoutRoomPacket.roomId, "")

        let withRoomData = makeSignalCreateRoomData(roomId: "room-42")
        let withRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withRoomData),
            "Failed to decode signalType=4 packet with roomId"
        )
        let withRoomPacket = try XCTUnwrap(
            withRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet with roomId"
        )
        XCTAssertEqual(withRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertEqual(withRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withRoomPacket.roomId, "room-42")
    }

    /// WebRTC packets decode in both the canonical two-field layout (identity fields read
    /// as `""`) and the legacy four-field layout that also carries public key and device id.
    func testWebRtcDecodeSupportsTwoAndFourFieldLayouts() throws {
        let twoFieldData = makeWebRtcData(includeIdentity: false)
        let decodedTwoField = try XCTUnwrap(
            PacketRegistry.decode(from: twoFieldData),
            "Failed to decode canonical 2-field 0x1B packet"
        )
        let twoFieldPacket = try XCTUnwrap(
            decodedTwoField.packet as? PacketWebRTC,
            "Failed to decode canonical 2-field 0x1B packet"
        )
        XCTAssertEqual(decodedTwoField.packetId, PacketWebRTC.packetId)
        XCTAssertEqual(twoFieldPacket.signalType, .offer)
        XCTAssertEqual(twoFieldPacket.sdpOrCandidate, Self.offerSdp)
        XCTAssertEqual(twoFieldPacket.publicKey, "")
        XCTAssertEqual(twoFieldPacket.deviceId, "")

        let fourFieldData = makeWebRtcData(
            includeIdentity: true,
            publicKey: "02legacyPublic",
            deviceId: "legacy-device"
        )
        let decodedFourField = try XCTUnwrap(
            PacketRegistry.decode(from: fourFieldData),
            "Failed to decode legacy 4-field 0x1B packet"
        )
        let fourFieldPacket = try XCTUnwrap(
            decodedFourField.packet as? PacketWebRTC,
            "Failed to decode legacy 4-field 0x1B packet"
        )
        XCTAssertEqual(decodedFourField.packetId, PacketWebRTC.packetId)
        XCTAssertEqual(fourFieldPacket.signalType, .offer)
        XCTAssertEqual(fourFieldPacket.sdpOrCandidate, Self.offerSdp)
        XCTAssertEqual(fourFieldPacket.publicKey, "02legacyPublic")
        XCTAssertEqual(fourFieldPacket.deviceId, "legacy-device")
    }

    /// Encoding a WebRTC packet writes only packet id, signal type and SDP/candidate, even
    /// when `publicKey` / `deviceId` are set.
    func testWebRtcEncodeUsesCanonicalTwoFieldLayout() {
        var packet = PacketWebRTC()
        packet.signalType = .answer
        packet.sdpOrCandidate = "{\"type\":\"answer\",\"sdp\":\"v=0\"}"
        packet.publicKey = "02shouldNotBeEncoded"
        packet.deviceId = "device-should-not-be-encoded"

        let encoded = PacketRegistry.encode(packet)
        let stream = Rosetta.Stream(data: encoded)

        XCTAssertEqual(stream.readInt16(), PacketWebRTC.packetId)
        XCTAssertEqual(stream.readInt8(), WebRTCSignalType.answer.rawValue)
        XCTAssertEqual(stream.readString(), packet.sdpOrCandidate)
        // Nothing may follow the SDP field in the canonical layout.
        XCTAssertFalse(stream.hasRemainingBits())
    }

    // MARK: - SessionManager malformed-message recovery

    /// Three malformed-message drops in quick succession collapse into a single resync.
    func testMalformedPacketRecoveryDebouncesResyncToSingleRequest() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_test",
            privateKeyHex: privateKeyHex
        )
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook { triggerCount += 1 }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 111, fingerprint: "fp-1", messageIdHint: "m1")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 112, fingerprint: "fp-2", messageIdHint: "m2")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 113, fingerprint: "fp-3", messageIdHint: "m3")

        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// A malformed drop that happens while a sync batch is in progress stays queued and
    /// fires exactly once after the batch ends.
    func testMalformedPacketRecoveryWaitsUntilSyncBatchEnds() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_batch_test",
            privateKeyHex: privateKeyHex
        )
        session.testResetMalformedMessageResyncState()
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: true)
        defer {
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook { triggerCount += 1 }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 211, fingerprint: "fp-batch", messageIdHint: "m-batch")

        // While sync is active, recovery must remain queued and not fire.
        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)
        XCTAssertEqual(triggerCount, 0)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 0)

        // After sync ends, queued recovery should fire once.
        session.testSetSyncState(syncBatchInProgress: false)
        try await Task.sleep(nanoseconds: Self.postBatchSettleDelayNanoseconds)
        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// Incoming messages whose encrypted plaintext is empty must not be persisted and must
    /// trigger exactly one debounced resync for the whole burst.
    ///
    /// NOTE(review): the database / repository bootstrap performed here is not undone in
    /// `defer`; confirm that later tests tolerate the bootstrapped state.
    func testPostDecryptEmptyPayloadDropsWithoutUpsertAndDebouncesResync() async throws {
        let session = SessionManager.shared
        let myPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let myPrivateKeyData = myPrivateKey.rawRepresentation
        let myPrivateKeyHex = myPrivateKeyData.hexString
        let myPublicKey = try CryptoManager.shared.deriveCompressedPublicKey(from: myPrivateKeyData).hexString

        try DatabaseManager.shared.bootstrap(accountPublicKey: myPublicKey)
        await MessageRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
        await DialogRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)

        session.testConfigureSessionForParityFlows(
            currentPublicKey: myPublicKey,
            privateKeyHex: myPrivateKeyHex
        )
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook { triggerCount += 1 }

        let peerPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let peerPublicKey = try CryptoManager.shared
            .deriveCompressedPublicKey(from: peerPrivateKey.rawRepresentation)
            .hexString

        var processedMessageIds: [String] = []
        for index in 0..<3 {
            // Encrypt an empty plaintext so the payload is empty after decryption.
            let encrypted = try MessageCrypto.encryptOutgoing(
                plaintext: "",
                recipientPublicKeyHex: myPublicKey
            )

            var packet = PacketMessage()
            packet.fromPublicKey = peerPublicKey
            packet.toPublicKey = myPublicKey
            packet.content = encrypted.content
            packet.chachaKey = encrypted.chachaKey
            packet.timestamp = 1_710_000_000_000 + Int64(index)
            packet.privateKey = "hash"
            // Unique ids so each message is checked independently below.
            packet.messageId = "empty-decrypted-\(index)-\(UUID().uuidString)"
            packet.attachments = []
            packet.aesChachaKey = ""

            processedMessageIds.append(packet.messageId)
            await session.testProcessIncomingMessage(packet)
        }

        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
        for messageId in processedMessageIds {
            XCTAssertFalse(
                MessageRepository.shared.hasMessage(messageId),
                "Post-decrypt empty payload must not be persisted as a bubble"
            )
        }
    }

    // MARK: - Wire fixtures

    /// Builds a raw `PacketMessage` frame with one image attachment.
    ///
    /// Write order: packet id, six strings plus an int64 (`"02from"`, `"02to"`,
    /// `"ciphertext"`, `"chacha-key"`, timestamp, `"hash"`, `"msg-hardening"` — presumably
    /// mirroring `PacketMessage`'s field order), attachment count `1`, the attachment's
    /// id/preview/blob strings and type byte, optional attachment meta strings, then
    /// `"aes-key"`.
    /// - Parameter metaFieldCount: `>= 2` appends two meta strings, `>= 4` appends two more;
    ///   smaller values append none.
    private func makePacketMessageData(metaFieldCount: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketMessage.packetId)
        stream.writeString("02from")
        stream.writeString("02to")
        stream.writeString("ciphertext")
        stream.writeString("chacha-key")
        stream.writeInt64(1_710_000_000_000)
        stream.writeString("hash")
        stream.writeString("msg-hardening")

        // Attachment list with exactly one entry.
        stream.writeInt8(1)
        stream.writeString("att-1")
        stream.writeString("preview")
        stream.writeString("blob")
        stream.writeInt8(AttachmentType.image.rawValue)
        if metaFieldCount >= 2 {
            stream.writeString("tag-1")
            stream.writeString("cdn.rosetta.im")
        }
        if metaFieldCount >= 4 {
            stream.writeString("02encoded-for")
            stream.writeString("desktop")
        }

        stream.writeString("aes-key")
        return stream.toData()
    }

    /// Builds a raw `PacketHandshake` frame whose final byte is `stateRaw`.
    /// NOTE(review): the two int8 values (1, 15) are assumed to be protocol/version-style
    /// fields — confirm against `PacketHandshake`.
    private func makeHandshakeData(stateRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketHandshake.packetId)
        stream.writeString("hash")
        stream.writeString("02public")
        stream.writeInt8(1)
        stream.writeInt8(15)
        stream.writeString("device-id")
        stream.writeString("iPhone")
        stream.writeString("iOS")
        stream.writeInt8(stateRaw)
        return stream.toData()
    }

    /// Builds a raw `PacketSync` frame: packet id, `statusRaw` byte, int64 timestamp.
    private func makeSyncData(statusRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSync.packetId)
        stream.writeInt8(statusRaw)
        stream.writeInt64(1_710_000_000_000)
        return stream.toData()
    }

    /// Builds a raw `createRoom` `PacketSignalPeer` frame; the room id string is appended
    /// only when `roomId` is non-nil.
    private func makeSignalCreateRoomData(roomId: String?) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSignalPeer.packetId)
        stream.writeInt8(SignalType.createRoom.rawValue)
        stream.writeString("02src")
        stream.writeString("02dst")
        if let roomId {
            stream.writeString(roomId)
        }
        return stream.toData()
    }

    /// Builds a raw `offer` `PacketWebRTC` frame carrying `offerSdp`.
    /// - Parameters:
    ///   - includeIdentity: when `true`, appends `publicKey` and `deviceId` (legacy
    ///     four-field layout); otherwise emits the canonical two-field layout.
    ///   - publicKey: written only when `includeIdentity` is `true`.
    ///   - deviceId: written only when `includeIdentity` is `true`.
    private func makeWebRtcData(
        includeIdentity: Bool,
        publicKey: String = "",
        deviceId: String = ""
    ) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketWebRTC.packetId)
        stream.writeInt8(WebRTCSignalType.offer.rawValue)
        stream.writeString(Self.offerSdp)
        if includeIdentity {
            stream.writeString(publicKey)
            stream.writeString(deviceId)
        }
        return stream.toData()
    }
}