Хэндшейк, сервисы, аннотационные блокировки в протоколе, репозитории

This commit is contained in:
RoyceDa
2026-02-03 05:42:46 +02:00
parent 4c290a01ac
commit 9b715df09d
12 changed files with 457 additions and 20 deletions

View File

@@ -1,6 +1,7 @@
package com.rosetta.im.database;
import java.util.HashMap;
import java.util.List;
import org.hibernate.Session;
import org.hibernate.Transaction;
@@ -16,6 +17,11 @@ public abstract class Repository<T> {
this.entityClass = entityClass;
}
/**
* Сохранение сущности в базе данных
* @param entity сущность для сохранения
* @return сохраненная сущность
*/
public T save(T entity) {
return executeInTransaction(session -> {
session.persist(entity);
@@ -23,6 +29,11 @@ public abstract class Repository<T> {
});
}
/**
* Обновление сущности в базе данных
* @param entity сущность для обновления
* @return обновленная сущность
*/
public T update(T entity) {
return executeInTransaction(session -> {
session.merge(entity);
@@ -30,6 +41,10 @@ public abstract class Repository<T> {
});
}
/**
* Удаление сущности из базы данных
* @param entity сущность для удаления
*/
public void delete(T entity) {
executeInTransaction(session -> {
session.remove(entity);
@@ -76,6 +91,126 @@ public abstract class Repository<T> {
});
}
/**
* Удаление сущностей по значению одного поля
* @param fieldName поле
* @param value значение
*/
public void deleteByField(String fieldName, Object value) {
executeInTransaction(session -> {
String queryString = "DELETE FROM " + entityClass.getSimpleName() + " WHERE " + fieldName + " = :value";
session.createQuery(queryString, entityClass)
.setParameter("value", value)
.executeUpdate();
return null;
});
}
/**
* Поиск всех сущностей по значению одного поля
* @param fieldName поле
* @param value значение
* @return список найденных сущностей
*/
public List<T> findAllByField(String fieldName, Object value) {
return executeInSession(session -> {
String queryString = "FROM " + entityClass.getSimpleName() + " WHERE " + fieldName + " = :value";
return session.createQuery(queryString, entityClass)
.setParameter("value", value)
.list();
});
}
/**
* Поиск всех сущностей по значению набора полей
* @param fields карта полей и их значений
* @return список найденных сущностей
*/
public List<T> findAllByField(HashMap<String, Object> fields) {
return executeInSession(session -> {
StringBuilder queryString = new StringBuilder("FROM " + entityClass.getSimpleName() + " WHERE ");
int index = 0;
for (String fieldName : fields.keySet()) {
if (index > 0) {
queryString.append(" AND ");
}
queryString.append(fieldName).append(" = :").append(fieldName);
index++;
}
var query = session.createQuery(queryString.toString(), this.entityClass);
for (var entry : fields.entrySet()) {
query.setParameter(entry.getKey(), entry.getValue());
}
return query.list();
});
}
/**
* Поиск всех сущностей, тяжелый метод, лучше не выполнять без необходимости
* @return список всех сущностей
*/
public List<T> findAll() {
return executeInSession(session -> {
String queryString = "FROM " + entityClass.getSimpleName();
return session.createQuery(queryString, entityClass).list();
});
}
/**
* Подсчет всех сущностей в таблице
* @return количество сущностей
*/
public long countAll() {
return executeInSession(session -> {
String queryString = "SELECT COUNT(*) FROM " + entityClass.getSimpleName();
return session.createQuery(queryString, Long.class).uniqueResult();
});
}
/**
* Подсчет сущностей по значению одного поля
* @param fieldName поле
* @param value значение
* @return количество сущностей
*/
public long countByField(String fieldName, Object value) {
return executeInSession(session -> {
String queryString = "SELECT COUNT(*) FROM " + entityClass.getSimpleName() + " WHERE " + fieldName + " = :value";
return session.createQuery(queryString, Long.class)
.setParameter("value", value)
.uniqueResult();
});
}
/**
* Подсчет сущностей по набору полей
* @param fields карта полей и их значений
* @return количество сущностей
*/
public long countByField(HashMap<String, Object> fields) {
return executeInSession(session -> {
StringBuilder queryString = new StringBuilder("SELECT COUNT(*) FROM " + entityClass.getSimpleName() + " WHERE ");
int index = 0;
for (String fieldName : fields.keySet()) {
if (index > 0) {
queryString.append(" AND ");
}
queryString.append(fieldName).append(" = :").append(fieldName);
index++;
}
var query = session.createQuery(queryString.toString(), Long.class);
for (var entry : fields.entrySet()) {
query.setParameter(entry.getKey(), entry.getValue());
}
return query.uniqueResult();
});
}
/**
* Обновление полей сущности по заданным условиям
* @param fieldsToUpdate поля для обновления
* @param whereFields условия для выбора сущностей
*/
public void update(HashMap<String, Object> fieldsToUpdate, HashMap<String, Object> whereFields) {
executeInTransaction(session -> {
StringBuilder queryString = new StringBuilder("UPDATE " + entityClass.getSimpleName() + " SET ");

View File

@@ -7,7 +7,6 @@ import jakarta.persistence.Entity;
import jakarta.persistence.GeneratedValue;
import jakarta.persistence.GenerationType;
import jakarta.persistence.Id;
import jakarta.persistence.PrePersist;
import jakarta.persistence.Table;
@Entity
@@ -29,14 +28,9 @@ public class Device extends CreateUpdateEntity {
/**
* Время завершения сессии устройства
*/
@Column(name = "leaveTime", nullable = true)
@Column(name = "leaveTime", nullable = true, columnDefinition = "bigint default 0")
private Long leaveTime;
@PrePersist
protected void onCreate() {
this.leaveTime = 0L;
}
public Long getId() {
return id;
}

View File

@@ -1,5 +1,25 @@
package com.rosetta.im.database.repository;
public class DeviceRepository {
import java.util.List;
import com.rosetta.im.database.Repository;
import com.rosetta.im.database.entity.Device;
import com.rosetta.im.database.entity.User;
public class DeviceRepository extends Repository<Device> {
public DeviceRepository() {
super(Device.class);
}
/**
* Найти все устройства пользователя
* @param user пользователь
* @return список устройств
*/
public List<Device> findAll(User user) {
return this.findAllByField("publicKey", user.getPublicKey());
}
}

View File

@@ -0,0 +1,18 @@
package com.rosetta.im.event.events.handshake;
import com.rosetta.im.client.tags.ECIAuthentificate;
import com.rosetta.im.client.tags.ECIDevice;
import io.orprotocol.client.Client;
/**
* Вызывается когда устройство клиента нуждается в подтверждении
* пользователем с другого устрйоства для завершения хэндшейка.
*/
public class HandshakeDeviceConfirmEvent extends BaseHandshakeEvent {
public HandshakeDeviceConfirmEvent(String publicKey, String privateKey, ECIDevice device, ECIAuthentificate eciAuthentificate, Client client) {
super(publicKey, privateKey, device, eciAuthentificate, client);
}
}

View File

@@ -5,23 +5,32 @@ import com.rosetta.im.Configuration;
import com.rosetta.im.Failures;
import com.rosetta.im.client.tags.ECIAuthentificate;
import com.rosetta.im.client.tags.ECIDevice;
import com.rosetta.im.database.entity.Device;
import com.rosetta.im.database.entity.User;
import com.rosetta.im.database.repository.DeviceRepository;
import com.rosetta.im.database.repository.UserRepository;
import com.rosetta.im.event.events.handshake.HandshakeCompletedEvent;
import com.rosetta.im.event.events.handshake.HandshakeDeviceConfirmEvent;
import com.rosetta.im.event.events.handshake.HandshakeFailedEvent;
import com.rosetta.im.packet.Packet0Handshake;
import com.rosetta.im.packet.enums.HandshakeStage;
import com.rosetta.im.service.services.DeviceService;
import io.orprotocol.ProtocolException;
import io.orprotocol.client.Client;
import io.orprotocol.lock.Lock;
import io.orprotocol.packet.Packet;
import io.orprotocol.packet.PacketExecutor;
public class Executor0Handshake extends PacketExecutor {
private final UserRepository userRepository = new UserRepository();
private final DeviceRepository deviceRepository = new DeviceRepository();
private final DeviceService deviceService = new DeviceService(deviceRepository);
@Override
@Lock(lockFor = "publicKey")
public void onPacketReceived(Packet packet, Client client) throws ProtocolException {
Packet0Handshake handshake = (Packet0Handshake) packet;
String publicKey = handshake.getPublicKey();
@@ -31,7 +40,6 @@ public class Executor0Handshake extends PacketExecutor {
String deviceOs = handshake.getDeviceOs();
int protocolVersion = handshake.getProtocolVersion();
AppContext context = (AppContext) this.getContext();
/**
* Получаем информацию об аутентификации клиента
* используя возможности ECI тэгов.
@@ -43,7 +51,6 @@ public class Executor0Handshake extends PacketExecutor {
*/
return;
}
/**
* Проверяем корректность версии протокола
*/
@@ -56,6 +63,7 @@ public class Executor0Handshake extends PacketExecutor {
* Создаем минимальную информацию об устройстве клиента
*/
ECIDevice device = new ECIDevice(deviceId, deviceName, deviceOs);
client.addTag(device);
/**
* Проверяем есть ли такой пользователь
@@ -100,10 +108,9 @@ public class Executor0Handshake extends PacketExecutor {
/**
* Отправляем клиенту подтверждение успешного хэндшейка
*/
Packet0Handshake response = new Packet0Handshake();
response.setHandshakeStage(HandshakeStage.COMPLETED);
response.setHeartbeatInterval(this.settings.heartbeatInterval);
client.send(response);
handshake.setHandshakeStage(HandshakeStage.COMPLETED);
handshake.setHeartbeatInterval(this.settings.heartbeatInterval);
client.send(handshake);
return;
}
/**
@@ -117,10 +124,77 @@ public class Executor0Handshake extends PacketExecutor {
client.disconnect(Failures.AUTHENTIFICATION_ERROR);
return;
}
long userDevicesCount = deviceService.countUserDevices(user);
/**
* Проверяем верифицировано ли устройство
*/
if(userDevicesCount > 0 && !deviceService.isDeviceVerifiedByUser(deviceId, user)) {
/**
* Устройство не верифицировано, нужно отправить клиента
* на подтверждение устройства
*/
handshake.setHandshakeStage(HandshakeStage.NEED_DEVICE_VERIFICATION);
handshake.setHeartbeatInterval(this.settings.heartbeatInterval);
/**
* Вызываем событие подтверждения устройства
*/
context.getEventManager().callEvent(
new HandshakeDeviceConfirmEvent(publicKey, privateKey, device, authentificate, client)
);
/**
* Ставим метку аутентификации на клиента
*/
ECIAuthentificate eciTag = new ECIAuthentificate
(publicKey, privateKey, HandshakeStage.NEED_DEVICE_VERIFICATION);
client.addTag(eciTag);
/**
* Отправляем клиенту информацию о необходимости
* подтверждения устройства
*/
client.send(handshake);
return;
}
if(userDevicesCount == 0) {
/**
* Это первое устройство пользователя, сохраняем его
* как верифицированное
*/
Device newDevice = new Device();
newDevice.setDeviceId(deviceId);
newDevice.setDeviceName(deviceName);
newDevice.setDeviceOs(deviceOs);
newDevice.setPublicKey(publicKey);
newDevice.setLeaveTime(System.currentTimeMillis());
deviceRepository.save(newDevice);
}
/**
* Ставим метку аутентификации на клиента
*/
ECIAuthentificate eciTag = new ECIAuthentificate
(publicKey, privateKey, HandshakeStage.COMPLETED);
client.addTag(eciTag);
/**
* Вызываем событие завершения хэндшейка
*/
boolean cancelled = context.getEventManager().callEvent(
new HandshakeCompletedEvent(publicKey, privateKey, device, eciTag, client)
);
if(cancelled) {
/**
* Событие было отменено, не даем завершить хэндшейк
*/
client.disconnect(Failures.DATA_MISSMATCH);
return;
}
/**
* Отправляем клиенту подтверждение успешного хэндшейка
*/
handshake.setHandshakeStage(HandshakeStage.COMPLETED);
handshake.setHeartbeatInterval(this.settings.heartbeatInterval);
client.send(handshake);
}
}

View File

@@ -0,0 +1,23 @@
package com.rosetta.im.service;
/**
* Базовый класс для всех сервисов. Нужно чтобы унифицировать доступ к репозиториям,
* а так же не раздувать логику в executor'ах. Так код в executor'ах будет чище и
* проще для понимания. Для атомарных операций с сущностями сервисы не используются, они используются только для
* более сложной логики, требующей взаимодействия с несколькими репозиториями или
* иной бизнес-логики.
* @param <T> тип репозитория
*/
public abstract class Service<T> {
private T repository;
public Service(T repository) {
this.repository = repository;
}
public T getRepository() {
return repository;
}
}

View File

@@ -0,0 +1,48 @@
package com.rosetta.im.service.services;
import java.util.List;
import com.rosetta.im.database.entity.Device;
import com.rosetta.im.database.entity.User;
import com.rosetta.im.database.repository.DeviceRepository;
import com.rosetta.im.service.Service;
public class DeviceService extends Service<DeviceRepository> {
public DeviceService(DeviceRepository repository) {
super(repository);
}
/**
* Проверяет, верифицировано ли устройство с deviceId для пользователя user
* @param deviceId ID устройства
* @param user пользователь
* @return true если устройство верифицировано, иначе false
*/
public boolean isDeviceVerifiedByUser(String deviceId, User user) {
List<Device> devices = this.getRepository().findAll(user);
if(devices.size() == 0) {
/**
* Если у пользователя нет устройств, значит текущее устройство верифицировано
* такого быть не может, это избыточная проверка
*/
return true;
}
for(Device device : devices) {
if(device.getDeviceId().equals(deviceId)) {
return true;
}
}
return false;
}
/**
* Считает количество устройств пользователя
* @param user пользователь
* @return количество устройств
*/
public long countUserDevices(User user) {
return this.getRepository().countByField("publicKey", user.getPublicKey());
}
}

View File

@@ -11,6 +11,7 @@ import org.java_websocket.handshake.ClientHandshake;
import org.java_websocket.server.WebSocketServer;
import io.orprotocol.client.Client;
import io.orprotocol.lock.ThreadLocker;
import io.orprotocol.packet.Packet;
import io.orprotocol.packet.PacketExecutor;
import io.orprotocol.packet.PacketManager;
@@ -22,6 +23,7 @@ public class Server extends WebSocketServer {
private ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1);
private Context context;
private ServerListener listener;
private ThreadLocker threadLocker = new ThreadLocker();
/**
* Конструктор сервера
@@ -144,10 +146,27 @@ public class Server extends WebSocketServer {
*/
return;
}
executor.onPacketReceived(packet, client);
/**
* Проверяем наличие блокировки для данного пакета и ключа в аннотации @Lock.
*/
if(!threadLocker.acquireLock(packet, executorClass)) {
/**
* Если блокировка уже существует, значит другой поток обрабатывает пакет
* с таким же значением lockFor, отклоняем текущий пакет.
*/
return;
}
try {
executor.onPacketReceived(packet, client);
} finally {
/**
* Снимаем блокировку после обработки пакета.
*/
threadLocker.releaseLock(packet, executorClass);
}
} catch (Exception e) {
System.out.println("Error while processing packet " + packetClass.getName());
System.out.println(e.getStackTrace());
e.printStackTrace();
}
}

View File

@@ -46,7 +46,7 @@ public class Client {
this.eciTags = new HashMap<Class<? extends ECITag>, ECITag>();
this.heartbeatInterval = heartbeatInterval;
this.lastHeartbeatTime = System.currentTimeMillis();
this.packetManager = new PacketManager();
this.packetManager = packetManager;
}
/**
@@ -57,7 +57,7 @@ public class Client {
* @return
*/
public boolean isAlive() {
return (System.currentTimeMillis() - this.lastHeartbeatTime) * 2 <= this.heartbeatInterval * 1000;
return (System.currentTimeMillis() - this.lastHeartbeatTime) <= ((this.heartbeatInterval * 1000) * 2);
}
/**

View File

@@ -0,0 +1,23 @@
package io.orprotocol.lock;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
/**
* Аннотация для указания блокировки
* при обработке пакета.
* Мультипоточная блокировка, то есть блокировка сработает только
* если другой поток уже обрабатывает пакет с полем lockFor, однако другие потоки
* могут обрабатывать пакеты с другими значениями lockFor параллельно.
*/
@Retention(RetentionPolicy.RUNTIME)
@Target(ElementType.METHOD)
public @interface Lock {
/**
* По какому полю в пакете
* будет осуществляться блокировка
*/
String lockFor();
}

View File

@@ -0,0 +1,82 @@
package io.orprotocol.lock;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.util.concurrent.ConcurrentHashMap;
import io.orprotocol.client.Client;
import io.orprotocol.packet.Packet;
import io.orprotocol.packet.PacketExecutor;
/**
* Менеджер блокировок для обработки пакетов с аннотацией @Lock.
*/
public class ThreadLocker {
private final ConcurrentHashMap<String, Boolean> locks = new ConcurrentHashMap<>();
/**
* Пытается захватить блокировку для указанного пакета и ключа в аннотации @Lock.
* @param packet Пакет для которого требуется блокировка
* @param exectuor Класс исполнителя пакета
* @return true, если блокировка успешно захвачена, иначе false.
*/
public boolean acquireLock(Packet packet, Class<? extends PacketExecutor> exectuor) {
try{
Method prMethod = exectuor.getMethod("onPacketReceived", Packet.class, Client.class);
if(prMethod == null) {
return true;
}
Lock lockAnnotation = prMethod.getAnnotation(Lock.class);
if(lockAnnotation == null) {
return true;
}
String fieldName = lockAnnotation.lockFor();
Field field = packet.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
String fieldValue = (String) field.get(packet);
String lockValue = packet.getClass().getName() + "_" + fieldValue;
if(locks.putIfAbsent(lockValue, true) != null) {
/**
* Если блокировка уже существует, значит другой поток обрабатывает пакет
* с таким же значением lockFor, отклоняем текущий пакет.
*/
return false;
}
return true;
}catch(Exception e) {
/**
* Игнорируем ошибки при попытке блокировки,
* чтобы не блокировать обработку пакета из-за ошибок рефлексии
*/
return true;
}
}
/**
* Освобождает блокировку для указанного пакета и ключа в аннотации @Lock.
* @param packet Пакет для которого требуется разблокировка
* @param exectuor Класс исполнителя пакета
*/
public void releaseLock(Packet packet, Class<? extends PacketExecutor> exectuor) {
try{
Method prMethod = exectuor.getMethod("onPacketReceived", Packet.class, Client.class);
if(prMethod == null) {
return;
}
Lock lockAnnotation = prMethod.getAnnotation(Lock.class);
if(lockAnnotation == null) {
return;
}
String fieldName = lockAnnotation.lockFor();
Field field = packet.getClass().getDeclaredField(fieldName);
field.setAccessible(true);
String fieldValue = (String) field.get(packet);
String lockValue = packet.getClass().getName() + "_" + fieldValue;
locks.remove(lockValue);
}catch(Exception e) {
// Игнорируем ошибки при разблокировке
}
}
}

View File

@@ -10,6 +10,7 @@
<property name="hibernate.hbm2ddl.auto">update</property>
<!--Зарегистрированные таблицы-->
<mapping class="com.rosetta.im.database.entity.User"/>
<mapping class="com.rosetta.im.database.entity.Device"/>
</session-factory>
</hibernate-configuration>