From c2ea79f96acd076351b48162644ed1cff4c8e090 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 13 Jan 2016 12:29:02 -0800
Subject: [SPARK-12642][SQL] improve the hash expression to be decoupled from
 unsafe row

https://issues.apache.org/jira/browse/SPARK-12642

Author: Wenchen Fan

Closes #10694 from cloud-fan/hash-expr.
---
 .../spark/sql/sources/BucketedWriteSuite.scala | 26 +++++++++++++---------
 1 file changed, 16 insertions(+), 10 deletions(-)

(limited to 'sql/hive')

diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index 7f1745705a..b718b7cefb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,10 +20,11 @@ package org.apache.spark.sql.sources
 import java.io.File
 
 import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.UnsafeRow
+import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
 import org.apache.spark.sql.functions._
 import org.apache.spark.sql.hive.test.TestHiveSingleton
 import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.util.Utils
 
 class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
   import testImplicits._
@@ -70,6 +71,8 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     }
   }
 
+  private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
+
   private def testBucketing(
       dataDir: File,
       source: String,
@@ -82,27 +85,30 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
     assert(groupedBucketFiles.size <= 8)
 
     for ((bucketId, bucketFiles) <- groupedBucketFiles) {
-      for (bucketFile <- bucketFiles) {
-        val df = sqlContext.read.format(source).load(bucketFile.getAbsolutePath)
-          .select((bucketCols ++ sortCols).map(col): _*)
+      for (bucketFilePath <- bucketFiles.map(_.getAbsolutePath)) {
+        val types = df.select((bucketCols ++ sortCols).map(col): _*).schema.map(_.dataType)
+        val columns = (bucketCols ++ sortCols).zip(types).map {
+          case (colName, dt) => col(colName).cast(dt)
+        }
+        val readBack = sqlContext.read.format(source).load(bucketFilePath).select(columns: _*)
 
         if (sortCols.nonEmpty) {
-          checkAnswer(df.sort(sortCols.map(col): _*), df.collect())
+          checkAnswer(readBack.sort(sortCols.map(col): _*), readBack.collect())
         }
 
-        val rows = df.select(bucketCols.map(col): _*).queryExecution.toRdd.map(_.copy()).collect()
+        val qe = readBack.select(bucketCols.map(col): _*).queryExecution
+        val rows = qe.toRdd.map(_.copy()).collect()
+        val getHashCode =
+          UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output)
 
         for (row <- rows) {
-          assert(row.isInstanceOf[UnsafeRow])
-          val actualBucketId = (row.hashCode() % 8 + 8) % 8
+          val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8)
           assert(actualBucketId == bucketId)
         }
       }
     }
   }
 
-  private val df = (0 until 50).map(i => (i % 5, i % 13, i.toString)).toDF("i", "j", "k")
-
   test("write bucketed data") {
     for (source <- Seq("parquet", "json", "orc")) {
       withTable("bucketed_table") {
--
cgit v1.2.3
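
Note on the bucket-id check in the hunk above: the patch replaces the UnsafeRow-specific (row.hashCode() % 8 + 8) % 8 with a Murmur3Hash projection over the bucket columns followed by Utils.nonNegativeMod(hash, 8). What follows is a minimal, dependency-free Scala sketch of that non-negative-modulo step only; BucketIdSketch, its local nonNegativeMod helper, and the sample hash values are illustrative stand-ins and are not part of the patch or of Spark's source.

// BucketIdSketch is a hypothetical, self-contained illustration; the real test
// hashes the bucket columns with Spark's Murmur3Hash expression and then calls
// Utils.nonNegativeMod(hash, 8).
object BucketIdSketch {

  // Plain % can return a negative value for a negative hash code, which is why
  // the old test wrote (row.hashCode() % 8 + 8) % 8 and the new one uses a
  // non-negative modulo. This local helper reproduces that behaviour.
  def nonNegativeMod(x: Int, mod: Int): Int = {
    val rawMod = x % mod
    rawMod + (if (rawMod < 0) mod else 0)
  }

  def main(args: Array[String]): Unit = {
    val numBuckets = 8
    // Stand-in hash codes (one negative, one near Int.MinValue) instead of
    // real Murmur3 output.
    val hashes = Seq(42, -7, Int.MinValue + 3)
    for (h <- hashes) {
      val bucketId = nonNegativeMod(h, numBuckets)
      // Every hash, including negative ones, must land in [0, numBuckets).
      assert(bucketId >= 0 && bucketId < numBuckets)
      println(s"hash=$h -> bucket=$bucketId")
    }
  }
}

The point of the change in the test is that the expected bucket id is now computed from the analyzed output via a hash expression rather than from UnsafeRow.hashCode, so the check no longer depends on the physical row format; the sketch only mirrors the final modulo mapping of a (possibly negative) hash code onto a bucket id.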