452 lines
17 KiB
Swift
452 lines
17 KiB
Swift
import XCTest
|
|
import P256K
|
|
@testable import Rosetta
|
|
|
|
@MainActor
final class MessageDecodeHardeningTests: XCTestCase {

    // MARK: - ProtocolManager dispatch hardening

    /// A `PacketMessage` followed by a trailing garbage byte must be reported via
    /// `onMalformedMessageReceived` and must never reach `onMessageReceived`.
    func testProtocolManagerDropsMalformedMessagePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        // `ProtocolManager.shared` is a process-wide singleton: put the original
        // callbacks back so later tests do not inherit these test closures.
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }

        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in
            messageDispatchCount += 1
        }
        proto.onMalformedMessageReceived = { _ in
            malformedCount += 1
        }

        // Otherwise-valid message bytes plus one extra byte; the test expects the
        // extra byte to make the packet be classified as malformed.
        let malformedData = makePacketMessageData(metaFieldCount: 2) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(messageDispatchCount, 0)
        XCTAssertEqual(malformedCount, 1)
    }

    /// Control case for the malformed-message test: the same fixture without the
    /// trailing byte is dispatched exactly once and not reported as malformed.
    func testProtocolManagerStillDispatchesValidMessagePacket() {
        let proto = ProtocolManager.shared
        let originalOnMessage = proto.onMessageReceived
        let originalOnMalformed = proto.onMalformedMessageReceived
        // Restore the singleton's callbacks regardless of how the test exits.
        defer {
            proto.onMessageReceived = originalOnMessage
            proto.onMalformedMessageReceived = originalOnMalformed
        }

        var messageDispatchCount = 0
        var malformedCount = 0
        proto.onMessageReceived = { _ in
            messageDispatchCount += 1
        }
        proto.onMalformedMessageReceived = { _ in
            malformedCount += 1
        }

        let validData = makePacketMessageData(metaFieldCount: 2)
        proto.testHandleIncomingData(validData)

        XCTAssertEqual(messageDispatchCount, 1)
        XCTAssertEqual(malformedCount, 0)
    }

    /// A handshake packet carrying an unrecognised state value must be reported via
    /// `onMalformedCriticalPacketReceived` (tagged with the handshake packet id)
    /// instead of completing the handshake.
    func testProtocolManagerDropsMalformedHandshakePacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnHandshake = proto.onHandshakeCompleted
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        // Restore the singleton's callbacks regardless of how the test exits.
        defer {
            proto.onHandshakeCompleted = originalOnHandshake
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var handshakeDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onHandshakeCompleted = { _ in
            handshakeDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        // NOTE(review): 9 is assumed to be outside the handshake-state enum's raw
        // values — confirm against the handshake state definition.
        let malformedData = makeHandshakeData(stateRaw: 9)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(handshakeDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketHandshake.packetId)
    }

    /// A sync packet carrying an unrecognised status value must be reported via
    /// `onMalformedCriticalPacketReceived` (tagged with the sync packet id) and must
    /// not reach `onSyncReceived`.
    func testProtocolManagerDropsMalformedSyncPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSync = proto.onSyncReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        // Restore the singleton's callbacks regardless of how the test exits.
        defer {
            proto.onSyncReceived = originalOnSync
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var syncDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSyncReceived = { _ in
            syncDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        // NOTE(review): 7 is assumed to be outside the sync-status enum's raw
        // values — confirm against the sync status definition.
        let malformedData = makeSyncData(statusRaw: 7)
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(syncDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSync.packetId)
    }

    /// A `PacketSignalPeer` followed by a trailing garbage byte must be reported via
    /// `onMalformedCriticalPacketReceived` and must not reach `onSignalPeerReceived`.
    func testProtocolManagerDropsMalformedSignalPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnSignal = proto.onSignalPeerReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        // Restore the singleton's callbacks regardless of how the test exits.
        defer {
            proto.onSignalPeerReceived = originalOnSignal
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var signalDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onSignalPeerReceived = { _ in
            signalDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        // Encode a real signal packet with the production encoder, then append one
        // extra byte so the payload no longer matches the expected layout.
        var signal = PacketSignalPeer()
        signal.signalType = .endCallBecauseBusy
        let malformedData = PacketRegistry.encode(signal) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(signalDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketSignalPeer.packetId)
    }

    /// A 2-field `PacketWebRTC` followed by a trailing garbage byte must be reported
    /// via `onMalformedCriticalPacketReceived` and must not reach `onWebRTCReceived`.
    func testProtocolManagerDropsMalformedWebRtcPacketBeforeDispatch() {
        let proto = ProtocolManager.shared
        let originalOnWebRtc = proto.onWebRTCReceived
        let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
        // Restore the singleton's callbacks regardless of how the test exits.
        defer {
            proto.onWebRTCReceived = originalOnWebRtc
            proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
        }

        var webRtcDispatchCount = 0
        var malformedInfos: [MalformedCriticalPacketInfo] = []
        proto.onWebRTCReceived = { _ in
            webRtcDispatchCount += 1
        }
        proto.onMalformedCriticalPacketReceived = { info in
            malformedInfos.append(info)
        }

        // One stray byte is neither the canonical 2-field layout nor the legacy
        // 4-field layout exercised in `testWebRtcDecodeSupportsTwoAndFourFieldLayouts`.
        let malformedData = makeWebRtcData(includeIdentity: false) + Data([0x00])
        proto.testHandleIncomingData(malformedData)

        XCTAssertEqual(webRtcDispatchCount, 0)
        XCTAssertEqual(malformedInfos.count, 1)
        XCTAssertEqual(malformedInfos.first?.packetId, PacketWebRTC.packetId)
    }

    // MARK: - Packet layout compatibility

    /// `createRoom` signal packets decode cleanly both with and without the optional
    /// trailing `roomId` string; a missing room id decodes as `""`.
    func testSignalTypeCreateRoomDecodesWithAndWithoutRoomId() throws {
        let withoutRoomData = makeSignalCreateRoomData(roomId: nil)
        let withoutRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withoutRoomData),
            "Failed to decode signalType=4 packet without roomId"
        )
        let withoutRoomPacket = try XCTUnwrap(
            withoutRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet without roomId"
        )

        XCTAssertEqual(withoutRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertFalse(withoutRoomPacket.isMalformed)
        XCTAssertEqual(withoutRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withoutRoomPacket.roomId, "")

        let withRoomData = makeSignalCreateRoomData(roomId: "room-42")
        let withRoomDecoded = try XCTUnwrap(
            PacketRegistry.decode(from: withRoomData),
            "Failed to decode signalType=4 packet with roomId"
        )
        let withRoomPacket = try XCTUnwrap(
            withRoomDecoded.packet as? PacketSignalPeer,
            "Failed to decode signalType=4 packet with roomId"
        )

        XCTAssertEqual(withRoomDecoded.packetId, PacketSignalPeer.packetId)
        XCTAssertFalse(withRoomPacket.isMalformed)
        XCTAssertEqual(withRoomPacket.signalType, .createRoom)
        XCTAssertEqual(withRoomPacket.roomId, "room-42")
    }

    /// `PacketWebRTC` (0x1B) decodes both the canonical 2-field layout
    /// (signal type + SDP/candidate) and the legacy 4-field layout that additionally
    /// carries `publicKey` and `deviceId`; the 2-field form leaves both empty.
    func testWebRtcDecodeSupportsTwoAndFourFieldLayouts() throws {
        let twoFieldData = makeWebRtcData(includeIdentity: false)
        let decodedTwoField = try XCTUnwrap(
            PacketRegistry.decode(from: twoFieldData),
            "Failed to decode canonical 2-field 0x1B packet"
        )
        let twoFieldPacket = try XCTUnwrap(
            decodedTwoField.packet as? PacketWebRTC,
            "Failed to decode canonical 2-field 0x1B packet"
        )

        XCTAssertEqual(decodedTwoField.packetId, PacketWebRTC.packetId)
        XCTAssertFalse(twoFieldPacket.isMalformed)
        XCTAssertEqual(twoFieldPacket.signalType, .offer)
        XCTAssertEqual(twoFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        XCTAssertEqual(twoFieldPacket.publicKey, "")
        XCTAssertEqual(twoFieldPacket.deviceId, "")

        let fourFieldData = makeWebRtcData(
            includeIdentity: true,
            publicKey: "02legacyPublic",
            deviceId: "legacy-device"
        )
        let decodedFourField = try XCTUnwrap(
            PacketRegistry.decode(from: fourFieldData),
            "Failed to decode legacy 4-field 0x1B packet"
        )
        let fourFieldPacket = try XCTUnwrap(
            decodedFourField.packet as? PacketWebRTC,
            "Failed to decode legacy 4-field 0x1B packet"
        )

        XCTAssertEqual(decodedFourField.packetId, PacketWebRTC.packetId)
        XCTAssertFalse(fourFieldPacket.isMalformed)
        XCTAssertEqual(fourFieldPacket.signalType, .offer)
        XCTAssertEqual(fourFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        XCTAssertEqual(fourFieldPacket.publicKey, "02legacyPublic")
        XCTAssertEqual(fourFieldPacket.deviceId, "legacy-device")
    }

    /// Encoding a `PacketWebRTC` always emits the canonical 2-field layout: packet id,
    /// signal type, SDP/candidate — and nothing else, even when `publicKey` and
    /// `deviceId` are populated on the packet.
    func testWebRtcEncodeUsesCanonicalTwoFieldLayout() {
        var packet = PacketWebRTC()
        packet.signalType = .answer
        packet.sdpOrCandidate = "{\"type\":\"answer\",\"sdp\":\"v=0\"}"
        packet.publicKey = "02shouldNotBeEncoded"
        packet.deviceId = "device-should-not-be-encoded"

        let encoded = PacketRegistry.encode(packet)
        let stream = Rosetta.Stream(data: encoded)

        XCTAssertEqual(stream.readInt16(), PacketWebRTC.packetId)
        XCTAssertEqual(stream.readInt8(), WebRTCSignalType.answer.rawValue)
        XCTAssertEqual(stream.readString(), packet.sdpOrCandidate)
        // No identity fields may follow the SDP/candidate string.
        XCTAssertFalse(stream.hasRemainingBits())
    }

    // MARK: - Malformed-packet resync recovery

    /// Several malformed-message drops in quick succession collapse into a single
    /// resync request.
    func testMalformedPacketRecoveryDebouncesResyncToSingleRequest() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_test",
            privateKeyHex: privateKeyHex
        )

        // Recovery is held back while a sync batch is in progress (see
        // `testMalformedPacketRecoveryWaitsUntilSyncBatchEnds`), so pin the shared
        // session to "idle" instead of relying on whatever state an earlier test left.
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 111, fingerprint: "fp-1", messageIdHint: "m1")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 112, fingerprint: "fp-2", messageIdHint: "m2")
        session.testSimulateMalformedMessagePacketDrop(packetSize: 113, fingerprint: "fp-3", messageIdHint: "m3")

        // NOTE(review): 900 ms is assumed to exceed SessionManager's resync debounce
        // window — confirm if that window changes, or this test becomes flaky.
        try await Task.sleep(nanoseconds: 900_000_000)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// A malformed-message drop that arrives during a sync batch stays queued until
    /// the batch ends, and then triggers exactly one resync.
    func testMalformedPacketRecoveryWaitsUntilSyncBatchEnds() async throws {
        let session = SessionManager.shared
        let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
        session.testConfigureSessionForParityFlows(
            currentPublicKey: "02malformed_resync_batch_test",
            privateKeyHex: privateKeyHex
        )

        session.testResetMalformedMessageResyncState()
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: true)
        defer {
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        session.testSimulateMalformedMessagePacketDrop(packetSize: 211, fingerprint: "fp-batch", messageIdHint: "m-batch")

        // While sync is active, recovery must remain queued and not fire.
        try await Task.sleep(nanoseconds: 900_000_000)
        XCTAssertEqual(triggerCount, 0)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 0)

        // After sync ends, queued recovery should fire once.
        session.testSetSyncState(syncBatchInProgress: false)
        try await Task.sleep(nanoseconds: 700_000_000)
        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
    }

    /// Incoming messages whose ciphertext decrypts to an empty plaintext are dropped.
    /// None of them may be persisted in `MessageRepository`, and the drops together
    /// trigger a single debounced resync.
    func testPostDecryptEmptyPayloadDropsWithoutUpsertAndDebouncesResync() async throws {
        let session = SessionManager.shared
        let myPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let myPrivateKeyData = myPrivateKey.rawRepresentation
        let myPrivateKeyHex = myPrivateKeyData.hexString
        let myPublicKey = try CryptoManager.shared.deriveCompressedPublicKey(from: myPrivateKeyData).hexString

        // Fresh local stores keyed to a throwaway account so repository lookups below
        // reflect only this test's messages.
        try DatabaseManager.shared.bootstrap(accountPublicKey: myPublicKey)
        await MessageRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
        await DialogRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)

        session.testConfigureSessionForParityFlows(
            currentPublicKey: myPublicKey,
            privateKeyHex: myPrivateKeyHex
        )
        session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        session.testResetMalformedMessageResyncState()
        defer {
            session.testSetMalformedMessageResyncHook(nil)
            session.testResetMalformedMessageResyncState()
            session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
        }

        var triggerCount = 0
        session.testSetMalformedMessageResyncHook {
            triggerCount += 1
        }

        let peerPrivateKey = try P256K.KeyAgreement.PrivateKey()
        let peerPublicKey = try CryptoManager.shared
            .deriveCompressedPublicKey(from: peerPrivateKey.rawRepresentation)
            .hexString

        var processedMessageIds: [String] = []
        for index in 0..<3 {
            // Legitimately encrypted for us, but with an empty plaintext.
            let encrypted = try MessageCrypto.encryptOutgoing(
                plaintext: "",
                recipientPublicKeyHex: myPublicKey
            )

            var packet = PacketMessage()
            packet.fromPublicKey = peerPublicKey
            packet.toPublicKey = myPublicKey
            packet.content = encrypted.content
            packet.chachaKey = encrypted.chachaKey
            packet.timestamp = 1_710_000_000_000 + Int64(index)
            packet.privateKey = "hash"
            // UUID suffix keeps ids unique across runs against a reused database.
            packet.messageId = "empty-decrypted-\(index)-\(UUID().uuidString)"
            packet.attachments = []
            packet.aesChachaKey = ""

            processedMessageIds.append(packet.messageId)
            await session.testProcessIncomingMessage(packet)
        }

        // NOTE(review): timing-based wait for the debounced resync; assumed to be
        // longer than the debounce window.
        try await Task.sleep(nanoseconds: 900_000_000)

        XCTAssertEqual(triggerCount, 1)
        XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)

        for messageId in processedMessageIds {
            XCTAssertFalse(
                MessageRepository.shared.hasMessage(messageId),
                "Post-decrypt empty payload must not be persisted as a bubble"
            )
        }
    }

    // MARK: - Wire fixtures

    /// Serializes a `PacketMessage` with one image attachment.
    ///
    /// The attachment always carries id, preview, blob and type. Extra attachment
    /// string fields are appended when `metaFieldCount` is at least 2 (two fields)
    /// and at least 4 (two more fields). The packet ends with an AES key string.
    /// - Parameter metaFieldCount: Selects how many extra attachment fields to write
    ///   (values below 2 write none).
    /// - Returns: Raw packet bytes, starting with `PacketMessage.packetId`.
    private func makePacketMessageData(metaFieldCount: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketMessage.packetId)
        stream.writeString("02from")
        stream.writeString("02to")
        stream.writeString("ciphertext")
        stream.writeString("chacha-key")
        stream.writeInt64(1_710_000_000_000)
        stream.writeString("hash")
        stream.writeString("msg-hardening")
        // Attachment count, followed by the single attachment's fields.
        stream.writeInt8(1)
        stream.writeString("att-1")
        stream.writeString("preview")
        stream.writeString("blob")
        stream.writeInt8(AttachmentType.image.rawValue)
        if metaFieldCount >= 2 {
            stream.writeString("tag-1")
            stream.writeString("cdn.rosetta.im")
        }
        if metaFieldCount >= 4 {
            stream.writeString("02encoded-for")
            stream.writeString("desktop")
        }
        stream.writeString("aes-key")
        return stream.toData()
    }

    /// Serializes a `PacketHandshake` with fixed fields and a caller-chosen state.
    /// - Parameter stateRaw: Raw value written into the final (state) byte.
    /// - Returns: Raw packet bytes, starting with `PacketHandshake.packetId`.
    private func makeHandshakeData(stateRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketHandshake.packetId)
        stream.writeString("hash")
        stream.writeString("02public")
        stream.writeInt8(1)
        stream.writeInt8(15)
        stream.writeString("device-id")
        stream.writeString("iPhone")
        stream.writeString("iOS")
        stream.writeInt8(stateRaw)
        return stream.toData()
    }

    /// Serializes a `PacketSync` with a caller-chosen status and a fixed timestamp.
    /// - Parameter statusRaw: Raw value written into the status byte.
    /// - Returns: Raw packet bytes, starting with `PacketSync.packetId`.
    private func makeSyncData(statusRaw: Int) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSync.packetId)
        stream.writeInt8(statusRaw)
        stream.writeInt64(1_710_000_000_000)
        return stream.toData()
    }

    /// Serializes a `createRoom` `PacketSignalPeer` with fixed source and destination keys.
    /// - Parameter roomId: When non-nil, appended as a trailing string; when nil the
    ///   field is omitted entirely (not written as an empty string).
    /// - Returns: Raw packet bytes, starting with `PacketSignalPeer.packetId`.
    private func makeSignalCreateRoomData(roomId: String?) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketSignalPeer.packetId)
        stream.writeInt8(SignalType.createRoom.rawValue)
        stream.writeString("02src")
        stream.writeString("02dst")
        if let roomId {
            stream.writeString(roomId)
        }
        return stream.toData()
    }

    /// Serializes an `offer` `PacketWebRTC` with a fixed SDP JSON payload.
    /// - Parameters:
    ///   - includeIdentity: `false` yields the canonical 2-field layout; `true`
    ///     appends `publicKey` and `deviceId` (legacy 4-field layout).
    ///   - publicKey: Written only when `includeIdentity` is `true`.
    ///   - deviceId: Written only when `includeIdentity` is `true`.
    /// - Returns: Raw packet bytes, starting with `PacketWebRTC.packetId`.
    private func makeWebRtcData(
        includeIdentity: Bool,
        publicKey: String = "",
        deviceId: String = ""
    ) -> Data {
        let stream = Rosetta.Stream()
        stream.writeInt16(PacketWebRTC.packetId)
        stream.writeInt8(WebRTCSignalType.offer.rawValue)
        stream.writeString("{\"type\":\"offer\",\"sdp\":\"v=0\"}")
        if includeIdentity {
            stream.writeString(publicKey)
            stream.writeString(deviceId)
        }
        return stream.toData()
    }
}
|