diff --git a/app/src/main/java/com/rosetta/messenger/network/Protocol.kt b/app/src/main/java/com/rosetta/messenger/network/Protocol.kt index 9605f41..b305993 100644 --- a/app/src/main/java/com/rosetta/messenger/network/Protocol.kt +++ b/app/src/main/java/com/rosetta/messenger/network/Protocol.kt @@ -32,7 +32,7 @@ class Protocol( private const val TAG = "RosettaProtocol" private const val RECONNECT_INTERVAL = 5000L // 5 seconds (как в Архиве) private const val HANDSHAKE_TIMEOUT = 10000L // 10 seconds - private const val MIN_PACKET_ID_BITS = 18 // Stream.readInt16() = 2 * readInt8() (9 bits each) + private const val MIN_PACKET_ID_BITS = 16 // Stream.readInt16() reads exactly 16 bits private const val DEFAULT_HEARTBEAT_INTERVAL_SECONDS = 15 private const val MIN_HEARTBEAT_SEND_INTERVAL_MS = 2_000L private const val HEARTBEAT_OK_LOG_THROTTLE_MS = 30_000L diff --git a/app/src/main/java/com/rosetta/messenger/network/Stream.kt b/app/src/main/java/com/rosetta/messenger/network/Stream.kt index c2e825f..bc898b2 100644 --- a/app/src/main/java/com/rosetta/messenger/network/Stream.kt +++ b/app/src/main/java/com/rosetta/messenger/network/Stream.kt @@ -1,163 +1,332 @@ package com.rosetta.messenger.network /** - * Binary stream for protocol packets - * Matches the React Native implementation exactly + * Binary stream for protocol packets. + * Ported from desktop/dev stream.ts implementation. */ class Stream(stream: ByteArray = ByteArray(0)) { - private var _stream = mutableListOf() - private var _readPointer = 0 - private var _writePointer = 0 - + private var stream: ByteArray + private var readPointer = 0 // bits + private var writePointer = 0 // bits + init { - _stream = stream.map { it.toInt() and 0xFF }.toMutableList() + if (stream.isEmpty()) { + this.stream = ByteArray(0) + } else { + this.stream = stream.copyOf() + this.writePointer = this.stream.size shl 3 + } } - + fun getStream(): ByteArray { - return _stream.map { it.toByte() }.toByteArray() + return stream.copyOf(length()) } - fun getReadPointerBits(): Int = _readPointer - - fun getTotalBits(): Int = _stream.size * 8 - - fun getRemainingBits(): Int = getTotalBits() - _readPointer - - fun hasRemainingBits(): Boolean = _readPointer < getTotalBits() - - fun setStream(stream: ByteArray) { - _stream = stream.map { it.toInt() and 0xFF }.toMutableList() - _readPointer = 0 - } - - fun writeInt8(value: Int) { - val negationBit = if (value < 0) 1 else 0 - val int8Value = Math.abs(value) and 0xFF - - ensureCapacity(_writePointer shr 3) - _stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (negationBit shl (7 - (_writePointer and 7))) - _writePointer++ - - for (i in 0 until 8) { - val bit = (int8Value shr (7 - i)) and 1 - ensureCapacity(_writePointer shr 3) - _stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (bit shl (7 - (_writePointer and 7))) - _writePointer++ + fun setStream(stream: ByteArray = ByteArray(0)) { + if (stream.isEmpty()) { + this.stream = ByteArray(0) + this.readPointer = 0 + this.writePointer = 0 + return } + this.stream = stream.copyOf() + this.readPointer = 0 + this.writePointer = this.stream.size shl 3 } - - fun readInt8(): Int { - var value = 0 - val negationBit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1 - _readPointer++ - - for (i in 0 until 8) { - val bit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1 - value = value or (bit shl (7 - i)) - _readPointer++ - } - - return if (negationBit == 1) -value else value - } - + + fun getBuffer(): ByteArray = getStream() + + fun isEmpty(): Boolean = writePointer == 0 + + fun length(): Int = (writePointer + 7) shr 3 + + fun getReadPointerBits(): Int = readPointer + + fun getTotalBits(): Int = writePointer + + fun getRemainingBits(): Int = writePointer - readPointer + + fun hasRemainingBits(): Boolean = readPointer < writePointer + fun writeBit(value: Int) { - val bit = value and 1 - ensureCapacity(_writePointer shr 3) - _stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (bit shl (7 - (_writePointer and 7))) - _writePointer++ + writeBits((value and 1).toULong(), 1) } - - fun readBit(): Int { - val bit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1 - _readPointer++ - return bit - } - + + fun readBit(): Int = readBits(1).toInt() + fun writeBoolean(value: Boolean) { writeBit(if (value) 1 else 0) } - - fun readBoolean(): Boolean { - return readBit() == 1 + + fun readBoolean(): Boolean = readBit() == 1 + + fun writeByte(value: Int) { + writeUInt8(value and 0xFF) } - + + fun readByte(): Int { + val value = readUInt8() + return if (value >= 0x80) value - 0x100 else value + } + + fun writeUInt8(value: Int) { + val v = value and 0xFF + + if ((writePointer and 7) == 0) { + reserveBits(8) + stream[writePointer shr 3] = v.toByte() + writePointer += 8 + return + } + + writeBits(v.toULong(), 8) + } + + fun readUInt8(): Int { + if (remainingBits() < 8L) { + throw IllegalStateException("Not enough bits to read UInt8") + } + + if ((readPointer and 7) == 0) { + val value = stream[readPointer shr 3].toInt() and 0xFF + readPointer += 8 + return value + } + + return readBits(8).toInt() + } + + fun writeInt8(value: Int) { + writeUInt8(value) + } + + fun readInt8(): Int { + val value = readUInt8() + return if (value >= 0x80) value - 0x100 else value + } + + fun writeUInt16(value: Int) { + val v = value and 0xFFFF + writeUInt8((v ushr 8) and 0xFF) + writeUInt8(v and 0xFF) + } + + fun readUInt16(): Int { + val hi = readUInt8() + val lo = readUInt8() + return (hi shl 8) or lo + } + fun writeInt16(value: Int) { - writeInt8(value shr 8) - writeInt8(value and 0xFF) + writeUInt16(value) } - + fun readInt16(): Int { - val high = readInt8() shl 8 - return high or readInt8() + val value = readUInt16() + return if (value >= 0x8000) value - 0x10000 else value } - + + fun writeUInt32(value: Long) { + if (value < 0L || value > 0xFFFF_FFFFL) { + throw IllegalArgumentException("UInt32 out of range: $value") + } + + writeUInt8(((value ushr 24) and 0xFF).toInt()) + writeUInt8(((value ushr 16) and 0xFF).toInt()) + writeUInt8(((value ushr 8) and 0xFF).toInt()) + writeUInt8((value and 0xFF).toInt()) + } + + fun readUInt32(): Long { + val b1 = readUInt8().toLong() + val b2 = readUInt8().toLong() + val b3 = readUInt8().toLong() + val b4 = readUInt8().toLong() + return ((b1 shl 24) or (b2 shl 16) or (b3 shl 8) or b4) and 0xFFFF_FFFFL + } + fun writeInt32(value: Int) { - writeInt16(value shr 16) - writeInt16(value and 0xFFFF) + writeUInt32(value.toLong() and 0xFFFF_FFFFL) } - - fun readInt32(): Int { - val high = readInt16() shl 16 - return high or readInt16() + + fun readInt32(): Int = readUInt32().toInt() + + fun writeUInt64(value: ULong) { + writeUInt8(((value shr 56) and 0xFFu).toInt()) + writeUInt8(((value shr 48) and 0xFFu).toInt()) + writeUInt8(((value shr 40) and 0xFFu).toInt()) + writeUInt8(((value shr 32) and 0xFFu).toInt()) + writeUInt8(((value shr 24) and 0xFFu).toInt()) + writeUInt8(((value shr 16) and 0xFFu).toInt()) + writeUInt8(((value shr 8) and 0xFFu).toInt()) + writeUInt8((value and 0xFFu).toInt()) } - - fun writeInt64(value: Long) { - val high = (value shr 32).toInt() - val low = (value and 0xFFFFFFFF).toInt() - writeInt32(high) - writeInt32(low) - } - - fun readInt64(): Long { - val high = readInt32().toLong() - val low = (readInt32().toLong() and 0xFFFFFFFFL) + + fun readUInt64(): ULong { + val high = readUInt32().toULong() + val low = readUInt32().toULong() return (high shl 32) or low } - - fun writeString(value: String) { - writeInt32(value.length) - for (char in value) { - writeInt16(char.code) + + fun writeInt64(value: Long) { + writeUInt64(value.toULong()) + } + + fun readInt64(): Long = readUInt64().toLong() + + fun writeFloat32(value: Float) { + val bits = value.toRawBits().toLong() and 0xFFFF_FFFFL + writeUInt32(bits) + } + + fun readFloat32(): Float { + val bits = readUInt32().toInt() + return Float.fromBits(bits) + } + + fun writeString(value: String?) { + val str = value ?: "" + writeUInt32(str.length.toLong()) + + if (str.isEmpty()) return + + reserveBits(str.length.toLong() * 16L) + for (i in str.indices) { + writeUInt16(str[i].code and 0xFFFF) } } - + fun readString(): String { - val length = readInt32() - // Desktop parity + safety: don't trust malformed string length. - val bytesAvailable = _stream.size - (_readPointer shr 3) - if (length < 0 || (length.toLong() * 2L) > bytesAvailable.toLong()) { - android.util.Log.w( - "RosettaStream", - "readString invalid length=$length, bytesAvailable=$bytesAvailable, readPointer=$_readPointer" - ) - return "" + val len = readUInt32() + if (len > Int.MAX_VALUE.toLong()) { + throw IllegalStateException("String length too large: $len") } - val sb = StringBuilder() - for (i in 0 until length) { - sb.append(readInt16().toChar()) + + val requiredBits = len * 16L + if (requiredBits > remainingBits()) { + throw IllegalStateException("Not enough bits to read string") } - return sb.toString() + + val chars = CharArray(len.toInt()) + for (i in chars.indices) { + chars[i] = readUInt16().toChar() + } + return String(chars) } - - fun writeBytes(value: ByteArray) { - writeInt32(value.size) - for (byte in value) { - writeInt8(byte.toInt()) + + fun writeBytes(value: ByteArray?) { + val bytes = value ?: ByteArray(0) + writeUInt32(bytes.size.toLong()) + if (bytes.isEmpty()) return + + reserveBits(bytes.size.toLong() * 8L) + + if ((writePointer and 7) == 0) { + val byteIndex = writePointer shr 3 + ensureCapacity(byteIndex + bytes.size - 1) + System.arraycopy(bytes, 0, stream, byteIndex, bytes.size) + writePointer += bytes.size shl 3 + return + } + + for (b in bytes) { + writeUInt8(b.toInt() and 0xFF) } } - + fun readBytes(): ByteArray { - val length = readInt32() - val bytes = ByteArray(length) - for (i in 0 until length) { - bytes[i] = readInt8().toByte() + val len = readUInt32() + if (len == 0L) return ByteArray(0) + if (len > Int.MAX_VALUE.toLong()) return ByteArray(0) + + val requiredBits = len * 8L + if (requiredBits > remainingBits()) { + return ByteArray(0) } - return bytes + + val out = ByteArray(len.toInt()) + + if ((readPointer and 7) == 0) { + val byteIndex = readPointer shr 3 + System.arraycopy(stream, byteIndex, out, 0, out.size) + readPointer += out.size shl 3 + return out + } + + for (i in out.indices) { + out[i] = readUInt8().toByte() + } + return out } - - private fun ensureCapacity(index: Int) { - while (_stream.size <= index) { - _stream.add(0) + + private fun remainingBits(): Long = (writePointer - readPointer).toLong() + + private fun writeBits(value: ULong, bits: Int) { + if (bits <= 0) return + + reserveBits(bits.toLong()) + + for (i in bits - 1 downTo 0) { + val bit = ((value shr i) and 1u).toInt() + val byteIndex = writePointer shr 3 + val shift = 7 - (writePointer and 7) + + if (bit == 1) { + stream[byteIndex] = (stream[byteIndex].toInt() or (1 shl shift)).toByte() + } else { + stream[byteIndex] = (stream[byteIndex].toInt() and (1 shl shift).inv()).toByte() + } + + writePointer++ } } + + private fun readBits(bits: Int): ULong { + if (bits <= 0) return 0u + if (remainingBits() < bits.toLong()) { + throw IllegalStateException("Not enough bits to read") + } + + var value = 0uL + repeat(bits) { + val bit = (stream[readPointer shr 3].toInt() ushr (7 - (readPointer and 7))) and 1 + value = (value shl 1) or bit.toULong() + readPointer++ + } + return value + } + + private fun reserveBits(bitsToWrite: Long) { + if (bitsToWrite <= 0L) return + + val lastBitIndex = writePointer.toLong() + bitsToWrite - 1L + if (lastBitIndex < 0L) { + throw IllegalStateException("Bit index overflow") + } + + val byteIndex = lastBitIndex ushr 3 + if (byteIndex > Int.MAX_VALUE.toLong()) { + throw IllegalStateException("Stream too large") + } + + ensureCapacity(byteIndex.toInt()) + } + + private fun ensureCapacity(index: Int) { + val requiredSize = index + 1 + if (requiredSize <= stream.size) return + + var newSize = if (stream.isEmpty()) 32 else stream.size + while (newSize < requiredSize) { + if (newSize > (Int.MAX_VALUE shr 1)) { + newSize = requiredSize + break + } + newSize = newSize shl 1 + } + + val next = ByteArray(newSize) + System.arraycopy(stream, 0, next, 0, stream.size) + stream = next + } }