From b2a3f24dde7a69587a5fea50d3e1e4e8f02a2dc3 Mon Sep 17 00:00:00 2001 From: koeninger Date: Sun, 21 Apr 2013 00:29:37 -0500 Subject: first attempt at an RDD to pull data from JDBC sources --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 79 +++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) create mode 100644 core/src/main/scala/spark/rdd/JdbcRDD.scala diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala new file mode 100644 index 0000000000..c8a5d76012 --- /dev/null +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -0,0 +1,79 @@ +package spark.rdd + +import java.sql.{Connection, ResultSet} + +import spark.{Logging, Partition, RDD, SparkContext, TaskContext} +import spark.util.NextIterator + +/** + An RDD that executes an SQL query on a JDBC connection and reads results. + @param getConnection a function that returns an open Connection. + The RDD takes care of closing the connection. + @param sql the text of the query. + The query must contain two ? placeholders for parameters used to partition the results. + E.g. "select title, author from books where ? <= id and id <= ?" + @param lowerBound the minimum value of the first placeholder + @param upperBound the maximum value of the second placeholder + The lower and upper bounds are inclusive. + @param numPartitions the amount of parallelism. + Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + the query would be executed twice, once with (1, 10) and once with (11, 20) + @param mapRow a function from a ResultSet to a single row of the desired result type(s). + This should only call getInt, getString, etc; the RDD takes care of calling next. + The default maps a ResultSet to an array of Object. +*/ +class JdbcRDD[T: ClassManifest]( + sc: SparkContext, + getConnection: () => Connection, + sql: String, + lowerBound: Long, + upperBound: Long, + numPartitions: Int, + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray) + extends RDD[T](sc, Nil) with Logging { + + override def getPartitions: Array[Partition] = + ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions). + filter(! _.isEmpty). + zipWithIndex. + map(x => new JdbcPartition(x._2, x._1.head, x._1.last)). + toArray + + override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + val part = thePart.asInstanceOf[JdbcPartition] + val conn = getConnection() + context.addOnCompleteCallback{ () => closeIfNeeded() } + val stmt = conn.prepareStatement(sql) + stmt.setLong(1, part.lower) + stmt.setLong(2, part.upper) + val rs = stmt.executeQuery() + + override def getNext: T = { + if (rs.next()) { + mapRow(rs) + } else { + finished = true + null.asInstanceOf[T] + } + } + + override def close() { + try { + logInfo("closing connection") + conn.close() + } catch { + case e: Exception => logWarning("Exception closing connection", e) + } + } + } + +} + +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + +object JdbcRDD { + val resultSetToObjectArray = (rs: ResultSet) => + Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) +} -- cgit v1.2.3 From dfac0aa5c2e5f46955b008b1e8d9ee5d8069efa5 Mon Sep 17 00:00:00 2001 From: koeninger Date: Mon, 22 Apr 2013 21:12:52 -0500 Subject: prevent mysql driver from pulling entire resultset into memory. explicitly close resultset and statement. 
--- core/src/main/scala/spark/rdd/JdbcRDD.scala | 25 ++++++++++++++++++++----- 1 file changed, 20 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala index c8a5d76012..4c3054465c 100644 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -15,7 +15,7 @@ import spark.util.NextIterator @param lowerBound the minimum value of the first placeholder @param upperBound the maximum value of the second placeholder The lower and upper bounds are inclusive. - @param numPartitions the amount of parallelism. + @param numPartitions the number of partitions. Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, the query would be executed twice, once with (1, 10) and once with (11, 20) @param mapRow a function from a ResultSet to a single row of the desired result type(s). @@ -40,10 +40,15 @@ class JdbcRDD[T: ClassManifest]( toArray override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { + context.addOnCompleteCallback{ () => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() - context.addOnCompleteCallback{ () => closeIfNeeded() } - val stmt = conn.prepareStatement(sql) + val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) + // force mysql driver to stream rather than pull entire resultset into memory + if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { + stmt.setFetchSize(Integer.MIN_VALUE) + logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") + } stmt.setLong(1, part.lower) stmt.setLong(2, part.upper) val rs = stmt.executeQuery() @@ -59,8 +64,18 @@ class JdbcRDD[T: ClassManifest]( override def close() { try { - logInfo("closing connection") - conn.close() + if (null != rs && ! rs.isClosed()) rs.close() + } catch { + case e: Exception => logWarning("Exception closing resultset", e) + } + try { + if (null != stmt && ! stmt.isClosed()) stmt.close() + } catch { + case e: Exception => logWarning("Exception closing statement", e) + } + try { + if (null != conn && ! stmt.isClosed()) conn.close() + logInfo("closed connection") } catch { case e: Exception => logWarning("Exception closing connection", e) } -- cgit v1.2.3 From 3da2305ed0d4add7127953e5240632f86053b4aa Mon Sep 17 00:00:00 2001 From: Cody Koeninger Date: Sat, 11 May 2013 23:59:07 -0500 Subject: code cleanup per rxin comments --- core/src/main/scala/spark/rdd/JdbcRDD.scala | 67 ++++++++++++++++------------- 1 file changed, 37 insertions(+), 30 deletions(-) diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala index 4c3054465c..b0f7054233 100644 --- a/core/src/main/scala/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala @@ -5,23 +5,27 @@ import java.sql.{Connection, ResultSet} import spark.{Logging, Partition, RDD, SparkContext, TaskContext} import spark.util.NextIterator +private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { + override def index = idx +} + /** - An RDD that executes an SQL query on a JDBC connection and reads results. - @param getConnection a function that returns an open Connection. - The RDD takes care of closing the connection. - @param sql the text of the query. - The query must contain two ? placeholders for parameters used to partition the results. - E.g. "select title, author from books where ? 
<= id and id <= ?" - @param lowerBound the minimum value of the first placeholder - @param upperBound the maximum value of the second placeholder - The lower and upper bounds are inclusive. - @param numPartitions the number of partitions. - Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, - the query would be executed twice, once with (1, 10) and once with (11, 20) - @param mapRow a function from a ResultSet to a single row of the desired result type(s). - This should only call getInt, getString, etc; the RDD takes care of calling next. - The default maps a ResultSet to an array of Object. -*/ + * An RDD that executes an SQL query on a JDBC connection and reads results. + * @param getConnection a function that returns an open Connection. + * The RDD takes care of closing the connection. + * @param sql the text of the query. + * The query must contain two ? placeholders for parameters used to partition the results. + * E.g. "select title, author from books where ? <= id and id <= ?" + * @param lowerBound the minimum value of the first placeholder + * @param upperBound the maximum value of the second placeholder + * The lower and upper bounds are inclusive. + * @param numPartitions the number of partitions. + * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2, + * the query would be executed twice, once with (1, 10) and once with (11, 20) + * @param mapRow a function from a ResultSet to a single row of the desired result type(s). + * This should only call getInt, getString, etc; the RDD takes care of calling next. + * The default maps a ResultSet to an array of Object. + */ class JdbcRDD[T: ClassManifest]( sc: SparkContext, getConnection: () => Connection, @@ -29,26 +33,33 @@ class JdbcRDD[T: ClassManifest]( lowerBound: Long, upperBound: Long, numPartitions: Int, - mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray) + mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _) extends RDD[T](sc, Nil) with Logging { - override def getPartitions: Array[Partition] = - ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions). - filter(! _.isEmpty). - zipWithIndex. - map(x => new JdbcPartition(x._2, x._1.head, x._1.last)). - toArray + override def getPartitions: Array[Partition] = { + // bounds are inclusive, hence the + 1 here and - 1 on end + val length = 1 + upperBound - lowerBound + (0 until numPartitions).map(i => { + val start = lowerBound + ((i * length) / numPartitions).toLong + val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1 + new JdbcPartition(i, start, end) + }).toArray + } override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] { context.addOnCompleteCallback{ () => closeIfNeeded() } val part = thePart.asInstanceOf[JdbcPartition] val conn = getConnection() val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY) - // force mysql driver to stream rather than pull entire resultset into memory + + // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results, + // rather than pulling entire resultset into memory. 
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) { stmt.setFetchSize(Integer.MIN_VALUE) logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ") } + stmt.setLong(1, part.lower) stmt.setLong(2, part.upper) val rs = stmt.executeQuery() @@ -81,14 +92,10 @@ class JdbcRDD[T: ClassManifest]( } } } - -} - -private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition { - override def index = idx } object JdbcRDD { - val resultSetToObjectArray = (rs: ResultSet) => + def resultSetToObjectArray(rs: ResultSet) = { Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1)) + } } -- cgit v1.2.3 From b16c4896f617f352bb230908b7c08c7c5b028434 Mon Sep 17 00:00:00 2001 From: Cody Koeninger Date: Tue, 14 May 2013 23:44:04 -0500 Subject: add test for JdbcRDD using embedded derby, per rxin suggestion --- .gitignore | 1 + core/src/test/scala/spark/rdd/JdbcRDDSuite.scala | 56 ++++++++++++++++++++++++ project/SparkBuild.scala | 1 + 3 files changed, 58 insertions(+) create mode 100644 core/src/test/scala/spark/rdd/JdbcRDDSuite.scala diff --git a/.gitignore b/.gitignore index 155e785b01..b87fc1ee79 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,4 @@ streaming-tests.log dependency-reduced-pom.xml .ensime .ensime_lucene +derby.log diff --git a/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala new file mode 100644 index 0000000000..6afb0fa9bc --- /dev/null +++ b/core/src/test/scala/spark/rdd/JdbcRDDSuite.scala @@ -0,0 +1,56 @@ +package spark + +import org.scalatest.{ BeforeAndAfter, FunSuite } +import spark.SparkContext._ +import spark.rdd.JdbcRDD +import java.sql._ + +class JdbcRDDSuite extends FunSuite with BeforeAndAfter with LocalSparkContext { + + before { + Class.forName("org.apache.derby.jdbc.EmbeddedDriver") + val conn = DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb;create=true") + try { + val create = conn.createStatement + create.execute(""" + CREATE TABLE FOO( + ID INTEGER NOT NULL GENERATED ALWAYS AS IDENTITY (START WITH 1, INCREMENT BY 1), + DATA INTEGER + )""") + create.close + val insert = conn.prepareStatement("INSERT INTO FOO(DATA) VALUES(?)") + (1 to 100).foreach { i => + insert.setInt(1, i * 2) + insert.executeUpdate + } + insert.close + } catch { + case e: SQLException if e.getSQLState == "X0Y32" => + // table exists + } finally { + conn.close + } + } + + test("basic functionality") { + sc = new SparkContext("local", "test") + val rdd = new JdbcRDD( + sc, + () => { DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb") }, + "SELECT DATA FROM FOO WHERE ? 
<= ID AND ID <= ?", + 1, 100, 3, + (r: ResultSet) => { r.getInt(1) } ).cache + + assert(rdd.count === 100) + assert(rdd.reduce(_+_) === 10100) + } + + after { + try { + DriverManager.getConnection("jdbc:derby:;shutdown=true") + } catch { + case se: SQLException if se.getSQLState == "XJ015" => + // normal shutdown + } + } +} diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index f0b371b2cf..b11893590e 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -147,6 +147,7 @@ object SparkBuild extends Build { "cc.spray" % "spray-can" % "1.0-M2.1", "cc.spray" % "spray-server" % "1.0-M2.1", "cc.spray" % "spray-json_2.9.2" % "1.1.1", + "org.apache.derby" % "derby" % "10.4.2.0" % "test", "org.apache.mesos" % "mesos" % "0.9.0-incubating" ) ++ (if (HADOOP_MAJOR_VERSION == "2") Some("org.apache.hadoop" % "hadoop-client" % HADOOP_VERSION) else None).toSeq, unmanagedSourceDirectories in Compile <+= baseDirectory{ _ / ("src/hadoop" + HADOOP_MAJOR_VERSION + "/scala") } -- cgit v1.2.3
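
For reference, a minimal standalone sketch of how the JdbcRDD introduced by these patches is invoked, modeled on the embedded-Derby test in the last commit. This is not part of the patches themselves; the object name JdbcRDDExample and the println calls are illustrative only, and the sketch assumes the Derby database and FOO table created by the test's before-block already exist at target/JdbcRDDSuiteDb.

import java.sql.{DriverManager, ResultSet}

import spark.SparkContext
import spark.rdd.JdbcRDD

object JdbcRDDExample {
  def main(args: Array[String]) {
    // Load the embedded Derby driver, as the test does.
    Class.forName("org.apache.derby.jdbc.EmbeddedDriver")

    val sc = new SparkContext("local", "JdbcRDDExample")

    // The query must contain exactly two ? placeholders; JdbcRDD fills them with
    // the inclusive lower/upper bounds of each partition's slice of [1, 100].
    val rdd = new JdbcRDD(
      sc,
      // getConnection: opened per partition and closed by the RDD when the task completes.
      () => DriverManager.getConnection("jdbc:derby:target/JdbcRDDSuiteDb"),
      "SELECT DATA FROM FOO WHERE ? <= ID AND ID <= ?",
      1,    // lowerBound (inclusive)
      100,  // upperBound (inclusive)
      3,    // numPartitions
      // mapRow: read columns from the current row only; JdbcRDD calls rs.next() itself.
      (rs: ResultSet) => rs.getInt(1))

    println("row count: " + rdd.count)
    println("sum of DATA: " + rdd.reduce(_ + _))
  }
}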