diff options
author | Reynold Xin <rxin@cs.berkeley.edu> | 2013-05-14 23:12:00 -0700 |
---|---|---|
committer | Reynold Xin <rxin@cs.berkeley.edu> | 2013-05-14 23:12:00 -0700 |
commit | 81ad2fa3319bbe93fae88ce0831a1bc26c0797d2 (patch) | |
tree | 974342912ce640b4e92ff7a0ab37a9e3d47d88ea | |
parent | 016ac868303adbffb19165430610869363fa943f (diff) | |
parent | b16c4896f617f352bb230908b7c08c7c5b028434 (diff) | |
download | spark-81ad2fa3319bbe93fae88ce0831a1bc26c0797d2.tar.gz spark-81ad2fa3319bbe93fae88ce0831a1bc26c0797d2.tar.bz2 spark-81ad2fa3319bbe93fae88ce0831a1bc26c0797d2.zip |
Merge branch 'jdbc' of github.com:koeninger/spark
Conflicts:
project/SparkBuild.scala
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | core/src/main/scala/spark/rdd/JdbcRDD.scala | 101 | ||||
-rw-r--r-- | core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 56 | ||||
-rw-r--r-- | project/SparkBuild.scala | 3 |
4 files changed, 160 insertions, 1 deletions
diff --git a/.gitignore b/.gitignore index 155e785b01..b87fc1ee79 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ streaming-tests.log dependency-reduced-pom.xml .ensime .ensime_lucene +derby.log diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala new file mode 100644 index 0000000000..b0f7054233 --- /dev/null +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -0,0 +1,101 @@ +package spark.rdd + +import java.sql.{Connection, ResultSet} + +import spark.{Logging, Partition, RDD, SparkContext, TaskContext} +import spark.util.NextIterator + +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + +/** + * An RDD that executes an SQL query on a JDBC connection and reads results. + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ +class JdbcRDD[T: ClassManifest]( + sc: SparkContext, + getConnection: () => Connection, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartition(i, start, end) + }).toArray + } + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + context.addOnCompleteCallback{ () => closeIfNeeded() } + val part = thePart.asInstanceOf[JdbcPartition] + val conn = getConnection() + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. + // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } + + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + if (null != rs && ! rs.isClosed()) rs.close() + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) stmt.close() + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn && ! stmt.isClosed()) conn.close() + logInfo("closed connection") + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + } +} + +object JdbcRDD { + def resultSetToObjectArray(rs: ResultSet) = { + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } +} diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..6afb0fa9bc --- /dev/null +++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,56 @@ +package spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import spark.SparkContext._ +import spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 190d723435..267008bfa4 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -160,7 +160,8 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1" excludeAll(excludeNetty), "cc.spray" % "spray-server" % "1.0-M2.1" excludeAll(excludeNetty), "cc.spray" % "spray-json_2.9.2" % "1.1.1" excludeAll(excludeNetty), - "org.apache.mesos" % "mesos" % "0.9.0-incubating" + "org.apache.mesos" % "mesos" % "0.9.0-incubating", + "org.apache.derby" % "derby" % "10.4.2.0" % "test" ) ++ ( if (HADOOP_MAJOR_VERSION == "2") { if (HADOOP_YARN) { |