]> source.dussan.org Git - sonarqube.git/commitdiff
properly handle connections in DbTester utility methods
authorSébastien Lesaint <sebastien.lesaint@sonarsource.com>
Wed, 31 Aug 2016 08:53:09 +0000 (10:53 +0200)
committerSébastien Lesaint <sebastien.lesaint@sonarsource.com>
Mon, 5 Sep 2016 09:32:17 +0000 (11:32 +0200)
sonar-db/src/test/java/org/sonar/db/DbTester.java

index 15b62f2d1f63e6fc749777c6f151c1869b451bf2..86695a90e4e867ea2e07a876db03554e32d7435e 100644 (file)
@@ -20,7 +20,6 @@
 package org.sonar.db;
 
 import com.google.common.base.Joiner;
-import com.google.common.base.Supplier;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.Lists;
 import java.io.InputStream;
@@ -57,7 +56,7 @@ import org.dbunit.operation.DatabaseOperation;
 import org.junit.rules.ExternalResource;
 import org.picocontainer.containers.TransientPicoContainer;
 import org.sonar.api.utils.System2;
-import org.sonar.api.utils.log.Logger;
+import org.sonar.api.utils.log.Loggers;
 
 import static com.google.common.base.Preconditions.checkArgument;
 import static com.google.common.collect.Lists.newArrayList;
@@ -144,8 +143,6 @@ public class DbTester extends ExternalResource {
   /**
    * Very simple helper method to insert some data into a table.
    * It's the responsibility of the caller to convert column values to string.
-   *
-   * @param valuesByColumn column name and value pairs, if any value is null, the associated column won't be inserted
    */
   public void executeInsert(String table, String firstColumn, Object... others) {
     executeInsert(table, mapOf(firstColumn, others));
@@ -187,14 +184,14 @@ public class DbTester extends ExternalResource {
    * <pre>int issues = countRowsOfTable("issues")</pre>
    */
   public int countRowsOfTable(String tableName) {
-    return countRowsOfTable(tableName, this::getConnection);
+    return countRowsOfTable(tableName, new NewConnectionSupplier());
   }
 
   public int countRowsOfTable(DbSession dbSession, String tableName) {
-    return countRowsOfTable(tableName, () -> dbSession.getConnection());
+    return countRowsOfTable(tableName, new DbSessionConnectionSupplier(dbSession));
   }
 
-  private int countRowsOfTable(String tableName, SqlExceptionSupplier<Connection> connectionSupplier) {
+  private int countRowsOfTable(String tableName, ConnectionSupplier connectionSupplier) {
     checkArgument(StringUtils.containsNone(tableName, " "), "Parameter must be the name of a table. Got " + tableName);
     return countSql("select count(1) from " + tableName.toLowerCase(Locale.ENGLISH), connectionSupplier);
   }
@@ -204,19 +201,19 @@ public class DbTester extends ExternalResource {
    * <pre>int OpenIssues = countSql("select count('id') from issues where status is not null")</pre>
    */
   public int countSql(String sql) {
-    return countSql(sql, this::getConnection);
+    return countSql(sql, new NewConnectionSupplier());
   }
 
   public int countSql(DbSession dbSession, String sql) {
-    return countSql(sql, () -> dbSession.getConnection());
+    return countSql(sql, new DbSessionConnectionSupplier(dbSession));
   }
 
-  private int countSql(String sql, SqlExceptionSupplier<Connection> connectionSupplier) {
+  private int countSql(String sql, ConnectionSupplier connectionSupplier) {
     checkArgument(StringUtils.contains(sql, "count("),
       "Parameter must be a SQL request containing 'count(x)' function. Got " + sql);
     try (
-      Connection connection = connectionSupplier.get();
-      PreparedStatement stmt = connection.prepareStatement(sql);
+      ConnectionSupplier supplier = connectionSupplier;
+      PreparedStatement stmt = supplier.get().prepareStatement(sql);
       ResultSet rs = stmt.executeQuery()) {
       if (rs.next()) {
         return rs.getInt(1);
@@ -229,13 +226,13 @@ public class DbTester extends ExternalResource {
   }
 
   public List<Map<String, Object>> select(String selectSql) {
-    return select(selectSql, this::getConnection);
+    return select(selectSql, new NewConnectionSupplier());
   }
 
-  private List<Map<String, Object>> select(String selectSql, SqlExceptionSupplier<Connection> connectionSupplier) {
+  private List<Map<String, Object>> select(String selectSql, ConnectionSupplier connectionSupplier) {
     try (
-      Connection connection = connectionSupplier.get();
-      PreparedStatement stmt = connection.prepareStatement(selectSql);
+      ConnectionSupplier supplier = connectionSupplier;
+      PreparedStatement stmt = supplier.get().prepareStatement(selectSql);
       ResultSet rs = stmt.executeQuery()) {
       return getHashMap(rs);
     } catch (Exception e) {
@@ -244,14 +241,14 @@ public class DbTester extends ExternalResource {
   }
 
   public Map<String, Object> selectFirst(String selectSql) {
-    return selectFirst(selectSql, this::getConnection);
+    return selectFirst(selectSql, new NewConnectionSupplier());
   }
 
   public Map<String, Object> selectFirst(DbSession dbSession, String selectSql) {
-    return selectFirst(selectSql, () -> dbSession.getConnection());
+    return selectFirst(selectSql, new DbSessionConnectionSupplier(dbSession));
   }
 
-  private Map<String, Object> selectFirst(String selectSql, SqlExceptionSupplier<Connection> connectionSupplier) {
+  private Map<String, Object> selectFirst(String selectSql, ConnectionSupplier connectionSupplier) {
     List<Map<String, Object>> rows = select(selectSql, connectionSupplier);
     if (rows.isEmpty()) {
       throw new IllegalStateException("No results for " + selectSql);
@@ -503,11 +500,54 @@ public class DbTester extends ExternalResource {
   }
 
   /**
-   * A {@link Supplier} that declares the checked exception {@link SQLException}.
+   * An {@link AutoCloseable} supplier of {@link Connection}.
    */
-  @FunctionalInterface
-  private interface SqlExceptionSupplier<T> {
-    T get() throws SQLException;
+  private interface ConnectionSupplier extends AutoCloseable {
+    Connection get() throws SQLException;
+
+    @Override
+    void close();
+  }
+
+  private static class DbSessionConnectionSupplier implements ConnectionSupplier {
+    private final DbSession dbSession;
+
+    public DbSessionConnectionSupplier(DbSession dbSession) {
+      this.dbSession = dbSession;
+    }
+
+    @Override
+    public Connection get() throws SQLException {
+      return dbSession.getConnection();
+    }
+
+    @Override
+    public void close() {
+      // closing dbSession is not our responsability
+    }
   }
 
+  private class NewConnectionSupplier implements ConnectionSupplier {
+    private Connection connection;
+
+    @Override
+    public Connection get() throws SQLException {
+      if (this.connection == null) {
+        this.connection = getConnection();
+      }
+      return this.connection;
+    }
+
+    @Override
+    public void close() {
+      if (this.connection != null) {
+        try {
+          this.connection.close();
+        } catch (SQLException e) {
+          Loggers.get(DbTester.class).warn("Fail to close connection", e);
+          // do not re-throw the exception
+        }
+      }
+    }
+  }
 }