feat: Implement protocol packet classes and binary stream for network communication

This commit is contained in:
k1ngsterr1
2026-01-08 22:00:32 +05:00
parent b877d6fa73
commit 28a0d7a601
7 changed files with 790 additions and 5 deletions

View File

@@ -0,0 +1,187 @@
package com.rosetta.messenger.network
/**
* Base class for all protocol packets
*/
abstract class Packet {
abstract fun getPacketId(): Int
abstract fun receive(stream: Stream)
abstract fun send(): Stream
}
/**
* Handshake packet (ID: 0x00)
* First packet sent by client to authenticate with the server
*/
class PacketHandshake : Packet() {
var privateKey: String = ""
var publicKey: String = ""
var protocolVersion: Int = 1
var heartbeatInterval: Int = 15
override fun getPacketId(): Int = 0x00
override fun receive(stream: Stream) {
privateKey = stream.readString()
publicKey = stream.readString()
protocolVersion = stream.readInt8()
heartbeatInterval = stream.readInt8()
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeString(privateKey)
stream.writeString(publicKey)
stream.writeInt8(protocolVersion)
stream.writeInt8(heartbeatInterval)
return stream
}
}
/**
* Result packet (ID: 0x02)
* Server response for various operations
*/
class PacketResult : Packet() {
var resultCode: Int = 0
var message: String = ""
override fun getPacketId(): Int = 0x02
override fun receive(stream: Stream) {
resultCode = stream.readInt8()
message = stream.readString()
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeInt8(resultCode)
stream.writeString(message)
return stream
}
}
/**
* Search packet (ID: 0x03)
* Search for users by username or public key
*/
class PacketSearch : Packet() {
var privateKey: String = ""
var search: String = ""
var users: List<SearchUser> = emptyList()
override fun getPacketId(): Int = 0x03
override fun receive(stream: Stream) {
privateKey = stream.readString()
search = stream.readString()
val userCount = stream.readInt32()
val usersList = mutableListOf<SearchUser>()
for (i in 0 until userCount) {
val user = SearchUser(
publicKey = stream.readString(),
title = stream.readString(),
username = stream.readString(),
verified = stream.readInt8(),
online = stream.readInt8()
)
usersList.add(user)
}
users = usersList
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeString(privateKey)
stream.writeString(search)
return stream
}
}
data class SearchUser(
val publicKey: String,
val title: String,
val username: String,
val verified: Int,
val online: Int
)
/**
* User Info packet (ID: 0x01)
* Get/Set user information
*/
class PacketUserInfo : Packet() {
var publicKey: String = ""
var title: String = ""
var username: String = ""
var verified: Int = 0
var online: Int = 0
override fun getPacketId(): Int = 0x01
override fun receive(stream: Stream) {
publicKey = stream.readString()
title = stream.readString()
username = stream.readString()
verified = stream.readInt8()
online = stream.readInt8()
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeString(publicKey)
stream.writeString(title)
stream.writeString(username)
return stream
}
}
/**
* Online State packet (ID: 0x05)
* Notify about user online status
*/
class PacketOnlineState : Packet() {
var publicKey: String = ""
var online: Int = 0
var lastSeen: Long = 0
override fun getPacketId(): Int = 0x05
override fun receive(stream: Stream) {
publicKey = stream.readString()
online = stream.readInt8()
lastSeen = stream.readInt64()
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeString(publicKey)
return stream
}
}
/**
* Online Subscribe packet (ID: 0x04)
* Subscribe to user online status updates
*/
class PacketOnlineSubscribe : Packet() {
var publicKey: String = ""
override fun getPacketId(): Int = 0x04
override fun receive(stream: Stream) {
publicKey = stream.readString()
}
override fun send(): Stream {
val stream = Stream()
stream.writeInt16(getPacketId())
stream.writeString(publicKey)
return stream
}
}

View File

