Files
mobile-ios/Rosetta/Core/Network/Protocol/Stream.swift

438 lines
11 KiB
Swift

import Foundation
/// Short alias for `PacketBitStream`.
/// NOTE: Foundation (imported above) also exports a `Stream` class (NSStream). Inside this
/// module this alias shadows it; write `Foundation.Stream` to refer to the Foundation type.
typealias Stream = PacketBitStream
/// Errors thrown by the `...Strict` readers of `PacketBitStream`.
enum PacketBitStreamError: Error, Equatable {
    /// A strict read needed `neededBits` but only `remainingBits` were left before the write pointer.
    case underflow(operation: String, neededBits: Int, remainingBits: Int)
    /// `readStringStrict()` decoded a length prefix larger than `Int32.max`.
    case invalidStringLength(Int)
}

/// Bit-aligned binary stream for protocol packets.
/// Matches the server (Java) implementation exactly.
///
/// Bits are written most-significant first, and multi-byte integers are big-endian.
/// Reads and writes need not be byte-aligned. A byte-aligned fast path is used when possible;
/// otherwise values are packed bit by bit.
///
/// Supports:
/// - signed: Int8/16/32/64 (two's complement)
/// - unsigned: UInt8/16/32/64
/// - Float32: IEEE 754 bit pattern, stored as an Int32
/// - String: length(UInt32) + chars(UInt16)
/// - Bytes: length(UInt32) + raw bytes
///
/// Plain `readX()` methods are lenient: when not enough bits remain they yield 0 / `false` /
/// `""` / empty `Data` instead of failing. The `readXStrict()` variants throw
/// `PacketBitStreamError` instead.
///
/// NOTE(review): lenient multi-byte integer reads are composed of single-byte reads. A partial
/// underflow therefore consumes the bytes that are available and zero-fills the rest. This is
/// kept unchanged for wire parity; confirm against the Java implementation.
///
/// The class performs no internal synchronization.
final class PacketBitStream: NSObject {
    /// Backing storage. It grows geometrically and may be longer than the written data.
    private var bytes: [UInt8]
    /// Next bit to read, counted in bits from the start of `bytes`.
    private var readPointer: Int = 0
    /// Number of bits written so far. This is also the limit for reads.
    private var writePointer: Int = 0

    // MARK: - Init

    /// Creates an empty stream ready for writing.
    override init() {
        bytes = []
        super.init()
    }

    /// Creates a stream that reads `data` from its first bit.
    ///
    /// The write pointer is placed after the last byte, so further writes append to `data`.
    init(data: Data) {
        bytes = Array(data)
        super.init()
        writePointer = bytes.count << 3
    }

    // MARK: - Output

    /// Returns every written bit as bytes. A trailing partial byte is padded with zero bits.
    func toData() -> Data {
        let byteCount = (writePointer + 7) >> 3
        guard byteCount > 0 else { return Data() }
        return Data(bytes[0..<byteCount])
    }

    // MARK: - Pointer & State (Android/Desktop parity helpers)

    /// Current read position, in bits.
    func getReadPointerBits() -> Int {
        readPointer
    }

    /// Moves the read position to `bits`, clamped to `0...writePointer`.
    func setReadPointerBits(_ bits: Int) {
        readPointer = min(max(bits, 0), writePointer)
    }

    /// Number of bits between the read pointer and the write pointer.
    func remainingBits() -> Int {
        writePointer - readPointer
    }

    /// `true` while at least one bit is left to read.
    func hasRemainingBits() -> Bool {
        readPointer < writePointer
    }

    // MARK: - Bit-Level I/O

    /// Writes the lowest bit of `value`.
    func writeBit(_ value: Int) {
        writeBits(value & 1, count: 1)
    }

    /// Reads one bit. Returns 0 without advancing when the stream is exhausted.
    func readBit() -> Int {
        guard remainingBits() >= 1 else { return 0 }
        return readBitsUnchecked(1)
    }

    // MARK: - Bool

    /// Writes a boolean as a single bit.
    func writeBoolean(_ value: Bool) {
        writeBit(value ? 1 : 0)
    }

    /// Reads a single-bit boolean. Returns `false` when the stream is exhausted.
    func readBoolean() -> Bool {
        readBit() == 1
    }

    // MARK: - Byte

    func writeByte(_ value: UInt8) {
        writeUInt8(Int(value))
    }

