author     Wenchen Fan <wenchen@databricks.com>    2016-01-13 12:29:02 -0800
committer  Reynold Xin <rxin@databricks.com>       2016-01-13 12:29:02 -0800
commit     c2ea79f96acd076351b48162644ed1cff4c8e090 (patch)
tree       55ca22bdd84dac3cb225cd2b9bddaf0c11c93d19 /sql/hive
parent     e4e0b3f7b2945aae5ec7c3d68296010bbc5160cf (diff)
[SPARK-12642][SQL] improve the hash expression to be decoupled from unsafe row
https://issues.apache.org/jira/browse/SPARK-12642

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10694 from cloud-fan/hash-expr.
Diffstat (limited to 'sql/hive')
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala | 26
1 file changed, 16 insertions(+), 10 deletions(-)
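The core of the change in this suite: the expected bucket id is no longer derived from UnsafeRow.hashCode(), but from a Murmur3Hash expression evaluated over the bucket columns and then reduced with Utils.nonNegativeMod. Below is a minimal standalone sketch of that modular reduction. It uses scala.util.hashing.MurmurHash3 as an illustrative stand-in for Catalyst's Murmur3Hash expression, and the object name BucketIdSketch and the sample row are hypothetical; Catalyst's hash uses its own seed and row encoding, so actual hash values differ.

import scala.util.hashing.MurmurHash3

object BucketIdSketch {
  // Mirrors the behavior of Spark's Utils.nonNegativeMod: Java's % can return
  // a negative result for a negative hash, which would be an invalid bucket id,
  // so shift negative remainders back into [0, mod).
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def main(args: Array[String]): Unit = {
    val numBuckets = 8
    // Hypothetical stand-in for Catalyst's Murmur3Hash over the bucket
    // columns: hash the column values of one row as a tuple.
    val rowHash = MurmurHash3.productHash((3, 7))
    val bucketId = nonNegativeMod(rowHash, numBuckets)
    assert(bucketId >= 0 && bucketId < numBuckets)
    println(s"row (3, 7) lands in bucket $bucketId")
  }
}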
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 7f1745705a..b718b7cefb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
}
}
+ private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
private def testBucketing(
dataDir: File,
source: String,
@@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
assert(groupedBucketFiles.size <= 8)
for ((bucketId, bucketFiles) <- groupedBucketFiles) {
- for (bucketFile <- bucketFiles) {
- val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
- .select((bucketCols ++ sortCols).map(col): _*)
+ for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
+ val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+ val columns = (bucketCols ++ sortCols).zip(types).map {
+ case (colName, dt) => col(colName).cast(dt)
+ }
+ val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
if (sortCols.nonEmpty) {
- checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+ checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
}
- val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+ val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+ val rows = qe.toRdd.map(_.copy()).collect()
+ val getHashCode =
+ UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output)
for (row <- rows) {
- assert(row.isInstanceOf[UnsafeRow])
- val actualBucketId = (row.hashCode() % 8 + 8) % 8
+ val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8)
assert(actualBucketId == bucketId)
}
}
}
}
- private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
-
test("write bucketed data") {
for (source <- Seq("parquet", "json", "orc")) {
withTable("bucketed_table") {