@@ -0,0 +1,315 @@
package com.rosetta.messenger.network
import android.util.Log
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import okhttp3.*
import okio.ByteString
import java.util.concurrent.TimeUnit
/**
* Protocol connection states
*/
enum class ProtocolState {
DISCONNECTED,
CONNECTING,
CONNECTED,
HANDSHAKING,
AUTHENTICATED
}
/**
* Protocol client for Rosetta Messenger
* Handles WebSocket connection and packet exchange with server
*/
class Protocol(private val serverAddress: String) {
companion object {
private const val TAG = "RosettaProtocol"
private const val RECONNECT_INTERVAL = 10000L // 10 seconds
private const val MAX_RECONNECT_ATTEMPTS = 5
private const val HANDSHAKE_TIMEOUT = 10000L // 10 seconds
}
private val client = OkHttpClient.Builder()
.readTimeout(0, TimeUnit.MILLISECONDS)
.connectTimeout(10, TimeUnit.SECONDS)
.build()
private var webSocket: WebSocket? = null
private var reconnectAttempts = 0
private var isManuallyClosed = false
private var handshakeComplete = false
private var handshakeJob: Job? = null
private val scope = CoroutineScope(Dispatchers.IO + SupervisorJob())
private val _state = MutableStateFlow(ProtocolState.DISCONNECTED)
val state: StateFlow<ProtocolState> = _state.asStateFlow()
private val _lastError = MutableStateFlow<String?>(null)
val lastError: StateFlow<String?> = _lastError.asStateFlow()
// Packet waiters - callbacks for specific packet types
private val packetWaiters = mutableMapOf<Int, MutableList<(Packet) -> Unit>>()
// Packet queue for packets sent before handshake complete
private val packetQueue = mutableListOf<Packet>()
// Last used credentials for reconnection
private var lastPublicKey: String? = null
private var lastPrivateHash: String? = null
// Supported packets
private val supportedPackets = mapOf(
0x00 to { PacketHandshake() },
0x01 to { PacketUserInfo() },
0x02 to { PacketResult() },
0x03 to { PacketSearch() },
0x04 to { PacketOnlineSubscribe() },
0x05 to { PacketOnlineState() }
)
init {
// Register handshake response handler
waitPacket(0x00) { packet ->
if (packet is PacketHandshake) {
Log.d(TAG, "✅ Handshake response received, protocol version: ${packet.protocolVersion}")
handshakeJob?.cancel()
handshakeComplete = true
_state.value = ProtocolState.AUTHENTICATED
flushPacketQueue()
}
}
}
/**
* Initialize connection to server
*/
fun connect() {
if (_state.value == ProtocolState.CONNECTING || _state.value == ProtocolState.CONNECTED) {
Log.d(TAG, "Already connecting or connected")
return
}
isManuallyClosed = false
_state.value = ProtocolState.CONNECTING
_lastError.value = null
Log.d(TAG, "🔌 Connecting to: $serverAddress")
val request = Request.Builder()
.url(serverAddress)
.build()
webSocket = client.newWebSocket(request, object : WebSocketListener() {
override fun onOpen(webSocket: WebSocket, response: Response) {
Log.d(TAG, "✅ WebSocket connected")
reconnectAttempts = 0
_state.value = ProtocolState.CONNECTED
// If we have saved credentials, start handshake automatically
lastPublicKey?.let { publicKey ->
lastPrivateHash?.let { privateHash ->
startHandshake(publicKey, privateHash)
}
}
}
override fun onMessage(webSocket: WebSocket, bytes: ByteString) {
handleMessage(bytes.toByteArray())
}
override fun onMessage(webSocket: WebSocket, text: String) {
Log.d(TAG, "Received text message (unexpected): $text")
}
override fun onClosing(webSocket: WebSocket, code: Int, reason: String) {
Log.d(TAG, "WebSocket closing: $code - $reason")
}
override fun onClosed(webSocket: WebSocket, code: Int, reason: String) {
Log.d(TAG, "WebSocket closed: $code - $reason")
handleDisconnect()
}
override fun onFailure(webSocket: WebSocket, t: Throwable, response: Response?) {
Log.e(TAG, "❌ WebSocket error: ${t.message}")
_lastError.value = t.message
handleDisconnect()
}
})
}
/**
* Start handshake with server
*/
fun startHandshake(publicKey: String, privateHash: String) {
Log.d(TAG, "🤝 Starting handshake...")
Log.d(TAG, " Public key: ${publicKey.take(20)}...")
Log.d(TAG, " Private hash: ${privateHash.take(20)}...")
// Save credentials for reconnection
lastPublicKey = publicKey
lastPrivateHash = privateHash
if (_state.value != ProtocolState.CONNECTED && _state.value != ProtocolState.AUTHENTICATED) {
Log.d(TAG, "Not connected, will handshake after connection")
connect()
return
}
_state.value = ProtocolState.HANDSHAKING
handshakeComplete = false
val handshake = PacketHandshake().apply {
this.publicKey = publicKey
this.privateKey = privateHash
}
sendPacketDirect(handshake)
// Set handshake timeout
handshakeJob?.cancel()
handshakeJob = scope.launch {
delay(HANDSHAKE_TIMEOUT)
if (!handshakeComplete) {
Log.e(TAG, "❌ Handshake timeout")
_lastError.value = "Handshake timeout"
disconnect()
}
}
}
/**
* Send packet to server
* Packets are queued if handshake is not complete
*/
fun sendPacket(packet: Packet) {
if (!handshakeComplete && packet !is PacketHandshake) {
Log.d(TAG, "📦 Queueing packet: ${packet.getPacketId()}")
packetQueue.add(packet)
return
}
sendPacketDirect(packet)
}
private fun sendPacketDirect(packet: Packet) {
val stream = packet.send()
val data = stream.getStream()
Log.d(TAG, "📤 Sending packet: ${packet.getPacketId()} (${data.size} bytes)")
webSocket?.send(ByteString.of(*data))
}
private fun flushPacketQueue() {
Log.d(TAG, "📬 Flushing ${packetQueue.size} queued packets")
val packets = packetQueue.toList()
packetQueue.clear()
packets.forEach { sendPacketDirect(it) }
}
private fun handleMessage(data: ByteArray) {
try {
val stream = Stream(data)
val packetId = stream.readInt16()
Log.d(TAG, "📥 Received packet: $packetId")
val packetFactory = supportedPackets[packetId]
if (packetFactory == null) {
Log.w(TAG, "Unknown packet ID: $packetId")
return
}
val packet = packetFactory()
packet.receive(stream)
// Notify waiters
packetWaiters[packetId]?.forEach { callback ->
try {
callback(packet)
} catch (e: Exception) {
Log.e(TAG, "Error in packet handler: ${e.message}")
}
}
} catch (e: Exception) {
Log.e(TAG, "Error parsing packet: ${e.message}")
}
}
private fun handleDisconnect() {
_state.value = ProtocolState.DISCONNECTED
handshakeComplete = false
handshakeJob?.cancel()
if (!isManuallyClosed && reconnectAttempts < MAX_RECONNECT_ATTEMPTS) {
reconnectAttempts++
Log.d(TAG, "🔄 Reconnecting in ${RECONNECT_INTERVAL}ms (attempt $reconnectAttempts/$MAX_RECONNECT_ATTEMPTS)")
scope.launch {
delay(RECONNECT_INTERVAL)
connect()
}
} else if (reconnectAttempts >= MAX_RECONNECT_ATTEMPTS) {
Log.e(TAG, "❌ Max reconnect attempts reached")
_lastError.value = "Unable to connect to server"
}
}
/**
* Register callback for specific packet type
*/
fun waitPacket(packetId: Int, callback: (Packet) -> Unit) {
packetWaiters.getOrPut(packetId) { mutableListOf() }.add(callback)
}
/**
* Unregister callback for specific packet type
*/
fun unwaitPacket(packetId: Int, callback: (Packet) -> Unit) {
packetWaiters[packetId]?.remove(callback)
}
/**
* Disconnect from server
*/
fun disconnect() {
Log.d(TAG, "Disconnecting...")
isManuallyClosed = true
handshakeJob?.cancel()
webSocket?.close(1000, "User disconnected")
webSocket = null
_state.value = ProtocolState.DISCONNECTED
}
/**
* Check if connected and authenticated
*/
fun isAuthenticated(): Boolean = _state.value == ProtocolState.AUTHENTICATED
/**
* Check if connected (may not be authenticated yet)
*/
fun isConnected(): Boolean = _state.value == ProtocolState.CONNECTED ||
_state.value == ProtocolState.HANDSHAKING ||
_state.value == ProtocolState.AUTHENTICATED
/**
* Clear saved credentials
*/
fun clearCredentials() {
lastPublicKey = null
lastPrivateHash = null
}
/**
* Release resources
*/
fun destroy() {
disconnect()
scope.cancel()
}
}

