author     Cheng Lian <lian@databricks.com>        2014-11-27 18:01:14 -0800
committer  Matei Zaharia <matei@databricks.com>    2014-11-27 18:01:14 -0800
commit     120a350240f58196eafcb038ca3a353636d89239 (patch)
tree       eb38a76fe24422c3b92809f2980afdbcb6cc7fc0 /core
parent     84376d31392858f7df215ddb3f05419181152e68 (diff)
[SPARK-4613][Core] Java API for JdbcRDD
This PR introduces a set of Java APIs for using `JdbcRDD`:

1. Trait (interface) `JdbcRDD.ConnectionFactory`: equivalent to the `getConnection: () => Connection` parameter in the `JdbcRDD` constructor.
2. Two overloaded versions of `JdbcRDD.create`: used to create a `JavaRDD` that wraps a `JdbcRDD`.

Author: Cheng Lian <lian@databricks.com>

Closes #3478 from liancheng/japi-jdbc-rdd and squashes the following commits:

9a54625 [Cheng Lian] Only shuts down a single DB rather than the whole Derby driver
d4cedc5 [Cheng Lian] Moves Java JdbcRDD test case to a separate test suite
ffcdf2e [Cheng Lian] Java API for JdbcRDD
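As a quick orientation (not part of the patch itself), here is a minimal sketch of how the new Java API is meant to be called, using the simpler `create` overload that maps each row to an `Object[]`. The JDBC URL and the `BOOKS` table are hypothetical, borrowed from the Javadoc example in the diff below; `JavaJdbcRDDSuite` further down shows the real, tested usage.

import java.sql.Connection;
import java.sql.DriverManager;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.rdd.JdbcRDD;

public class JdbcRDDExample {
  public static void main(String[] args) {
    JavaSparkContext sc = new JavaSparkContext("local", "JdbcRDDExample");

    // Without an explicit mapRow function, each row becomes an Object[];
    // the RDD opens one connection per partition via the factory and
    // closes it when the partition has been consumed.
    JavaRDD<Object[]> rows = JdbcRDD.create(
      sc,
      new JdbcRDD.ConnectionFactory() {
        @Override
        public Connection getConnection() throws Exception {
          // Hypothetical embedded-Derby URL; any JDBC driver on the classpath works.
          return DriverManager.getConnection("jdbc:derby:target/ExampleDb");
        }
      },
      // The two ? placeholders receive each partition's ID bounds.
      "SELECT TITLE, AUTHOR FROM BOOKS WHERE ? <= ID AND ID <= ?",
      1, 20, 2);  // lowerBound = 1, upperBound = 20, numPartitions = 2

    System.out.println(rows.count());
    sc.stop();
  }
}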
Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala      |  84
-rw-r--r--  core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java   | 118
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala |   7
3 files changed, 204 insertions(+), 5 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac..642a12c1ed 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
-}
+  trait ConnectionFactory extends Serializable {
+    @throws[Exception]
+    def getConnection: Connection
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+   * For a usage example, see the test case JavaJdbcRDDSuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+   *   This should only call getInt, getString, etc; the RDD takes care of calling next.
+   *   The default maps a ResultSet to an array of Object.
+   */
+  def create[T](
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int,
+      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+    val jdbcRDD = new JdbcRDD[T](
+      sc.sc,
+      () => connectionFactory.getConnection,
+      sql,
+      lowerBound,
+      upperBound,
+      numPartitions,
+      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+    new JavaRDD[T](jdbcRDD)(fakeClassTag)
+  }
+
+  /**
+   * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+   * converted into an `Object` array. For a usage example, see the test case JavaJdbcRDDSuite.testJavaJdbcRDD.
+   *
+   * @param connectionFactory a factory that returns an open Connection.
+   *   The RDD takes care of closing the connection.
+   * @param sql the text of the query.
+   *   The query must contain two ? placeholders for parameters used to partition the results.
+   *   E.g. "select title, author from books where ? <= id and id <= ?"
+   * @param lowerBound the minimum value of the first placeholder
+   * @param upperBound the maximum value of the second placeholder
+   *   The lower and upper bounds are inclusive.
+   * @param numPartitions the number of partitions.
+   *   Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+   *   the query would be executed twice, once with (1, 10) and once with (11, 20)
+   */
+  def create(
+      sc: JavaSparkContext,
+      connectionFactory: ConnectionFactory,
+      sql: String,
+      lowerBound: Long,
+      upperBound: Long,
+      numPartitions: Int): JavaRDD[Array[Object]] = {
+
+    val mapRow = new JFunction[ResultSet, Array[Object]] {
+      override def call(resultSet: ResultSet): Array[Object] = {
+        resultSetToObjectArray(resultSet)
+      }
+    }
+
+    create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+  }
+}
diff --git a/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java
new file mode 100644
index 0000000000..7fe452a48d
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/JavaJdbcRDDSuite.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark;
+
+import java.io.Serializable;
+import java.sql.Connection;
+import java.sql.DriverManager;
+import java.sql.PreparedStatement;
+import java.sql.ResultSet;
+import java.sql.SQLException;
+import java.sql.Statement;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.api.java.function.Function;
+import org.apache.spark.api.java.function.Function2;
+import org.apache.spark.rdd.JdbcRDD;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+public class JavaJdbcRDDSuite implements Serializable {
+  private transient JavaSparkContext sc;
+
+  @Before
+  public void setUp() throws ClassNotFoundException, SQLException {
+    sc = new JavaSparkContext("local", "JavaAPISuite");
+
+    Class.forName("org.apache.derby.jdbc.EmbeddedDriver");
+    Connection connection =
+      DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;create=true");
+
+    try {
+      Statement create = connection.createStatement();
+      create.execute(
+        "CREATE TABLE FOO(" +
+        "ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1)," +
+        "DATA INTEGER)");
+      create.close();
+
+      PreparedStatement insert = connection.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)");
+      for (int i = 1; i <= 100; i++) {
+        insert.setInt(1, i * 2);
+        insert.executeUpdate();
+      }
+      insert.close();
+    } catch (SQLException e) {
+      // SQLState X0Y32: the table already exists, which is fine; rethrow anything else.
+      if (e.getSQLState().compareTo("X0Y32") != 0) {
+        throw e;
+      }
+    } finally {
+      connection.close();
+    }
+  }
+
+  @After
+  public void tearDown() throws SQLException {
+    try {
+      DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb;shutdown=true");
+    } catch (SQLException e) {
+      // Throw unless this is a normal single-database shutdown
+      // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
+      if (e.getSQLState().compareTo("08006") != 0) {
+        throw e;
+      }
+    }
+
+    sc.stop();
+    sc = null;
+  }
+
+  @Test
+  public void testJavaJdbcRDD() throws Exception {
+    JavaRDD<Integer> rdd = JdbcRDD.create(
+      sc,
+      new JdbcRDD.ConnectionFactory() {
+        @Override
+        public Connection getConnection() throws SQLException {
+          return DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb");
+        }
+      },
+      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
+      1, 100, 1,
+      new Function<ResultSet, Integer>() {
+        @Override
+        public Integer call(ResultSet r) throws Exception {
+          return r.getInt(1);
+        }
+      }
+    ).cache();
+
+    Assert.assertEquals(100, rdd.count());
+    Assert.assertEquals(
+      Integer.valueOf(10100),
+      rdd.reduce(new Function2<Integer, Integer, Integer>() {
+        @Override
+        public Integer call(Integer i1, Integer i2) {
+          return i1 + i2;
+        }
+      }));
+  }
+}
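Since `JdbcRDD.ConnectionFactory` and `Function` each declare a single abstract method, the call above could be written more compactly with Java 8 lambdas once the build allows them. A sketch, not part of this patch:

JavaRDD<Integer> rdd = JdbcRDD.create(
  sc,
  () -> DriverManager.getConnection("jdbc:derby:target/JavaJdbcRDDSuiteDb"),
  "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
  1, 100, 1,
  r -> r.getInt(1));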
diff --git a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
index 76e317d754..6138d0bbd5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/JdbcRDDSuite.scala
@@ -65,10 +65,11 @@ class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext {
  after {
    try {
-      DriverManager.getConnection("jdbc:derby:;shutdown=true")
+      DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;shutdown=true")
    } catch {
-      case se: SQLException if se.getSQLState == "XJ015" =>
-        // normal shutdown
+      case se: SQLException if se.getSQLState == "08006" =>
+        // Normal single database shutdown
+        // https://db.apache.org/derby/docs/10.2/ref/rrefexcept71493.html
    }
  }
}
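For reference, the distinction behind the new SQLState check: shutting down a single Derby database (`jdbc:derby:<db>;shutdown=true`) reports SQLState 08006, while shutting down the whole embedded engine (`jdbc:derby:;shutdown=true`) reports XJ015. A minimal Java sketch of the pattern both suites now follow (the database path is hypothetical):

try {
  // Shut down only this database, leaving the Derby engine (and any
  // other open databases) running.
  DriverManager.getConnection("jdbc:derby:target/ExampleDb;shutdown=true");
} catch (SQLException e) {
  // Derby signals a *successful* single-database shutdown as SQLState 08006;
  // anything else is a real error.
  if (!"08006".equals(e.getSQLState())) {
    throw e;
  }
}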