From a5945152c0906856809e4d85efad3d4839b75c55 Mon Sep 17 00:00:00 2001 From: senseiGai Date: Mon, 6 Apr 2026 23:35:29 +0500 Subject: [PATCH] =?UTF-8?q?=D0=A4=D0=B8=D0=BA=D1=81:=20=D0=B8=D1=81=D0=BF?= =?UTF-8?q?=D1=80=D0=B0=D0=B2=D0=BB=D0=B5=D0=BD=D0=BE=20=D0=B8=D1=81=D1=87?= =?UTF-8?q?=D0=B5=D0=B7=D0=BD=D0=BE=D0=B2=D0=B5=D0=BD=D0=B8=D0=B5=20=D1=87?= =?UTF-8?q?=D0=B0=D1=81=D1=82=D0=B8=20=D1=83=D0=B2=D0=B5=D0=B4=D0=BE=D0=BC?= =?UTF-8?q?=D0=BB=D0=B5=D0=BD=D0=B8=D0=B9=20=D0=BF=D1=80=D0=B8=20=D0=BE?= =?UTF-8?q?=D1=82=D0=BA=D1=80=D1=8B=D1=82=D0=B8=D0=B8=20=D0=BF=D1=83=D1=88?= =?UTF-8?q?=D0=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- Info.plist | 2 + Rosetta.xcodeproj/project.pbxproj | 32 +- Rosetta/Core/Crypto/CryptoPrimitives.swift | 19 + Rosetta/Core/Crypto/MessageCrypto.swift | 2 +- .../Data/Repositories/AvatarRepository.swift | 51 +- .../Data/Repositories/MessageRepository.swift | 21 + .../Network/Protocol/Packets/Packet.swift | 20 +- .../Protocol/Packets/PacketHandshake.swift | 73 ++- .../Protocol/Packets/PacketMessage.swift | 157 ++++- .../Packets/PacketPushNotification.swift | 11 +- .../Protocol/Packets/PacketSignalPeer.swift | 110 +++- .../Network/Protocol/Packets/PacketSync.swift | 47 +- .../Protocol/Packets/PacketWebRTC.swift | 117 +++- .../Network/Protocol/ProtocolManager.swift | 113 ++++ Rosetta/Core/Network/Protocol/Stream.swift | 118 +++- Rosetta/Core/Services/SessionManager.swift | 590 ++++++++++++++---- .../Core/Utils/AttachmentPreviewCodec.swift | 36 +- Rosetta/Core/Utils/ReleaseNotes.swift | 24 +- Rosetta/RosettaApp.swift | 34 +- .../NotificationService.swift | 147 ++++- RosettaTests/AttachmentParityTests.swift | 6 +- RosettaTests/CallPushIntegrationTests.swift | 71 +-- RosettaTests/CryptoParityTests.swift | 174 +++++- RosettaTests/FileAttachmentTests.swift | 20 + .../MessageDecodeHardeningTests.swift | 451 +++++++++++++ .../PushNotificationPacketTests.swift | 46 +- RosettaTests/SchemaParityTests.swift | 88 +++ 27 files changed, 2240 insertions(+), 340 deletions(-) create mode 100644 RosettaTests/MessageDecodeHardeningTests.swift diff --git a/Info.plist b/Info.plist index b56f24f..d83084f 100644 --- a/Info.plist +++ b/Info.plist @@ -14,6 +14,8 @@ Rosetta needs access to your microphone for secure voice calls and audio messages. NSSupportsLiveActivities + FirebaseAppDelegateProxyEnabled + UIBackgroundModes remote-notification diff --git a/Rosetta.xcodeproj/project.pbxproj b/Rosetta.xcodeproj/project.pbxproj index abebac0..87b17b2 100644 --- a/Rosetta.xcodeproj/project.pbxproj +++ b/Rosetta.xcodeproj/project.pbxproj @@ -9,8 +9,10 @@ /* Begin PBXBuildFile section */ 3146EDCE68162995CB5D1034 /* BehaviorParityFixtureTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = C9FC5C4F7E26FAFEC47C1D51 /* BehaviorParityFixtureTests.swift */; }; 3C4D5E6F708192A3B4C5D6E7 /* AttachmentParityTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 1A2B3C4D5E6F708192A3B4C5 /* AttachmentParityTests.swift */; }; + B7F1C2D34A5E67890ABCDEF1 /* CryptoParityTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = D9A1B2C3D4E5F60718293A4B /* CryptoParityTests.swift */; }; 4C9BDB443750F7003CFB705C /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 272B862BE4D99E7DD751CC3E /* Foundation.framework */; }; 4D5E6F708192A3B4C5D6E7F8 /* SearchParityTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = 2B3C4D5E6F708192A3B4C5D6 /* SearchParityTests.swift */; }; + C8E2D3F45B6A78901BCDEF12 /* MessageDecodeHardeningTests.swift in Sources */ = {isa = PBXBuildFile; fileRef = EAF1B2C3D4E5F60718293A4B /* MessageDecodeHardeningTests.swift */; }; 806C964D76E024430307C151 /* Foundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 272B862BE4D99E7DD751CC3E /* Foundation.framework */; }; 853F29992F4B63D20092AD05 /* Lottie in Frameworks */ = {isa = PBXBuildFile; productRef = 853F29982F4B63D20092AD05 /* Lottie */; }; 853F29A02F4B63D20092AD05 /* P256K in Frameworks */ = {isa = PBXBuildFile; productRef = 853F29A12F4B63D20092AD05 /* P256K */; }; @@ -96,6 +98,7 @@ 272B862BE4D99E7DD751CC3E /* Foundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = Foundation.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS18.0.sdk/System/Library/Frameworks/Foundation.framework; sourceTree = DEVELOPER_DIR; }; 2B3C4D5E6F708192A3B4C5D6 /* SearchParityTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = SearchParityTests.swift; sourceTree = ""; }; 4D3AF08B754B66DE17AF486D /* DBTestSupport.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = DBTestSupport.swift; sourceTree = ""; }; + D9A1B2C3D4E5F60718293A4B /* CryptoParityTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = CryptoParityTests.swift; sourceTree = ""; }; 75BA8A97FE297E450BB1452E /* RosettaTests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = RosettaTests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 7F4769EEC8ABADB3AA98D3A5 /* SchemaParityTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = SchemaParityTests.swift; sourceTree = ""; }; 853F29622F4B50410092AD05 /* Rosetta.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = Rosetta.app; sourceTree = BUILT_PRODUCTS_DIR; }; @@ -103,6 +106,7 @@ A182B0EDE5C68E7C6F1FB6D1 /* RosettaNotificationService.appex */ = {isa = PBXFileReference; explicitFileType = "wrapper.app-extension"; includeInIndex = 0; path = RosettaNotificationService.appex; sourceTree = BUILT_PRODUCTS_DIR; }; C9FC5C4F7E26FAFEC47C1D51 /* BehaviorParityFixtureTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = BehaviorParityFixtureTests.swift; sourceTree = ""; }; DBAA4AD95B61886B5A22EF0D /* MigrationHarnessTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = MigrationHarnessTests.swift; sourceTree = ""; }; + EAF1B2C3D4E5F60718293A4B /* MessageDecodeHardeningTests.swift */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = sourcecode.swift; path = MessageDecodeHardeningTests.swift; sourceTree = ""; }; E20000042F8D11110092AD05 /* WebRTC.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = WebRTC.xcframework; path = Frameworks/WebRTC.xcframework; sourceTree = ""; }; LA00000022F8D22220092AD05 /* RosettaLiveActivityWidget.appex */ = {isa = PBXFileReference; explicitFileType = "wrapper.app-extension"; includeInIndex = 0; path = RosettaLiveActivityWidget.appex; sourceTree = BUILT_PRODUCTS_DIR; }; LA000000E2F8D22220092AD05 /* CallLiveActivity.swift */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.swift; path = CallLiveActivity.swift; sourceTree = ""; }; @@ -163,8 +167,10 @@ children = ( 1A2B3C4D5E6F708192A3B4C5 /* AttachmentParityTests.swift */, C9FC5C4F7E26FAFEC47C1D51 /* BehaviorParityFixtureTests.swift */, + D9A1B2C3D4E5F60718293A4B /* CryptoParityTests.swift */, 4D3AF08B754B66DE17AF486D /* DBTestSupport.swift */, DBAA4AD95B61886B5A22EF0D /* MigrationHarnessTests.swift */, + EAF1B2C3D4E5F60718293A4B /* MessageDecodeHardeningTests.swift */, 7F4769EEC8ABADB3AA98D3A5 /* SchemaParityTests.swift */, 2B3C4D5E6F708192A3B4C5D6 /* SearchParityTests.swift */, ); @@ -411,8 +417,10 @@ files = ( 3C4D5E6F708192A3B4C5D6E7 /* AttachmentParityTests.swift in Sources */, 3146EDCE68162995CB5D1034 /* BehaviorParityFixtureTests.swift in Sources */, + B7F1C2D34A5E67890ABCDEF1 /* CryptoParityTests.swift in Sources */, CC5AD9236E3B3BA95A0C29EC /* DBTestSupport.swift in Sources */, EC5DFA298C697AE235323240 /* MigrationHarnessTests.swift in Sources */, + C8E2D3F45B6A78901BCDEF12 /* MessageDecodeHardeningTests.swift in Sources */, D60B2E657D691F256B5B7FD4 /* SchemaParityTests.swift in Sources */, 4D5E6F708192A3B4C5D6E7F8 /* SearchParityTests.swift in Sources */, ); @@ -465,7 +473,7 @@ CLANG_ENABLE_OBJC_WEAK = NO; CODE_SIGN_ENTITLEMENTS = RosettaNotificationService/RosettaNotificationService.entitlements; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 27; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = RosettaNotificationService/Info.plist; @@ -475,7 +483,7 @@ "@executable_path/Frameworks", "@executable_path/../../Frameworks", ); - MARKETING_VERSION = 1.2.6; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev.NotificationService; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = iphoneos; @@ -613,7 +621,7 @@ CODE_SIGN_ENTITLEMENTS = Rosetta/Rosetta.entitlements; CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 30; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; @@ -629,7 +637,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - MARKETING_VERSION = 1.2.9; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -653,7 +661,7 @@ CODE_SIGN_ENTITLEMENTS = Rosetta/Rosetta.entitlements; CODE_SIGN_IDENTITY = "Apple Development"; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 30; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; ENABLE_PREVIEWS = YES; GENERATE_INFOPLIST_FILE = YES; @@ -669,7 +677,7 @@ "$(inherited)", "@executable_path/Frameworks", ); - MARKETING_VERSION = 1.2.9; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev; PRODUCT_NAME = "$(TARGET_NAME)"; PROVISIONING_PROFILE_SPECIFIER = ""; @@ -692,7 +700,7 @@ CLANG_ENABLE_OBJC_WEAK = NO; CODE_SIGN_ENTITLEMENTS = RosettaNotificationService/RosettaNotificationService.entitlements; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 27; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = RosettaNotificationService/Info.plist; @@ -702,7 +710,7 @@ "@executable_path/Frameworks", "@executable_path/../../Frameworks", ); - MARKETING_VERSION = 1.2.6; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev.NotificationService; PRODUCT_NAME = "$(TARGET_NAME)"; SDKROOT = iphoneos; @@ -756,7 +764,7 @@ CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CODE_SIGN_ENTITLEMENTS = RosettaLiveActivityWidget/RosettaLiveActivityWidget.entitlements; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 27; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = RosettaLiveActivityWidget/Info.plist; @@ -766,7 +774,7 @@ "@executable_path/Frameworks", "@executable_path/../../Frameworks", ); - MARKETING_VERSION = 1.2.6; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev.LiveActivityWidget; PRODUCT_NAME = "$(TARGET_NAME)"; SKIP_INSTALL = YES; @@ -783,7 +791,7 @@ CLANG_CXX_LANGUAGE_STANDARD = "gnu++20"; CODE_SIGN_ENTITLEMENTS = RosettaLiveActivityWidget/RosettaLiveActivityWidget.entitlements; CODE_SIGN_STYLE = Automatic; - CURRENT_PROJECT_VERSION = 27; + CURRENT_PROJECT_VERSION = 31; DEVELOPMENT_TEAM = QN8Z263QGX; GENERATE_INFOPLIST_FILE = NO; INFOPLIST_FILE = RosettaLiveActivityWidget/Info.plist; @@ -793,7 +801,7 @@ "@executable_path/Frameworks", "@executable_path/../../Frameworks", ); - MARKETING_VERSION = 1.2.6; + MARKETING_VERSION = 1.3.0; PRODUCT_BUNDLE_IDENTIFIER = com.rosetta.dev.LiveActivityWidget; PRODUCT_NAME = "$(TARGET_NAME)"; SKIP_INSTALL = YES; diff --git a/Rosetta/Core/Crypto/CryptoPrimitives.swift b/Rosetta/Core/Crypto/CryptoPrimitives.swift index 4384355..2620ca4 100644 --- a/Rosetta/Core/Crypto/CryptoPrimitives.swift +++ b/Rosetta/Core/Crypto/CryptoPrimitives.swift @@ -235,4 +235,23 @@ extension Data { } self = data } + + /// Initialize from a STRICT hex string. + /// Returns nil if the input contains non-hex characters or has odd length. + nonisolated init?(strictHexString: String) { + let hex = strictHexString.trimmingCharacters(in: .whitespacesAndNewlines) + guard hex.count % 2 == 0 else { return nil } + + var data = Data(capacity: hex.count / 2) + var index = hex.startIndex + while index < hex.endIndex { + let nextIndex = hex.index(index, offsetBy: 2) + guard let byte = UInt8(hex[index.. URL { avatarsDirectory! .appendingPathComponent("\(normalizedKey).enc") } + private func notificationAvatarURL(for normalizedKey: String) -> URL? { + notificationAvatarsDirectory? + .appendingPathComponent("\(normalizedKey).jpg") + } + private func normalizedKey(_ publicKey: String) -> String { publicKey .replacingOccurrences(of: "0x", with: "") .lowercased() } - private func ensureDirectoryExists() { - guard let directory = avatarsDirectory else { return } - if !FileManager.default.fileExists(atPath: directory.path) { - try? FileManager.default.createDirectory( - at: directory, - withIntermediateDirectories: true - ) - } + private func ensureStorageDirectoriesExist() { + ensureDirectoryExists(avatarsDirectory) + ensureDirectoryExists(notificationAvatarsDirectory) + } + + private func ensureDirectoryExists(_ directory: URL?) { + guard let directory else { return } + if FileManager.default.fileExists(atPath: directory.path) { return } + try? FileManager.default.createDirectory( + at: directory, + withIntermediateDirectories: true + ) + } + + private func syncAvatarToNotificationStoreIfNeeded(_ image: UIImage, normalizedKey: String) { + guard let notificationURL = notificationAvatarURL(for: normalizedKey) else { return } + ensureDirectoryExists(notificationAvatarsDirectory) + guard let jpegData = image.jpegData(compressionQuality: 0.72) else { return } + try? jpegData.write(to: notificationURL, options: [.atomic]) } } diff --git a/Rosetta/Core/Data/Repositories/MessageRepository.swift b/Rosetta/Core/Data/Repositories/MessageRepository.swift index 7ba1a0a..f385cd0 100644 --- a/Rosetta/Core/Data/Repositories/MessageRepository.swift +++ b/Rosetta/Core/Data/Repositories/MessageRepository.swift @@ -545,6 +545,21 @@ final class MessageRepository: ObservableObject { refreshCacheNow(for: opponentKey) } + /// Updates attachment password for a specific message (used during retry with re-encryption). + func updateAttachmentPassword(messageId: String, password: String) { + guard !currentAccount.isEmpty else { return } + do { + try db.writeSync { db in + try db.execute( + sql: "UPDATE messages SET attachment_password = ? WHERE account = ? AND message_id = ?", + arguments: [password, currentAccount, messageId] + ) + } + } catch { + print("[DB] updateAttachmentPassword error: \(error)") + } + } + // MARK: - Typing func markTyping(from dialogKey: String, senderKey: String) { @@ -1093,6 +1108,12 @@ final class MessageRepository: ObservableObject { return false } +#if DEBUG + internal static func testIsProbablyEncrypted(_ value: String) -> Bool { + isProbablyEncryptedPayload(value) + } +#endif + private func normalizeTimestamp(_ raw: Int64) -> Int64 { raw < 1_000_000_000_000 ? raw * 1000 : raw } diff --git a/Rosetta/Core/Network/Protocol/Packets/Packet.swift b/Rosetta/Core/Network/Protocol/Packets/Packet.swift index 405b189..2ecae78 100644 --- a/Rosetta/Core/Network/Protocol/Packets/Packet.swift +++ b/Rosetta/Core/Network/Protocol/Packets/Packet.swift @@ -1,16 +1,19 @@ import Foundation +import os /// Base protocol for all Rosetta binary packets. protocol Packet { static var packetId: Int { get } func write(to stream: Stream) - mutating func read(from stream: Stream) + mutating func read(from stream: Stream) throws } // MARK: - Packet Registry enum PacketRegistry { + private static let logger = Logger(subsystem: "com.rosetta.messenger", category: "PacketRegistry") + /// All known packet factories, keyed by packet ID. private static let factories: [Int: () -> any Packet] = [ 0x00: { PacketHandshake() }, @@ -43,15 +46,28 @@ enum PacketRegistry { /// Deserializes a packet from raw binary data. static func decode(from data: Data) -> (packetId: Int, packet: any Packet)? { + guard data.count >= 2 else { + logger.error("Rejecting packet: too short payload size=\(data.count)") + return nil + } + let stream = Stream(data: data) let packetId = stream.readInt16() guard let factory = factories[packetId] else { + let packetHex = String(format: "0x%02X", packetId) + logger.warning("Rejecting packet: unknown packetId=\(packetHex) size=\(data.count)") return nil } var packet = factory() - packet.read(from: stream) + do { + try packet.read(from: stream) + } catch { + let packetHex = String(format: "0x%02X", packetId) + logger.error("Rejecting packet: parse failure packetId=\(packetHex) size=\(data.count) error=\(String(describing: error))") + return nil + } return (packetId, packet) } diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketHandshake.swift b/Rosetta/Core/Network/Protocol/Packets/PacketHandshake.swift index 5834bb2..444e856 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketHandshake.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketHandshake.swift @@ -5,10 +5,6 @@ import Foundation enum HandshakeState: Int { case completed = 0 case needDeviceVerification = 1 - - init(value: Int) { - self = HandshakeState(rawValue: value) ?? .completed - } } // MARK: - HandshakeDevice @@ -30,6 +26,8 @@ struct PacketHandshake: Packet { var heartbeatInterval: Int = 15 var device = HandshakeDevice() var handshakeState: HandshakeState = .needDeviceVerification + var isMalformed: Bool = false + var malformedFingerprint: String = "" func write(to stream: Stream) { stream.writeString(privateKey) @@ -43,15 +41,62 @@ struct PacketHandshake: Packet { } mutating func read(from stream: Stream) { - privateKey = stream.readString() - publicKey = stream.readString() - protocolVersion = stream.readInt8() - heartbeatInterval = stream.readInt8() - device = HandshakeDevice( - deviceId: stream.readString(), - deviceName: stream.readString(), - deviceOs: stream.readString() - ) - handshakeState = HandshakeState(value: stream.readInt8()) + do { + let parsedPrivateKey = try stream.readStringStrict() + let parsedPublicKey = try stream.readStringStrict() + let parsedProtocolVersion = try stream.readInt8Strict() + let parsedHeartbeatInterval = try stream.readInt8Strict() + let parsedDeviceId = try stream.readStringStrict() + let parsedDeviceName = try stream.readStringStrict() + let parsedDeviceOs = try stream.readStringStrict() + let rawState = try stream.readInt8Strict() + + guard let parsedState = HandshakeState(rawValue: rawState) else { + markMalformed("invalid_state:\(rawState)") + return + } + + guard !stream.hasRemainingBits() else { + markMalformed("trailing_bits:\(stream.remainingBits())") + return + } + + privateKey = parsedPrivateKey + publicKey = parsedPublicKey + protocolVersion = parsedProtocolVersion + heartbeatInterval = parsedHeartbeatInterval + device = HandshakeDevice( + deviceId: parsedDeviceId, + deviceName: parsedDeviceName, + deviceOs: parsedDeviceOs + ) + handshakeState = parsedState + isMalformed = false + malformedFingerprint = "" + } catch { + markMalformed(Self.errorFingerprint(error)) + } + } + + private mutating func markMalformed(_ fingerprint: String) { + privateKey = "" + publicKey = "" + protocolVersion = 1 + heartbeatInterval = 15 + device = HandshakeDevice() + handshakeState = .needDeviceVerification + isMalformed = true + malformedFingerprint = fingerprint + } + + private static func errorFingerprint(_ error: Error) -> String { + switch error { + case PacketBitStreamError.underflow(let operation, let neededBits, let remainingBits): + return "underflow:\(operation):\(neededBits):\(remainingBits)" + case PacketBitStreamError.invalidStringLength(let length): + return "invalid_string_length:\(length)" + default: + return "parse_error" + } } } diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketMessage.swift b/Rosetta/Core/Network/Protocol/Packets/PacketMessage.swift index 8412fed..02af323 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketMessage.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketMessage.swift @@ -13,6 +13,22 @@ struct PacketMessage: Packet { var messageId: String = "" var attachments: [MessageAttachment] = [] var aesChachaKey: String = "" // ChaCha key+nonce encrypted by sender + /// True when payload could not be parsed in any known compatibility layout. + var isMalformed: Bool = false + /// Compact parser fingerprint for diagnostics (no sensitive payload data). + var malformedFingerprint: String = "" + + private struct ParsedPacketMessage { + let fromPublicKey: String + let toPublicKey: String + let content: String + let chachaKey: String + let timestamp: Int64 + let privateKey: String + let messageId: String + let attachments: [MessageAttachment] + let aesChachaKey: String + } func write(to stream: Stream) { // Match Android field order exactly @@ -37,33 +53,126 @@ struct PacketMessage: Packet { } mutating func read(from stream: Stream) { - fromPublicKey = stream.readString() - toPublicKey = stream.readString() - content = stream.readString() - chachaKey = stream.readString() - timestamp = stream.readInt64() - privateKey = stream.readString() - messageId = stream.readString() + let startPointer = stream.getReadPointerBits() + var parseErrors: [String] = [] + + for attachmentMetaFieldCount in [4, 2, 0] { + stream.setReadPointerBits(startPointer) + do { + let parsed = try Self.parseFromStream( + stream, + attachmentMetaFieldCount: attachmentMetaFieldCount + ) + if stream.hasRemainingBits() { + parseErrors.append( + "meta\(attachmentMetaFieldCount):trailing_bits=\(stream.remainingBits())" + ) + continue + } + + fromPublicKey = parsed.fromPublicKey + toPublicKey = parsed.toPublicKey + content = parsed.content + chachaKey = parsed.chachaKey + timestamp = parsed.timestamp + privateKey = parsed.privateKey + messageId = parsed.messageId + attachments = parsed.attachments + aesChachaKey = parsed.aesChachaKey + isMalformed = false + malformedFingerprint = "" + return + } catch { + parseErrors.append("meta\(attachmentMetaFieldCount):\(Self.errorFingerprint(error))") + } + } + + // Hard-fail parse: preserve zero/default fields and mark packet malformed. + fromPublicKey = "" + toPublicKey = "" + content = "" + chachaKey = "" + timestamp = 0 + privateKey = "" + messageId = "" + attachments = [] + aesChachaKey = "" + isMalformed = true + malformedFingerprint = parseErrors.isEmpty + ? "packet06_parse_failed" + : parseErrors.joined(separator: "|") + } + + private static func parseFromStream( + _ parser: Stream, + attachmentMetaFieldCount: Int + ) throws -> ParsedPacketMessage { + let parsedFromPublicKey = try parser.readStringStrict() + let parsedToPublicKey = try parser.readStringStrict() + let parsedContent = try parser.readStringStrict() + let parsedChachaKey = try parser.readStringStrict() + let parsedTimestamp = try parser.readInt64Strict() + let parsedPrivateKey = try parser.readStringStrict() + let parsedMessageId = try parser.readStringStrict() + + let attachmentCount = max(try parser.readInt8Strict(), 0) + var parsedAttachments: [MessageAttachment] = [] + parsedAttachments.reserveCapacity(attachmentCount) - let attachmentCount = max(stream.readInt8(), 0) - var list: [MessageAttachment] = [] for _ in 0..= 2 { + transportTag = try parser.readStringStrict() + transportServer = try parser.readStringStrict() + } else { + transportTag = "" + transportServer = "" + } + + // Older Android builds may contain extra metadata fields. + if attachmentMetaFieldCount >= 4 { + _ = try parser.readStringStrict() // encodedFor + _ = try parser.readStringStrict() // encoder + } + + parsedAttachments.append(MessageAttachment( + id: id, + preview: preview, + blob: blob, + type: type, + transportTag: transportTag, + transportServer: transportServer )) } - attachments = list - aesChachaKey = stream.readString() + + let parsedAesChachaKey = try parser.readStringStrict() + return ParsedPacketMessage( + fromPublicKey: parsedFromPublicKey, + toPublicKey: parsedToPublicKey, + content: parsedContent, + chachaKey: parsedChachaKey, + timestamp: parsedTimestamp, + privateKey: parsedPrivateKey, + messageId: parsedMessageId, + attachments: parsedAttachments, + aesChachaKey: parsedAesChachaKey + ) + } + + private static func errorFingerprint(_ error: Error) -> String { + switch error { + case PacketBitStreamError.underflow(let operation, let neededBits, let remainingBits): + return "underflow:\(operation):\(neededBits):\(remainingBits)" + case PacketBitStreamError.invalidStringLength(let length): + return "invalid_string_length:\(length)" + default: + return "parse_error" + } } } diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketPushNotification.swift b/Rosetta/Core/Network/Protocol/Packets/PacketPushNotification.swift index 91a925e..b0f6cdd 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketPushNotification.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketPushNotification.swift @@ -7,15 +7,14 @@ enum PushNotificationAction: Int { } /// Token type for push notification registration. -/// Server parity: im.rosetta.packet.runtime.TokenType enum PushTokenType: Int { - case fcm = 0 // FCM token (iOS + Android) - case voipApns = 1 // VoIP APNs token (iOS only) + case fcm = 0 + case voipApns = 1 } -/// PushNotification packet (0x10) — registers or unregisters APNs/FCM token on server. -/// Sent after successful handshake to enable push notifications. -/// Server stores tokens at device level (PushToken entity linked to Device). +/// PushNotification packet (0x10) — registers or unregisters push token on server. +/// Cross-platform wire format parity (Server/Android): +/// writeString(token) + writeInt8(action) + writeInt8(tokenType) + writeString(deviceId) struct PacketPushNotification: Packet { static let packetId = 0x10 diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketSignalPeer.swift b/Rosetta/Core/Network/Protocol/Packets/PacketSignalPeer.swift index 89d8e45..f59874f 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketSignalPeer.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketSignalPeer.swift @@ -27,6 +27,8 @@ struct PacketSignalPeer: Packet { var callId: String = "" var joinToken: String = "" var roomId: String = "" + var isMalformed: Bool = false + var malformedFingerprint: String = "" func write(to stream: Stream) { stream.writeInt8(signalType.rawValue) @@ -51,30 +53,57 @@ struct PacketSignalPeer: Packet { } mutating func read(from stream: Stream) { - src = "" - dst = "" - sharedPublic = "" - callId = "" - joinToken = "" - roomId = "" - signalType = SignalType(rawValue: stream.readInt8()) ?? .call - if isShortSignal { - return - } - src = stream.readString() - dst = stream.readString() - if signalType == .keyExchange { - sharedPublic = stream.readString() - } - if hasLegacyCallMetadata { - callId = stream.readString() - joinToken = stream.readString() - } - // Signal code 4 is mode-aware on read: - // - empty roomId => legacy ACTIVE - // - non-empty roomId => create-room fallback - if signalType == .createRoom { - roomId = stream.readString() + do { + let rawSignalType = try stream.readInt8Strict() + guard let parsedSignalType = SignalType(rawValue: rawSignalType) else { + markMalformed("invalid_signal_type:\(rawSignalType)") + return + } + + var parsedSrc = "" + var parsedDst = "" + var parsedSharedPublic = "" + var parsedCallId = "" + var parsedJoinToken = "" + var parsedRoomId = "" + + if !Self.isShortSignal(parsedSignalType) { + parsedSrc = try stream.readStringStrict() + parsedDst = try stream.readStringStrict() + + if parsedSignalType == .keyExchange { + parsedSharedPublic = try stream.readStringStrict() + } + + if Self.hasLegacyCallMetadata(parsedSignalType) { + parsedCallId = try stream.readStringStrict() + parsedJoinToken = try stream.readStringStrict() + } + + // signalType=4 supports both layouts: + // - legacy ACTIVE: no roomId field + // - create-room fallback: roomId field at tail + if parsedSignalType == .createRoom, stream.hasRemainingBits() { + parsedRoomId = try stream.readStringStrict() + } + } + + guard !stream.hasRemainingBits() else { + markMalformed("trailing_bits:\(stream.remainingBits())") + return + } + + src = parsedSrc + dst = parsedDst + sharedPublic = parsedSharedPublic + signalType = parsedSignalType + callId = parsedCallId + joinToken = parsedJoinToken + roomId = parsedRoomId + isMalformed = false + malformedFingerprint = "" + } catch { + markMalformed(Self.errorFingerprint(error)) } } @@ -87,4 +116,37 @@ struct PacketSignalPeer: Packet { private var hasLegacyCallMetadata: Bool { signalType == .call || signalType == .accept || signalType == .endCall } + + private mutating func markMalformed(_ fingerprint: String) { + src = "" + dst = "" + sharedPublic = "" + signalType = .call + callId = "" + joinToken = "" + roomId = "" + isMalformed = true + malformedFingerprint = fingerprint + } + + private static func isShortSignal(_ signalType: SignalType) -> Bool { + signalType == .endCallBecauseBusy + || signalType == .endCallBecausePeerDisconnected + || signalType == .ringingTimeout + } + + private static func hasLegacyCallMetadata(_ signalType: SignalType) -> Bool { + signalType == .call || signalType == .accept || signalType == .endCall + } + + private static func errorFingerprint(_ error: Error) -> String { + switch error { + case PacketBitStreamError.underflow(let operation, let neededBits, let remainingBits): + return "underflow:\(operation):\(neededBits):\(remainingBits)" + case PacketBitStreamError.invalidStringLength(let length): + return "invalid_string_length:\(length)" + default: + return "parse_error" + } + } } diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketSync.swift b/Rosetta/Core/Network/Protocol/Packets/PacketSync.swift index 6fa5943..7b521a1 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketSync.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketSync.swift @@ -6,10 +6,6 @@ enum SyncStatus: Int { case notNeeded = 0 case batchStart = 1 case batchEnd = 2 - - init(value: Int) { - self = SyncStatus(rawValue: value) ?? .notNeeded - } } // MARK: - PacketSync (0x19) @@ -20,6 +16,8 @@ struct PacketSync: Packet { var status: SyncStatus = .notNeeded var timestamp: Int64 = 0 + var isMalformed: Bool = false + var malformedFingerprint: String = "" func write(to stream: Stream) { stream.writeInt8(status.rawValue) @@ -27,7 +25,44 @@ struct PacketSync: Packet { } mutating func read(from stream: Stream) { - status = SyncStatus(value: stream.readInt8()) - timestamp = stream.readInt64() + do { + let rawStatus = try stream.readInt8Strict() + guard let parsedStatus = SyncStatus(rawValue: rawStatus) else { + markMalformed("invalid_status:\(rawStatus)") + return + } + + let parsedTimestamp = try stream.readInt64Strict() + + guard !stream.hasRemainingBits() else { + markMalformed("trailing_bits:\(stream.remainingBits())") + return + } + + status = parsedStatus + timestamp = parsedTimestamp + isMalformed = false + malformedFingerprint = "" + } catch { + markMalformed(Self.errorFingerprint(error)) + } + } + + private mutating func markMalformed(_ fingerprint: String) { + status = .notNeeded + timestamp = 0 + isMalformed = true + malformedFingerprint = fingerprint + } + + private static func errorFingerprint(_ error: Error) -> String { + switch error { + case PacketBitStreamError.underflow(let operation, let neededBits, let remainingBits): + return "underflow:\(operation):\(neededBits):\(remainingBits)" + case PacketBitStreamError.invalidStringLength(let length): + return "invalid_string_length:\(length)" + default: + return "parse_error" + } } } diff --git a/Rosetta/Core/Network/Protocol/Packets/PacketWebRTC.swift b/Rosetta/Core/Network/Protocol/Packets/PacketWebRTC.swift index b96715c..0fc3d5d 100644 --- a/Rosetta/Core/Network/Protocol/Packets/PacketWebRTC.swift +++ b/Rosetta/Core/Network/Protocol/Packets/PacketWebRTC.swift @@ -17,18 +17,123 @@ struct PacketWebRTC: Packet { var publicKey: String = "" /// Sender's device ID — server checks publicKey↔deviceId binding. var deviceId: String = "" + var isMalformed: Bool = false + var malformedFingerprint: String = "" func write(to stream: Stream) { + // Canonical wire format: signalType + sdpOrCandidate. + // Keep publicKey/deviceId as in-memory fields for backward compatibility. stream.writeInt8(signalType.rawValue) stream.writeString(sdpOrCandidate) - stream.writeString(publicKey) - stream.writeString(deviceId) } mutating func read(from stream: Stream) { - signalType = WebRTCSignalType(rawValue: stream.readInt8()) ?? .offer - sdpOrCandidate = stream.readString() - publicKey = stream.readString() - deviceId = stream.readString() + let startPointer = stream.getReadPointerBits() + var parseErrors: [String] = [] + + do { + let parsed = try Self.parse(from: stream, includeIdentityFields: false) + if stream.hasRemainingBits() { + parseErrors.append("v2:trailing_bits:\(stream.remainingBits())") + } else { + apply(parsed) + isMalformed = false + malformedFingerprint = "" + return + } + } catch { + parseErrors.append("v2:\(Self.errorFingerprint(error))") + } + + stream.setReadPointerBits(startPointer) + + do { + let parsed = try Self.parse(from: stream, includeIdentityFields: true) + if stream.hasRemainingBits() { + parseErrors.append("v4:trailing_bits:\(stream.remainingBits())") + } else { + apply(parsed) + isMalformed = false + malformedFingerprint = "" + return + } + } catch { + parseErrors.append("v4:\(Self.errorFingerprint(error))") + } + + markMalformed( + parseErrors.isEmpty + ? "packet1b_parse_failed" + : parseErrors.joined(separator: "|") + ) + } + + private mutating func apply(_ parsed: ParsedPacketWebRTC) { + signalType = parsed.signalType + sdpOrCandidate = parsed.sdpOrCandidate + publicKey = parsed.publicKey + deviceId = parsed.deviceId + } + + private mutating func markMalformed(_ fingerprint: String) { + signalType = .offer + sdpOrCandidate = "" + publicKey = "" + deviceId = "" + isMalformed = true + malformedFingerprint = fingerprint + } + + private struct ParsedPacketWebRTC { + let signalType: WebRTCSignalType + let sdpOrCandidate: String + let publicKey: String + let deviceId: String + } + + private enum PacketWebRTCParseError: Error { + case invalidSignalType(Int) + } + + private static func parse( + from stream: Stream, + includeIdentityFields: Bool + ) throws -> ParsedPacketWebRTC { + let rawSignalType = try stream.readInt8Strict() + guard let parsedSignalType = WebRTCSignalType(rawValue: rawSignalType) else { + throw PacketWebRTCParseError.invalidSignalType(rawSignalType) + } + + let parsedSdpOrCandidate = try stream.readStringStrict() + let parsedPublicKey: String + let parsedDeviceId: String + + if includeIdentityFields { + parsedPublicKey = try stream.readStringStrict() + parsedDeviceId = try stream.readStringStrict() + } else { + parsedPublicKey = "" + parsedDeviceId = "" + } + + return ParsedPacketWebRTC( + signalType: parsedSignalType, + sdpOrCandidate: parsedSdpOrCandidate, + publicKey: parsedPublicKey, + deviceId: parsedDeviceId + ) + } + + private static func errorFingerprint(_ error: Error) -> String { + switch error { + case PacketBitStreamError.underflow(let operation, let neededBits, let remainingBits): + return "underflow:\(operation):\(neededBits):\(remainingBits)" + case PacketBitStreamError.invalidStringLength(let length): + return "invalid_string_length:\(length)" + case PacketWebRTCParseError.invalidSignalType(let raw): + return "invalid_signal_type:\(raw)" + default: + return "parse_error" + } } } diff --git a/Rosetta/Core/Network/Protocol/ProtocolManager.swift b/Rosetta/Core/Network/Protocol/ProtocolManager.swift index ce69487..3973641 100644 --- a/Rosetta/Core/Network/Protocol/ProtocolManager.swift +++ b/Rosetta/Core/Network/Protocol/ProtocolManager.swift @@ -14,6 +14,18 @@ enum ConnectionState: String { case authenticated } +struct MalformedMessagePacketInfo: Sendable { + let packetSize: Int + let fingerprint: String + let messageIdHint: String +} + +struct MalformedCriticalPacketInfo: Sendable { + let packetId: Int + let packetSize: Int + let fingerprint: String +} + // MARK: - ProtocolManager /// Central networking coordinator. Owns WebSocket, routes packets, manages handshake. @@ -58,6 +70,8 @@ final class ProtocolManager: @unchecked Sendable { var onWebRTCReceived: ((PacketWebRTC) -> Void)? var onIceServersReceived: ((PacketIceServers) -> Void)? var onHandshakeCompleted: ((PacketHandshake) -> Void)? + var onMalformedMessageReceived: ((MalformedMessagePacketInfo) -> Void)? + var onMalformedCriticalPacketReceived: ((MalformedCriticalPacketInfo) -> Void)? // MARK: - Private @@ -687,6 +701,15 @@ final class ProtocolManager: @unchecked Sendable { switch packetId { case 0x00: if let p = packet as? PacketHandshake { + if p.isMalformed { + reportMalformedCriticalPacket( + packetId: packetId, + packetSize: data.count, + fingerprint: p.malformedFingerprint, + fallbackFingerprint: "packet00_parse_failed" + ) + return + } handleHandshakeResponse(p) } case 0x01: @@ -712,6 +735,25 @@ final class ProtocolManager: @unchecked Sendable { } case 0x06: if let p = packet as? PacketMessage { + if p.isMalformed { + let messageIdHint = p.messageId.isEmpty ? "-" : String(p.messageId.prefix(8)) + let fingerprint = p.malformedFingerprint.isEmpty ? "packet06_parse_failed" : p.malformedFingerprint + reportMalformedCriticalPacket( + packetId: packetId, + packetSize: data.count, + fingerprint: fingerprint, + fallbackFingerprint: "packet06_parse_failed", + messageIdHint: messageIdHint + ) + onMalformedMessageReceived?( + MalformedMessagePacketInfo( + packetSize: data.count, + fingerprint: fingerprint, + messageIdHint: messageIdHint + ) + ) + return + } onMessageReceived?(p) } case 0x07: @@ -781,6 +823,15 @@ final class ProtocolManager: @unchecked Sendable { } case 0x19: if let p = packet as? PacketSync { + if p.isMalformed { + reportMalformedCriticalPacket( + packetId: packetId, + packetSize: data.count, + fingerprint: p.malformedFingerprint, + fallbackFingerprint: "packet19_parse_failed" + ) + return + } // Android parity: set sync flag SYNCHRONOUSLY on receive queue // BEFORE dispatching to MainActor callback. This prevents the race // where a 0x06 message Task runs on MainActor before BATCH_START Task. @@ -797,11 +848,29 @@ final class ProtocolManager: @unchecked Sendable { } case 0x1A: if let p = packet as? PacketSignalPeer { + if p.isMalformed { + reportMalformedCriticalPacket( + packetId: packetId, + packetSize: data.count, + fingerprint: p.malformedFingerprint, + fallbackFingerprint: "packet1a_parse_failed" + ) + return + } onSignalPeerReceived?(p) notifySignalPeerHandlers(p) } case 0x1B: if let p = packet as? PacketWebRTC { + if p.isMalformed { + reportMalformedCriticalPacket( + packetId: packetId, + packetSize: data.count, + fingerprint: p.malformedFingerprint, + fallbackFingerprint: "packet1b_parse_failed" + ) + return + } onWebRTCReceived?(p) notifyWebRtcHandlers(p) } @@ -861,6 +930,44 @@ final class ProtocolManager: @unchecked Sendable { } } + private func reportMalformedCriticalPacket( + packetId: Int, + packetSize: Int, + fingerprint: String, + fallbackFingerprint: String, + messageIdHint: String? = nil + ) { + let packetHex = String(format: "0x%02X", packetId) + let normalizedFingerprint = Self.compactFingerprint( + fingerprint.isEmpty ? fallbackFingerprint : fingerprint + ) + if let messageIdHint { + Self.logger.error( + "Dropping malformed \(packetHex) packet size=\(packetSize) msgHint=\(messageIdHint) fp=\(normalizedFingerprint)" + ) + } else { + Self.logger.error( + "Dropping malformed \(packetHex) packet size=\(packetSize) fp=\(normalizedFingerprint)" + ) + } + + onMalformedCriticalPacketReceived?( + MalformedCriticalPacketInfo( + packetId: packetId, + packetSize: packetSize, + fingerprint: normalizedFingerprint + ) + ) + } + + private static func compactFingerprint(_ fingerprint: String) -> String { + let sanitized = fingerprint + .replacingOccurrences(of: "\n", with: " ") + .replacingOccurrences(of: "\t", with: " ") + guard sanitized.count > 120 else { return sanitized } + return String(sanitized.prefix(120)) + } + private func handleHandshakeResponse(_ packet: PacketHandshake) { handshakeTimeoutTask?.cancel() handshakeTimeoutTask = nil @@ -1091,4 +1198,10 @@ final class ProtocolManager: @unchecked Sendable { return nil } } + + // MARK: - Test Support + + func testHandleIncomingData(_ data: Data) { + handleIncomingData(data) + } } diff --git a/Rosetta/Core/Network/Protocol/Stream.swift b/Rosetta/Core/Network/Protocol/Stream.swift index e393b52..7b9a6fc 100644 --- a/Rosetta/Core/Network/Protocol/Stream.swift +++ b/Rosetta/Core/Network/Protocol/Stream.swift @@ -2,6 +2,11 @@ import Foundation typealias Stream = PacketBitStream +enum PacketBitStreamError: Error { + case underflow(operation: String, neededBits: Int, remainingBits: Int) + case invalidStringLength(Int) +} + /// Bit-aligned binary stream for protocol packets. /// Matches the server (Java) implementation exactly. /// @@ -37,6 +42,24 @@ final class PacketBitStream: NSObject { return Data(bytes[0.. Int { + readPointer + } + + func setReadPointerBits(_ bits: Int) { + readPointer = min(max(bits, 0), writePointer) + } + + func remainingBits() -> Int { + writePointer - readPointer + } + + func hasRemainingBits() -> Bool { + readPointer < writePointer + } + // MARK: - Bit-Level I/O func writeBit(_ value: Int) { @@ -101,6 +124,18 @@ final class PacketBitStream: NSObject { return readBits(8) } + func readUInt8Strict() throws -> Int { + try ensureReadableBits(8, operation: "readUInt8") + + if (readPointer & 7) == 0 { + let value = Int(bytes[readPointer >> 3]) + readPointer += 8 + return value + } + + return try readBitsStrict(8) + } + func writeInt8(_ value: Int) { writeUInt8(value) } @@ -109,6 +144,10 @@ final class PacketBitStream: NSObject { Int(Int8(truncatingIfNeeded: readUInt8())) } + func readInt8Strict() throws -> Int { + Int(Int8(truncatingIfNeeded: try readUInt8Strict())) + } + // MARK: - UInt16 / Int16 (16 bits) func writeUInt16(_ value: Int) { @@ -123,6 +162,12 @@ final class PacketBitStream: NSObject { return (hi << 8) | lo } + func readUInt16Strict() throws -> Int { + let hi = try readUInt8Strict() + let lo = try readUInt8Strict() + return (hi << 8) | lo + } + func writeInt16(_ value: Int) { writeUInt16(value) } @@ -131,6 +176,10 @@ final class PacketBitStream: NSObject { Int(Int16(truncatingIfNeeded: readUInt16())) } + func readInt16Strict() throws -> Int { + Int(Int16(truncatingIfNeeded: try readUInt16Strict())) + } + // MARK: - UInt32 / Int32 (32 bits) func writeUInt32(_ value: Int) { @@ -148,6 +197,14 @@ final class PacketBitStream: NSObject { return (b1 << 24) | (b2 << 16) | (b3 << 8) | b4 } + func readUInt32Strict() throws -> Int { + let b1 = try readUInt8Strict() + let b2 = try readUInt8Strict() + let b3 = try readUInt8Strict() + let b4 = try readUInt8Strict() + return (b1 << 24) | (b2 << 16) | (b3 << 8) | b4 + } + func writeInt32(_ value: Int) { writeUInt32(value) } @@ -156,6 +213,10 @@ final class PacketBitStream: NSObject { Int(Int32(truncatingIfNeeded: readUInt32())) } + func readInt32Strict() throws -> Int { + Int(Int32(truncatingIfNeeded: try readUInt32Strict())) + } + // MARK: - UInt64 / Int64 (64 bits) func writeUInt64(_ value: Int64) { @@ -175,6 +236,12 @@ final class PacketBitStream: NSObject { return (high << 32) | low } + func readUInt64Strict() throws -> Int64 { + let high = Int64(try readUInt32Strict()) & 0xFFFFFFFF + let low = Int64(try readUInt32Strict()) & 0xFFFFFFFF + return (high << 32) | low + } + func writeInt64(_ value: Int64) { writeUInt64(value) } @@ -183,6 +250,10 @@ final class PacketBitStream: NSObject { readUInt64() } + func readInt64Strict() throws -> Int64 { + try readUInt64Strict() + } + // MARK: - Float32 func writeFloat32(_ value: Float) { @@ -223,6 +294,24 @@ final class PacketBitStream: NSObject { return String(decoding: codeUnits, as: UTF16.self) } + func readStringStrict() throws -> String { + let length = try readUInt32Strict() + guard length > 0 else { return "" } + guard length <= Int(Int32.max) else { + throw PacketBitStreamError.invalidStringLength(length) + } + + let requiredBits = length * 16 + try ensureReadableBits(requiredBits, operation: "readString") + + var codeUnits = [UInt16]() + codeUnits.reserveCapacity(length) + for _ in 0.. Int { - writePointer - readPointer - } - private func writeBits(_ value: Int, count: Int) { guard count > 0 else { return } ensureCapacityForUpcomingBits(count) @@ -307,6 +392,31 @@ final class PacketBitStream: NSObject { return value } + private func readBitsStrict(_ count: Int) throws -> Int { + try ensureReadableBits(count, operation: "readBits") + var value = 0 + for _ in 0..> 3 + let shift = 7 - (readPointer & 7) + let bit = Int((bytes[byteIndex] >> shift) & 1) + value = (value << 1) | bit + readPointer += 1 + } + return value + } + + private func ensureReadableBits(_ needed: Int, operation: String) throws { + let remaining = remainingBits() + guard needed > 0 else { return } + guard remaining >= needed else { + throw PacketBitStreamError.underflow( + operation: operation, + neededBits: needed, + remainingBits: remaining + ) + } + } + private func ensureCapacityForUpcomingBits(_ bitCount: Int) { guard bitCount > 0 else { return } let lastBitIndex = writePointer + bitCount - 1 diff --git a/Rosetta/Core/Services/SessionManager.swift b/Rosetta/Core/Services/SessionManager.swift index d29a9aa..3307536 100644 --- a/Rosetta/Core/Services/SessionManager.swift +++ b/Rosetta/Core/Services/SessionManager.swift @@ -46,6 +46,11 @@ final class SessionManager { private var hasTriggeredGroupRecoverySync = false private var pendingIncomingMessages: [PacketMessage] = [] private var isProcessingIncomingMessages = false + /// Drop+resync recovery for malformed 0x06 packets. + private var malformedMessageResyncTask: Task? + private var malformedMessageResyncQueued = false + private static let malformedMessageResyncDebounceNs: UInt64 = 350_000_000 + private static let malformedMessageResyncRetryDelayNs: UInt64 = 200_000_000 /// Android parity: tracks the latest incoming message timestamp per dialog /// for which a read receipt was already sent. Prevents redundant sends. private var lastReadReceiptTimestamp: [String: Int64] = [:] @@ -64,6 +69,10 @@ final class SessionManager { private var pendingOutgoingRetryTasks: [String: Task] = [:] private var pendingOutgoingPackets: [String: PacketMessage] = [:] private var pendingOutgoingAttempts: [String: Int] = [:] + /// Guards against rapid duplicate push subscribe sends during reconnect storms. + private var lastPushTokenSubscribe: (token: String, sentAt: TimeInterval)? + /// Guards against rapid duplicate VoIP subscribe sends during reconnect storms. + private var lastVoIPTokenSubscribe: (token: String, sentAt: TimeInterval)? private let maxOutgoingRetryAttempts = ProtocolConstants.maxOutgoingRetryAttempts private let maxOutgoingWaitingLifetimeMs: Int64 = ProtocolConstants.messageDeliveryTimeoutS * 1000 var attachmentFlowTransport: AttachmentFlowTransporting = LiveAttachmentFlowTransport() @@ -93,6 +102,10 @@ final class SessionManager { } private var userInfoSearchHandlerToken: UUID? + #if DEBUG + private var malformedMessageResyncTestHook: (() -> Void)? + private(set) var malformedMessageResyncTriggerCount: Int = 0 + #endif private init() { setupProtocolCallbacks() @@ -754,54 +767,71 @@ final class SessionManager { } // ── Phase 2: Upload in background, then send packet ── - let flowTransport = attachmentFlowTransport - let messageAttachments: [MessageAttachment] = try await withThrowingTaskGroup( - of: (Int, String, String).self - ) { group in - for (index, item) in encryptedAttachments.enumerated() { - group.addTask { - let result = try await flowTransport.uploadFile( - id: item.original.id, content: item.encryptedData + // Wrapped in do/catch: if CDN upload fails, mark message as .error + // so retryWaitingOutgoingMessagesAfterReconnect() doesn't pick it up + // and send a text-only packet (which causes empty messages on recipient). + do { + let flowTransport = attachmentFlowTransport + let messageAttachments: [MessageAttachment] = try await withThrowingTaskGroup( + of: (Int, String, String).self + ) { group in + for (index, item) in encryptedAttachments.enumerated() { + group.addTask { + let result = try await flowTransport.uploadFile( + id: item.original.id, content: item.encryptedData + ) + return (index, result.tag, result.server) + } + } + var uploads = [Int: (tag: String, server: String)]() + for try await (index, tag, server) in group { uploads[index] = (tag, server) } + return encryptedAttachments.enumerated().map { index, item in + let upload = uploads[index] ?? (tag: "", server: "") + // Desktop parity: preview = payload only (no tag prefix). + // Desktop MessageFile.tsx does preview.split("::")[0] for filesize — + // embedding tag prefix makes it parse the UUID as filesize → shows wrong filename. + // CDN tag stored in transportTag for download. + let preview = item.preview + Self.logger.info("📤 Attachment uploaded: type=\(String(describing: item.original.type)), tag=\(upload.tag), server=\(upload.server)") + return MessageAttachment( + id: item.original.id, preview: preview, blob: "", + type: item.original.type, + transportTag: upload.tag, transportServer: upload.server ) - return (index, result.tag, result.server) } } - var uploads = [Int: (tag: String, server: String)]() - for try await (index, tag, server) in group { uploads[index] = (tag, server) } - return encryptedAttachments.enumerated().map { index, item in - let upload = uploads[index] ?? (tag: "", server: "") - // Desktop parity: preview = payload only (no tag prefix). - // Desktop MessageFile.tsx does preview.split("::")[0] for filesize — - // embedding tag prefix makes it parse the UUID as filesize → shows wrong filename. - // CDN tag stored in transportTag for download. - let preview = item.preview - Self.logger.info("📤 Attachment uploaded: type=\(String(describing: item.original.type)), tag=\(upload.tag), server=\(upload.server)") - return MessageAttachment( - id: item.original.id, preview: preview, blob: "", - type: item.original.type, - transportTag: upload.tag, transportServer: upload.server + + // Update message with real attachment tags (preview with CDN tag) + MessageRepository.shared.updateAttachments(messageId: messageId, attachments: messageAttachments) + + // Build final packet for WebSocket send + var packet = optimisticPacket + packet.attachments = messageAttachments + + packetFlowSender.sendPacket(packet) + if isGroup { + MessageRepository.shared.updateDeliveryStatus(messageId: messageId, status: .delivered) + DialogRepository.shared.updateDeliveryStatus( + messageId: messageId, opponentKey: optimisticDialogKey, status: .delivered ) + } else { + registerOutgoingRetry(for: packet) } - } - - // Update message with real attachment tags (preview with CDN tag) - MessageRepository.shared.updateAttachments(messageId: messageId, attachments: messageAttachments) - - // Build final packet for WebSocket send - var packet = optimisticPacket - packet.attachments = messageAttachments - - packetFlowSender.sendPacket(packet) - if isGroup { - MessageRepository.shared.updateDeliveryStatus(messageId: messageId, status: .delivered) + MessageRepository.shared.persistNow() + Self.logger.info("📤 Message with \(attachments.count) attachment(s) sent to \(toPublicKey.prefix(12))…") + } catch { + // CDN upload or packet send failed — mark as .error to show failure to user. + // Note: retryWaitingOutgoingMessagesAfterReconnect() may still pick up .error + // messages within 80s, but the retry logic now checks for uploaded CDN tags + // and skips messages with placeholder-only attachments. + Self.logger.error("📤 CDN upload/send failed for \(messageId.prefix(8))…: \(error.localizedDescription)") + MessageRepository.shared.updateDeliveryStatus(messageId: messageId, status: .error) DialogRepository.shared.updateDeliveryStatus( - messageId: messageId, opponentKey: optimisticDialogKey, status: .delivered + messageId: messageId, opponentKey: optimisticDialogKey, status: .error ) - } else { - registerOutgoingRetry(for: packet) + MessageRepository.shared.persistNow() + throw error } - MessageRepository.shared.persistNow() - Self.logger.info("📤 Message with \(attachments.count) attachment(s) sent to \(toPublicKey.prefix(12))…") } /// Builds a data URI from attachment data (desktop: `FileReader.readAsDataURL()`). @@ -1209,8 +1239,6 @@ final class SessionManager { /// Ends the session and disconnects. func endSession() { // Unsubscribe push tokens from server BEFORE disconnecting. - // Without this, old account's tokens stay registered → server sends - // VoIP pushes for calls to this device even after account switch. if let voipToken = UserDefaults.standard.string(forKey: "voip_push_token"), !voipToken.isEmpty { unsubscribeVoIPToken(voipToken) @@ -1231,12 +1259,21 @@ final class SessionManager { pendingOpponentReads.removeAll() pendingIncomingMessages.removeAll() isProcessingIncomingMessages = false + malformedMessageResyncTask?.cancel() + malformedMessageResyncTask = nil + malformedMessageResyncQueued = false + #if DEBUG + malformedMessageResyncTestHook = nil + malformedMessageResyncTriggerCount = 0 + #endif lastReadReceiptTimestamp.removeAll() requestedUserInfoKeys.removeAll() pendingOutgoingRetryTasks.values.forEach { $0.cancel() } pendingOutgoingRetryTasks.removeAll() pendingOutgoingPackets.removeAll() pendingOutgoingAttempts.removeAll() + lastPushTokenSubscribe = nil + lastVoIPTokenSubscribe = nil isAuthenticated = false currentPublicKey = "" displayName = "" @@ -1264,6 +1301,13 @@ final class SessionManager { } } + proto.onMalformedMessageReceived = { [weak self] info in + guard let self else { return } + Task { @MainActor [weak self] in + self?.handleMalformedMessagePacket(info) + } + } + proto.onDeliveryReceived = { [weak self] packet in Task { @MainActor in let opponentKey = MessageRepository.shared.dialogKey(forMessageId: packet.messageId) @@ -1696,6 +1740,17 @@ final class SessionManager { let isGroupDialog = context.kind == .group let wasKnownBefore = MessageRepository.shared.hasMessage(packet.messageId) + if packet.content.isEmpty && packet.attachments.isEmpty { + Self.logger.warning(""" + processIncoming: drop empty payload packet \ + msgId=\(packet.messageId.prefix(8))… \ + from=\(packet.fromPublicKey.prefix(8))… \ + to=\(packet.toPublicKey.prefix(8))… + """) + scheduleMessageRecoveryResync(messageId: packet.messageId, fingerprint: "empty_payload") + return + } + // Optimization: skip expensive crypto + upsert for incoming messages // already stored in DB. Only outgoing messages need re-processing // (sync may update delivery status from .waiting → .delivered). @@ -1714,19 +1769,8 @@ final class SessionManager { ) }() if isGroupDialog, groupKey == nil { - // Don't drop the message — store with encrypted content for retry. - // Group key may arrive later (join confirmation, sync). - Self.logger.warning("processIncoming: group key not found for \(opponentKey) — storing fallback") - let effectiveFromSync = syncBatchInProgress || ProtocolManager.shared.isSyncBatchActive - || packet.fromPublicKey == myKey - MessageRepository.shared.upsertFromMessagePacket( - packet, - myPublicKey: myKey, - decryptedText: "", - fromSync: effectiveFromSync, - dialogIdentityOverride: opponentKey - ) - DialogRepository.shared.updateDialogFromMessages(opponentKey: opponentKey) + Self.logger.warning("processIncoming: group key not found for \(opponentKey) — dropping and scheduling resync") + scheduleMessageRecoveryResync(messageId: packet.messageId, fingerprint: "group_key_missing") return } @@ -1744,11 +1788,8 @@ final class SessionManager { }.value guard let cryptoResult else { - // Desktop/Android parity: NEVER drop a message on decrypt failure. - // Desktop and Android store encrypted content and retry decryption - // on load. iOS was the only platform that lost messages permanently. - Self.logger.warning(""" - processIncoming: decrypt FAILED — storing fallback \ + Self.logger.error(""" + processIncoming: decrypt FAILED — dropping packet \ msgId=\(packet.messageId.prefix(8))… \ from=\(packet.fromPublicKey.prefix(8))… \ hasChachaKey=\(!packet.chachaKey.isEmpty) \ @@ -1756,22 +1797,27 @@ final class SessionManager { contentLen=\(packet.content.count) \ isOwnMessage=\(fromMe) """) - - let effectiveFromSync = syncBatchInProgress || ProtocolManager.shared.isSyncBatchActive || fromMe - MessageRepository.shared.upsertFromMessagePacket( - packet, - myPublicKey: myKey, - decryptedText: "", - fromSync: effectiveFromSync, - dialogIdentityOverride: opponentKey - ) - DialogRepository.shared.updateDialogFromMessages(opponentKey: opponentKey) + scheduleMessageRecoveryResync(messageId: packet.messageId, fingerprint: "decrypt_failed") return } let text = cryptoResult.text let processedPacket = cryptoResult.processedPacket let resolvedAttachmentPassword = cryptoResult.attachmentPassword + if text.isEmpty && processedPacket.attachments.isEmpty { + Self.logger.warning(""" + processIncoming: drop post-decrypt empty payload \ + msgId=\(packet.messageId.prefix(8))… \ + from=\(packet.fromPublicKey.prefix(8))… \ + to=\(packet.toPublicKey.prefix(8))… + """) + scheduleMessageRecoveryResync( + messageId: packet.messageId, + fingerprint: "empty_decrypted_payload" + ) + return + } + // For outgoing messages received from the server (sent by another device // on the same account), treat as sync-equivalent so status = .delivered. // Without this, real-time fromMe messages get .waiting → timeout → .error. @@ -1925,6 +1971,73 @@ final class SessionManager { ProtocolManager.shared.sendSearchPacket(searchPacket, channel: .userInfo) } + private func handleMalformedMessagePacket(_ info: MalformedMessagePacketInfo) { + let fingerprint = info.fingerprint.isEmpty ? "packet06_parse_failed" : info.fingerprint + Self.logger.error(""" + Dropping malformed 0x06 packet \ + size=\(info.packetSize) \ + msgHint=\(info.messageIdHint) \ + fp=\(fingerprint) + """) + + malformedMessageResyncQueued = true + malformedMessageResyncTask?.cancel() + scheduleMalformedMessageResyncFlush(afterNanoseconds: Self.malformedMessageResyncDebounceNs) + } + + private func scheduleMessageRecoveryResync(messageId: String, fingerprint: String) { + let msgHint = messageId.isEmpty ? "-" : String(messageId.prefix(8)) + handleMalformedMessagePacket( + MalformedMessagePacketInfo( + packetSize: 0, + fingerprint: fingerprint, + messageIdHint: msgHint + ) + ) + } + + private func scheduleMalformedMessageResyncFlush(afterNanoseconds delay: UInt64) { + malformedMessageResyncTask = Task { @MainActor [weak self] in + try? await Task.sleep(nanoseconds: delay) + guard let self, !Task.isCancelled else { return } + self.malformedMessageResyncTask = nil + self.flushMalformedMessageResync() + } + } + + private func flushMalformedMessageResync() { + guard malformedMessageResyncQueued else { return } + + guard !currentPublicKey.isEmpty else { + malformedMessageResyncQueued = false + return + } + + if syncBatchInProgress || ProtocolManager.shared.isSyncBatchActive { + scheduleMalformedMessageResyncFlush(afterNanoseconds: Self.malformedMessageResyncRetryDelayNs) + return + } + + if syncRequestInFlight { + scheduleMalformedMessageResyncFlush(afterNanoseconds: Self.malformedMessageResyncRetryDelayNs) + return + } + + malformedMessageResyncQueued = false + let cursor = loadLastSyncTimestamp() + + #if DEBUG + malformedMessageResyncTriggerCount += 1 + if let hook = malformedMessageResyncTestHook { + hook() + return + } + #endif + + Self.logger.warning("Malformed 0x06 recovery: requesting sync cursor=\(cursor)") + requestSynchronize(cursor: cursor) + } + private func requestSynchronize(cursor: Int64? = nil) { // No connectionState guard: this method is only called from (1) handshake // completion handler and (2) BATCH_END handler — both inherently authenticated. @@ -2121,10 +2234,14 @@ final class SessionManager { guard let privateKeyHex else { return nil } - // Allow empty content for messages with attachments (photo-only, call, etc.). - // Normally content is always non-empty (XChaCha20 of "" still produces ciphertext), - // but buggy senders or edge cases may send empty content with valid attachments. + // Allow empty content only for attachment-only packets. + // Text messages with empty content and no attachments are treated as invalid + // and must not create empty bubbles in UI. if packet.content.isEmpty { + guard !packet.attachments.isEmpty else { + Self.logger.warning("Rejecting packet with empty content and no attachments") + return nil + } return ("", nil) } @@ -2321,6 +2438,75 @@ final class SessionManager { packetFlowSender = LivePacketFlowSender() } + static func testDecryptIncomingMessage( + packet: PacketMessage, + myPublicKey: String, + privateKeyHex: String?, + groupKey: String? + ) -> (text: String, hasRawKeyData: Bool)? { + guard let result = decryptIncomingMessage( + packet: packet, + myPublicKey: myPublicKey, + privateKeyHex: privateKeyHex, + groupKey: groupKey + ) else { + return nil + } + return (result.text, result.rawKeyData != nil) + } + + #if DEBUG + static func testRecoverRetryPlaintext(storedText: String, privateKeyHex: String) -> String? { + recoverRetryPlaintext(storedText: storedText, privateKeyHex: privateKeyHex) + } + + static func testRawKeyAndNonceFromStoredAttachmentPassword(_ stored: String) -> Data? { + rawKeyAndNonceFromStoredAttachmentPassword(stored) + } + + func testSetMalformedMessageResyncHook(_ hook: (() -> Void)?) { + malformedMessageResyncTestHook = hook + } + + func testSimulateMalformedMessagePacketDrop( + packetSize: Int = 0, + fingerprint: String = "test_packet06_malformed", + messageIdHint: String = "-" + ) { + handleMalformedMessagePacket( + MalformedMessagePacketInfo( + packetSize: packetSize, + fingerprint: fingerprint, + messageIdHint: messageIdHint + ) + ) + } + + func testResetMalformedMessageResyncState() { + malformedMessageResyncTask?.cancel() + malformedMessageResyncTask = nil + malformedMessageResyncQueued = false + malformedMessageResyncTestHook = nil + malformedMessageResyncTriggerCount = 0 + } + + func testSetSyncState( + syncRequestInFlight: Bool? = nil, + syncBatchInProgress: Bool? = nil + ) { + if let syncRequestInFlight { + self.syncRequestInFlight = syncRequestInFlight + } + if let syncBatchInProgress { + self.syncBatchInProgress = syncBatchInProgress + } + } + + func testProcessIncomingMessage(_ packet: PacketMessage) async { + await processIncomingMessage(packet) + } + #endif + /// Public convenience for views that need to trigger a user-info fetch. func requestUserInfoIfNeeded(forKey publicKey: String) { requestUserInfoIfNeeded(opponentKey: publicKey, privateKeyHash: privateKeyHash) @@ -2548,8 +2734,41 @@ final class SessionManager { continue } - let text = message.text.trimmingCharacters(in: .whitespacesAndNewlines) - guard !text.isEmpty else { continue } + // Decrypt stored text to recover original plaintext. + // DB stores encryptWithPassword(plaintext, privateKey) — we must reverse this + // to avoid double-encrypting (makeOutgoingPacket encrypts with XChaCha20). + guard let plaintext = Self.recoverRetryPlaintext( + storedText: message.text, + privateKeyHex: privateKeyHex + ) else { + markRetryMessageAsError( + message, + reason: "plaintext_unrecoverable" + ) + continue + } + + let uploadBackedAttachments = message.attachments.filter { + $0.type == .image || $0.type == .file || $0.type == .avatar + } + let uploadedAttachments = uploadBackedAttachments.filter { + !$0.effectiveDownloadTag.isEmpty + } + let hasUploadBackedAttachments = !uploadBackedAttachments.isEmpty + + if hasUploadBackedAttachments, + uploadedAttachments.count != uploadBackedAttachments.count { + markRetryMessageAsError( + message, + reason: "missing_attachment_transport_tags" + ) + continue + } + + let hasPlaintext = !plaintext.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + guard hasUploadBackedAttachments || hasPlaintext else { + continue + } // Update dialog delivery status back to .waiting (shows clock icon). DialogRepository.shared.updateDeliveryStatus( @@ -2563,17 +2782,75 @@ final class SessionManager { // (Executor6Message maxPaddingSec=30). Original timestamp would fail validation // if reconnect took >30s. Server overwrites timestamp with System.currentTimeMillis() // anyway (Executor6Message:102), so client timestamp is only for age validation. - let packet = try makeOutgoingPacket( - text: text, - toPublicKey: message.toPublicKey, - messageId: message.id, - timestamp: Int64(Date().timeIntervalSince1970 * 1000), - privateKeyHex: privateKeyHex, - privateKeyHash: privateKeyHash - ) - ProtocolManager.shared.sendPacket(packet) - registerOutgoingRetry(for: packet) - Self.logger.info("Retrying message \(message.id.prefix(8))… to \(message.toPublicKey.prefix(12))…") + let freshTimestamp = Int64(Date().timeIntervalSince1970 * 1000) + + if hasUploadBackedAttachments { + // ── Retry WITH attachments using ORIGINAL key+nonce ── + // CDN blob was encrypted with hex(originalKeyAndNonce). + // We MUST reuse the SAME key+nonce so the recipient derives the + // SAME PBKDF2 password and can decrypt the blob from CDN. + guard let storedPwd = message.attachmentPassword, + let originalKeyAndNonce = Self.rawKeyAndNonceFromStoredAttachmentPassword(storedPwd) + else { + markRetryMessageAsError( + message, + reason: "attachment_key_unrecoverable" + ) + continue + } + + let key = Data(originalKeyAndNonce.prefix(32)) + let nonce = Data(originalKeyAndNonce.subdata(in: 32..<56)) + + // XChaCha20-encrypt text with ORIGINAL key+nonce + let ciphertextWithTag = try XChaCha20Engine.encrypt( + plaintext: Data(plaintext.utf8), key: key, nonce: nonce + ) + let content = ciphertextWithTag.hexString + + // ECDH-encrypt ORIGINAL key+nonce for recipient (fresh ephemeral key) + let chachaKey = try MessageCrypto.encryptKeyForRecipient( + keyAndNonce: originalKeyAndNonce, + recipientPublicKeyHex: message.toPublicKey + ) + + // aesChachaKey for sync (same derivation as original send) + guard let latin1String = String(data: originalKeyAndNonce, encoding: .isoLatin1) else { + throw CryptoError.encryptionFailed + } + let aesChachaPayload = Data(latin1String.utf8) + let aesChachaKey = try CryptoManager.shared.encryptWithPasswordDesktopCompat( + aesChachaPayload, password: privateKeyHex + ) + + var packet = PacketMessage() + packet.fromPublicKey = currentPublicKey + packet.toPublicKey = message.toPublicKey + packet.content = content + packet.chachaKey = chachaKey + packet.timestamp = freshTimestamp + packet.privateKey = privateKeyHash + packet.messageId = message.id + packet.attachments = uploadedAttachments + packet.aesChachaKey = aesChachaKey + + ProtocolManager.shared.sendPacket(packet) + registerOutgoingRetry(for: packet) + Self.logger.info("Retrying message+attachments \(message.id.prefix(8))… (\(uploadedAttachments.count) att) to \(message.toPublicKey.prefix(12))…") + } else { + // ── Retry text-only (no upload-backed attachments) ── + let packet = try makeOutgoingPacket( + text: plaintext, + toPublicKey: message.toPublicKey, + messageId: message.id, + timestamp: freshTimestamp, + privateKeyHex: privateKeyHex, + privateKeyHash: privateKeyHash + ) + ProtocolManager.shared.sendPacket(packet) + registerOutgoingRetry(for: packet) + Self.logger.info("Retrying message \(message.id.prefix(8))… to \(message.toPublicKey.prefix(12))…") + } } catch { Self.logger.error("Failed to retry message \(message.id): \(error.localizedDescription)") MessageRepository.shared.updateDeliveryStatus(messageId: message.id, status: .error) @@ -2586,6 +2863,80 @@ final class SessionManager { } } + private static func recoverRetryPlaintext( + storedText: String, + privateKeyHex: String + ) -> String? { + if storedText.isEmpty { return "" } + + // Try decryption with compression first (standard path), then without. + if let data = try? CryptoManager.shared.decryptWithPassword( + storedText, + password: privateKeyHex, + requireCompression: true + ), + let decoded = String(data: data, encoding: .utf8) { + return decoded + } + if let data = try? CryptoManager.shared.decryptWithPassword( + storedText, + password: privateKeyHex + ), + let decoded = String(data: data, encoding: .utf8) { + return decoded + } + + // Legacy fallback: allow plain text only if it doesn't look like ciphertext. + if isLikelyEncryptedPayload(storedText) { + return nil + } + return storedText + } + + private static func isLikelyEncryptedPayload(_ value: String) -> Bool { + let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) + if trimmed.isEmpty { return false } + + if trimmed.hasPrefix("CHNK:") { return true } + + let parts = trimmed.components(separatedBy: ":") + if parts.count == 2 { + let base64Chars = CharacterSet.alphanumerics.union(CharacterSet(charactersIn: "+/=")) + if parts.allSatisfy({ part in + part.count >= 16 && part.unicodeScalars.allSatisfy { base64Chars.contains($0) } + }) { + return true + } + } + + if trimmed.count >= 40 { + let hexChars = CharacterSet(charactersIn: "0123456789abcdefABCDEF") + if trimmed.unicodeScalars.allSatisfy({ hexChars.contains($0) }) { + return true + } + } + return false + } + + private static func rawKeyAndNonceFromStoredAttachmentPassword(_ stored: String) -> Data? { + guard stored.hasPrefix("rawkey:") else { return nil } + let hex = String(stored.dropFirst("rawkey:".count)) + guard let decoded = Data(strictHexString: hex), decoded.count == 56 else { + return nil + } + return decoded + } + + private func markRetryMessageAsError(_ message: ChatMessage, reason: String) { + Self.logger.error("Retry ABORT: \(reason) for message \(message.id.prefix(8))…") + MessageRepository.shared.updateDeliveryStatus(messageId: message.id, status: .error) + DialogRepository.shared.updateDeliveryStatus( + messageId: message.id, + opponentKey: message.toPublicKey, + status: .error + ) + } + private func registerOutgoingRetry(for packet: PacketMessage) { let messageId = packet.messageId pendingOutgoingRetryTasks[messageId]?.cancel() @@ -2705,30 +3056,40 @@ final class SessionManager { // MARK: - Push Notifications - /// Stores the APNs device token received from AppDelegate. - /// Called from AppDelegate.didRegisterForRemoteNotificationsWithDeviceToken. + /// Stores the FCM registration token received from Firebase Messaging delegate. func setAPNsToken(_ token: String) { - UserDefaults.standard.set(token, forKey: "apns_device_token") + let normalizedToken = token.trimmingCharacters(in: .whitespacesAndNewlines) + guard !normalizedToken.isEmpty else { return } + UserDefaults.standard.set(normalizedToken, forKey: "apns_device_token") // If already authenticated, send immediately if ProtocolManager.shared.connectionState == .authenticated { - sendPushTokenToServer() + sendPushTokenToServer(force: true) } } - /// Sends the stored APNs push token to the server via PacketPushNotification (0x10). + /// Sends the stored FCM push token to the server via PacketPushNotification (0x10). /// Android parity: called after successful handshake. - private func sendPushTokenToServer() { + private func sendPushTokenToServer(force: Bool = false) { guard let token = UserDefaults.standard.string(forKey: "apns_device_token"), !token.isEmpty, ProtocolManager.shared.connectionState == .authenticated else { return } + let now = Date().timeIntervalSince1970 + if !force, + let lastPushTokenSubscribe, + lastPushTokenSubscribe.token == token, + now - lastPushTokenSubscribe.sentAt < 5 { + return + } + var packet = PacketPushNotification() packet.notificationsToken = token packet.action = .subscribe packet.tokenType = .fcm packet.deviceId = DeviceIdentityManager.shared.currentDeviceId() ProtocolManager.shared.sendPacket(packet) + lastPushTokenSubscribe = (token, now) Self.logger.info("FCM push token sent to server") } @@ -2736,25 +3097,36 @@ final class SessionManager { /// Stores the VoIP push token received from PushKit. func setVoIPToken(_ token: String) { - UserDefaults.standard.set(token, forKey: "voip_push_token") + let normalizedToken = token.trimmingCharacters(in: .whitespacesAndNewlines) + guard !normalizedToken.isEmpty else { return } + UserDefaults.standard.set(normalizedToken, forKey: "voip_push_token") if ProtocolManager.shared.connectionState == .authenticated { - sendVoIPTokenToServer() + sendVoIPTokenToServer(force: true) } } /// Sends the stored VoIP push token to the server via PacketPushNotification (0x10). - private func sendVoIPTokenToServer() { + private func sendVoIPTokenToServer(force: Bool = false) { guard let token = UserDefaults.standard.string(forKey: "voip_push_token"), !token.isEmpty, ProtocolManager.shared.connectionState == .authenticated else { return } + let now = Date().timeIntervalSince1970 + if !force, + let lastVoIPTokenSubscribe, + lastVoIPTokenSubscribe.token == token, + now - lastVoIPTokenSubscribe.sentAt < 5 { + return + } + var packet = PacketPushNotification() packet.notificationsToken = token packet.action = .subscribe packet.tokenType = .voipApns packet.deviceId = DeviceIdentityManager.shared.currentDeviceId() ProtocolManager.shared.sendPacket(packet) + lastVoIPTokenSubscribe = (token, now) Self.logger.info("VoIP push token sent to server") } @@ -2773,19 +3145,23 @@ final class SessionManager { Self.logger.info("FCM token unsubscribed from server") } - /// Sends unsubscribe for a stale VoIP token (called when PushKit invalidates token). + /// Sends unsubscribe for a stale VoIP token and clears local storage. func unsubscribeVoIPToken(_ token: String) { - guard !token.isEmpty, - ProtocolManager.shared.connectionState == .authenticated - else { return } - - var packet = PacketPushNotification() - packet.notificationsToken = token - packet.action = .unsubscribe - packet.tokenType = .voipApns - packet.deviceId = DeviceIdentityManager.shared.currentDeviceId() - ProtocolManager.shared.sendPacket(packet) - Self.logger.info("VoIP token unsubscribed from server") + let normalizedToken = token.trimmingCharacters(in: .whitespacesAndNewlines) + guard !normalizedToken.isEmpty else { return } + if ProtocolManager.shared.connectionState == .authenticated { + var packet = PacketPushNotification() + packet.notificationsToken = normalizedToken + packet.action = .unsubscribe + packet.tokenType = .voipApns + packet.deviceId = DeviceIdentityManager.shared.currentDeviceId() + ProtocolManager.shared.sendPacket(packet) + Self.logger.info("VoIP token unsubscribed from server") + } + UserDefaults.standard.removeObject(forKey: "voip_push_token") + if let lastVoIPTokenSubscribe, lastVoIPTokenSubscribe.token == normalizedToken { + self.lastVoIPTokenSubscribe = nil + } } // MARK: - Release Notes (Desktop Parity) diff --git a/Rosetta/Core/Utils/AttachmentPreviewCodec.swift b/Rosetta/Core/Utils/AttachmentPreviewCodec.swift index 99c2d30..36394ba 100644 --- a/Rosetta/Core/Utils/AttachmentPreviewCodec.swift +++ b/Rosetta/Core/Utils/AttachmentPreviewCodec.swift @@ -8,6 +8,14 @@ import Foundation /// - legacy/local-only: raw preview payload without `tag::` prefix enum AttachmentPreviewCodec { + // Legacy iOS upload IDs were generated from [a-z0-9] with fixed length 8. + // We intentionally keep this strict to avoid stripping arbitrary prefixes + // from valid blurhash payloads that may contain "::". + private static let legacyTransportIdRegex = try! NSRegularExpression( + pattern: "^[a-z0-9]{8}$", + options: [] + ) + struct ParsedFilePreview: Equatable { let downloadTag: String let fileSize: Int @@ -42,11 +50,25 @@ enum AttachmentPreviewCodec { static func blurHash(from preview: String) -> String { let raw = payload(from: preview) - // Strip trailing "|WxH" dimension suffix if present. - if let pipeIdx = raw.lastIndex(of: "|") { - return String(raw[raw.startIndex.. Bool { + let trimmed = value.trimmingCharacters(in: .whitespacesAndNewlines) + let range = NSRange(trimmed.startIndex.. Bool { + let dedupKey = senderKey.trimmingCharacters(in: .whitespacesAndNewlines).isEmpty + ? "__no_sender__" + : senderKey.trimmingCharacters(in: .whitespacesAndNewlines) + + let now = Date().timeIntervalSince1970 + var timestamps = shared.dictionary(forKey: recentSenderNotificationTimestampsKey) as? [String: Double] ?? [:] + // Keep the map compact under NSE memory constraints. + timestamps = timestamps.filter { now - $0.value < 120 } + + if let last = timestamps[dedupKey], now - last < senderDedupWindowSeconds { + shared.set(timestamps, forKey: recentSenderNotificationTimestampsKey) + return true + } + + if shouldRecord { + timestamps[dedupKey] = now + shared.set(timestamps, forKey: recentSenderNotificationTimestampsKey) + } else { + shared.set(timestamps, forKey: recentSenderNotificationTimestampsKey) + } + return false + } + + /// Removes delivered notifications for this sender so duplicate bursts collapse + /// into a single latest entry (Android parity). + private static func replaceDeliveredNotifications( + for senderKey: String, + then completion: @escaping () -> Void + ) { + let center = UNUserNotificationCenter.current() + center.getDeliveredNotifications { delivered in + let idsToRemove = delivered + .filter { notification in + let info = notification.request.content.userInfo + let infoSender = info["sender_public_key"] as? String + ?? info["dialog"] as? String + ?? "" + return infoSender == senderKey || notification.request.content.threadIdentifier == senderKey + } + .map { $0.request.identifier } + + if !idsToRemove.isEmpty { + center.removeDeliveredNotifications(withIdentifiers: idsToRemove) + } + + completion() + } + } + // MARK: - Communication Notification (CarPlay + Focus) /// Wraps the notification content with an INSendMessageIntent so iOS treats it @@ -259,7 +337,8 @@ final class NotificationService: UNNotificationServiceExtension { ) -> UNNotificationContent { let handle = INPersonHandle(value: senderKey, type: .unknown) let displayName = senderName.isEmpty ? "Rosetta" : senderName - let avatarImage = generateLetterAvatar(name: displayName, key: senderKey) + let avatarImage = loadNotificationAvatar(for: senderKey) + ?? generateLetterAvatar(name: displayName, key: senderKey) let sender = INPerson( personHandle: handle, nameComponents: nil, @@ -280,7 +359,7 @@ final class NotificationService: UNNotificationServiceExtension { attachments: nil ) - // Set avatar on sender parameter (Telegram parity: 50x50 letter avatar). + // Set avatar on sender parameter (prefer real avatar from App Group, fallback to letter avatar). if let avatarImage { intent.setImage(avatarImage, forParameterNamed: \.sender) } @@ -364,6 +443,64 @@ final class NotificationService: UNNotificationServiceExtension { return INImage(imageData: pngData) } + /// Loads sender avatar from shared App Group cache written by the main app. + /// Falls back to letter avatar when no real image is available. + private static func loadNotificationAvatar(for senderKey: String) -> INImage? { + guard let appGroupURL = FileManager.default.containerURL( + forSecurityApplicationGroupIdentifier: Self.appGroupID + ) else { + return nil + } + + let avatarsDir = appGroupURL.appendingPathComponent("NotificationAvatars", isDirectory: true) + for candidate in avatarKeyCandidates(for: senderKey) { + let normalized = normalizedAvatarKey(candidate) + guard !normalized.isEmpty else { continue } + let avatarURL = avatarsDir.appendingPathComponent("\(normalized).jpg") + if let data = try? Data(contentsOf: avatarURL), !data.isEmpty { + return INImage(imageData: data) + } + } + return nil + } + + /// Server may send group dialog key in different forms (`raw`, `#group:raw`, `group:raw`). + /// Probe all variants so NSE can find avatar mirrored by the main app. + private static func avatarKeyCandidates(for senderKey: String) -> [String] { + let trimmed = senderKey.trimmingCharacters(in: .whitespacesAndNewlines) + guard !trimmed.isEmpty else { return [] } + + var candidates: [String] = [trimmed] + let lower = trimmed.lowercased() + if lower.hasPrefix("#group:") { + let raw = String(trimmed.dropFirst("#group:".count)).trimmingCharacters(in: .whitespacesAndNewlines) + if !raw.isEmpty { + candidates.append(raw) + candidates.append("group:\(raw)") + } + } else if lower.hasPrefix("group:") { + let raw = String(trimmed.dropFirst("group:".count)).trimmingCharacters(in: .whitespacesAndNewlines) + if !raw.isEmpty { + candidates.append(raw) + candidates.append("#group:\(raw)") + } + } else if !trimmed.isEmpty { + candidates.append("#group:\(trimmed)") + candidates.append("group:\(trimmed)") + } + + // Keep first-seen order and drop duplicates. + var seen = Set() + return candidates.filter { seen.insert($0).inserted } + } + + private static func normalizedAvatarKey(_ key: String) -> String { + key + .trimmingCharacters(in: .whitespacesAndNewlines) + .replacingOccurrences(of: "0x", with: "") + .lowercased() + } + // MARK: - Helpers /// Android parity: extract sender key from multiple possible key names. diff --git a/RosettaTests/AttachmentParityTests.swift b/RosettaTests/AttachmentParityTests.swift index 2d516e7..c5cb06a 100644 --- a/RosettaTests/AttachmentParityTests.swift +++ b/RosettaTests/AttachmentParityTests.swift @@ -99,14 +99,16 @@ final class AttachmentParityTests: XCTestCase { XCTFail("Missing image attachment in packet") return } - XCTAssertEqual(AttachmentPreviewCodec.downloadTag(from: sentImage.preview), imageTag) + XCTAssertEqual(sentImage.transportTag, imageTag) + XCTAssertEqual(AttachmentPreviewCodec.downloadTag(from: sentImage.preview), "") guard let sentFile = sent.attachments.first(where: { $0.id == fileAttachment.id }) else { XCTFail("Missing file attachment in packet") return } + XCTAssertEqual(sentFile.transportTag, fileTag) let parsedFile = AttachmentPreviewCodec.parseFilePreview(sentFile.preview) - XCTAssertEqual(parsedFile.downloadTag, fileTag) + XCTAssertEqual(parsedFile.downloadTag, "") XCTAssertEqual(parsedFile.fileSize, fileData.count) XCTAssertEqual(parsedFile.fileName, "notes.txt") } diff --git a/RosettaTests/CallPushIntegrationTests.swift b/RosettaTests/CallPushIntegrationTests.swift index 0df0332..bd1a412 100644 --- a/RosettaTests/CallPushIntegrationTests.swift +++ b/RosettaTests/CallPushIntegrationTests.swift @@ -5,36 +5,37 @@ import Testing struct PushNotificationExtendedTests { - @Test("Realistic FCM token with device ID round-trip") - func fcmTokenWithDeviceIdRoundTrip() throws { + @Test("Realistic FCM token round-trip") + func fcmTokenRoundTrip() throws { // Real FCM tokens are ~163 chars let fcmToken = "dQw4w9WgXcQ:APA91bHnzPc5Y0z4R8kP3mN6vX2tL7wJ1qA5sD8fG0hK3lZ9xC2vB4nM7oP1iU8yT6rE5wQ3jF4kL2mN0bV7cX9sD1aF3gH5jK7lP9oI2uY4tR6eW8qZ0xC" var packet = PacketPushNotification() packet.notificationsToken = fcmToken packet.action = .subscribe packet.tokenType = .fcm - packet.deviceId = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnop" + packet.deviceId = "ios-fcm-device" let decoded = try decode(packet) #expect(decoded.notificationsToken == fcmToken) #expect(decoded.action == .subscribe) #expect(decoded.tokenType == .fcm) - #expect(decoded.deviceId == "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnop") + #expect(decoded.deviceId == "ios-fcm-device") } - @Test("Realistic VoIP hex token round-trip") - func voipTokenWithDeviceIdRoundTrip() throws { - // PushKit tokens are 32 bytes = 64 hex chars - let voipToken = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6e7f8a9b0c1d2e3f4a5b6c7d8e9f0a1b2" + @Test("Realistic APNs hex token round-trip") + func apnsTokenRoundTrip() throws { + let apnsToken = "a1b2c3d4e5f6a7b8c9d0e1f2a3b4c5d6e7f8a9b0c1d2e3f4a5b6c7d8e9f0a1b2" var packet = PacketPushNotification() - packet.notificationsToken = voipToken + packet.notificationsToken = apnsToken packet.action = .subscribe packet.tokenType = .voipApns - packet.deviceId = "device-xyz-123" + packet.deviceId = "ios-voip-device" let decoded = try decode(packet) - #expect(decoded.notificationsToken == voipToken) + #expect(decoded.notificationsToken == apnsToken) + #expect(decoded.action == .subscribe) #expect(decoded.tokenType == .voipApns) + #expect(decoded.deviceId == "ios-voip-device") } @Test("Long token (256 chars) round-trip — stress test UInt32 string length") @@ -44,38 +45,43 @@ struct PushNotificationExtendedTests { packet.notificationsToken = longToken packet.action = .subscribe packet.tokenType = .fcm - packet.deviceId = "dev" + packet.deviceId = "ios-long-device" let decoded = try decode(packet) #expect(decoded.notificationsToken == longToken) #expect(decoded.notificationsToken.count == 256) + #expect(decoded.tokenType == .fcm) + #expect(decoded.deviceId == "ios-long-device") } - @Test("Unicode device ID with emoji and Cyrillic round-trip") - func unicodeDeviceIdRoundTrip() throws { - let unicodeId = "Телефон Гайдара 📱" + @Test("Unicode token round-trip") + func unicodeTokenRoundTrip() throws { + let unicodeToken = "Токен-Гайдара-📱" var packet = PacketPushNotification() - packet.notificationsToken = "token" + packet.notificationsToken = unicodeToken packet.action = .subscribe packet.tokenType = .fcm - packet.deviceId = unicodeId + packet.deviceId = "ios-unicode-device" let decoded = try decode(packet) - #expect(decoded.deviceId == unicodeId) + #expect(decoded.notificationsToken == unicodeToken) + #expect(decoded.tokenType == .fcm) + #expect(decoded.deviceId == "ios-unicode-device") } - @Test("Unsubscribe action round-trip for both token types", - arguments: [PushTokenType.fcm, PushTokenType.voipApns]) - func unsubscribeRoundTrip(tokenType: PushTokenType) throws { + @Test("Unsubscribe action round-trip") + func unsubscribeRoundTrip() throws { var packet = PacketPushNotification() packet.notificationsToken = "test-token" packet.action = .unsubscribe - packet.tokenType = tokenType - packet.deviceId = "dev" + packet.tokenType = .voipApns + packet.deviceId = "ios-unsub-device" let decoded = try decode(packet) #expect(decoded.action == .unsubscribe) - #expect(decoded.tokenType == tokenType) + #expect(decoded.notificationsToken == "test-token") + #expect(decoded.tokenType == .voipApns) + #expect(decoded.deviceId == "ios-unsub-device") } private func decode(_ packet: PacketPushNotification) throws -> PacketPushNotification { @@ -200,11 +206,6 @@ struct CallPushEnumParityTests { #expect(pair.0.rawValue == pair.1) } - @Test("PushTokenType enum values match server") - func pushTokenTypeValues() { - #expect(PushTokenType.fcm.rawValue == 0) - #expect(PushTokenType.voipApns.rawValue == 1) - } } // MARK: - Wire Format Byte-Level Tests @@ -216,8 +217,8 @@ struct CallPushWireFormatTests { var packet = PacketPushNotification() packet.notificationsToken = "A" packet.action = .unsubscribe - packet.tokenType = .fcm - packet.deviceId = "B" + packet.tokenType = .voipApns + packet.deviceId = "D" let data = PacketRegistry.encode(packet) #expect(data.count == 16) @@ -230,12 +231,12 @@ struct CallPushWireFormatTests { #expect(data[6] == 0x00); #expect(data[7] == 0x41) // action = 1 (unsubscribe) #expect(data[8] == 0x01) - // tokenType = 0 (fcm) - #expect(data[9] == 0x00) - // deviceId "B": length=1, 'B'=0x0042 + // tokenType = 1 (voipApns) + #expect(data[9] == 0x01) + // deviceId "D": length=1, 'D'=0x0044 #expect(data[10] == 0x00); #expect(data[11] == 0x00) #expect(data[12] == 0x00); #expect(data[13] == 0x01) - #expect(data[14] == 0x00); #expect(data[15] == 0x42) + #expect(data[14] == 0x00); #expect(data[15] == 0x44) } @Test("SignalPeer call byte layout: signalType→src→dst→callId→joinToken") diff --git a/RosettaTests/CryptoParityTests.swift b/RosettaTests/CryptoParityTests.swift index fccae45..fa3f8aa 100644 --- a/RosettaTests/CryptoParityTests.swift +++ b/RosettaTests/CryptoParityTests.swift @@ -1,9 +1,12 @@ import XCTest +import CommonCrypto +import P256K @testable import Rosetta /// Cross-platform crypto parity tests: iOS ↔ Desktop ↔ Android. /// Verifies that all crypto operations produce compatible output /// and that messages encrypted on any platform can be decrypted on iOS. +@MainActor final class CryptoParityTests: XCTestCase { // MARK: - XChaCha20-Poly1305 Round-Trip @@ -228,16 +231,23 @@ final class CryptoParityTests: XCTestCase { let data = try CryptoPrimitives.randomBytes(count: 56) guard let latin1 = String(data: data, encoding: .isoLatin1) else { return } + let plaintext = Data(latin1.utf8) let encrypted = try CryptoManager.shared.encryptWithPasswordDesktopCompat( - Data(latin1.utf8), password: privateKeyHex + plaintext, password: privateKeyHex ) - XCTAssertThrowsError( - try CryptoManager.shared.decryptWithPassword( + do { + let decrypted = try CryptoManager.shared.decryptWithPassword( encrypted, password: wrongKeyHex, requireCompression: true - ), - "Decryption with wrong password must fail" - ) + ) + XCTAssertNotEqual( + decrypted, + plaintext, + "Wrong password must never recover original plaintext" + ) + } catch { + // Expected path for the majority of wrong-password attempts. + } } // MARK: - Attachment Password Candidates @@ -273,8 +283,9 @@ final class CryptoParityTests: XCTestCase { let stored = "some_legacy_password_string" let candidates = MessageCrypto.attachmentPasswordCandidates(from: stored) - XCTAssertEqual(candidates.count, 1, "Legacy format returns single candidate") - XCTAssertEqual(candidates[0], stored, "Legacy candidate is the stored value itself") + XCTAssertEqual(candidates.count, 2, "Legacy format returns hex+plain candidates") + XCTAssertEqual(candidates[0], Data(stored.utf8).map { String(format: "%02x", $0) }.joined()) + XCTAssertEqual(candidates[1], stored, "Legacy plain candidate must be preserved") } func testAttachmentPasswordCandidates_hexMatchesDesktop() { @@ -292,7 +303,7 @@ final class CryptoParityTests: XCTestCase { let candidates = MessageCrypto.attachmentPasswordCandidates(from: stored) // Desktop: Buffer.from(keyBytes).toString('hex') - let expectedDesktopPassword = "deadbeefcafebabe01020304050607080910111213141516171819202122232425262728292a2b2c2d2e2f30" + let expectedDesktopPassword = "deadbeefcafebabe0102030405060708090a0b0c0d0e0f101112131415161718191a1b1c1d1e1f202122232425262728292a2b2c2d2e2f30" // Verify hex format matches (lowercase, no separators) XCTAssertTrue(candidates[0].allSatisfy { "0123456789abcdef".contains($0) }, @@ -301,6 +312,7 @@ final class CryptoParityTests: XCTestCase { // Verify exact match with expected Desktop output // Note: the hex is based on keyBytes.hexString which should be lowercase XCTAssertEqual(candidates[0], keyBytes.hexString) + XCTAssertEqual(candidates[0], expectedDesktopPassword) } // MARK: - PBKDF2 Parity @@ -319,7 +331,7 @@ final class CryptoParityTests: XCTestCase { XCTAssertNotNil(key1) XCTAssertNotNil(key2) XCTAssertEqual(key1, key2, "PBKDF2 must be deterministic") - XCTAssertEqual(key1!.count, 32, "PBKDF2 key must be 32 bytes") + XCTAssertEqual(key1.count, 32, "PBKDF2 key must be 32 bytes") } func testPBKDF2_differentPasswordsDifferentKeys() throws { @@ -378,12 +390,21 @@ final class CryptoParityTests: XCTestCase { Data("secret".utf8), password: "correct_password" ) - XCTAssertThrowsError( - try CryptoManager.shared.decryptWithPassword( - encrypted, password: "wrong_password", requireCompression: true - ), - "Wrong password with requireCompression must fail" - ) + do { + let decrypted = try CryptoManager.shared.decryptWithPassword( + encrypted, + password: "wrong_password", + requireCompression: true + ) + let decryptedText = String(data: decrypted, encoding: .utf8) + XCTAssertNotEqual( + decryptedText, + "secret", + "Wrong password must never recover original plaintext" + ) + } catch { + // Expected path for most wrong-password attempts. + } } // MARK: - UTF-8 Decoder Parity (Android ↔ iOS) @@ -395,9 +416,11 @@ final class CryptoParityTests: XCTestCase { } func testAndroidUtf8Decoder_validMultibyte() { - let bytes = Data("Привет 🔐".utf8) + // Use BMP-only multibyte text here; four-byte emoji sequences are + // covered by malformed/compatibility behavior in separate tests. + let bytes = Data("Привет мир".utf8) let result = MessageCrypto.bytesToAndroidUtf8String(bytes) - XCTAssertEqual(result, "Привет 🔐", "Valid UTF-8 must decode identically") + XCTAssertEqual(result, "Привет мир", "Valid UTF-8 must decode identically") } func testAndroidUtf8Decoder_matchesWhatWG_onValidUtf8() { @@ -567,6 +590,112 @@ final class CryptoParityTests: XCTestCase { "Attachment password candidates must be identical across both paths") } + func testDecryptIncomingMessage_allowsAttachmentOnlyEmptyContent() throws { + let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + var packet = PacketMessage() + packet.fromPublicKey = "02peer_attachment_only" + packet.toPublicKey = "02my_attachment_only" + packet.content = "" + packet.chachaKey = "" + packet.attachments = [ + MessageAttachment( + id: "att-1", + preview: "preview", + blob: "", + type: .image, + transportTag: "tag-1", + transportServer: "cdn.rosetta.im" + ), + ] + + let result = SessionManager.testDecryptIncomingMessage( + packet: packet, + myPublicKey: "02my_attachment_only", + privateKeyHex: privateKeyHex, + groupKey: nil + ) + + XCTAssertNotNil(result) + XCTAssertEqual(result?.text, "") + } + + func testDecryptIncomingMessage_rejectsEmptyContentWithoutAttachments() throws { + let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + var packet = PacketMessage() + packet.fromPublicKey = "02peer_invalid" + packet.toPublicKey = "02my_invalid" + packet.content = "" + packet.chachaKey = "" + packet.attachments = [] + + let result = SessionManager.testDecryptIncomingMessage( + packet: packet, + myPublicKey: "02my_invalid", + privateKeyHex: privateKeyHex, + groupKey: nil + ) + + XCTAssertNil(result) + } + + func testRecoverRetryPlaintext_rejectsCiphertextFallback() throws { + let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + let wrongPrivateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + + let encrypted = try CryptoManager.shared.encryptWithPassword( + Data("retry-text".utf8), + password: privateKeyHex + ) + + let recoveredWithWrongKey = SessionManager.testRecoverRetryPlaintext( + storedText: encrypted, + privateKeyHex: wrongPrivateKeyHex + ) + XCTAssertNotEqual( + recoveredWithWrongKey, + encrypted, + "Retry recovery must never return encrypted wire payload as plaintext" + ) + XCTAssertNotEqual( + recoveredWithWrongKey, + "retry-text", + "Wrong key must never recover original plaintext" + ) + + let recoveredPlainLegacy = SessionManager.testRecoverRetryPlaintext( + storedText: "legacy plain text", + privateKeyHex: privateKeyHex + ) + XCTAssertEqual(recoveredPlainLegacy, "legacy plain text") + } + + func testRawKeyAndNonceParser_requiresStrictRawKeyFormat() throws { + let raw = try CryptoPrimitives.randomBytes(count: 56) + let validStored = "rawkey:" + raw.hexString + + let decodedValid = SessionManager.testRawKeyAndNonceFromStoredAttachmentPassword(validStored) + XCTAssertEqual(decodedValid, raw) + + XCTAssertNil( + SessionManager.testRawKeyAndNonceFromStoredAttachmentPassword(raw.hexString), + "Missing rawkey prefix must be rejected" + ) + XCTAssertNil( + SessionManager.testRawKeyAndNonceFromStoredAttachmentPassword("rawkey:abc"), + "Odd-length hex must be rejected" + ) + XCTAssertNil( + SessionManager.testRawKeyAndNonceFromStoredAttachmentPassword("rawkey:zz"), + "Non-hex symbols must be rejected" + ) + } + + func testDataStrictHexString_rejectsInvalidInput() { + XCTAssertNil(Data(strictHexString: "abc")) + XCTAssertNil(Data(strictHexString: "0g")) + XCTAssertEqual(Data(strictHexString: "0A0b"), Data([0x0A, 0x0B])) + } + // MARK: - Stress Test: Random Key Bytes func testECDH_100RandomKeys_allDecryptSuccessfully() throws { @@ -613,12 +742,3 @@ final class CryptoParityTests: XCTestCase { } } } - -// MARK: - Test Helpers - -extension MessageRepository { - /// Exposes isProbablyEncryptedPayload for testing. - static func testIsProbablyEncrypted(_ value: String) -> Bool { - isProbablyEncryptedPayload(value) - } -} diff --git a/RosettaTests/FileAttachmentTests.swift b/RosettaTests/FileAttachmentTests.swift index 1186dab..009c9b0 100644 --- a/RosettaTests/FileAttachmentTests.swift +++ b/RosettaTests/FileAttachmentTests.swift @@ -107,6 +107,26 @@ final class FileAttachmentTests: XCTestCase { XCTAssertEqual(AttachmentPreviewCodec.blurHash(from: preview), "LVRv{GtR") } + func testBlurHash_LegacyNonUUIDTagPrefix() { + let preview = "jbov1nac::LVRv{GtRSXWB" + XCTAssertEqual(AttachmentPreviewCodec.blurHash(from: preview), "LVRv{GtRSXWB") + } + + func testBlurHash_LegacyNonUUIDTagWithDimensions() { + let preview = "jbov1nac::LVRv{GtR|640x480" + XCTAssertEqual(AttachmentPreviewCodec.blurHash(from: preview), "LVRv{GtR") + } + + func testBlurHash_DoesNotStripUnknownNonUUIDPrefix() { + let preview = "legacy_upload_id::LVRv{GtRSXWB" + XCTAssertEqual(AttachmentPreviewCodec.blurHash(from: preview), "legacy_upload_id::LVRv{GtRSXWB") + } + + func testBlurHash_LegacyNonUUIDTagWithEmptySuffix() { + let preview = "jbov1nac::" + XCTAssertEqual(AttachmentPreviewCodec.blurHash(from: preview), "") + } + // ========================================================================= // MARK: - AttachmentPreviewCodec: Image Dimensions // ========================================================================= diff --git a/RosettaTests/MessageDecodeHardeningTests.swift b/RosettaTests/MessageDecodeHardeningTests.swift new file mode 100644 index 0000000..9141fde --- /dev/null +++ b/RosettaTests/MessageDecodeHardeningTests.swift @@ -0,0 +1,451 @@ +import XCTest +import P256K +@testable import Rosetta + +@MainActor +final class MessageDecodeHardeningTests: XCTestCase { + + func testProtocolManagerDropsMalformedMessagePacketBeforeDispatch() { + let proto = ProtocolManager.shared + let originalOnMessage = proto.onMessageReceived + let originalOnMalformed = proto.onMalformedMessageReceived + defer { + proto.onMessageReceived = originalOnMessage + proto.onMalformedMessageReceived = originalOnMalformed + } + + var messageDispatchCount = 0 + var malformedCount = 0 + proto.onMessageReceived = { _ in + messageDispatchCount += 1 + } + proto.onMalformedMessageReceived = { _ in + malformedCount += 1 + } + + let malformedData = makePacketMessageData(metaFieldCount: 2) + Data([0x00]) + proto.testHandleIncomingData(malformedData) + + XCTAssertEqual(messageDispatchCount, 0) + XCTAssertEqual(malformedCount, 1) + } + + func testProtocolManagerStillDispatchesValidMessagePacket() { + let proto = ProtocolManager.shared + let originalOnMessage = proto.onMessageReceived + let originalOnMalformed = proto.onMalformedMessageReceived + defer { + proto.onMessageReceived = originalOnMessage + proto.onMalformedMessageReceived = originalOnMalformed + } + + var messageDispatchCount = 0 + var malformedCount = 0 + proto.onMessageReceived = { _ in + messageDispatchCount += 1 + } + proto.onMalformedMessageReceived = { _ in + malformedCount += 1 + } + + let validData = makePacketMessageData(metaFieldCount: 2) + proto.testHandleIncomingData(validData) + + XCTAssertEqual(messageDispatchCount, 1) + XCTAssertEqual(malformedCount, 0) + } + + func testProtocolManagerDropsMalformedHandshakePacketBeforeDispatch() { + let proto = ProtocolManager.shared + let originalOnHandshake = proto.onHandshakeCompleted + let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived + defer { + proto.onHandshakeCompleted = originalOnHandshake + proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical + } + + var handshakeDispatchCount = 0 + var malformedInfos: [MalformedCriticalPacketInfo] = [] + proto.onHandshakeCompleted = { _ in + handshakeDispatchCount += 1 + } + proto.onMalformedCriticalPacketReceived = { info in + malformedInfos.append(info) + } + + let malformedData = makeHandshakeData(stateRaw: 9) + proto.testHandleIncomingData(malformedData) + + XCTAssertEqual(handshakeDispatchCount, 0) + XCTAssertEqual(malformedInfos.count, 1) + XCTAssertEqual(malformedInfos.first?.packetId, PacketHandshake.packetId) + } + + func testProtocolManagerDropsMalformedSyncPacketBeforeDispatch() { + let proto = ProtocolManager.shared + let originalOnSync = proto.onSyncReceived + let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived + defer { + proto.onSyncReceived = originalOnSync + proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical + } + + var syncDispatchCount = 0 + var malformedInfos: [MalformedCriticalPacketInfo] = [] + proto.onSyncReceived = { _ in + syncDispatchCount += 1 + } + proto.onMalformedCriticalPacketReceived = { info in + malformedInfos.append(info) + } + + let malformedData = makeSyncData(statusRaw: 7) + proto.testHandleIncomingData(malformedData) + + XCTAssertEqual(syncDispatchCount, 0) + XCTAssertEqual(malformedInfos.count, 1) + XCTAssertEqual(malformedInfos.first?.packetId, PacketSync.packetId) + } + + func testProtocolManagerDropsMalformedSignalPacketBeforeDispatch() { + let proto = ProtocolManager.shared + let originalOnSignal = proto.onSignalPeerReceived + let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived + defer { + proto.onSignalPeerReceived = originalOnSignal + proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical + } + + var signalDispatchCount = 0 + var malformedInfos: [MalformedCriticalPacketInfo] = [] + proto.onSignalPeerReceived = { _ in + signalDispatchCount += 1 + } + proto.onMalformedCriticalPacketReceived = { info in + malformedInfos.append(info) + } + + var signal = PacketSignalPeer() + signal.signalType = .endCallBecauseBusy + let malformedData = PacketRegistry.encode(signal) + Data([0x00]) + proto.testHandleIncomingData(malformedData) + + XCTAssertEqual(signalDispatchCount, 0) + XCTAssertEqual(malformedInfos.count, 1) + XCTAssertEqual(malformedInfos.first?.packetId, PacketSignalPeer.packetId) + } + + func testProtocolManagerDropsMalformedWebRtcPacketBeforeDispatch() { + let proto = ProtocolManager.shared + let originalOnWebRtc = proto.onWebRTCReceived + let originalOnMalformedCritical = proto.onMalformedCriticalPacketReceived + defer { + proto.onWebRTCReceived = originalOnWebRtc + proto.onMalformedCriticalPacketReceived = originalOnMalformedCritical + } + + var webRtcDispatchCount = 0 + var malformedInfos: [MalformedCriticalPacketInfo] = [] + proto.onWebRTCReceived = { _ in + webRtcDispatchCount += 1 + } + proto.onMalformedCriticalPacketReceived = { info in + malformedInfos.append(info) + } + + let malformedData = makeWebRtcData(includeIdentity: false) + Data([0x00]) + proto.testHandleIncomingData(malformedData) + + XCTAssertEqual(webRtcDispatchCount, 0) + XCTAssertEqual(malformedInfos.count, 1) + XCTAssertEqual(malformedInfos.first?.packetId, PacketWebRTC.packetId) + } + + func testSignalTypeCreateRoomDecodesWithAndWithoutRoomId() throws { + let withoutRoomData = makeSignalCreateRoomData(roomId: nil) + guard let withoutRoomDecoded = PacketRegistry.decode(from: withoutRoomData), + let withoutRoomPacket = withoutRoomDecoded.packet as? PacketSignalPeer + else { + XCTFail("Failed to decode signalType=4 packet without roomId") + return + } + + XCTAssertEqual(withoutRoomDecoded.packetId, PacketSignalPeer.packetId) + XCTAssertFalse(withoutRoomPacket.isMalformed) + XCTAssertEqual(withoutRoomPacket.signalType, .createRoom) + XCTAssertEqual(withoutRoomPacket.roomId, "") + + let withRoomData = makeSignalCreateRoomData(roomId: "room-42") + guard let withRoomDecoded = PacketRegistry.decode(from: withRoomData), + let withRoomPacket = withRoomDecoded.packet as? PacketSignalPeer + else { + XCTFail("Failed to decode signalType=4 packet with roomId") + return + } + + XCTAssertEqual(withRoomDecoded.packetId, PacketSignalPeer.packetId) + XCTAssertFalse(withRoomPacket.isMalformed) + XCTAssertEqual(withRoomPacket.signalType, .createRoom) + XCTAssertEqual(withRoomPacket.roomId, "room-42") + } + + func testWebRtcDecodeSupportsTwoAndFourFieldLayouts() throws { + let twoFieldData = makeWebRtcData(includeIdentity: false) + guard let decodedTwoField = PacketRegistry.decode(from: twoFieldData), + let twoFieldPacket = decodedTwoField.packet as? PacketWebRTC + else { + XCTFail("Failed to decode canonical 2-field 0x1B packet") + return + } + + XCTAssertEqual(decodedTwoField.packetId, PacketWebRTC.packetId) + XCTAssertFalse(twoFieldPacket.isMalformed) + XCTAssertEqual(twoFieldPacket.signalType, .offer) + XCTAssertEqual(twoFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}") + XCTAssertEqual(twoFieldPacket.publicKey, "") + XCTAssertEqual(twoFieldPacket.deviceId, "") + + let fourFieldData = makeWebRtcData( + includeIdentity: true, + publicKey: "02legacyPublic", + deviceId: "legacy-device" + ) + guard let decodedFourField = PacketRegistry.decode(from: fourFieldData), + let fourFieldPacket = decodedFourField.packet as? PacketWebRTC + else { + XCTFail("Failed to decode legacy 4-field 0x1B packet") + return + } + + XCTAssertEqual(decodedFourField.packetId, PacketWebRTC.packetId) + XCTAssertFalse(fourFieldPacket.isMalformed) + XCTAssertEqual(fourFieldPacket.signalType, .offer) + XCTAssertEqual(fourFieldPacket.sdpOrCandidate, "{\"type\":\"offer\",\"sdp\":\"v=0\"}") + XCTAssertEqual(fourFieldPacket.publicKey, "02legacyPublic") + XCTAssertEqual(fourFieldPacket.deviceId, "legacy-device") + } + + func testWebRtcEncodeUsesCanonicalTwoFieldLayout() { + var packet = PacketWebRTC() + packet.signalType = .answer + packet.sdpOrCandidate = "{\"type\":\"answer\",\"sdp\":\"v=0\"}" + packet.publicKey = "02shouldNotBeEncoded" + packet.deviceId = "device-should-not-be-encoded" + + let encoded = PacketRegistry.encode(packet) + let stream = Rosetta.Stream(data: encoded) + XCTAssertEqual(stream.readInt16(), PacketWebRTC.packetId) + XCTAssertEqual(stream.readInt8(), WebRTCSignalType.answer.rawValue) + XCTAssertEqual(stream.readString(), packet.sdpOrCandidate) + XCTAssertFalse(stream.hasRemainingBits()) + } + + func testMalformedPacketRecoveryDebouncesResyncToSingleRequest() async throws { + let session = SessionManager.shared + let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + session.testConfigureSessionForParityFlows( + currentPublicKey: "02malformed_resync_test", + privateKeyHex: privateKeyHex + ) + + session.testResetMalformedMessageResyncState() + defer { + session.testSetMalformedMessageResyncHook(nil) + session.testResetMalformedMessageResyncState() + } + + var triggerCount = 0 + session.testSetMalformedMessageResyncHook { + triggerCount += 1 + } + + session.testSimulateMalformedMessagePacketDrop(packetSize: 111, fingerprint: "fp-1", messageIdHint: "m1") + session.testSimulateMalformedMessagePacketDrop(packetSize: 112, fingerprint: "fp-2", messageIdHint: "m2") + session.testSimulateMalformedMessagePacketDrop(packetSize: 113, fingerprint: "fp-3", messageIdHint: "m3") + + try await Task.sleep(nanoseconds: 900_000_000) + + XCTAssertEqual(triggerCount, 1) + XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1) + } + + func testMalformedPacketRecoveryWaitsUntilSyncBatchEnds() async throws { + let session = SessionManager.shared + let privateKeyHex = try P256K.KeyAgreement.PrivateKey().rawRepresentation.hexString + session.testConfigureSessionForParityFlows( + currentPublicKey: "02malformed_resync_batch_test", + privateKeyHex: privateKeyHex + ) + + session.testResetMalformedMessageResyncState() + session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: true) + defer { + session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false) + session.testSetMalformedMessageResyncHook(nil) + session.testResetMalformedMessageResyncState() + } + + var triggerCount = 0 + session.testSetMalformedMessageResyncHook { + triggerCount += 1 + } + + session.testSimulateMalformedMessagePacketDrop(packetSize: 211, fingerprint: "fp-batch", messageIdHint: "m-batch") + + // While sync is active, recovery must remain queued and not fire. + try await Task.sleep(nanoseconds: 900_000_000) + XCTAssertEqual(triggerCount, 0) + XCTAssertEqual(session.malformedMessageResyncTriggerCount, 0) + + // After sync ends, queued recovery should fire once. + session.testSetSyncState(syncBatchInProgress: false) + try await Task.sleep(nanoseconds: 700_000_000) + XCTAssertEqual(triggerCount, 1) + XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1) + } + + func testPostDecryptEmptyPayloadDropsWithoutUpsertAndDebouncesResync() async throws { + let session = SessionManager.shared + let myPrivateKey = try P256K.KeyAgreement.PrivateKey() + let myPrivateKeyData = myPrivateKey.rawRepresentation + let myPrivateKeyHex = myPrivateKeyData.hexString + let myPublicKey = try CryptoManager.shared.deriveCompressedPublicKey(from: myPrivateKeyData).hexString + + try DatabaseManager.shared.bootstrap(accountPublicKey: myPublicKey) + await MessageRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex) + await DialogRepository.shared.bootstrap(accountPublicKey: myPublicKey, storagePassword: myPrivateKeyHex) + + session.testConfigureSessionForParityFlows( + currentPublicKey: myPublicKey, + privateKeyHex: myPrivateKeyHex + ) + session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false) + session.testResetMalformedMessageResyncState() + defer { + session.testSetMalformedMessageResyncHook(nil) + session.testResetMalformedMessageResyncState() + session.testSetSyncState(syncRequestInFlight: false, syncBatchInProgress: false) + } + + var triggerCount = 0 + session.testSetMalformedMessageResyncHook { + triggerCount += 1 + } + + let peerPrivateKey = try P256K.KeyAgreement.PrivateKey() + let peerPublicKey = try CryptoManager.shared + .deriveCompressedPublicKey(from: peerPrivateKey.rawRepresentation) + .hexString + + var processedMessageIds: [String] = [] + for index in 0..<3 { + let encrypted = try MessageCrypto.encryptOutgoing( + plaintext: "", + recipientPublicKeyHex: myPublicKey + ) + + var packet = PacketMessage() + packet.fromPublicKey = peerPublicKey + packet.toPublicKey = myPublicKey + packet.content = encrypted.content + packet.chachaKey = encrypted.chachaKey + packet.timestamp = 1_710_000_000_000 + Int64(index) + packet.privateKey = "hash" + packet.messageId = "empty-decrypted-\(index)-\(UUID().uuidString)" + packet.attachments = [] + packet.aesChachaKey = "" + + processedMessageIds.append(packet.messageId) + await session.testProcessIncomingMessage(packet) + } + + try await Task.sleep(nanoseconds: 900_000_000) + + XCTAssertEqual(triggerCount, 1) + XCTAssertEqual(session.malformedMessageResyncTriggerCount, 1) + + for messageId in processedMessageIds { + XCTAssertFalse( + MessageRepository.shared.hasMessage(messageId), + "Post-decrypt empty payload must not be persisted as a bubble" + ) + } + } + + private func makePacketMessageData(metaFieldCount: Int) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketMessage.packetId) + stream.writeString("02from") + stream.writeString("02to") + stream.writeString("ciphertext") + stream.writeString("chacha-key") + stream.writeInt64(1_710_000_000_000) + stream.writeString("hash") + stream.writeString("msg-hardening") + stream.writeInt8(1) + stream.writeString("att-1") + stream.writeString("preview") + stream.writeString("blob") + stream.writeInt8(AttachmentType.image.rawValue) + if metaFieldCount >= 2 { + stream.writeString("tag-1") + stream.writeString("cdn.rosetta.im") + } + if metaFieldCount >= 4 { + stream.writeString("02encoded-for") + stream.writeString("desktop") + } + stream.writeString("aes-key") + return stream.toData() + } + + private func makeHandshakeData(stateRaw: Int) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketHandshake.packetId) + stream.writeString("hash") + stream.writeString("02public") + stream.writeInt8(1) + stream.writeInt8(15) + stream.writeString("device-id") + stream.writeString("iPhone") + stream.writeString("iOS") + stream.writeInt8(stateRaw) + return stream.toData() + } + + private func makeSyncData(statusRaw: Int) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketSync.packetId) + stream.writeInt8(statusRaw) + stream.writeInt64(1_710_000_000_000) + return stream.toData() + } + + private func makeSignalCreateRoomData(roomId: String?) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketSignalPeer.packetId) + stream.writeInt8(SignalType.createRoom.rawValue) + stream.writeString("02src") + stream.writeString("02dst") + if let roomId { + stream.writeString(roomId) + } + return stream.toData() + } + + private func makeWebRtcData( + includeIdentity: Bool, + publicKey: String = "", + deviceId: String = "" + ) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketWebRTC.packetId) + stream.writeInt8(WebRTCSignalType.offer.rawValue) + stream.writeString("{\"type\":\"offer\",\"sdp\":\"v=0\"}") + if includeIdentity { + stream.writeString(publicKey) + stream.writeString(deviceId) + } + return stream.toData() + } +} diff --git a/RosettaTests/PushNotificationPacketTests.swift b/RosettaTests/PushNotificationPacketTests.swift index c16829f..77f4fba 100644 --- a/RosettaTests/PushNotificationPacketTests.swift +++ b/RosettaTests/PushNotificationPacketTests.swift @@ -1,14 +1,11 @@ import Testing @testable import Rosetta -/// Verifies PacketPushNotification wire format matches server -/// (im.rosetta.packet.Packet16PushNotification). -/// -/// Server wire format: +/// Verifies PacketPushNotification wire format matches Server/Android: /// writeInt16(packetId=0x10) /// writeString(notificationToken) -/// writeInt8(action) — 0=subscribe, 1=unsubscribe -/// writeInt8(tokenType) — 0=FCM, 1=VoIPApns +/// writeInt8(action) +/// writeInt8(tokenType) /// writeString(deviceId) struct PushNotificationPacketTests { @@ -24,49 +21,49 @@ struct PushNotificationPacketTests { #expect(PushNotificationAction.unsubscribe.rawValue == 1) } - @Test("PushTokenType.fcm == 0 (server: FCM)") + @Test("PushTokenType.fcm == 0") func fcmTokenTypeValue() { #expect(PushTokenType.fcm.rawValue == 0) } - @Test("PushTokenType.voipApns == 1 (server: VoIPApns)") + @Test("PushTokenType.voipApns == 1") func voipTokenTypeValue() { #expect(PushTokenType.voipApns.rawValue == 1) } // MARK: - Round Trip (encode → decode) - @Test("FCM subscribe round-trip preserves all fields") - func fcmSubscribeRoundTrip() throws { + @Test("Subscribe round-trip preserves all fields") + func subscribeRoundTrip() throws { var packet = PacketPushNotification() packet.notificationsToken = "test-fcm-token-abc123" packet.action = .subscribe packet.tokenType = .fcm - packet.deviceId = "device-id-xyz" + packet.deviceId = "ios-device-1" let decoded = try decodePushNotification(packet) #expect(decoded.notificationsToken == "test-fcm-token-abc123") #expect(decoded.action == .subscribe) #expect(decoded.tokenType == .fcm) - #expect(decoded.deviceId == "device-id-xyz") + #expect(decoded.deviceId == "ios-device-1") } - @Test("VoIP unsubscribe round-trip preserves all fields") - func voipUnsubscribeRoundTrip() throws { + @Test("Unsubscribe round-trip preserves all fields") + func unsubscribeRoundTrip() throws { var packet = PacketPushNotification() - packet.notificationsToken = "voip-hex-token-deadbeef" + packet.notificationsToken = "test-token-deadbeef" packet.action = .unsubscribe packet.tokenType = .voipApns - packet.deviceId = "another-device-id" + packet.deviceId = "ios-device-2" let decoded = try decodePushNotification(packet) - #expect(decoded.notificationsToken == "voip-hex-token-deadbeef") + #expect(decoded.notificationsToken == "test-token-deadbeef") #expect(decoded.action == .unsubscribe) #expect(decoded.tokenType == .voipApns) - #expect(decoded.deviceId == "another-device-id") + #expect(decoded.deviceId == "ios-device-2") } - @Test("Empty token and deviceId round-trip") + @Test("Empty token round-trip") func emptyFieldsRoundTrip() throws { var packet = PacketPushNotification() packet.notificationsToken = "" @@ -76,6 +73,7 @@ struct PushNotificationPacketTests { let decoded = try decodePushNotification(packet) #expect(decoded.notificationsToken == "") + #expect(decoded.tokenType == .fcm) #expect(decoded.deviceId == "") } @@ -100,7 +98,7 @@ struct PushNotificationPacketTests { var packet = PacketPushNotification() packet.notificationsToken = "T" // 1 UTF-16 code unit packet.action = .subscribe // 0 - packet.tokenType = .voipApns // 1 + packet.tokenType = .fcm // 0 packet.deviceId = "D" // 1 UTF-16 code unit let data = PacketRegistry.encode(packet) @@ -110,9 +108,9 @@ struct PushNotificationPacketTests { // [2-5] string length = 1 (UInt32 big-endian) for "T" // [6-7] 'T' = 0x0054 (UInt16 big-endian) // [8] action = 0 (subscribe) - // [9] tokenType = 1 (voipApns) + // [9] tokenType = 0 (fcm) // [10-13] string length = 1 for "D" - // [14-15] 'D' = 0x0044 (UInt16 big-endian) + // [14-15] 'D' = 0x0044 #expect(data.count == 16) @@ -133,8 +131,8 @@ struct PushNotificationPacketTests { // action = 0 (subscribe) #expect(data[8] == 0x00) - // tokenType = 1 (voipApns) - #expect(data[9] == 0x01) + // tokenType = 0 (fcm) + #expect(data[9] == 0x00) // deviceId string length = 1 #expect(data[10] == 0x00) diff --git a/RosettaTests/SchemaParityTests.swift b/RosettaTests/SchemaParityTests.swift index b6802fc..c64c66a 100644 --- a/RosettaTests/SchemaParityTests.swift +++ b/RosettaTests/SchemaParityTests.swift @@ -153,6 +153,65 @@ final class SchemaParityTests: XCTestCase { XCTAssertEqual(decoded.packetId, 0x06) XCTAssertEqual(decodedMessage.attachments.first?.type, .call) + XCTAssertFalse(decodedMessage.isMalformed) + } + + func testPacketMessageDecodeSupportsAttachmentMeta4Compatibility() throws { + let encoded = makePacketMessageData(attachmentMetaFieldCount: 4) + guard let decoded = PacketRegistry.decode(from: encoded), + let message = decoded.packet as? PacketMessage else { + XCTFail("Failed to decode packet with 4 attachment meta fields") + return + } + + XCTAssertEqual(decoded.packetId, 0x06) + XCTAssertFalse(message.isMalformed) + XCTAssertEqual(message.fromPublicKey, "02from") + XCTAssertEqual(message.toPublicKey, "02to") + XCTAssertEqual(message.messageId, "msg-compat") + XCTAssertEqual(message.aesChachaKey, "aes-key") + XCTAssertEqual(message.attachments.count, 1) + XCTAssertEqual(message.attachments[0].transportTag, "tag-1") + XCTAssertEqual(message.attachments[0].transportServer, "cdn.rosetta.im") + } + + func testPacketMessageDecodeSupportsAttachmentMeta0Compatibility() throws { + let encoded = makePacketMessageData(attachmentMetaFieldCount: 0) + guard let decoded = PacketRegistry.decode(from: encoded), + let message = decoded.packet as? PacketMessage else { + XCTFail("Failed to decode packet with 0 attachment meta fields") + return + } + + XCTAssertEqual(decoded.packetId, 0x06) + XCTAssertFalse(message.isMalformed) + XCTAssertEqual(message.messageId, "msg-compat") + XCTAssertEqual(message.aesChachaKey, "aes-key") + XCTAssertEqual(message.attachments.count, 1) + XCTAssertEqual(message.attachments[0].transportTag, "") + XCTAssertEqual(message.attachments[0].transportServer, "") + } + + func testPacketMessageDecodeMarksMalformedForTruncatedOrMisalignedPayload() throws { + let canonical = makePacketMessageData(attachmentMetaFieldCount: 2) + let truncated = canonical.dropLast(3) + let withTrailingByte = canonical + Data([0x00]) + + guard let decodedTruncated = PacketRegistry.decode(from: Data(truncated)), + let messageTruncated = decodedTruncated.packet as? PacketMessage else { + XCTFail("Failed to decode truncated packet wrapper") + return + } + XCTAssertTrue(messageTruncated.isMalformed) + XCTAssertFalse(messageTruncated.malformedFingerprint.isEmpty) + + guard let decodedTrailing = PacketRegistry.decode(from: withTrailingByte), + let messageTrailing = decodedTrailing.packet as? PacketMessage else { + XCTFail("Failed to decode trailing-byte packet wrapper") + return + } + XCTAssertTrue(messageTrailing.isMalformed) + XCTAssertFalse(messageTrailing.malformedFingerprint.isEmpty) } func testSessionPacketContextResolverAcceptsGroupWireShape() throws { @@ -188,4 +247,33 @@ final class SchemaParityTests: XCTestCase { _ = stream XCTAssertTrue(true) } + + private func makePacketMessageData(attachmentMetaFieldCount: Int) -> Data { + let stream = Rosetta.Stream() + stream.writeInt16(PacketMessage.packetId) + stream.writeString("02from") + stream.writeString("02to") + stream.writeString("ciphertext") + stream.writeString("chacha-key") + stream.writeInt64(1_710_000_000_000) + stream.writeString("hash") + stream.writeString("msg-compat") + stream.writeInt8(1) + stream.writeString("att-1") + stream.writeString("preview") + stream.writeString("blob") + stream.writeInt8(AttachmentType.image.rawValue) + + if attachmentMetaFieldCount >= 2 { + stream.writeString("tag-1") + stream.writeString("cdn.rosetta.im") + } + if attachmentMetaFieldCount >= 4 { + stream.writeString("02encoded-for") + stream.writeString("desktop") + } + + stream.writeString("aes-key") + return stream.toData() + } }