diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala | 48 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala | 65 |
2 files changed, 98 insertions, 15 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala index 233b7891d6..11613dd912 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCRelation.scala @@ -21,6 +21,7 @@ import java.util.Properties import scala.collection.mutable.ArrayBuffer +import org.apache.spark.internal.Logging import org.apache.spark.Partition import org.apache.spark.rdd.RDD import org.apache.spark.sql.{DataFrame, Row, SaveMode, SparkSession, SQLContext} @@ -36,7 +37,7 @@ private[sql] case class JDBCPartitioningInfo( upperBound: Long, numPartitions: Int) -private[sql] object JDBCRelation { +private[sql] object JDBCRelation extends Logging { /** * Given a partitioning schematic (a column of integral type, a number of * partitions, and upper and lower bounds on the column's value), generate @@ -52,29 +53,46 @@ private[sql] object JDBCRelation { * @return an array of partitions with where clause for each partition */ def columnPartition(partitioning: JDBCPartitioningInfo): Array[Partition] = { - if (partitioning == null) return Array[Partition](JDBCPartition(null, 0)) + if (partitioning == null || partitioning.numPartitions <= 1 || + partitioning.lowerBound == partitioning.upperBound) { + return Array[Partition](JDBCPartition(null, 0)) + } - val numPartitions = partitioning.numPartitions - val column = partitioning.column - if (numPartitions == 1) return Array[Partition](JDBCPartition(null, 0)) + val lowerBound = partitioning.lowerBound + val upperBound = partitioning.upperBound + require (lowerBound <= upperBound, + "Operation not allowed: the lower bound of partitioning column is larger than the upper " + + s"bound. Lower bound: $lowerBound; Upper bound: $upperBound") + + val numPartitions = + if ((upperBound - lowerBound) >= partitioning.numPartitions) { + partitioning.numPartitions + } else { + logWarning("The number of partitions is reduced because the specified number of " + + "partitions is less than the difference between upper bound and lower bound. " + + s"Updated number of partitions: ${upperBound - lowerBound}; Input number of " + + s"partitions: ${partitioning.numPartitions}; Lower bound: $lowerBound; " + + s"Upper bound: $upperBound.") + upperBound - lowerBound + } // Overflow and silliness can happen if you subtract then divide. // Here we get a little roundoff, but that's (hopefully) OK. - val stride: Long = (partitioning.upperBound / numPartitions - - partitioning.lowerBound / numPartitions) + val stride: Long = upperBound / numPartitions - lowerBound / numPartitions + val column = partitioning.column var i: Int = 0 - var currentValue: Long = partitioning.lowerBound + var currentValue: Long = lowerBound var ans = new ArrayBuffer[Partition]() while (i < numPartitions) { - val lowerBound = if (i != 0) s"$column >= $currentValue" else null + val lBound = if (i != 0) s"$column >= $currentValue" else null currentValue += stride - val upperBound = if (i != numPartitions - 1) s"$column < $currentValue" else null + val uBound = if (i != numPartitions - 1) s"$column < $currentValue" else null val whereClause = - if (upperBound == null) { - lowerBound - } else if (lowerBound == null) { - s"$upperBound or $column is null" + if (uBound == null) { + lBound + } else if (lBound == null) { + s"$uBound or $column is null" } else { - s"$lowerBound AND $upperBound" + s"$lBound AND $uBound" } ans += JDBCPartition(whereClause, i) i = i + 1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index d6ec40c18b..fd6671a39b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -184,6 +184,16 @@ class JDBCSuite extends SparkFunSuite "insert into test.emp values ('kathy', null, null)").executeUpdate() conn.commit() + conn.prepareStatement( + "create table test.seq(id INTEGER)").executeUpdate() + (0 to 6).foreach { value => + conn.prepareStatement( + s"insert into test.seq values ($value)").executeUpdate() + } + conn.prepareStatement( + "insert into test.seq values (null)").executeUpdate() + conn.commit() + sql( s""" |CREATE TEMPORARY TABLE nullparts @@ -373,6 +383,61 @@ class JDBCSuite extends SparkFunSuite .collect().length === 4) } + test("Partitioning on column where numPartitions is zero") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 0, + upperBound = 4, + numPartitions = 0, + connectionProperties = new Properties + ) + assert(res.count() === 8) + } + + test("Partitioning on column where numPartitions are more than the number of total rows") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 1, + upperBound = 5, + numPartitions = 10, + connectionProperties = new Properties + ) + assert(res.count() === 8) + } + + test("Partitioning on column where lowerBound is equal to upperBound") { + val res = spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 5, + upperBound = 5, + numPartitions = 4, + connectionProperties = new Properties + ) + assert(res.count() === 8) + } + + test("Partitioning on column where lowerBound is larger than upperBound") { + val e = intercept[IllegalArgumentException] { + spark.read.jdbc( + url = urlWithUserAndPass, + table = "TEST.seq", + columnName = "id", + lowerBound = 5, + upperBound = 1, + numPartitions = 3, + connectionProperties = new Properties + ) + }.getMessage + assert(e.contains("Operation not allowed: the lower bound of partitioning column " + + "is larger than the upper bound. Lower bound: 5; Upper bound: 1")) + } + test("SELECT * on partitioned table with a nullable partition column") { assert(sql("SELECT * FROM nullparts").collect().size == 4) } |