    func readByte() -> UInt8 {
        UInt8(truncatingIfNeeded: readUInt8())
    }

    // MARK: - UInt8 / Int8 (8 bits)

    /// Writes the low 8 bits of `value`.
    func writeUInt8(_ value: Int) {
        let v = UInt8(truncatingIfNeeded: value)
        // Fast path: byte-aligned
        if (writePointer & 7) == 0 {
            ensureCapacityForUpcomingBits(8)
            bytes[writePointer >> 3] = v
            writePointer += 8
            return
        }
        writeBits(Int(v), count: 8)
    }

    /// Reads 8 bits as `0...255`. Returns 0 without advancing when fewer than 8 bits remain.
    func readUInt8() -> Int {
        guard remainingBits() >= 8 else { return 0 }
        return readUInt8Unchecked()
    }

    /// Reads 8 bits as `0...255`.
    /// - Throws: `PacketBitStreamError.underflow` when fewer than 8 bits remain.
    func readUInt8Strict() throws -> Int {
        try ensureReadableBits(8, operation: "readUInt8")
        return readUInt8Unchecked()
    }

    func writeInt8(_ value: Int) {
        writeUInt8(value)
    }

    func readInt8() -> Int {
        Int(Int8(truncatingIfNeeded: readUInt8()))
    }

    func readInt8Strict() throws -> Int {
        Int(Int8(truncatingIfNeeded: try readUInt8Strict()))
    }

    // MARK: - UInt16 / Int16 (16 bits)

    /// Writes the low 16 bits of `value`, big-endian.
    func writeUInt16(_ value: Int) {
        let v = value & 0xFFFF
        writeUInt8((v >> 8) & 0xFF)
        writeUInt8(v & 0xFF)
    }

    func readUInt16() -> Int {
        let hi = readUInt8()
        let lo = readUInt8()
        return (hi << 8) | lo
    }

    func readUInt16Strict() throws -> Int {
        let hi = try readUInt8Strict()
        let lo = try readUInt8Strict()
        return (hi << 8) | lo
    }

    func writeInt16(_ value: Int) {
        writeUInt16(value)
    }

    func readInt16() -> Int {
        Int(Int16(truncatingIfNeeded: readUInt16()))
    }

    func readInt16Strict() throws -> Int {
        Int(Int16(truncatingIfNeeded: try readUInt16Strict()))
    }

    // MARK: - UInt32 / Int32 (32 bits)

    /// Writes the low 32 bits of `value`, big-endian.
    func writeUInt32(_ value: Int) {
        writeUInt8((value >> 24) & 0xFF)
        writeUInt8((value >> 16) & 0xFF)
        writeUInt8((value >> 8) & 0xFF)
        writeUInt8(value & 0xFF)
    }

    func readUInt32() -> Int {
        let b1 = readUInt8()
        let b2 = readUInt8()
        let b3 = readUInt8()
        let b4 = readUInt8()
        return (b1 << 24) | (b2 << 16) | (b3 << 8) | b4
    }

    func readUInt32Strict() throws -> Int {
        let b1 = try readUInt8Strict()
        let b2 = try readUInt8Strict()
        let b3 = try readUInt8Strict()
        let b4 = try readUInt8Strict()
        return (b1 << 24) | (b2 << 16) | (b3 << 8) | b4
    }

    func writeInt32(_ value: Int) {
        writeUInt32(value)
    }

    func readInt32() -> Int {
        Int(Int32(truncatingIfNeeded: readUInt32()))
    }

    func readInt32Strict() throws -> Int {
        Int(Int32(truncatingIfNeeded: try readUInt32Strict()))
    }

    // MARK: - UInt64 / Int64 (64 bits)

    /// Writes all 64 bits of `value`, big-endian.
    func writeUInt64(_ value: Int64) {
        writeUInt8(Int((value >> 56) & 0xFF))
        writeUInt8(Int((value >> 48) & 0xFF))
        writeUInt8(Int((value >> 40) & 0xFF))
        writeUInt8(Int((value >> 32) & 0xFF))
        writeUInt8(Int((value >> 24) & 0xFF))
        writeUInt8(Int((value >> 16) & 0xFF))
        writeUInt8(Int((value >> 8) & 0xFF))
        writeUInt8(Int(value & 0xFF))
    }

