-rw-r--r--  R/pkg/inst/tests/testthat/test_sparkSQL.R                                                   |  2
-rw-r--r--  python/pyspark/sql/dataframe.py                                                             | 26
-rw-r--r--  python/pyspark/sql/group.py                                                                 |  6
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala                       | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala    | 20
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java                    |  4
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java                      | 33
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala                           | 21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                             |  4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala                            |  2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala               | 11
12 files changed, 84 insertions, 64 deletions
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 97625b94a0..40d5066a93 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1173,7 +1173,7 @@ test_that("group by, agg functions", {
expect_equal(3, count(mean(gd)))
expect_equal(3, count(max(gd)))
- expect_equal(30, collect(max(gd))[1, 2])
+ expect_equal(30, collect(max(gd))[2, 2])
expect_equal(1, collect(count(gd))[1, 2])
mockLines2 <- c("{\"name\":\"ID1\", \"value\": \"10\"}",
diff --git a/python/pyspark/sql/dataframe.py b/python/pyspark/sql/dataframe.py
index a7bc288e38..90a6b5d9c0 100644
--- a/python/pyspark/sql/dataframe.py
+++ b/python/pyspark/sql/dataframe.py
@@ -403,10 +403,10 @@ class DataFrame(object):
+---+-----+
|age| name|
+---+-----+
- | 2|Alice|
- | 2|Alice|
| 5| Bob|
| 5| Bob|
+ | 2|Alice|
+ | 2|Alice|
+---+-----+
>>> data = data.repartition(7, "age")
>>> data.show()
@@ -552,7 +552,7 @@ class DataFrame(object):
>>> df_as2 = df.alias("df_as2")
>>> joined_df = df_as1.join(df_as2, col("df_as1.name") == col("df_as2.name"), 'inner')
>>> joined_df.select(col("df_as1.name"), col("df_as2.name"), col("df_as2.age")).collect()
- [Row(name=u'Alice', name=u'Alice', age=2), Row(name=u'Bob', name=u'Bob', age=5)]
+ [Row(name=u'Bob', name=u'Bob', age=5), Row(name=u'Alice', name=u'Alice', age=2)]
"""
assert isinstance(alias, basestring), "alias should be a string"
return DataFrame(getattr(self._jdf, "as")(alias), self.sql_ctx)
@@ -573,14 +573,14 @@ class DataFrame(object):
One of `inner`, `outer`, `left_outer`, `right_outer`, `leftsemi`.
>>> df.join(df2, df.name == df2.name, 'outer').select(df.name, df2.height).collect()
- [Row(name=None, height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+ [Row(name=None, height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
>>> df.join(df2, 'name', 'outer').select('name', 'height').collect()
- [Row(name=u'Tom', height=80), Row(name=u'Alice', height=None), Row(name=u'Bob', height=85)]
+ [Row(name=u'Tom', height=80), Row(name=u'Bob', height=85), Row(name=u'Alice', height=None)]
>>> cond = [df.name == df3.name, df.age == df3.age]
>>> df.join(df3, cond, 'outer').select(df.name, df3.age).collect()
- [Row(name=u'Bob', age=5), Row(name=u'Alice', age=2)]
+ [Row(name=u'Alice', age=2), Row(name=u'Bob', age=5)]
>>> df.join(df2, 'name').select(df.name, df2.height).collect()
[Row(name=u'Bob', height=85)]
@@ -880,9 +880,9 @@ class DataFrame(object):
>>> df.groupBy().avg().collect()
[Row(avg(age)=3.5)]
- >>> df.groupBy('name').agg({'age': 'mean'}).collect()
+ >>> sorted(df.groupBy('name').agg({'age': 'mean'}).collect())
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
- >>> df.groupBy(df.name).avg().collect()
+ >>> sorted(df.groupBy(df.name).avg().collect())
[Row(name=u'Alice', avg(age)=2.0), Row(name=u'Bob', avg(age)=5.0)]
>>> df.groupBy(['name', df.age]).count().collect()
[Row(name=u'Bob', age=5, count=1), Row(name=u'Alice', age=2, count=1)]
@@ -901,11 +901,11 @@ class DataFrame(object):
+-----+----+-----+
| name| age|count|
+-----+----+-----+
- |Alice|null| 1|
+ |Alice| 2| 1|
| Bob| 5| 1|
| Bob|null| 1|
| null|null| 2|
- |Alice| 2| 1|
+ |Alice|null| 1|
+-----+----+-----+
"""
jgd = self._jdf.rollup(self._jcols(*cols))
@@ -923,12 +923,12 @@ class DataFrame(object):
| name| age|count|
+-----+----+-----+
| null| 2| 1|
- |Alice|null| 1|
+ |Alice| 2| 1|
| Bob| 5| 1|
- | Bob|null| 1|
| null| 5| 1|
+ | Bob|null| 1|
| null|null| 2|
- |Alice| 2| 1|
+ |Alice|null| 1|
+-----+----+-----+
"""
jgd = self._jdf.cube(self._jcols(*cols))
diff --git a/python/pyspark/sql/group.py b/python/pyspark/sql/group.py
index 9ca303a974..ee734cb439 100644
--- a/python/pyspark/sql/group.py
+++ b/python/pyspark/sql/group.py
@@ -74,11 +74,11 @@ class GroupedData(object):
or a list of :class:`Column`.
>>> gdf = df.groupBy(df.name)
- >>> gdf.agg({"*": "count"}).collect()
+ >>> sorted(gdf.agg({"*": "count"}).collect())
[Row(name=u'Alice', count(1)=1), Row(name=u'Bob', count(1)=1)]
>>> from pyspark.sql import functions as F
- >>> gdf.agg(F.min(df.age)).collect()
+ >>> sorted(gdf.agg(F.min(df.age)).collect())
[Row(name=u'Alice', min(age)=2), Row(name=u'Bob', min(age)=5)]
"""
assert exprs, "exprs should not be empty"
@@ -96,7 +96,7 @@ class GroupedData(object):
def count(self):
"""Counts the number of records for each group.
- >>> df.groupBy(df.age).count().collect()
+ >>> sorted(df.groupBy(df.age).count().collect())
[Row(age=2, count=1), Row(age=5, count=1)]
"""
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
index 1bfe0ecb1e..d6e10c412c 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/partitioning.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.catalyst.plans.physical
-import org.apache.spark.sql.catalyst.expressions.{Expression, SortOrder, Unevaluable}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.{DataType, IntegerType}
/**
@@ -249,6 +249,11 @@ case class HashPartitioning(expressions: Seq[Expression], numPartitions: Int)
case _ => false
}
+ /**
+ * Returns an expression that will produce a valid partition ID (i.e. non-negative and less
+ * than numPartitions) based on hashing expressions.
+ */
+ def partitionIdExpression: Expression = Pmod(new Murmur3Hash(expressions), Literal(numPartitions))
}
/**
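As added above, partitionIdExpression evaluates to Pmod(Murmur3Hash(expressions), Literal(numPartitions)), i.e. the Murmur3 hash of the partitioning expressions reduced to a non-negative id below numPartitions. A minimal plain-Scala sketch of that arithmetic (here hash stands in for the already-computed Murmur3 hash value, not the Catalyst expression):

// Sketch only: the positive-modulus ("pmod") step that maps an arbitrary
// hash value to a valid partition id in [0, numPartitions).
def partitionId(hash: Int, numPartitions: Int): Int = {
  val r = hash % numPartitions
  if (r < 0) r + numPartitions else r // never negative, always < numPartitions
}

// e.g. partitionId(-13, 8) == 3 and partitionId(13, 8) == 5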
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
index 058d147c7d..3770883af1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Exchange.scala
@@ -143,7 +143,13 @@ case class Exchange(
val rdd = child.execute()
val part: Partitioner = newPartitioning match {
case RoundRobinPartitioning(numPartitions) => new HashPartitioner(numPartitions)
- case HashPartitioning(expressions, numPartitions) => new HashPartitioner(numPartitions)
+ case HashPartitioning(_, n) =>
+ new Partitioner {
+ override def numPartitions: Int = n
+ // For HashPartitioning, the partitioning key is already a valid partition ID, as we use
+ // `HashPartitioning.partitionIdExpression` to produce the partitioning key.
+ override def getPartition(key: Any): Int = key.asInstanceOf[Int]
+ }
case RangePartitioning(sortingExpressions, numPartitions) =>
// Internally, RangePartitioner runs a job on the RDD that samples keys to compute
// partition bounds. To get accurate samples, we need to copy the mutable keys.
@@ -173,7 +179,9 @@ case class Exchange(
position += 1
position
}
- case HashPartitioning(expressions, _) => newMutableProjection(expressions, child.output)()
+ case h: HashPartitioning =>
+ val projection = UnsafeProjection.create(h.partitionIdExpression :: Nil, child.output)
+ row => projection(row).getInt(0)
case RangePartitioning(_, _) | SinglePartition => identity
case _ => sys.error(s"Exchange not implemented for $newPartitioning")
}
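The HashPartitioning branch above can use a pass-through Partitioner because the shuffle key has already been turned into a partition id by projecting partitionIdExpression (see the second hunk). A minimal sketch against the public org.apache.spark.Partitioner API, assuming Spark is on the classpath; the name IdentityPartitioner is illustrative, not part of the patch:

import org.apache.spark.Partitioner

// Sketch only: the key handed to the shuffle is already an Int in [0, n),
// so getPartition simply returns it unchanged.
class IdentityPartitioner(n: Int) extends Partitioner {
  override def numPartitions: Int = n
  override def getPartition(key: Any): Int = key.asInstanceOf[Int]
}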
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
index fff72872c1..fc77529b7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/WriterContainer.scala
@@ -19,8 +19,6 @@ package org.apache.spark.sql.execution.datasources
import java.util.{Date, UUID}
-import scala.collection.JavaConverters._
-
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce._
import org.apache.hadoop.mapreduce.lib.output.{FileOutputCommitter => MapReduceFileOutputCommitter}
@@ -30,6 +28,7 @@ import org.apache.spark._
import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.UnsafeKVExternalSorter
import org.apache.spark.sql.sources.{HadoopFsRelation, OutputWriter, OutputWriterFactory}
@@ -322,9 +321,12 @@ private[sql] class DynamicPartitionWriterContainer(
spec => spec.sortColumnNames.map(c => inputSchema.find(_.name == c).get)
}
- private def bucketIdExpression: Option[Expression] = for {
- BucketSpec(numBuckets, _, _) <- bucketSpec
- } yield Pmod(new Murmur3Hash(bucketColumns), Literal(numBuckets))
+ private def bucketIdExpression: Option[Expression] = bucketSpec.map { spec =>
+ // Use `HashPartitioning.partitionIdExpression` as our bucket id expression, so that we can
+ // guarantee the data distribution is the same between shuffle and bucketed data source, which
+ // enables us to only shuffle one side when joining a bucketed table with a normal one.
+ HashPartitioning(bucketColumns, spec.numBuckets).partitionIdExpression
+ }
// Expressions that given a partition key build a string like: col1=val/col2=val/...
private def partitionStringExpression: Seq[Expression] = {
@@ -341,12 +343,8 @@ private[sql] class DynamicPartitionWriterContainer(
}
}
- private def getBucketIdFromKey(key: InternalRow): Option[Int] = {
- if (bucketSpec.isDefined) {
- Some(key.getInt(partitionColumns.length))
- } else {
- None
- }
+ private def getBucketIdFromKey(key: InternalRow): Option[Int] = bucketSpec.map { _ =>
+ key.getInt(partitionColumns.length)
}
/**
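The point of delegating bucketIdExpression to HashPartitioning.partitionIdExpression is that the bucketed write and a hash-partitioned shuffle now compute the same row-to-id mapping. A self-contained plain-Scala sketch of that agreement (rowHash stands in for the Murmur3 hash of one row's bucket columns; the pmod helper mirrors Pmod semantics for a positive divisor):

// Sketch only: both sides reduce to pmod(murmur3Hash(bucketColumns), numBuckets).
def pmod(hash: Int, n: Int): Int = { val r = hash % n; if (r < 0) r + n else r }

val numBuckets = 8
val rowHash = -13                                 // stand-in for one row's Murmur3 hash
val bucketId = pmod(rowHash, numBuckets)          // id used when writing the bucket file
val shufflePartition = pmod(rowHash, numBuckets)  // id produced by the exchange above
assert(bucketId == shufflePartition)              // the two sides agree by construction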
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 8e0b2dbca4..ac1607ba35 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -237,8 +237,8 @@ public class JavaDataFrameSuite {
DataFrame crosstab = df.stat().crosstab("a", "b");
String[] columnNames = crosstab.schema().fieldNames();
Assert.assertEquals("a_b", columnNames[0]);
- Assert.assertEquals("1", columnNames[1]);
- Assert.assertEquals("2", columnNames[2]);
+ Assert.assertEquals("2", columnNames[1]);
+ Assert.assertEquals("1", columnNames[2]);
Row[] rows = crosstab.collect();
Arrays.sort(rows, crosstabRowComparator);
Integer count = 1;
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 9f8db39e33..1a3df1b117 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -187,7 +187,7 @@ public class JavaDatasetSuite implements Serializable {
}
}, Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
Dataset<String> flatMapped = grouped.flatMapGroups(
new FlatMapGroupsFunction<Integer, String, String>() {
@@ -202,7 +202,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), flatMapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(flatMapped.collectAsList()));
Dataset<Tuple2<Integer, String>> reduced = grouped.reduce(new ReduceFunction<String>() {
@Override
@@ -212,8 +212,8 @@ public class JavaDatasetSuite implements Serializable {
});
Assert.assertEquals(
- Arrays.asList(tuple2(1, "a"), tuple2(3, "foobar")),
- reduced.collectAsList());
+ asSet(tuple2(1, "a"), tuple2(3, "foobar")),
+ toSet(reduced.collectAsList()));
List<Integer> data2 = Arrays.asList(2, 6, 10);
Dataset<Integer> ds2 = context.createDataset(data2, Encoders.INT());
@@ -245,7 +245,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a#2", "3foobar#6", "5#10"), cogrouped.collectAsList());
+ Assert.assertEquals(asSet("1a#2", "3foobar#6", "5#10"), toSet(cogrouped.collectAsList()));
}
@Test
@@ -268,7 +268,7 @@ public class JavaDatasetSuite implements Serializable {
},
Encoders.STRING());
- Assert.assertEquals(Arrays.asList("1a", "3foobar"), mapped.collectAsList());
+ Assert.assertEquals(asSet("1a", "3foobar"), toSet(mapped.collectAsList()));
}
@Test
@@ -290,9 +290,7 @@ public class JavaDatasetSuite implements Serializable {
List<String> data = Arrays.asList("abc", "abc", "xyz");
Dataset<String> ds = context.createDataset(data, Encoders.STRING());
- Assert.assertEquals(
- Arrays.asList("abc", "xyz"),
- sort(ds.distinct().collectAsList().toArray(new String[0])));
+ Assert.assertEquals(asSet("abc", "xyz"), toSet(ds.distinct().collectAsList()));
List<String> data2 = Arrays.asList("xyz", "foo", "foo");
Dataset<String> ds2 = context.createDataset(data2, Encoders.STRING());
@@ -302,16 +300,23 @@ public class JavaDatasetSuite implements Serializable {
Dataset<String> unioned = ds.union(ds2);
Assert.assertEquals(
- Arrays.asList("abc", "abc", "foo", "foo", "xyz", "xyz"),
- sort(unioned.collectAsList().toArray(new String[0])));
+ Arrays.asList("abc", "abc", "xyz", "xyz", "foo", "foo"),
+ unioned.collectAsList());
Dataset<String> subtracted = ds.subtract(ds2);
Assert.assertEquals(Arrays.asList("abc", "abc"), subtracted.collectAsList());
}
- private <T extends Comparable<T>> List<T> sort(T[] data) {
- Arrays.sort(data);
- return Arrays.asList(data);
+ private <T> Set<T> toSet(List<T> records) {
+ Set<T> set = new HashSet<T>();
+ for (T record : records) {
+ set.add(record);
+ }
+ return set;
+ }
+
+ private <T> Set<T> asSet(T... records) {
+ return toSet(Arrays.asList(records));
}
@Test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 983dfbdede..d6c140dfea 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -1083,17 +1083,20 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
// Walk each partition and verify that it is sorted descending and does not contain all
// the values.
df4.rdd.foreachPartition { p =>
- var previousValue: Int = -1
- var allSequential: Boolean = true
- p.foreach { r =>
- val v: Int = r.getInt(1)
- if (previousValue != -1) {
- if (previousValue < v) throw new SparkException("Partition is not ordered.")
- if (v + 1 != previousValue) allSequential = false
+ // Skip empty partition
+ if (p.hasNext) {
+ var previousValue: Int = -1
+ var allSequential: Boolean = true
+ p.foreach { r =>
+ val v: Int = r.getInt(1)
+ if (previousValue != -1) {
+ if (previousValue < v) throw new SparkException("Partition is not ordered.")
+ if (v + 1 != previousValue) allSequential = false
+ }
+ previousValue = v
}
- previousValue = v
+ if (allSequential) throw new SparkException("Partition should not be globally ordered")
}
- if (allSequential) throw new SparkException("Partition should not be globally ordered")
}
// Distribute and order by with multiple order bys
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 693f5aea2d..d7b86e3811 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -456,8 +456,8 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
implicit val kryoEncoder = Encoders.kryo[KryoData]
val ds = Seq(KryoData(1), KryoData(2)).toDS()
- assert(ds.groupBy(p => p).count().collect().toSeq ==
- Seq((KryoData(1), 1L), (KryoData(2), 1L)))
+ assert(ds.groupBy(p => p).count().collect().toSet ==
+ Set((KryoData(1), 1L), (KryoData(2), 1L)))
}
test("Kryo encoder self join") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 5de0979606..03d67c4e91 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -806,7 +806,7 @@ class SQLQuerySuite extends QueryTest with SharedSQLContext {
sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n DESC")
.limit(2)
.registerTempTable("subset1")
- sql("SELECT DISTINCT n FROM lowerCaseData")
+ sql("SELECT DISTINCT n FROM lowerCaseData ORDER BY n ASC")
.limit(2)
.registerTempTable("subset2")
checkAnswer(
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
index b718b7cefb..3ea9826544 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/BucketedWriteSuite.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.sources
import java.io.File
import org.apache.spark.sql.{AnalysisException, QueryTest}
-import org.apache.spark.sql.catalyst.expressions.{Murmur3Hash, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
+import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive.test.TestHiveSingleton
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.util.Utils
class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingleton {
import testImplicits._
@@ -98,11 +98,12 @@ class BucketedWriteSuite extends QueryTest with SQLTestUtils with TestHiveSingle
val qe = readBack.select(bucketCols.map(col): _*).queryExecution
val rows = qe.toRdd.map(_.copy()).collect()
- val getHashCode =
- UnsafeProjection.create(new Murmur3Hash(qe.analyzed.output) :: Nil, qe.analyzed.output)
+ val getHashCode = UnsafeProjection.create(
+ HashPartitioning(qe.analyzed.output, 8).partitionIdExpression :: Nil,
+ qe.analyzed.output)
for (row <- rows) {
- val actualBucketId = Utils.nonNegativeMod(getHashCode(row).getInt(0), 8)
+ val actualBucketId = getHashCode(row).getInt(0)
assert(actualBucketId == bucketId)
}
}
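A note on the assertion change just above: the old test projected the raw Murmur3 hash and then applied Utils.nonNegativeMod(_, 8), while the new test projects partitionIdExpression, which already folds the positive modulus in, so getHashCode(row).getInt(0) is the bucket id itself. A quick standalone check of that equivalence (plain Scala; both helpers below are illustrative re-implementations, not the Spark code):

// Sketch only: a non-negative mod applied after the fact equals folding the
// positive modulus into the expression up front.
def nonNegativeMod(x: Int, mod: Int): Int = ((x % mod) + mod) % mod        // old test's post-processing
def pmod(x: Int, n: Int): Int = { val r = x % n; if (r < 0) r + n else r } // new expression's semantics

for (h <- Seq(-13, -1, 0, 7, 13, Int.MinValue)) {
  assert(nonNegativeMod(h, 8) == pmod(h, 8))
}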