// mobile-ios/RosettaTests/MessageDecodeHardeningTests.swift
// (Listing metadata from the source browser: Swift, 452 lines, 17 KiB.)

import XCTest
import P256K
@testable import Rosetta
@MainActor
final class MessageDecodeHardeningTests: XCTestCase {
    // MARK: - Timing

    /// How long to wait for `SessionManager` to flush a debounced malformed-packet resync.
    /// NOTE(review): assumed to exceed SessionManager's resync debounce window — confirm.
    private static let resyncSettleDelayNanoseconds: UInt64 = 900_000_000

    /// How long to wait, after a sync batch ends, for a queued resync to fire.
    /// NOTE(review): assumed to exceed SessionManager's post-sync flush delay — confirm.
    private static let postSyncFlushDelayNanoseconds: UInt64 = 700_000_000

    // MARK: - ProtocolManager dispatch gating

    /// A `PacketMessage` frame followed by one stray byte must be reported through
    /// `onMalformedMessageReceived` and never reach `onMessageReceived`.
    func testProtocolManagerDropsMalformedMessagePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        // ProtocolManager is a shared singleton: put the original handlers back afterwards.
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }
        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in
            messageDispatchCount += 1
        }
        proto.onMalformedMessageReceived = { _ in
            malformedCount += 1
        }

        let malformedData = makePacketMessageData(metaFieldCount: 2) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(messageDispatchCount, 0)
        XCTAssertEqual(malformedCount, 1)
    }

    /// The same `PacketMessage` frame without the stray byte must still be dispatched
    /// exactly once and not be flagged as malformed.
    func testProtocolManagerStillDispatchesValidMessagePacket() {
        let proto = ProtocolManager.shared
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }
        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in
            messageDispatchCount += 1
        }
        proto.onMalformedMessageReceived = { _ in
            malformedCount += 1
        }

        let validData = makePacketMessageData(metaFieldCount: 2)
        proto.testHandleIncomingData(validData)

        XCTAssertEqual(messageDispatchCount, 1)
        XCTAssertEqual(malformedCount, 0)
    }

    /// A handshake whose trailing state byte is 9 (presumably outside the known states)
    /// must be reported as a malformed critical packet and not complete the handshake.
    func testProtocolManagerDropsMalformedHandshakePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnHandshake = proto.onHandshakeCompleted
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onHandshakeCompleted = originalOnHandshake
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }
        var handshakeDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onHandshakeCompleted = { _ in
            handshakeDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        let malformedData = makeHandshakeData(stateRaw: 9)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(handshakeDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketHandshake.packetId)
    }

    /// A sync packet whose status byte is 7 (presumably outside the known statuses)
    /// must be reported as a malformed critical packet and not dispatched.
    func testProtocolManagerDropsMalformedSyncPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSync = proto.onSyncReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onSyncReceived = originalOnSync
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }
        var syncDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSyncReceived = { _ in
            syncDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        let malformedData = makeSyncData(statusRaw: 7)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(syncDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSync.packetId)
    }

    /// An encoded `PacketSignalPeer` followed by one stray byte must be reported as a
    /// malformed critical packet and not dispatched.
    func testProtocolManagerDropsMalformedSignalPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSignal = proto.onSignalPeerReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onSignalPeerReceived = originalOnSignal
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }
        var signalDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSignalPeerReceived = { _ in
            signalDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        var signal = PacketSignalPeer()
        signal.signalType = .endCallBecauseBusy
        let malformedData = PacketRegistry.encode(signal) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(signalDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSignalPeer.packetId)
    }

    /// A 2-field `PacketWebRTC` frame followed by one stray byte must be reported as a
    /// malformed critical packet and not dispatched.
    func testProtocolManagerDropsMalformedWebRtcPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnWebRtc = proto.onWebRTCReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        defer {
            proto.onWebRTCReceived = originalOnWebRtc
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }
        var webRtcDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onWebRTCReceived = { _ in
            webRtcDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        let malformedData = makeWebRtcData(includeIdentity: false) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(webRtcDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketWebRTC.packetId)
    }

    // MARK: - PacketRegistry layouts

    /// `createRoom` signal packets decode both with and without the trailing roomId
    /// string. A missing roomId decodes as "".
    func testSignalTypeCreateRoomDecodesWithAndWithoutRoomId() throws {
        let withoutRoomData = makeSignalCreateRoomData(roomId: nil)
        let withoutRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withoutRoomData),
            "Failed to decode signalType=4 packet without roomId"
        )
        let withoutRoomPacket = try XCTUnwrap(
            withoutRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet without roomId"
        )
        XCTAssertEqual(withoutRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertFalse(withoutRoomPacket.isMalformed)
        XCTAssertEqual(withoutRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withoutRoomPacket.roomId, "")

        let withRoomData = makeSignalCreateRoomData(roomId: "room-42")
        let withRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withRoomData),
            "Failed to decode signalType=4 packet with roomId"
        )
        let withRoomPacket = try XCTUnwrap(
            withRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet with roomId"
        )
        XCTAssertEqual(withRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertFalse(withRoomPacket.isMalformed)
        XCTAssertEqual(withRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withRoomPacket.roomId, "room-42")
    }

    /// `PacketWebRTC` decodes both the canonical 2-field layout (signal type and SDP/candidate)
    /// and the legacy 4-field layout, which adds publicKey and deviceId.
    func testWebRtcDecodeSupportsTwoAndFourFieldLayouts() throws {
        let twoFieldData = makeWebRtcData(includeIdentity: false)
        let decodedTwoField = try XCTUnwrap(
            PacketRegistry.decode(from: twoFieldData),
            "Failed to decode canonical 2-field 0x1B packet"
        )
        let twoFieldPacket = try XCTUnwrap(
            decodedTwoField.packet as? PacketWebRTC,
            "Failed to decode canonical 2-field 0x1B packet"
        )
        XCTAssertEqual(decodedTwoField.packetId, PacketWebRTC.packetId)
        XCTAssertFalse(twoFieldPacket.isMalformed)
        XCTAssertEqual(twoFieldPacket.signalType, .offer)
        XCTAssertEqual(twoFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        XCTAssertEqual(twoFieldPacket.publicKey, "")
        XCTAssertEqual(twoFieldPacket.deviceId, "")

        let fourFieldData = makeWebRtcData(
            includeIdentity: true,
            publicKey: "02legacyPublic",
            deviceId: "legacy-device"
        )
        let decodedFourField = try XCTUnwrap(
            PacketRegistry.decode(from: fourFieldData),
            "Failed to decode legacy 4-field 0x1B packet"
        )
        let fourFieldPacket = try XCTUnwrap(
            decodedFourField.packet as? PacketWebRTC,
            "Failed to decode legacy 4-field 0x1B packet"
        )
        XCTAssertEqual(decodedFourField.packetId, PacketWebRTC.packetId)
        XCTAssertFalse(fourFieldPacket.isMalformed)
        XCTAssertEqual(fourFieldPacket.signalType, .offer)
        XCTAssertEqual(fourFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        XCTAssertEqual(fourFieldPacket.publicKey, "02legacyPublic")
        XCTAssertEqual(fourFieldPacket.deviceId, "legacy-device")
    }

    /// Encoding `PacketWebRTC` writes only packetId, signal type and SDP/candidate.
    /// publicKey and deviceId are omitted even when they are set.
    func testWebRtcEncodeUsesCanonicalTwoFieldLayout() {
        var packet = PacketWebRTC()
        packet.signalType = .answer
        packet.sdpOrCandidate = "{\"type\":\"answer\",\"sdp\":\"v=0\"}"
        packet.publicKey = "02shouldNotBeEncoded"
        packet.deviceId = "device-should-not-be-encoded"

        let encoded = PacketRegistry.encode(packet)
        let stream = Rosetta.Stream(data: encoded)

        XCTAssertEqual(stream.readInt16(), PacketWebRTC.packetId)
        XCTAssertEqual(stream.readInt8(), WebRTCSignalType.answer.rawValue)
        XCTAssertEqual(stream.readString(), packet.sdpOrCandidate)
        // Nothing may follow the SDP string: the identity fields must not be on the wire.
        XCTAssertFalse(stream.hasRemainingBits())
    }

    // MARK: - SessionManager malformed-packet recovery

    /// Three malformed message drops in quick succession must collapse into one resync.
    func testMalformedPacketRecoveryDebouncesResyncToSingleRequest() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_test",
            privateKeyHex: privateKeyHex
        )
        // Start from an idle sync state. SessionManager is shared across tests, and an
        // in-progress sync batch holds the resync back (asserted by
        // testMalformedPacketRecoveryWaitsUntilSyncBatchEnds). Leftover state would
        // make this test order-dependent.
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        }
        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 111, fingerprint: "fp-1", messageIdHint: "m1")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 112, fingerprint: "fp-2", messageIdHint: "m2")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 113, fingerprint: "fp-3", messageIdHint: "m3")
        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// A resync requested while a sync batch is in progress must stay queued.
    /// It must then fire exactly once after the batch ends.
    func testMalformedPacketRecoveryWaitsUntilSyncBatchEnds() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_batch_test",
            privateKeyHex: privateKeyHex
        )
        session.testResetMalformedMessageResyncState()
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: true)
        defer {
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
        }
        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 211, fingerprint: "fp-batch", messageIdHint: "m-batch")

        // While sync is active, recovery must remain queued and not fire.
        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)
        XCTAssertEqual(triggerCount, 0)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 0)

        // After sync ends, queued recovery should fire once.
        session.testSetSyncState(syncBatchInProgress: false)
        try await Task.sleep(nanoseconds: Self.postSyncFlushDelayNanoseconds)
        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// Messages that decrypt to an empty plaintext must not be stored.
    /// They must also trigger one debounced resync for the whole burst.
    func testPostDecryptEmptyPayloadDropsWithoutUpsertAndDebouncesResync() async throws {
        let session = SessionManager.shared
        let myPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let myPrivateKeyData = myPrivateKey.rawRepresentation
        let myPrivateKeyHex = myPrivateKeyData.hexString
        let myPublicKey = try CryptoManager.shared.deriveCompressedPublicKey(from: myPrivateKeyData).hexString
        try DatabaseManager.shared.bootstrap(accountPublicKey: myPublicKey)
        await MessageRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
        await DialogRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
        session.testConfigureSessionForParityFlows(
            currentPublicKey: myPublicKey,
            privateKeyHex: myPrivateKeyHex
        )
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        }
        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        let peerPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let peerPublicKey = try CryptoManager.shared
            .deriveCompressedPublicKey(from: peerPrivateKey.rawRepresentation)
            .hexString

        var processedMessageIds: [String] = []
        for index in 0..<3 {
            // Real encryption of an empty plaintext. Only the content after decryption is empty.
            let encrypted = try MessageCrypto.encryptOutgoing(
                plaintext: "",
                recipientPublicKeyHex: myPublicKey
            )
            var packet = PacketMessage()
            packet.fromPublicKey = peerPublicKey
            packet.toPublicKey = myPublicKey
            packet.content = encrypted.content
            packet.chachaKey = encrypted.chachaKey
            packet.timestamp = 1_710_000_000_000 + Int64(index)
            packet.privateKey = "hash"
            packet.messageId = "empty-decrypted-\(index)-\(UUID().uuidString)"
            packet.attachments = []
            packet.aesChachaKey = ""
            processedMessageIds.append(packet.messageId)
            await session.testProcessIncomingMessage(packet)
        }
        try await Task.sleep(nanoseconds: Self.resyncSettleDelayNanoseconds)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
        for messageId in processedMessageIds {
            XCTAssertFalse(
                MessageRepository.shared.hasMessage(messageId),
                "Post-decrypt empty payload must not be persisted as a bubble"
            )
        }
    }

    // MARK: - Wire-frame builders

    /// Builds a `PacketMessage` frame with one image attachment and a trailing "aes-key"
    /// string.
    /// - Parameter metaFieldCount: `>= 2` appends two extra attachment strings
    ///   ("tag-1", "cdn.rosetta.im"). `>= 4` also appends "02encoded-for" and "desktop".
    private func makePacketMessageData(metaFieldCount: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketMessage.packetId)
        stream.writeString("02from")
        stream.writeString("02to")
        stream.writeString("ciphertext")
        stream.writeString("chacha-key")
        stream.writeInt64(1_710_000_000_000)
        stream.writeString("hash")
        stream.writeString("msg-hardening")
        stream.writeInt8(1) // attachment count
        stream.writeString("att-1")
        stream.writeString("preview")
        stream.writeString("blob")
        stream.writeInt8(AttachmentType.image.rawValue)
        if metaFieldCount >= 2 {
            stream.writeString("tag-1")
            stream.writeString("cdn.rosetta.im")
        }
        if metaFieldCount >= 4 {
            stream.writeString("02encoded-for")
            stream.writeString("desktop")
        }
        stream.writeString("aes-key")
        return stream.toData()
    }

    /// Builds a `PacketHandshake` frame whose final byte is `stateRaw`.
    private func makeHandshakeData(stateRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketHandshake.packetId)
        stream.writeString("hash")
        stream.writeString("02public")
        stream.writeInt8(1)
        stream.writeInt8(15)
        stream.writeString("device-id")
        stream.writeString("iPhone")
        stream.writeString("iOS")
        stream.writeInt8(stateRaw)
        return stream.toData()
    }

    /// Builds a `PacketSync` frame from `statusRaw` followed by a fixed Int64 timestamp.
    private func makeSyncData(statusRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSync.packetId)
        stream.writeInt8(statusRaw)
        stream.writeInt64(1_710_000_000_000)
        return stream.toData()
    }

    /// Builds a `createRoom` `PacketSignalPeer` frame.
    /// The trailing roomId string is written only when `roomId` is non-nil.
    private func makeSignalCreateRoomData(roomId: String?) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSignalPeer.packetId)
        stream.writeInt8(SignalType.createRoom.rawValue)
        stream.writeString("02src")
        stream.writeString("02dst")
        if let roomId {
            stream.writeString(roomId)
        }
        return stream.toData()
    }

    /// Builds an `offer` `PacketWebRTC` frame.
    /// Appends `publicKey` and `deviceId` (the legacy 4-field layout) only when
    /// `includeIdentity` is true.
    private func makeWebRtcData(
        includeIdentity: Bool,
        publicKey: String = "",
        deviceId: String = ""
    ) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketWebRTC.packetId)
        stream.writeInt8(WebRTCSignalType.offer.rawValue)
        stream.writeString("{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        if includeIdentity {
            stream.writeString(publicKey)
            stream.writeString(deviceId)
        }
        return stream.toData()
    }
}