aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorCody Koeninger <cody@koeninger.org>2013-05-11 23:59:07 -0500
committerCody Koeninger <cody@koeninger.org>2013-05-11 23:59:07 -0500
commit3da2305ed0d4add7127953e5240632f86053b4aa (patch)
treee4384c6d22b403ebaf75a14f000c2c8cccd12057 /core
parentdfac0aa5c2e5f46955b008b1e8d9ee5d8069efa5 (diff)
downloadspark-3da2305ed0d4add7127953e5240632f86053b4aa.tar.gz
spark-3da2305ed0d4add7127953e5240632f86053b4aa.tar.bz2
spark-3da2305ed0d4add7127953e5240632f86053b4aa.zip
code cleanup per rxin comments
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/spark/rdd/JdbcRDD.scala67
1 files changed, 37 insertions, 30 deletions
diff --git a/core/src/main/scala/spark/rdd/JdbcRDD.scala b/core/src/main/scala/spark/rdd/JdbcRDD.scala
index 4c3054465c..b0f7054233 100644
--- a/core/src/main/scala/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/spark/rdd/JdbcRDD.scala
@@ -5,23 +5,27 @@ import java.sql.{Connection, ResultSet}
import spark.{Logging, Partition, RDD, SparkContext, TaskContext}
import spark.util.NextIterator
+private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
+ override def index = idx
+}
+
/**
- An RDD that executes an SQL query on a JDBC connection and reads results.
- @param getConnection a function that returns an open Connection.
- The RDD takes care of closing the connection.
- @param sql the text of the query.
- The query must contain two ? placeholders for parameters used to partition the results.
- E.g. "select title, author from books where ? <= id and id <= ?"
- @param lowerBound the minimum value of the first placeholder
- @param upperBound the maximum value of the second placeholder
- The lower and upper bounds are inclusive.
- @param numPartitions the number of partitions.
- Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
- the query would be executed twice, once with (1, 10) and once with (11, 20)
- @param mapRow a function from a ResultSet to a single row of the desired result type(s).
- This should only call getInt, getString, etc; the RDD takes care of calling next.
- The default maps a ResultSet to an array of Object.
-*/
+ * An RDD that executes an SQL query on a JDBC connection and reads results.
+ * @param getConnection a function that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
class JdbcRDD[T: ClassManifest](
sc: SparkContext,
getConnection: () => Connection,
@@ -29,26 +33,33 @@ class JdbcRDD[T: ClassManifest](
lowerBound: Long,
upperBound: Long,
numPartitions: Int,
- mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray)
+ mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
extends RDD[T](sc, Nil) with Logging {
- override def getPartitions: Array[Partition] =
- ParallelCollectionRDD.slice(lowerBound to upperBound, numPartitions).
- filter(! _.isEmpty).
- zipWithIndex.
- map(x => new JdbcPartition(x._2, x._1.head, x._1.last)).
- toArray
+ override def getPartitions: Array[Partition] = {
+ // bounds are inclusive, hence the + 1 here and - 1 on end
+ val length = 1 + upperBound - lowerBound
+ (0 until numPartitions).map(i => {
+ val start = lowerBound + ((i * length) / numPartitions).toLong
+ val end = lowerBound + (((i + 1) * length) / numPartitions).toLong - 1
+ new JdbcPartition(i, start, end)
+ }).toArray
+ }
override def compute(thePart: Partition, context: TaskContext) = new NextIterator[T] {
context.addOnCompleteCallback{ () => closeIfNeeded() }
val part = thePart.asInstanceOf[JdbcPartition]
val conn = getConnection()
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
- // force mysql driver to stream rather than pull entire resultset into memory
+
+ // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force streaming results,
+ // rather than pulling entire resultset into memory.
+ // see http://dev.mysql.com/doc/refman/5.0/en/connector-j-reference-implementation-notes.html
if (conn.getMetaData.getURL.matches("jdbc:mysql:.*")) {
stmt.setFetchSize(Integer.MIN_VALUE)
logInfo("statement fetch size set to: " + stmt.getFetchSize + " to force MySQL streaming ")
}
+
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
@@ -81,14 +92,10 @@ class JdbcRDD[T: ClassManifest](
}
}
}
-
-}
-
-private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
- override def index = idx
}
object JdbcRDD {
- val resultSetToObjectArray = (rs: ResultSet) =>
+ def resultSetToObjectArray(rs: ResultSet) = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
+ }
}