Фикс: исправлено исчезновение части уведомлений при открытии пуша
This commit is contained in:
451
RosettaTests/MessageDecodeHardeningTests.swift
Normal file
451
RosettaTests/MessageDecodeHardeningTests.swift
Normal file
@@ -0,0 +1,451 @@
|
||||
import XCTest
|
||||
import P256K
|
||||
@testable import Rosetta
|
||||
|
||||
@MainActor
|
||||
final class MessageDecodeHardeningTests: XCTestCase {
|
||||
|
||||
func testProtocolManagerDropsMalformedMessagePacketBeforeDispatch() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnMessage = proto.onMessageReceived
|
||||
let originalOnMalformed = proto.onMalformedMessageReceived
|
||||
defer {
|
||||
proto.onMessageReceived = originalOnMessage
|
||||
proto.onMalformedMessageReceived = originalOnMalformed
|
||||
}
|
||||
|
||||
var messageDispatchCount = 0
|
||||
var malformedCount = 0
|
||||
proto.onMessageReceived = { _ in
|
||||
messageDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedMessageReceived = { _ in
|
||||
malformedCount += 1
|
||||
}
|
||||
|
||||
let malformedData = makePacketMessageData(metaFieldCount: 2) + Data([0x00])
|
||||
proto.testHandleIncomingData(malformedData)
|
||||
|
||||
XCTAssertEqual(messageDispatchCount, 0)
|
||||
XCTAssertEqual(malformedCount, 1)
|
||||
}
|
||||
|
||||
func testProtocolManagerStillDispatchesValidMessagePacket() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnMessage = proto.onMessageReceived
|
||||
let originalOnMalformed = proto.onMalformedMessageReceived
|
||||
defer {
|
||||
proto.onMessageReceived = originalOnMessage
|
||||
proto.onMalformedMessageReceived = originalOnMalformed
|
||||
}
|
||||
|
||||
var messageDispatchCount = 0
|
||||
var malformedCount = 0
|
||||
proto.onMessageReceived = { _ in
|
||||
messageDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedMessageReceived = { _ in
|
||||
malformedCount += 1
|
||||
}
|
||||
|
||||
let validData = makePacketMessageData(metaFieldCount: 2)
|
||||
proto.testHandleIncomingData(validData)
|
||||
|
||||
XCTAssertEqual(messageDispatchCount, 1)
|
||||
XCTAssertEqual(malformedCount, 0)
|
||||
}
|
||||
|
||||
func testProtocolManagerDropsMalformedHandshakePacketBeforeDispatch() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnHandshake = proto.onHandshakeCompleted
|
||||
let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
|
||||
defer {
|
||||
proto.onHandshakeCompleted = originalOnHandshake
|
||||
proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
|
||||
}
|
||||
|
||||
var handshakeDispatchCount = 0
|
||||
var malformedInfos: [MalformedCriticalPacketInfo] = []
|
||||
proto.onHandshakeCompleted = { _ in
|
||||
handshakeDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedCriticalPacketReceived = { info in
|
||||
malformedInfos.append(info)
|
||||
}
|
||||
|
||||
let malformedData = makeHandshakeData(stateRaw: 9)
|
||||
proto.testHandleIncomingData(malformedData)
|
||||
|
||||
XCTAssertEqual(handshakeDispatchCount, 0)
|
||||
XCTAssertEqual(malformedInfos.count, 1)
|
||||
XCTAssertEqual(malformedInfos.first?.packetId, PacketHandshake.packetId)
|
||||
}
|
||||
|
||||
func testProtocolManagerDropsMalformedSyncPacketBeforeDispatch() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnSync = proto.onSyncReceived
|
||||
let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
|
||||
defer {
|
||||
proto.onSyncReceived = originalOnSync
|
||||
proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
|
||||
}
|
||||
|
||||
var syncDispatchCount = 0
|
||||
var malformedInfos: [MalformedCriticalPacketInfo] = []
|
||||
proto.onSyncReceived = { _ in
|
||||
syncDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedCriticalPacketReceived = { info in
|
||||
malformedInfos.append(info)
|
||||
}
|
||||
|
||||
let malformedData = makeSyncData(statusRaw: 7)
|
||||
proto.testHandleIncomingData(malformedData)
|
||||
|
||||
XCTAssertEqual(syncDispatchCount, 0)
|
||||
XCTAssertEqual(malformedInfos.count, 1)
|
||||
XCTAssertEqual(malformedInfos.first?.packetId, PacketSync.packetId)
|
||||
}
|
||||
|
||||
func testProtocolManagerDropsMalformedSignalPacketBeforeDispatch() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnSignal = proto.onSignalPeerReceived
|
||||
let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
|
||||
defer {
|
||||
proto.onSignalPeerReceived = originalOnSignal
|
||||
proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
|
||||
}
|
||||
|
||||
var signalDispatchCount = 0
|
||||
var malformedInfos: [MalformedCriticalPacketInfo] = []
|
||||
proto.onSignalPeerReceived = { _ in
|
||||
signalDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedCriticalPacketReceived = { info in
|
||||
malformedInfos.append(info)
|
||||
}
|
||||
|
||||
var signal = PacketSignalPeer()
|
||||
signal.signalType = .endCallBecauseBusy
|
||||
let malformedData = PacketRegistry.encode(signal) + Data([0x00])
|
||||
proto.testHandleIncomingData(malformedData)
|
||||
|
||||
XCTAssertEqual(signalDispatchCount, 0)
|
||||
XCTAssertEqual(malformedInfos.count, 1)
|
||||
XCTAssertEqual(malformedInfos.first?.packetId, PacketSignalPeer.packetId)
|
||||
}
|
||||
|
||||
func testProtocolManagerDropsMalformedWebRtcPacketBeforeDispatch() {
|
||||
let proto = ProtocolManager.shared
|
||||
let originalOnWebRtc = proto.onWebRTCReceived
|
||||
let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived
|
||||
defer {
|
||||
proto.onWebRTCReceived = originalOnWebRtc
|
||||
proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical
|
||||
}
|
||||
|
||||
var webRtcDispatchCount = 0
|
||||
var malformedInfos: [MalformedCriticalPacketInfo] = []
|
||||
proto.onWebRTCReceived = { _ in
|
||||
webRtcDispatchCount += 1
|
||||
}
|
||||
proto.onMalformedCriticalPacketReceived = { info in
|
||||
malformedInfos.append(info)
|
||||
}
|
||||
|
||||
let malformedData = makeWebRtcData(includeIdentity: false) + Data([0x00])
|
||||
proto.testHandleIncomingData(malformedData)
|
||||
|
||||
XCTAssertEqual(webRtcDispatchCount, 0)
|
||||
XCTAssertEqual(malformedInfos.count, 1)
|
||||
XCTAssertEqual(malformedInfos.first?.packetId, PacketWebRTC.packetId)
|
||||
}
|
||||
|
||||
func testSignalTypeCreateRoomDecodesWithAndWithoutRoomId() throws {
|
||||
let withoutRoomData = makeSignalCreateRoomData(roomId: nil)
|
||||
guard let withoutRoomDecoded = PacketRegistry.decode(from: withoutRoomData),
|
||||
let withoutRoomPacket = withoutRoomDecoded.packet as? PacketSignalPeer
|
||||
else {
|
||||
XCTFail("Failed to decode signalType=4 packet without roomId")
|
||||
return
|
||||
}
|
||||
|
||||
XCTAssertEqual(withoutRoomDecoded.packetId, PacketSignalPeer.packetId)
|
||||
XCTAssertFalse(withoutRoomPacket.isMalformed)
|
||||
XCTAssertEqual(withoutRoomPacket.signalType, .createRoom)
|
||||
XCTAssertEqual(withoutRoomPacket.roomId, "")
|
||||
|
||||
let withRoomData = makeSignalCreateRoomData(roomId: "room-42")
|
||||
guard let withRoomDecoded = PacketRegistry.decode(from: withRoomData),
|
||||
let withRoomPacket = withRoomDecoded.packet as? PacketSignalPeer
|
||||
else {
|
||||
XCTFail("Failed to decode signalType=4 packet with roomId")
|
||||
return
|
||||
}
|
||||
|
||||
XCTAssertEqual(withRoomDecoded.packetId, PacketSignalPeer.packetId)
|
||||
XCTAssertFalse(withRoomPacket.isMalformed)
|
||||
XCTAssertEqual(withRoomPacket.signalType, .createRoom)
|
||||
XCTAssertEqual(withRoomPacket.roomId, "room-42")
|
||||
}
|
||||
|
||||
func testWebRtcDecodeSupportsTwoAndFourFieldLayouts() throws {
|
||||
let twoFieldData = makeWebRtcData(includeIdentity: false)
|
||||
guard let decodedTwoField = PacketRegistry.decode(from: twoFieldData),
|
||||
let twoFieldPacket = decodedTwoField.packet as? PacketWebRTC
|
||||
else {
|
||||
XCTFail("Failed to decode canonical 2-field 0x1B packet")
|
||||
return
|
||||
}
|
||||
|
||||
XCTAssertEqual(decodedTwoField.packetId, PacketWebRTC.packetId)
|
||||
XCTAssertFalse(twoFieldPacket.isMalformed)
|
||||
XCTAssertEqual(twoFieldPacket.signalType, .offer)
|
||||
XCTAssertEqual(twoFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
|
||||
XCTAssertEqual(twoFieldPacket.publicKey, "")
|
||||
XCTAssertEqual(twoFieldPacket.deviceId, "")
|
||||
|
||||
let fourFieldData = makeWebRtcData(
|
||||
includeIdentity: true,
|
||||
publicKey: "02legacyPublic",
|
||||
deviceId: "legacy-device"
|
||||
)
|
||||
guard let decodedFourField = PacketRegistry.decode(from: fourFieldData),
|
||||
let fourFieldPacket = decodedFourField.packet as? PacketWebRTC
|
||||
else {
|
||||
XCTFail("Failed to decode legacy 4-field 0x1B packet")
|
||||
return
|
||||
}
|
||||
|
||||
XCTAssertEqual(decodedFourField.packetId, PacketWebRTC.packetId)
|
||||
XCTAssertFalse(fourFieldPacket.isMalformed)
|
||||
XCTAssertEqual(fourFieldPacket.signalType, .offer)
|
||||
XCTAssertEqual(fourFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}")
|
||||
XCTAssertEqual(fourFieldPacket.publicKey, "02legacyPublic")
|
||||
XCTAssertEqual(fourFieldPacket.deviceId, "legacy-device")
|
||||
}
|
||||
|
||||
func testWebRtcEncodeUsesCanonicalTwoFieldLayout() {
|
||||
var packet = PacketWebRTC()
|
||||
packet.signalType = .answer
|
||||
packet.sdpOrCandidate = "{\"type\":\"answer\",\"sdp\":\"v=0\"}"
|
||||
packet.publicKey = "02shouldNotBeEncoded"
|
||||
packet.deviceId = "device-should-not-be-encoded"
|
||||
|
||||
let encoded = PacketRegistry.encode(packet)
|
||||
let stream = Rosetta.Stream(data: encoded)
|
||||
XCTAssertEqual(stream.readInt16(), PacketWebRTC.packetId)
|
||||
XCTAssertEqual(stream.readInt8(), WebRTCSignalType.answer.rawValue)
|
||||
XCTAssertEqual(stream.readString(), packet.sdpOrCandidate)
|
||||
XCTAssertFalse(stream.hasRemainingBits())
|
||||
}
|
||||
|
||||
func testMalformedPacketRecoveryDebouncesResyncToSingleRequest() async throws {
|
||||
let session = SessionManager.shared
|
||||
let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
|
||||
session.testConfigureSessionForParityFlows(
|
||||
currentPublicKey: "02malformed_resync_test",
|
||||
privateKeyHex: privateKeyHex
|
||||
)
|
||||
|
||||
session.testResetMalformedMessageResyncState()
|
||||
defer {
|
||||
session.testSetMalformedMessageResyncHook(nil)
|
||||
session.testResetMalformedMessageResyncState()
|
||||
}
|
||||
|
||||
var triggerCount = 0
|
||||
session.testSetMalformedMessageResyncHook {
|
||||
triggerCount += 1
|
||||
}
|
||||
|
||||
session.testSimulateMalformedMessagePacketDrop(packetSize: 111, fingerprint: "fp-1", messageIdHint: "m1")
|
||||
session.testSimulateMalformedMessagePacketDrop(packetSize: 112, fingerprint: "fp-2", messageIdHint: "m2")
|
||||
session.testSimulateMalformedMessagePacketDrop(packetSize: 113, fingerprint: "fp-3", messageIdHint: "m3")
|
||||
|
||||
try await Task.sleep(nanoseconds: 900_000_000)
|
||||
|
||||
XCTAssertEqual(triggerCount, 1)
|
||||
XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
|
||||
}
|
||||
|
||||
func testMalformedPacketRecoveryWaitsUntilSyncBatchEnds() async throws {
|
||||
let session = SessionManager.shared
|
||||
let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString
|
||||
session.testConfigureSessionForParityFlows(
|
||||
currentPublicKey: "02malformed_resync_batch_test",
|
||||
privateKeyHex: privateKeyHex
|
||||
)
|
||||
|
||||
session.testResetMalformedMessageResyncState()
|
||||
session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: true)
|
||||
defer {
|
||||
session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
|
||||
session.testSetMalformedMessageResyncHook(nil)
|
||||
session.testResetMalformedMessageResyncState()
|
||||
}
|
||||
|
||||
var triggerCount = 0
|
||||
session.testSetMalformedMessageResyncHook {
|
||||
triggerCount += 1
|
||||
}
|
||||
|
||||
session.testSimulateMalformedMessagePacketDrop(packetSize: 211, fingerprint: "fp-batch", messageIdHint: "m-batch")
|
||||
|
||||
// While sync is active, recovery must remain queued and not fire.
|
||||
try await Task.sleep(nanoseconds: 900_000_000)
|
||||
XCTAssertEqual(triggerCount, 0)
|
||||
XCTAssertEqual(session.malformedMessageResyncTriggerCount, 0)
|
||||
|
||||
// After sync ends, queued recovery should fire once.
|
||||
session.testSetSyncState(syncBatchInProgress: false)
|
||||
try await Task.sleep(nanoseconds: 700_000_000)
|
||||
XCTAssertEqual(triggerCount, 1)
|
||||
XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
|
||||
}
|
||||
|
||||
func testPostDecryptEmptyPayloadDropsWithoutUpsertAndDebouncesResync() async throws {
|
||||
let session = SessionManager.shared
|
||||
let myPrivateKey = try P256K.KeyAgreement.PrivateKey()
|
||||
let myPrivateKeyData = myPrivateKey.rawRepresentation
|
||||
let myPrivateKeyHex = myPrivateKeyData.hexString
|
||||
let myPublicKey = try CryptoManager.shared.deriveCompressedPublicKey(from: myPrivateKeyData).hexString
|
||||
|
||||
try DatabaseManager.shared.bootstrap(accountPublicKey: myPublicKey)
|
||||
await MessageRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
|
||||
await DialogRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex)
|
||||
|
||||
session.testConfigureSessionForParityFlows(
|
||||
currentPublicKey: myPublicKey,
|
||||
privateKeyHex: myPrivateKeyHex
|
||||
)
|
||||
session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
|
||||
session.testResetMalformedMessageResyncState()
|
||||
defer {
|
||||
session.testSetMalformedMessageResyncHook(nil)
|
||||
session.testResetMalformedMessageResyncState()
|
||||
session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false)
|
||||
}
|
||||
|
||||
var triggerCount = 0
|
||||
session.testSetMalformedMessageResyncHook {
|
||||
triggerCount += 1
|
||||
}
|
||||
|
||||
let peerPrivateKey = try P256K.KeyAgreement.PrivateKey()
|
||||
let peerPublicKey = try CryptoManager.shared
|
||||
.deriveCompressedPublicKey(from: peerPrivateKey.rawRepresentation)
|
||||
.hexString
|
||||
|
||||
var processedMessageIds: [String] = []
|
||||
for index in 0..<3 {
|
||||
let encrypted = try MessageCrypto.encryptOutgoing(
|
||||
plaintext: "",
|
||||
recipientPublicKeyHex: myPublicKey
|
||||
)
|
||||
|
||||
var packet = PacketMessage()
|
||||
packet.fromPublicKey = peerPublicKey
|
||||
packet.toPublicKey = myPublicKey
|
||||
packet.content = encrypted.content
|
||||
packet.chachaKey = encrypted.chachaKey
|
||||
packet.timestamp = 1_710_000_000_000 + Int64(index)
|
||||
packet.privateKey = "hash"
|
||||
packet.messageId = "empty-decrypted-\(index)-\(UUID().uuidString)"
|
||||
packet.attachments = []
|
||||
packet.aesChachaKey = ""
|
||||
|
||||
processedMessageIds.append(packet.messageId)
|
||||
await session.testProcessIncomingMessage(packet)
|
||||
}
|
||||
|
||||
try await Task.sleep(nanoseconds: 900_000_000)
|
||||
|
||||
XCTAssertEqual(triggerCount, 1)
|
||||
XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1)
|
||||
|
||||
for messageId in processedMessageIds {
|
||||
XCTAssertFalse(
|
||||
MessageRepository.shared.hasMessage(messageId),
|
||||
"Post-decrypt empty payload must not be persisted as a bubble"
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
private func makePacketMessageData(metaFieldCount: Int) -> Data {
|
||||
let stream = Rosetta.Stream()
|
||||
stream.writeInt16(PacketMessage.packetId)
|
||||
stream.writeString("02from")
|
||||
stream.writeString("02to")
|
||||
stream.writeString("ciphertext")
|
||||
stream.writeString("chacha-key")
|
||||
stream.writeInt64(1_710_000_000_000)
|
||||
stream.writeString("hash")
|
||||
stream.writeString("msg-hardening")
|
||||
stream.writeInt8(1)
|
||||
stream.writeString("att-1")
|
||||
stream.writeString("preview")
|
||||
stream.writeString("blob")
|
||||
stream.writeInt8(AttachmentType.image.rawValue)
|
||||
if metaFieldCount >= 2 {
|
||||
stream.writeString("tag-1")
|
||||
stream.writeString("cdn.rosetta.im")
|
||||
}
|
||||
if metaFieldCount >= 4 {
|
||||
stream.writeString("02encoded-for")
|
||||
stream.writeString("desktop")
|
||||
}
|
||||
stream.writeString("aes-key")
|
||||
return stream.toData()
|
||||
}
|
||||
|
||||
private func makeHandshakeData(stateRaw: Int) -> Data {
|
||||
let stream = Rosetta.Stream()
|
||||
stream.writeInt16(PacketHandshake.packetId)
|
||||
stream.writeString("hash")
|
||||
stream.writeString("02public")
|
||||
stream.writeInt8(1)
|
||||
stream.writeInt8(15)
|
||||
stream.writeString("device-id")
|
||||
stream.writeString("iPhone")
|
||||
stream.writeString("iOS")
|
||||
stream.writeInt8(stateRaw)
|
||||
return stream.toData()
|
||||
}
|
||||
|
||||
private func makeSyncData(statusRaw: Int) -> Data {
|
||||
let stream = Rosetta.Stream()
|
||||
stream.writeInt16(PacketSync.packetId)
|
||||
stream.writeInt8(statusRaw)
|
||||
stream.writeInt64(1_710_000_000_000)
|
||||
return stream.toData()
|
||||
}
|
||||
|
||||
private func makeSignalCreateRoomData(roomId: String?) -> Data {
|
||||
let stream = Rosetta.Stream()
|
||||
stream.writeInt16(PacketSignalPeer.packetId)
|
||||
stream.writeInt8(SignalType.createRoom.rawValue)
|
||||
stream.writeString("02src")
|
||||
stream.writeString("02dst")
|
||||
if let roomId {
|
||||
stream.writeString(roomId)
|
||||
}
|
||||
return stream.toData()
|
||||
}
|
||||
|
||||
private func makeWebRtcData(
|
||||
includeIdentity: Bool,
|
||||
publicKey: String = "",
|
||||
deviceId: String = ""
|
||||
) -> Data {
|
||||
let stream = Rosetta.Stream()
|
||||
stream.writeInt16(PacketWebRTC.packetId)
|
||||
stream.writeInt8(WebRTCSignalType.offer.rawValue)
|
||||
stream.writeString("{\"type\":\"offer\",\"sdp\":\"v=0\"}")
|
||||
if includeIdentity {
|
||||
stream.writeString(publicKey)
|
||||
stream.writeString(deviceId)
|
||||
}
|
||||
return stream.toData()
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user