    /// Reads 64 bits. The result carries the raw bit pattern in an `Int64`.
    func readUInt64() -> Int64 {
        let high = Int64(readUInt32()) & 0xFFFFFFFF
        let low = Int64(readUInt32()) & 0xFFFFFFFF
        return (high << 32) | low
    }

    func readUInt64Strict() throws -> Int64 {
        let high = Int64(try readUInt32Strict()) & 0xFFFFFFFF
        let low = Int64(try readUInt32Strict()) & 0xFFFFFFFF
        return (high << 32) | low
    }

    func writeInt64(_ value: Int64) {
        writeUInt64(value)
    }

    func readInt64() -> Int64 {
        readUInt64()
    }

    func readInt64Strict() throws -> Int64 {
        try readUInt64Strict()
    }

    // MARK: - Float32

    /// Writes `value` as its 32-bit IEEE 754 bit pattern.
    func writeFloat32(_ value: Float) {
        writeInt32(Int(Int32(bitPattern: value.bitPattern)))
    }

    func readFloat32() -> Float {
        Float(bitPattern: UInt32(bitPattern: Int32(truncatingIfNeeded: readInt32())))
    }

    /// Strict counterpart of `readFloat32()`.
    /// - Throws: `PacketBitStreamError.underflow` when fewer than 32 bits remain.
    func readFloat32Strict() throws -> Float {
        let raw = try readInt32Strict()
        return Float(bitPattern: UInt32(bitPattern: Int32(truncatingIfNeeded: raw)))
    }

    // MARK: - String (UInt32 length + UInt16 chars)

    /// Writes the UTF-16 code-unit count as a UInt32, followed by each code unit as a UInt16.
    func writeString(_ value: String) {
        let utf16Units = Array(value.utf16)
        let length = utf16Units.count
        writeUInt32(length)
        guard length > 0 else { return }
        ensureCapacityForUpcomingBits(length * 16)
        for codeUnit in utf16Units {
            writeUInt16(Int(codeUnit))
        }
    }

    /// Reads a string written by `writeString(_:)`.
    ///
    /// Returns `""` when the payload is shorter than the declared length. The length prefix is
    /// still consumed in that case.
    func readString() -> String {
        let length = readUInt32()
        guard length > 0 else { return "" }
        let requiredBits = length * 16
        guard requiredBits <= remainingBits() else { return "" }
        var codeUnits = [UInt16]()
        codeUnits.reserveCapacity(length)
        for _ in 0..<length {
            codeUnits.append(UInt16(truncatingIfNeeded: readUInt16()))
        }
        // Unpaired surrogates decode to U+FFFD.
        return String(decoding: codeUnits, as: UTF16.self)
    }

    /// Strict counterpart of `readString()`.
    /// - Throws: `PacketBitStreamError.underflow` when the prefix or payload is truncated, or
    ///   `.invalidStringLength` when the prefix exceeds `Int32.max`.
    func readStringStrict() throws -> String {
        let length = try readUInt32Strict()
        guard length > 0 else { return "" }
        guard length <= Int(Int32.max) else {
            throw PacketBitStreamError.invalidStringLength(length)
        }
        let requiredBits = length * 16
        try ensureReadableBits(requiredBits, operation: "readString")
        var codeUnits = [UInt16]()
        codeUnits.reserveCapacity(length)
        for _ in 0..<length {
            codeUnits.append(UInt16(truncatingIfNeeded: try readUInt16Strict()))
        }
        return String(decoding: codeUnits, as: UTF16.self)
    }

    // MARK: - Bytes (UInt32 length + raw bytes)

    /// Writes `value.count` as a UInt32, followed by the raw bytes.
    func writeBytes(_ value: Data) {
        let length = value.count
        writeUInt32(length)
        guard length > 0 else { return }
        ensureCapacityForUpcomingBits(length * 8)
        // Fast path: byte-aligned. Capacity is ensured above, so the payload is copied in one
        // bulk replacement of equal length.
        if (writePointer & 7) == 0 {
            let start = writePointer >> 3
            bytes.replaceSubrange(start..<(start + length), with: value)
            writePointer += length << 3
            return
        }
        for byte in value {
            writeUInt8(Int(byte))
        }
    }

