From e0188c9d2755900310c98e218cd68c14dc695853 Mon Sep 17 00:00:00 2001 From: =?utf8?q?S=C3=A9bastien=20Lesaint?= Date: Wed, 31 Aug 2016 10:53:09 +0200 Subject: [PATCH] properly handle connections in DbTester utility methods --- .../src/test/java/org/sonar/db/DbTester.java | 86 ++++++++++++++----- 1 file changed, 63 insertions(+), 23 deletions(-) diff --git a/sonar-db/src/test/java/org/sonar/db/DbTester.java b/sonar-db/src/test/java/org/sonar/db/DbTester.java index 15b62f2d1f6..86695a90e4e 100644 --- a/sonar-db/src/test/java/org/sonar/db/DbTester.java +++ b/sonar-db/src/test/java/org/sonar/db/DbTester.java @@ -20,7 +20,6 @@ package org.sonar.db; import com.google.common.base.Joiner; -import com.google.common.base.Supplier; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import java.io.InputStream; @@ -57,7 +56,7 @@ import org.dbunit.operation.DatabaseOperation; import org.junit.rules.ExternalResource; import org.picocontainer.containers.TransientPicoContainer; import org.sonar.api.utils.System2; -import org.sonar.api.utils.log.Logger; +import org.sonar.api.utils.log.Loggers; import static com.google.common.base.Preconditions.checkArgument; import static com.google.common.collect.Lists.newArrayList; @@ -144,8 +143,6 @@ public class DbTester extends ExternalResource { /** * Very simple helper method to insert some data into a table. * It's the responsibility of the caller to convert column values to string. - * - * @param valuesByColumn column name and value pairs, if any value is null, the associated column won't be inserted */ public void executeInsert(String table, String firstColumn, Object... others) { executeInsert(table, mapOf(firstColumn, others)); @@ -187,14 +184,14 @@ public class DbTester extends ExternalResource { *
int issues = countRowsOfTable("issues")
*/ public int countRowsOfTable(String tableName) { - return countRowsOfTable(tableName, this::getConnection); + return countRowsOfTable(tableName, new NewConnectionSupplier()); } public int countRowsOfTable(DbSession dbSession, String tableName) { - return countRowsOfTable(tableName, () -> dbSession.getConnection()); + return countRowsOfTable(tableName, new DbSessionConnectionSupplier(dbSession)); } - private int countRowsOfTable(String tableName, SqlExceptionSupplier connectionSupplier) { + private int countRowsOfTable(String tableName, ConnectionSupplier connectionSupplier) { checkArgument(StringUtils.containsNone(tableName, " "), "Parameter must be the name of a table. Got " + tableName); return countSql("select count(1) from " + tableName.toLowerCase(Locale.ENGLISH), connectionSupplier); } @@ -204,19 +201,19 @@ public class DbTester extends ExternalResource { *
int OpenIssues = countSql("select count('id') from issues where status is not null")
*/ public int countSql(String sql) { - return countSql(sql, this::getConnection); + return countSql(sql, new NewConnectionSupplier()); } public int countSql(DbSession dbSession, String sql) { - return countSql(sql, () -> dbSession.getConnection()); + return countSql(sql, new DbSessionConnectionSupplier(dbSession)); } - private int countSql(String sql, SqlExceptionSupplier connectionSupplier) { + private int countSql(String sql, ConnectionSupplier connectionSupplier) { checkArgument(StringUtils.contains(sql, "count("), "Parameter must be a SQL request containing 'count(x)' function. Got " + sql); try ( - Connection connection = connectionSupplier.get(); - PreparedStatement stmt = connection.prepareStatement(sql); + ConnectionSupplier supplier = connectionSupplier; + PreparedStatement stmt = supplier.get().prepareStatement(sql); ResultSet rs = stmt.executeQuery()) { if (rs.next()) { return rs.getInt(1); @@ -229,13 +226,13 @@ public class DbTester extends ExternalResource { } public List> select(String selectSql) { - return select(selectSql, this::getConnection); + return select(selectSql, new NewConnectionSupplier()); } - private List> select(String selectSql, SqlExceptionSupplier connectionSupplier) { + private List> select(String selectSql, ConnectionSupplier connectionSupplier) { try ( - Connection connection = connectionSupplier.get(); - PreparedStatement stmt = connection.prepareStatement(selectSql); + ConnectionSupplier supplier = connectionSupplier; + PreparedStatement stmt = supplier.get().prepareStatement(selectSql); ResultSet rs = stmt.executeQuery()) { return getHashMap(rs); } catch (Exception e) { @@ -244,14 +241,14 @@ public class DbTester extends ExternalResource { } public Map selectFirst(String selectSql) { - return selectFirst(selectSql, this::getConnection); + return selectFirst(selectSql, new NewConnectionSupplier()); } public Map selectFirst(DbSession dbSession, String selectSql) { - return selectFirst(selectSql, () -> dbSession.getConnection()); + return selectFirst(selectSql, new DbSessionConnectionSupplier(dbSession)); } - private Map selectFirst(String selectSql, SqlExceptionSupplier connectionSupplier) { + private Map selectFirst(String selectSql, ConnectionSupplier connectionSupplier) { List> rows = select(selectSql, connectionSupplier); if (rows.isEmpty()) { throw new IllegalStateException("No results for " + selectSql); @@ -503,11 +500,54 @@ public class DbTester extends ExternalResource { } /** - * A {@link Supplier} that declares the checked exception {@link SQLException}. + * An {@link AutoCloseable} supplier of {@link Connection}. */ - @FunctionalInterface - private interface SqlExceptionSupplier { - T get() throws SQLException; + private interface ConnectionSupplier extends AutoCloseable { + Connection get() throws SQLException; + + @Override + void close(); + } + + private static class DbSessionConnectionSupplier implements ConnectionSupplier { + private final DbSession dbSession; + + public DbSessionConnectionSupplier(DbSession dbSession) { + this.dbSession = dbSession; + } + + @Override + public Connection get() throws SQLException { + return dbSession.getConnection(); + } + + @Override + public void close() { + // closing dbSession is not our responsability + } } + private class NewConnectionSupplier implements ConnectionSupplier { + private Connection connection; + + @Override + public Connection get() throws SQLException { + if (this.connection == null) { + this.connection = getConnection(); + } + return this.connection; + } + + @Override + public void close() { + if (this.connection != null) { + try { + this.connection.close(); + } catch (SQLException e) { + Loggers.get(DbTester.class).warn("Fail to close connection", e); + // do not re-throw the exception + } + } + } + } } -- 2.39.5