View File

@@ -0,0 +1,102 @@
package com.rosetta.messenger.network
import android.util.Log
import kotlinx.coroutines.flow.StateFlow
/**
* Singleton manager for Protocol instance
* Ensures single connection across the app
*/
object ProtocolManager {
private const val TAG = "ProtocolManager"
// Server address - same as React Native version
private const val SERVER_ADDRESS = "ws://46.28.71.12:3000"
private var protocol: Protocol? = null
/**
* Get or create Protocol instance
*/
fun getProtocol(): Protocol {
if (protocol == null) {
Log.d(TAG, "Creating new Protocol instance")
protocol = Protocol(SERVER_ADDRESS)
}
return protocol!!
}
/**
* Get connection state flow
*/
val state: StateFlow<ProtocolState>
get() = getProtocol().state
/**
* Get last error flow
*/
val lastError: StateFlow<String?>
get() = getProtocol().lastError
/**
* Connect to server
*/
fun connect() {
getProtocol().connect()
}
/**
* Authenticate with server
*/
fun authenticate(publicKey: String, privateHash: String) {
Log.d(TAG, "Authenticating...")
getProtocol().startHandshake(publicKey, privateHash)
}
/**
* Send packet
*/
fun sendPacket(packet: Packet) {
getProtocol().sendPacket(packet)
}
/**
* Register packet handler
*/
fun waitPacket(packetId: Int, callback: (Packet) -> Unit) {
getProtocol().waitPacket(packetId, callback)
}
/**
* Unregister packet handler
*/
fun unwaitPacket(packetId: Int, callback: (Packet) -> Unit) {
getProtocol().unwaitPacket(packetId, callback)
}
/**
* Disconnect and clear
*/
fun disconnect() {
protocol?.disconnect()
protocol?.clearCredentials()
}
/**
* Destroy instance completely
*/
fun destroy() {
protocol?.destroy()
protocol = null
}
/**
* Check if authenticated
*/
fun isAuthenticated(): Boolean = protocol?.isAuthenticated() ?: false
/**
* Check if connected
*/
fun isConnected(): Boolean = protocol?.isConnected() ?: false
}