    /// Reads a byte payload written by `writeBytes(_:)`.
    ///
    /// Returns empty `Data` when the payload is shorter than the declared length. The length
    /// prefix is still consumed in that case.
    func readBytes() -> Data {
        let length = readUInt32()
        guard length > 0 else { return Data() }
        guard length * 8 <= remainingBits() else { return Data() }
        return readRawBytesUnchecked(length)
    }

    /// Strict counterpart of `readBytes()`.
    /// - Throws: `PacketBitStreamError.underflow` when the prefix or payload is truncated.
    func readBytesStrict() throws -> Data {
        let length = try readUInt32Strict()
        guard length > 0 else { return Data() }
        try ensureReadableBits(length * 8, operation: "readBytes")
        return readRawBytesUnchecked(length)
    }

    // MARK: - Private

    /// Writes the low `count` bits of `value`, most-significant first.
    private func writeBits(_ value: Int, count: Int) {
        guard count > 0 else { return }
        ensureCapacityForUpcomingBits(count)
        for i in stride(from: count - 1, through: 0, by: -1) {
            let bit = UInt8((value >> i) & 1)
            let byteIndex = writePointer >> 3
            let shift = 7 - (writePointer & 7)
            if bit == 1 {
                bytes[byteIndex] |= (1 << shift)
            } else {
                bytes[byteIndex] &= ~(1 << shift)
            }
            writePointer += 1
        }
    }

    /// Lenient `count`-bit read. Returns 0 without advancing when bits are missing.
    private func readBits(_ count: Int) -> Int {
        guard count > 0, remainingBits() >= count else { return 0 }
        return readBitsUnchecked(count)
    }

    /// Strict `count`-bit read.
    private func readBitsStrict(_ count: Int) throws -> Int {
        try ensureReadableBits(count, operation: "readBits")
        return readBitsUnchecked(count)
    }

    /// Reads `count` bits MSB-first with no bounds check.
    /// Callers must already have verified `remainingBits() >= count`.
    private func readBitsUnchecked(_ count: Int) -> Int {
        var value = 0
        for _ in 0..<count {
            let byteIndex = readPointer >> 3
            let shift = 7 - (readPointer & 7)
            value = (value << 1) | Int((bytes[byteIndex] >> shift) & 1)
            readPointer += 1
        }
        return value
    }

    /// Reads 8 bits, using direct indexing when the read pointer is byte-aligned.
    /// Callers must already have verified `remainingBits() >= 8`.
    private func readUInt8Unchecked() -> Int {
        if (readPointer & 7) == 0 {
            let value = Int(bytes[readPointer >> 3])
            readPointer += 8
            return value
        }
        return readBitsUnchecked(8)
    }

    /// Reads `length` raw bytes (no length prefix).
    /// Callers must already have verified `remainingBits() >= length * 8`.
    private func readRawBytesUnchecked(_ length: Int) -> Data {
        // Fast path: byte-aligned
        if (readPointer & 7) == 0 {
            let start = readPointer >> 3
            let result = Data(bytes[start..<(start + length)])
            readPointer += length << 3
            return result
        }
        var result = Data(capacity: length)
        for _ in 0..<length {
            result.append(UInt8(truncatingIfNeeded: readUInt8Unchecked()))
        }
        return result
    }

    /// Throws `.underflow` when fewer than `needed` bits remain. Non-positive requests always pass.
    private func ensureReadableBits(_ needed: Int, operation: String) throws {
        let remaining = remainingBits()
        guard needed > 0 else { return }
        guard remaining >= needed else {
            throw PacketBitStreamError.underflow(
                operation: operation,
                neededBits: needed,
                remainingBits: remaining
            )
        }
    }

    /// Grows `bytes` so the next `bitCount` bits after the write pointer are addressable.
    private func ensureCapacityForUpcomingBits(_ bitCount: Int) {
        guard bitCount > 0 else { return }
        let lastBitIndex = writePointer + bitCount - 1
        ensureCapacity(lastBitIndex >> 3)
    }

    /// Grows `bytes` with zero bytes until `byteIndex` is valid.
    /// The size starts at 32 bytes and doubles as needed.
    private func ensureCapacity(_ byteIndex: Int) {
        let requiredSize = byteIndex + 1
        guard requiredSize > bytes.count else { return }
        var newSize = bytes.isEmpty ? 32 : bytes.count
        while newSize < requiredSize {
            newSize <<= 1
        }
        bytes.append(contentsOf: repeatElement(0 as UInt8, count: newSize - bytes.count))
    }
}