author     Wenchen Fan <wenchen@databricks.com>      2016-01-13 22:43:28 -0800
committer  Reynold Xin <rxin@databricks.com>         2016-01-13 22:43:28 -0800
commit     962e9bcf94da6f5134983f2bf1e56c5cd84f2bf7 (patch)
tree       fa7174220efa51f56287d32bc82a379508ee4c17 /sql/core
parent     e2ae7bd046f6d8d6a375c2e81e5a51d7d78ca984 (diff)
[SPARK-12756][SQL] use hash expression in Exchange
This PR makes bucketing and exchange share one common hash algorithm, so that we can guarantee that the data distribution is the same between shuffle and the bucketed data source, which enables us to shuffle only one side when joining a bucketed table with a normal one. This PR also fixes the tests that were broken by the new hash behaviour in shuffle.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #10703 from cloud-fan/use-hash-expr-in-shuffle.
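As a rough, Spark-free sketch of the shared idea (not the actual Spark implementation): both the shuffle and the bucketed write reduce a Murmur3-style hash of the key columns to a non-negative id with a Pmod-like modulo, so a row lands in the same numbered partition/bucket on both sides. The names `SharedHashSketch`, `pmod` and `partitionId` are placeholders for illustration only, and Scala's `MurmurHash3` merely stands in for Spark's `Murmur3Hash` expression.

    import scala.util.hashing.MurmurHash3

    object SharedHashSketch {
      // Pmod-like modulo: always returns a value in [0, numPartitions).
      def pmod(hash: Int, numPartitions: Int): Int = {
        val r = hash % numPartitions
        if (r < 0) r + numPartitions else r
      }

      // Placeholder for the shared hash expression: hash the key columns of a row.
      def partitionId(keyColumns: Seq[Any], numPartitions: Int): Int =
        pmod(MurmurHash3.seqHash(keyColumns), numPartitions)

      def main(args: Array[String]): Unit = {
        val numBuckets = 8
        // Because shuffle and bucketed writes use the same function, a row ends up
        // in the same numbered partition/bucket on both sides, so a join on the
        // bucket columns only needs to shuffle the non-bucketed side.
        println(partitionId(Seq("alice", 1), numBuckets))
      }
    }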
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                      12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala   20
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java                     4
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                      33
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                           21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                              4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                             2
7 files changed, 55 insertions, 41 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 058d147c7d..3770883af1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -143,7 +143,13 @@ case class Exchange(
val rdd = child.execute()
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
- case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+ case HashPartitioning(_, n) =>
+ new Partitioner {
+ override def numPartitions: Int = n
+ // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+ // `HashPartitioning.partitionIdExpression` to produce partitioning key.
+ override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+ }
case RangePartitioning(sortingExpressions, numPartitions) =>
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -173,7 +179,9 @@ case class Exchange(
position += 1
position
}
- case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+ case h: HashPartitioning =>
+ val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
+ row => projection(row).getInt(0)
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
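A minimal, Spark-free sketch of the pass-through pattern used in the Exchange change above, assuming the key extractor has already turned each row into a valid partition id; `SimplePartitioner` and `rowsWithIds` are hypothetical stand-ins, not Spark classes.

    object PassThroughPartitionerSketch {
      // Stand-in for org.apache.spark.Partitioner: maps a shuffle key to a partition.
      trait SimplePartitioner {
        def numPartitions: Int
        def getPartition(key: Any): Int
      }

      def main(args: Array[String]): Unit = {
        val n = 4
        // The key extractor already produced partition ids in [0, n), so the
        // partitioner just casts the key back to Int instead of hashing again.
        val partitioner = new SimplePartitioner {
          override def numPartitions: Int = n
          override def getPartition(key: Any): Int = key.asInstanceOf[Int]
        }
        val rowsWithIds = Seq(("a", 3), ("b", 0), ("c", 2)) // (row, precomputed id)
        rowsWithIds.foreach { case (row, id) =>
          assert(partitioner.getPartition(id) == id)
          println(s"row $row -> partition ${partitioner.getPartition(id)}")
        }
      }
    }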
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index fff72872c1..fc77529b7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}
-import scala.collection.JavaConverters._
-
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
@@ -30,6 +28,7 @@ import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
@@ -322,9 +321,12 @@ private[sql] class DynamicPartitionWriterContainer(
spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get)
}
- private def bucketIdExpression: Option[Expression] = for {
- BucketSpec(numBuckets, _, _) <- bucketSpec
- } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets))
+ private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec =>
+ // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+ // guarantee the data distribution is same between shuffle and bucketed data source, which
+ // enables us to only shuffle one side when join a bucketed table and a normal one.
+ HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+ }
// Expressions that given a partition key build a string like: col1=val/col2=val/...
private def partitionStringExpression: Seq[Expression] = {
@@ -341,12 +343,8 @@ private[sql] class DynamicPartitionWriterContainer(
}
}
- private def getBucketIdFromKey(key: InternalRow): Option[Int] = {
- if (bucketSpec.isDefined) {
- Some(key.getInt(partitionColumns.length))
- } else {
- None
- }
+ private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ =>
+ key.getInt(partitionColumns.length)
}
/**
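A small sketch of the key layout that `getBucketIdFromKey` relies on in the WriterContainer change above: the grouping key is the partition column values followed by the bucket id, so the bucket id sits at index `partitionColumns.length`. `BucketKeySketch`, `partitionColumns` and `keyRow` are illustrative placeholders, with a plain `Seq` standing in for `InternalRow`.

    object BucketKeySketch {
      def main(args: Array[String]): Unit = {
        val partitionColumns = Seq("year", "month") // hypothetical partition columns
        // Key row layout assumed above: partition values ++ bucket id.
        val keyRow: Seq[Any] = Seq(2016, 1, 7)
        val bucketId = keyRow(partitionColumns.length).asInstanceOf[Int]
        assert(bucketId == 7)
        println(s"bucket id = $bucketId")
      }
    }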
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 8e0b2dbca4..ac1607ba35 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -237,8 +237,8 @@ public class JavaDataFrameSuite {
DataFrame crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
- Assert.assertEquals("1", columnNames[1]);
- Assert.assertEquals("2", columnNames[2]);
+ Assert.assertEquals("2", columnNames[1]);
+ Assert.assertEquals("1", columnNames[2]);
Row[] rows = crosstab.collect();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 9f8db39e33..1a3df1b117 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
}
}, Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
Dataset<String> flatMapped = grouped.flatMapGroups(
new FlatMapGroupsFunction<Integer, String, String>() {
@@ -202,7 +202,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
@Override
@@ -212,8 +212,8 @@ public class JavaDatasetSuite implements Serializable {
});
Assert.assertEquals(
- Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
- reduced.collectAsList());
+ asSet(tuple2(1, "a"), tuple2(3, "foobar")),
+ toSet(reduced.collectAsList()));
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
@@ -245,7 +245,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
+ Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList()));
}
@Test
@@ -268,7 +268,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
}
@Test
@@ -290,9 +290,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("abc", "abc", "xyz");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
- Assert.assertEquals(
- Arrays.asList("abc", "xyz"),
- sort(ds.distinct().collectAsList().toArray(new String[0])));
+ Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList()));
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
@@ -302,16 +300,23 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> unioned = ds.union(ds2);
Assert.assertEquals(
- Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"),
- sort(unioned.collectAsList().toArray(new String[0])));
+ Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"),
+ unioned.collectAsList());
Dataset<String> subtracted = ds.subtract(ds2);
Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
}
- private <T extends Comparable<T>> List<T> sort(T[] data) {
- Arrays.sort(data);
- return Arrays.asList(data);
+ private <T> Set<T> toSet(List<T> records) {
+ Set<T> set = new HashSet<T>();
+ for (T record : records) {
+ set.add(record);
+ }
+ return set;
+ }
+
+ private <T> Set<T> asSet(T... records) {
+ return toSet(Arrays.asList(records));
}
@Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 983dfbdede..d6c140dfea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1083,17 +1083,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Walk each partition and verify that it is sorted descending and does not contain all
// the values.
df4.rdd.foreachPartition { p =>
- var previousValue: Int = -1
- var allSequential: Boolean = true
- p.foreach { r =>
- val v: Int = r.getInt(1)
- if (previousValue != -1) {
- if (previousValue < v) throw new SparkException("Partition is not ordered.")
- if (v + 1 != previousValue) allSequential = false
+ // Skip empty partition
+ if (p.hasNext) {
+ var previousValue: Int = -1
+ var allSequential: Boolean = true
+ p.foreach { r =>
+ val v: Int = r.getInt(1)
+ if (previousValue != -1) {
+ if (previousValue < v) throw new SparkException("Partition is not ordered.")
+ if (v + 1 != previousValue) allSequential = false
+ }
+ previousValue = v
}
- previousValue = v
+ if (allSequential) throw new SparkException("Partition should not be globally ordered")
}
- if (allSequential) throw new SparkException("Partition should not be globally ordered")
}
// Distribute and order by with multiple order bys
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 693f5aea2d..d7b86e3811 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -456,8 +456,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.groupBy(p => p).count().collect().toSeq ==
- Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+ assert(ds.groupBy(p => p).count().collect().toSet ==
+ Set((KryoData(1), 1L), (KryoData(2), 1L)))
}
test("Kryo encoder self join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5de0979606..03d67c4e91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -806,7 +806,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
.limit(2)
.registerTempTable("subset1")
- sql("SELECT DISTINCT n FROM lowerCaseData")
+ sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC")
.limit(2)
.registerTempTable("subset2")
checkAnswer(