View File

@@ -0,0 +1,146 @@
package com.rosetta.messenger.network
/**
* Binary stream for protocol packets
* Matches the React Native implementation exactly
*/
class Stream(stream: ByteArray = ByteArray(0)) {
private var _stream = mutableListOf<Int>()
private var _readPointer = 0
private var _writePointer = 0
init {
_stream = stream.map { it.toInt() and 0xFF }.toMutableList()
}
fun getStream(): ByteArray {
return _stream.map { it.toByte() }.toByteArray()
}
fun setStream(stream: ByteArray) {
_stream = stream.map { it.toInt() and 0xFF }.toMutableList()
_readPointer = 0
}
fun writeInt8(value: Int) {
val negationBit = if (value < 0) 1 else 0
val int8Value = Math.abs(value) and 0xFF
ensureCapacity(_writePointer shr 3)
_stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (negationBit shl (7 - (_writePointer and 7)))
_writePointer++
for (i in 0 until 8) {
val bit = (int8Value shr (7 - i)) and 1
ensureCapacity(_writePointer shr 3)
_stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (bit shl (7 - (_writePointer and 7)))
_writePointer++
}
}
fun readInt8(): Int {
var value = 0
val negationBit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1
_readPointer++
for (i in 0 until 8) {
val bit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1
value = value or (bit shl (7 - i))
_readPointer++
}
return if (negationBit == 1) -value else value
}
fun writeBit(value: Int) {
val bit = value and 1
ensureCapacity(_writePointer shr 3)
_stream[_writePointer shr 3] = _stream[_writePointer shr 3] or (bit shl (7 - (_writePointer and 7)))
_writePointer++
}
fun readBit(): Int {
val bit = (_stream[_readPointer shr 3] shr (7 - (_readPointer and 7))) and 1
_readPointer++
return bit
}
fun writeBoolean(value: Boolean) {
writeBit(if (value) 1 else 0)
}
fun readBoolean(): Boolean {
return readBit() == 1
}
fun writeInt16(value: Int) {
writeInt8(value shr 8)
writeInt8(value and 0xFF)
}
fun readInt16(): Int {
val high = readInt8() shl 8
return high or readInt8()
}
fun writeInt32(value: Int) {
writeInt16(value shr 16)
writeInt16(value and 0xFFFF)
}
fun readInt32(): Int {
val high = readInt16() shl 16
return high or readInt16()
}
fun writeInt64(value: Long) {
val high = (value shr 32).toInt()
val low = (value and 0xFFFFFFFF).toInt()
writeInt32(high)
writeInt32(low)
}
fun readInt64(): Long {
val high = readInt32().toLong()
val low = (readInt32().toLong() and 0xFFFFFFFFL)
return (high shl 32) or low
}
fun writeString(value: String) {
writeInt32(value.length)
for (char in value) {
writeInt16(char.code)
}
}
fun readString(): String {
val length = readInt32()
val sb = StringBuilder()
for (i in 0 until length) {
sb.append(readInt16().toChar())
}
return sb.toString()
}
fun writeBytes(value: ByteArray) {
writeInt32(value.size)
for (byte in value) {
writeInt8(byte.toInt())
}
}
fun readBytes(): ByteArray {
val length = readInt32()
val bytes = ByteArray(length)
for (i in 0 until length) {
bytes[i] = readInt8().toByte()
}
return bytes
}
private fun ensureCapacity(index: Int) {
while (_stream.size <= index) {
_stream.add(0)
}
}
}

