diff --git a/src/main/java/client/Client.java b/src/main/java/client/Client.java index 020efc290c..14f52c8ebc 100644 --- a/src/main/java/client/Client.java +++ b/src/main/java/client/Client.java @@ -315,25 +315,6 @@ public class Client extends ChannelInboundHandlerAdapter { return inServerTransition; } - // TODO: load ipbans on server start and query it on demand. This query should not be run on every login! - @Deprecated - public boolean hasBannedIP() { - boolean ret = false; - try (Connection con = DatabaseConnection.getConnection(); - PreparedStatement ps = con.prepareStatement("SELECT COUNT(*) FROM ipbans WHERE ? LIKE CONCAT(ip, '%')")) { - ps.setString(1, remoteAddress); - try (ResultSet rs = ps.executeQuery()) { - rs.next(); - if (rs.getInt(1) > 0) { - ret = true; - } - } - } catch (SQLException e) { - e.printStackTrace(); - } - return ret; - } - // TODO: load hwidbans on server start and query it on demand. This query should not be run on every login! @Deprecated public boolean hasBannedHWID() { diff --git a/src/main/java/database/JdbiConfig.java b/src/main/java/database/JdbiConfig.java index e3a109ad5f..ed1b0edf85 100644 --- a/src/main/java/database/JdbiConfig.java +++ b/src/main/java/database/JdbiConfig.java @@ -1,6 +1,7 @@ package database; import database.account.AccountRowMapper; +import database.ban.IpBanRowMapper; import database.drop.GlobalMonsterDropRowMapper; import database.drop.MonsterDropRowMapper; import database.maker.MakerIngredientRowMapper; @@ -36,7 +37,8 @@ public final class JdbiConfig { new GlobalMonsterDropRowMapper(), new ShopRowMapper(), new ShopItemRowMapper(), - new MonsterCardRowMapper() + new MonsterCardRowMapper(), + new IpBanRowMapper() ); } } diff --git a/src/main/java/database/ban/IpBan.java b/src/main/java/database/ban/IpBan.java new file mode 100644 index 0000000000..96550c0210 --- /dev/null +++ b/src/main/java/database/ban/IpBan.java @@ -0,0 +1,12 @@ +package database.ban; + +import lombok.Builder; + +import java.util.Objects; + +@Builder +public record IpBan(String ip, Integer accountId) { + public IpBan { + Objects.requireNonNull(ip); + } +} diff --git a/src/main/java/database/ban/IpBanRepository.java b/src/main/java/database/ban/IpBanRepository.java new file mode 100644 index 0000000000..f95cb459ec --- /dev/null +++ b/src/main/java/database/ban/IpBanRepository.java @@ -0,0 +1,45 @@ +package database.ban; + +import database.PgDatabaseConnection; +import lombok.extern.slf4j.Slf4j; +import org.jdbi.v3.core.Handle; + +import java.util.List; + +/** + * @author Ponk + */ +@Slf4j +public class IpBanRepository { + private final PgDatabaseConnection connection; + + public IpBanRepository(PgDatabaseConnection connection) { + this.connection = connection; + } + + public List getAllIpBans() { + String sql = """ + SELECT ip, account_id + FROM ip_ban"""; + try (Handle handle = connection.getHandle()) { + return handle.createQuery(sql) + .mapTo(IpBan.class) + .list(); + } + } + + public boolean saveIpBan(int accountId, String ip) { + String sql = """ + INSERT INTO ip_ban (account_id, ip) + VALUES (:accountId, :ip)"""; + try (Handle handle = connection.getHandle()) { + return handle.createUpdate(sql) + .bind("accountId", accountId) + .bind("ip", ip) + .execute() > 0; + } catch (Exception e) { + log.error("Failed to save ip ban. The ip is already banned? accountId: {}, ip: {}", accountId, ip, e); + return false; + } + } +} diff --git a/src/main/java/database/ban/IpBanRowMapper.java b/src/main/java/database/ban/IpBanRowMapper.java new file mode 100644 index 0000000000..6ea7a26aa2 --- /dev/null +++ b/src/main/java/database/ban/IpBanRowMapper.java @@ -0,0 +1,18 @@ +package database.ban; + +import org.jdbi.v3.core.mapper.RowMapper; +import org.jdbi.v3.core.statement.StatementContext; + +import java.sql.ResultSet; +import java.sql.SQLException; + +public class IpBanRowMapper implements RowMapper { + + @Override + public IpBan map(ResultSet rs, StatementContext ctx) throws SQLException { + return IpBan.builder() + .ip(rs.getString("ip")) + .accountId(rs.getObject("account_id", Integer.class)) + .build(); + } +} diff --git a/src/main/java/net/ChannelDependencies.java b/src/main/java/net/ChannelDependencies.java index 256ef3e462..21d843e864 100644 --- a/src/main/java/net/ChannelDependencies.java +++ b/src/main/java/net/ChannelDependencies.java @@ -8,6 +8,7 @@ import database.character.CharacterLoader; import database.character.CharacterSaver; import database.drop.DropProvider; import lombok.Builder; +import server.ban.IpBanManager; import server.shop.ShopFactory; import service.AccountService; import service.BanService; @@ -25,7 +26,7 @@ public record ChannelDependencies( CharacterCreator characterCreator, CharacterLoader characterLoader, CharacterSaver characterSaver, NoteService noteService, FredrickProcessor fredrickProcessor, MakerProcessor makerProcessor, DropProvider dropProvider, CommandsExecutor commandsExecutor, ShopFactory shopFactory, - TransitionService transitionService, BanService banService + TransitionService transitionService, IpBanManager ipBanManager, BanService banService ) { public ChannelDependencies { @@ -40,6 +41,7 @@ public record ChannelDependencies( Objects.requireNonNull(commandsExecutor); Objects.requireNonNull(shopFactory); Objects.requireNonNull(transitionService); + Objects.requireNonNull(ipBanManager); Objects.requireNonNull(banService); } } diff --git a/src/main/java/net/PacketProcessor.java b/src/main/java/net/PacketProcessor.java index b09bb336cb..e2b66e5adf 100644 --- a/src/main/java/net/PacketProcessor.java +++ b/src/main/java/net/PacketProcessor.java @@ -282,7 +282,7 @@ public final class PacketProcessor { registerHandler(RecvOpcode.CHARLIST_REQUEST, new CharlistRequestHandler()); registerHandler(RecvOpcode.CHAR_SELECT, new CharSelectedHandler(channelDeps.transitionService())); registerHandler(RecvOpcode.LOGIN_PASSWORD, new LoginPasswordHandler(channelDeps.accountService(), - channelDeps.transitionService())); + channelDeps.transitionService(), channelDeps.banService())); registerHandler(RecvOpcode.RELOG, new RelogRequestHandler()); registerHandler(RecvOpcode.SERVERLIST_REQUEST, new ServerlistRequestHandler()); registerHandler(RecvOpcode.SERVERSTATUS_REQUEST, new ServerStatusRequestHandler()); diff --git a/src/main/java/net/server/Server.java b/src/main/java/net/server/Server.java index d06eba1943..64888de6db 100644 --- a/src/main/java/net/server/Server.java +++ b/src/main/java/net/server/Server.java @@ -45,6 +45,7 @@ import constants.net.ServerConstants; import database.PgDatabaseConfig; import database.PgDatabaseConnection; import database.account.AccountRepository; +import database.ban.IpBanRepository; import database.character.CharacterLoader; import database.character.CharacterRepository; import database.character.CharacterSaver; @@ -85,6 +86,7 @@ import server.CashShop.CashItemFactory; import server.SkillbookInformationProvider; import server.ThreadManager; import server.TimerManager; +import server.ban.IpBanManager; import server.expeditions.ExpeditionBossLog; import server.life.PlayerNPC; import server.quest.Quest; @@ -716,6 +718,7 @@ public class Server { futures.add(initExecutor.submit(CashItemFactory::loadAllCashItems)); futures.add(initExecutor.submit(Quest::loadAllQuests)); futures.add(initExecutor.submit(SkillbookInformationProvider::loadAllSkillbookInformation)); + futures.add(initExecutor.submit(channelDependencies.ipBanManager()::loadIpBans)); initExecutor.shutdown(); TimeZone.setDefault(TimeZone.getTimeZone(YamlConfig.config.server.TIMEZONE)); @@ -829,7 +832,8 @@ public class Server { NoteService noteService = new NoteService(new NoteDao(connection)); DropProvider dropProvider = new DropProvider(new DropRepository(connection)); ShopFactory shopFactory = new ShopFactory(new ShopDao(connection)); - BanService banService = new BanService(accountService, transitionService); + IpBanManager ipBanManager = new IpBanManager(new IpBanRepository(connection)); + BanService banService = new BanService(accountService, transitionService, ipBanManager); ChannelDependencies channelDependencies = ChannelDependencies.builder() .accountService(accountService) .characterCreator(new CharacterCreator(connection, characterRepository)) @@ -843,6 +847,7 @@ public class Server { characterSaver, transitionService, banService))) .shopFactory(shopFactory) .transitionService(transitionService) + .ipBanManager(ipBanManager) .banService(banService) .build(); diff --git a/src/main/java/net/server/handlers/login/LoginPasswordHandler.java b/src/main/java/net/server/handlers/login/LoginPasswordHandler.java index 81df8a5101..c198bc1a24 100644 --- a/src/main/java/net/server/handlers/login/LoginPasswordHandler.java +++ b/src/main/java/net/server/handlers/login/LoginPasswordHandler.java @@ -36,6 +36,7 @@ import net.server.world.World; import org.slf4j.Logger; import org.slf4j.LoggerFactory; import service.AccountService; +import service.BanService; import service.TransitionService; import tools.BCrypt; import tools.HexTool; @@ -49,10 +50,13 @@ public final class LoginPasswordHandler implements PacketHandler { private final AccountService accountService; private final TransitionService transitionService; + private final BanService banService; - public LoginPasswordHandler(AccountService accountService, TransitionService transitionService) { + public LoginPasswordHandler(AccountService accountService, TransitionService transitionService, + BanService banService) { this.accountService = accountService; this.transitionService = transitionService; + this.banService = banService; } @Override @@ -110,7 +114,7 @@ public final class LoginPasswordHandler implements PacketHandler { } boolean banCheckDisabled = false; - if (!banCheckDisabled && (c.hasBannedIP() || c.hasBannedMac() || c.hasBannedHWID())) { + if (!banCheckDisabled && (banService.isBanned(c) || c.hasBannedMac() || c.hasBannedHWID())) { c.sendPacket(PacketCreator.getLoginFailed(3)); return; } diff --git a/src/main/java/server/ban/IpBanManager.java b/src/main/java/server/ban/IpBanManager.java new file mode 100644 index 0000000000..281b93b50c --- /dev/null +++ b/src/main/java/server/ban/IpBanManager.java @@ -0,0 +1,45 @@ +package server.ban; + +import database.ban.IpBan; +import database.ban.IpBanRepository; +import lombok.extern.slf4j.Slf4j; +import net.jcip.annotations.ThreadSafe; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * @author Ponk + */ +@ThreadSafe +@Slf4j +public class IpBanManager { + private final IpBanRepository ipBanRepository; + private final Set bannedIps = new HashSet<>(); + + public IpBanManager(IpBanRepository ipBanRepository) { + this.ipBanRepository = ipBanRepository; + } + + public synchronized void loadIpBans() { + List ipBans = ipBanRepository.getAllIpBans(); + log.debug("Loaded {} ip bans", ipBans.size()); + bannedIps.addAll(ipBans.stream().map(IpBan::ip).toList()); + } + + public synchronized boolean isBanned(String ip) { + return bannedIps.contains(ip); + } + + public synchronized void banIp(String ip, int accountId) { + if (ip == null) { + throw new IllegalArgumentException("ip cannot be null"); + } + // TODO: validate ip format. Or create "Ip" model class. + + bannedIps.add(ip); + ipBanRepository.saveIpBan(accountId, ip); + } + +} diff --git a/src/main/java/service/BanService.java b/src/main/java/service/BanService.java index 4b73c8acc6..ab839c5c64 100644 --- a/src/main/java/service/BanService.java +++ b/src/main/java/service/BanService.java @@ -8,6 +8,7 @@ import lombok.extern.slf4j.Slf4j; import net.packet.Packet; import net.server.Server; import server.TimerManager; +import server.ban.IpBanManager; import tools.PacketCreator; import java.time.Duration; @@ -19,10 +20,12 @@ import java.util.concurrent.TimeUnit; public class BanService { private final AccountService accountService; private final TransitionService transitionService; + private final IpBanManager ipBanManager; - public BanService(AccountService accountService, TransitionService transitionService) { + public BanService(AccountService accountService, TransitionService transitionService, IpBanManager ipBanManager) { this.accountService = accountService; this.transitionService = transitionService; + this.ipBanManager = ipBanManager; } public void autoban(Character chr, AutobanFactory type, String reason) { @@ -111,4 +114,13 @@ public class BanService { } accountService.ban(accountId, bannedUntil, reason, description); } + + public boolean isBanned(Client c) { + return isIpBanned(c); + } + + private boolean isIpBanned(Client c) { + String ip = c.getRemoteAddress(); + return ip != null && ipBanManager.isBanned(ip); + } } diff --git a/src/main/resources/db/migration/postgresql/V0.10__ban.sql b/src/main/resources/db/migration/postgresql/V0.10__ban.sql new file mode 100644 index 0000000000..437942a82d --- /dev/null +++ b/src/main/resources/db/migration/postgresql/V0.10__ban.sql @@ -0,0 +1,8 @@ +CREATE TABLE ip_ban +( + ip varchar(15) NOT NULL, + account_id integer, + created_at timestamp DEFAULT now() NOT NULL, + PRIMARY KEY (ip) +); +GRANT SELECT, INSERT ON TABLE ip_ban TO ${server-username}; diff --git a/src/test/java/database/DatabaseTest.java b/src/test/java/database/DatabaseTest.java index b8e5932344..b3df4c9dc7 100644 --- a/src/test/java/database/DatabaseTest.java +++ b/src/test/java/database/DatabaseTest.java @@ -38,14 +38,14 @@ public abstract class DatabaseTest { @Container static PostgreSQLContainer postgres = new PostgreSQLContainer<>("postgres:%s".formatted(POSTGRES_VERSION)); - protected PgDatabaseConnection pgConnection; + protected PgDatabaseConnection connection; protected GeneratedIds testIds; @BeforeAll - void setUp() { + void setUpDatabase() { prepareMysqlConnection(); runDbMigrations(); - this.pgConnection = createPgConnection(); + this.connection = createPgConnection(); } // Not using this, but due to the nature of how the db connections are set up, the application requires @@ -90,8 +90,8 @@ public abstract class DatabaseTest { @BeforeEach void insertTestData() { - int accountId = insertAccount(pgConnection); - try (Handle handle = pgConnection.getHandle()) { + int accountId = insertAccount(connection); + try (Handle handle = connection.getHandle()) { int chrId = insertChr(handle, accountId); this.testIds = new GeneratedIds(accountId, chrId); } @@ -121,8 +121,8 @@ public abstract class DatabaseTest { List.of("chr", "account").forEach(this::clearTable); } - private void clearTable(String tableName) { + protected void clearTable(String tableName) { String sql = "DELETE FROM %s".formatted(tableName); - pgConnection.getHandle().execute(sql); + connection.getHandle().execute(sql); } } diff --git a/src/test/java/database/character/CharacterSaverTest.java b/src/test/java/database/character/CharacterSaverTest.java index 77e7b717d5..872c9c9da7 100644 --- a/src/test/java/database/character/CharacterSaverTest.java +++ b/src/test/java/database/character/CharacterSaverTest.java @@ -22,8 +22,8 @@ class CharacterSaverTest extends DatabaseTest { @BeforeEach void reset() { - this.characterSaver = new CharacterSaver(pgConnection, new CharacterRepository(), - new MonsterCardRepository(pgConnection)); + this.characterSaver = new CharacterSaver(connection, new CharacterRepository(), + new MonsterCardRepository(connection)); } @Test @@ -53,7 +53,7 @@ class CharacterSaverTest extends DatabaseTest { SELECT level FROM chr WHERE id = :id"""; - try (Handle handle = pgConnection.getHandle()) { + try (Handle handle = connection.getHandle()) { return handle.createQuery(sql) .bind("id", chrId) .mapTo(Integer.class) diff --git a/src/test/java/server/ban/IpBanManagerTest.java b/src/test/java/server/ban/IpBanManagerTest.java new file mode 100644 index 0000000000..2e22c3e093 --- /dev/null +++ b/src/test/java/server/ban/IpBanManagerTest.java @@ -0,0 +1,58 @@ +package server.ban; + +import database.DatabaseTest; +import database.ban.IpBan; +import database.ban.IpBanRepository; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import testutil.AnyValues; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +class IpBanManagerTest extends DatabaseTest { + private IpBanRepository ipBanRepository; + private IpBanManager ipBanManager; + + @BeforeEach + void setUp() { + this.ipBanRepository = new IpBanRepository(connection); + this.ipBanManager = new IpBanManager(ipBanRepository); + } + + @AfterEach + void deleteIpBans() { + clearTable("ip_ban"); + } + + @Test + void loadIpBans_shouldLoadFromRepository() { + String ip = "157.210.75.9"; + assertFalse(ipBanManager.isBanned(ip)); + + ipBanManager.loadIpBans(); + assertFalse(ipBanManager.isBanned(ip)); + + ipBanRepository.saveIpBan(AnyValues.integer(), ip); + ipBanManager.loadIpBans(); + + assertTrue(ipBanManager.isBanned(ip)); + } + + @Test + void banIp_shouldSaveInRepository() { + String ip = "123.231.312.123"; + assertFalse(ipBanManager.isBanned(ip)); + + ipBanManager.banIp(ip, 1001); + + assertTrue(ipBanManager.isBanned(ip)); + List ipBans = ipBanRepository.getAllIpBans(); + assertEquals(1, ipBans.size()); + assertEquals(new IpBan(ip, 1001), ipBans.getFirst()); + } +}