View File

@@ -1,5 +1,6 @@
package com.rosetta.messenger.ui.auth
import android.util.Log
import androidx.compose.animation.*
import androidx.compose.animation.core.*
import androidx.compose.foundation.*
@@ -27,6 +28,7 @@ import androidx.compose.ui.unit.sp
import com.rosetta.messenger.crypto.CryptoManager
import com.rosetta.messenger.data.AccountManager
import com.rosetta.messenger.data.EncryptedAccount
import com.rosetta.messenger.network.ProtocolManager
import com.rosetta.messenger.ui.onboarding.PrimaryBlue
import kotlinx.coroutines.launch
@@ -426,6 +428,11 @@ fun SetPasswordScreen(
accountManager.saveAccount(account)
accountManager.setCurrentAccount(keyPair.publicKey)
// 🔌 Connect to server and authenticate
val privateKeyHash = CryptoManager.generatePrivateKeyHash(keyPair.privateKey)
Log.d("SetPasswordScreen", "🔌 Connecting to server...")
ProtocolManager.authenticate(keyPair.publicKey, privateKeyHash)
onAccountCreated()
} catch (e: Exception) {
error = "Failed to create account: ${e.message}"

View File

@@ -1,5 +1,6 @@
package com.rosetta.messenger.ui.auth
import android.util.Log
import androidx.compose.animation.*
import androidx.compose.animation.core.*
import androidx.compose.foundation.*
@@ -29,6 +30,8 @@ import com.rosetta.messenger.R
import com.rosetta.messenger.crypto.CryptoManager
import com.rosetta.messenger.data.AccountManager
import com.rosetta.messenger.data.DecryptedAccount
import com.rosetta.messenger.network.ProtocolManager
import com.rosetta.messenger.network.ProtocolState
import com.rosetta.messenger.ui.onboarding.PrimaryBlue
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.launch
@@ -249,6 +252,10 @@ fun UnlockScreen(
name = account.name
)
// 🔌 Connect to server and authenticate
Log.d("UnlockScreen", "🔌 Connecting to server...")
ProtocolManager.authenticate(account.publicKey, privateKeyHash)
accountManager.setCurrentAccount(publicKey)
onUnlocked(decryptedAccount)

View File

@@ -28,6 +28,8 @@ import androidx.compose.ui.unit.dp
import androidx.compose.ui.unit.sp
import com.airbnb.lottie.compose.*
import com.rosetta.messenger.R
import com.rosetta.messenger.network.ProtocolManager
import com.rosetta.messenger.network.ProtocolState
import com.rosetta.messenger.ui.onboarding.PrimaryBlue
import kotlinx.coroutines.launch
import java.text.SimpleDateFormat
@@ -120,6 +122,9 @@ fun ChatsListScreen(
val drawerState = rememberDrawerState(initialValue = DrawerValue.Closed)
val scope = rememberCoroutineScope()
// Protocol connection state
val protocolState by ProtocolManager.state.collectAsState()
var visible by remember { mutableStateOf(false) }
LaunchedEffect(Unit) {
@@ -283,11 +288,27 @@ fun ChatsListScreen(
Spacer(modifier = Modifier.width(12.dp))
Text(
"Rosetta",
fontWeight = FontWeight.Bold,
fontSize = 20.sp
)
// Title with connection status
Column {
Text(
"Rosetta",
fontWeight = FontWeight.Bold,
fontSize = 20.sp
)
if (protocolState != ProtocolState.AUTHENTICATED) {
Text(
text = when (protocolState) {
ProtocolState.DISCONNECTED -> "Connecting..."
ProtocolState.CONNECTING -> "Connecting..."
ProtocolState.CONNECTED -> "Authenticating..."
ProtocolState.HANDSHAKING -> "Authenticating..."
ProtocolState.AUTHENTICATED -> ""
},
fontSize = 12.sp,
color = secondaryTextColor
)
}
}
}
},
actions = {