author     Wenchen Fan <wenchen@databricks.com>    2015-11-24 09:28:39 -0800
committer  Michael Armbrust <michael@databricks.com>    2015-11-24 09:28:39 -0800
commit     e5aaae6e1145b8c25c4872b2992ab425da9c6f9b (patch)
tree       8d54936bb41ffca0fb875dab8b62c432f62880bc /sql
parent     be9dd1550c1816559d3d418a19c692e715f1c94e (diff)
[SPARK-11942][SQL] fix encoder life cycle for CoGroup
We should pass resolved encoders in to the logical `CoGroup` and bind them in the physical `CoGroup`.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9928 from cloud-fan/cogroup.
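At a glance, the change to the `CoGroup` factory in the hunks below boils down to the following (argument names taken from the diff, calls abbreviated; this is a summary sketch, not additional code in the patch):

    // Before: the factory re-derived fresh, unresolved encoders via encoderFor[T]
    CoGroup(func, leftGroup, rightGroup, left, right)

    // After: already-resolved encoders are passed in explicitly and are only
    // bound to the children's output later, inside the physical CoGroup operator.
    CoGroup(func, keyEnc, leftEnc, rightEnc, leftGroup, rightGroup, left, right)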
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala  27
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala  12
4 files changed, 41 insertions, 22 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 737e62fd59..5665fd7e5f 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -553,19 +553,22 @@ case class MapGroups[K, T, U](
/** Factory for constructing new `CoGroup` nodes. */
object CoGroup {
- def apply[K : Encoder, Left : Encoder, Right : Encoder, R : Encoder](
- func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
+ def apply[Key, Left, Right, Result : Encoder](
+ func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+ keyEnc: ExpressionEncoder[Key],
+ leftEnc: ExpressionEncoder[Left],
+ rightEnc: ExpressionEncoder[Right],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
left: LogicalPlan,
- right: LogicalPlan): CoGroup[K, Left, Right, R] = {
+ right: LogicalPlan): CoGroup[Key, Left, Right, Result] = {
CoGroup(
func,
- encoderFor[K],
- encoderFor[Left],
- encoderFor[Right],
- encoderFor[R],
- encoderFor[R].schema.toAttributes,
+ keyEnc,
+ leftEnc,
+ rightEnc,
+ encoderFor[Result],
+ encoderFor[Result].schema.toAttributes,
leftGroup,
rightGroup,
left,
@@ -577,12 +580,12 @@ object CoGroup {
* A relation produced by applying `func` to each grouping key and associated values from left and
* right children.
*/
-case class CoGroup[K, Left, Right, R](
- func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
- kEncoder: ExpressionEncoder[K],
+case class CoGroup[Key, Left, Right, Result](
+ func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+ keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
- rEncoder: ExpressionEncoder[R],
+ resultEnc: ExpressionEncoder[Result],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 793a86b132..a10a89342f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -304,11 +304,13 @@ class GroupedDataset[K, V] private[sql](
def cogroup[U, R : Encoder](
other: GroupedDataset[K, U])(
f: (K, Iterator[V], Iterator[U]) => TraversableOnce[R]): Dataset[R] = {
- implicit def uEnc: Encoder[U] = other.unresolvedVEncoder
new Dataset[R](
sqlContext,
CoGroup(
f,
+ this.resolvedKEncoder,
+ this.resolvedVEncoder,
+ other.resolvedVEncoder,
this.groupingAttributes,
other.groupingAttributes,
this.logicalPlan,
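Note: the `resolvedKEncoder` / `resolvedVEncoder` values passed above are defined elsewhere in GroupedDataset and are not shown in this hunk. Roughly, they are the unresolved encoders resolved once on the driver against the grouping and data attributes, so the physical operator only needs to bind them. A sketch under that assumption (field names and the exact `resolve` signature may differ in this Spark version):

    // Assumed shape of the fields referenced above (not part of this hunk):
    // resolve once, against the attributes the encoder will read from.
    private val resolvedKEncoder = unresolvedKEncoder.resolve(groupingAttributes)
    private val resolvedVEncoder = unresolvedVEncoder.resolve(dataAttributes)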
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index d57b8e7a9e..a42aea0b96 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -375,12 +375,12 @@ case class MapGroups[K, T, U](
* iterators containing all elements in the group from left and right side.
* The result of this function is encoded and flattened before being output.
*/
-case class CoGroup[K, Left, Right, R](
- func: (K, Iterator[Left], Iterator[Right]) => TraversableOnce[R],
- kEncoder: ExpressionEncoder[K],
+case class CoGroup[Key, Left, Right, Result](
+ func: (Key, Iterator[Left], Iterator[Right]) => TraversableOnce[Result],
+ keyEnc: ExpressionEncoder[Key],
leftEnc: ExpressionEncoder[Left],
rightEnc: ExpressionEncoder[Right],
- rEncoder: ExpressionEncoder[R],
+ resultEnc: ExpressionEncoder[Result],
output: Seq[Attribute],
leftGroup: Seq[Attribute],
rightGroup: Seq[Attribute],
@@ -397,15 +397,17 @@ case class CoGroup[K, Left, Right, R](
left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
- val groupKeyEncoder = kEncoder.bind(leftGroup)
+ val boundKeyEnc = keyEnc.bind(leftGroup)
+ val boundLeftEnc = leftEnc.bind(left.output)
+ val boundRightEnc = rightEnc.bind(right.output)
new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
case (key, leftResult, rightResult) =>
val result = func(
- groupKeyEncoder.fromRow(key),
- leftResult.map(leftEnc.fromRow),
- rightResult.map(rightEnc.fromRow))
- result.map(rEncoder.toRow)
+ boundKeyEnc.fromRow(key),
+ leftResult.map(boundLeftEnc.fromRow),
+ rightResult.map(boundRightEnc.fromRow))
+ result.map(resultEnc.toRow)
}
}
}
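In isolation, the per-partition pattern introduced above is: bind the already-resolved input encoders to the children's output once, decode each row, apply the user function, and re-encode the results. A minimal, self-contained sketch of that shape (function and parameter names are illustrative; only `bind`, `fromRow` and `toRow` come from the hunk above):

    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
    import org.apache.spark.sql.catalyst.expressions.Attribute

    // Illustrative helper, not the actual operator: decode -> apply -> encode.
    def processPartition[L, R](
        resolvedLeftEnc: ExpressionEncoder[L],  // assumed resolved on the driver
        resultEnc: ExpressionEncoder[R],
        childOutput: Seq[Attribute],
        rows: Iterator[InternalRow],
        f: L => TraversableOnce[R]): Iterator[InternalRow] = {
      // Bind once per partition, then reuse the bound encoder for every row.
      val boundLeftEnc = resolvedLeftEnc.bind(childOutput)
      rows.flatMap(row => f(boundLeftEnc.fromRow(row)).map(resultEnc.toRow))
    }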
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index dbdd7ba14a..13eede1b17 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -340,6 +340,18 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
}
+ test("cogroup with complex data") {
+ val ds1 = Seq(1 -> ClassData("a", 1), 2 -> ClassData("b", 2)).toDS()
+ val ds2 = Seq(2 -> ClassData("c", 3), 3 -> ClassData("d", 4)).toDS()
+ val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+ Iterator(key -> (data1.map(_._2.a).mkString + data2.map(_._2.a).mkString))
+ }
+
+ checkAnswer(
+ cogrouped,
+ 1 -> "a", 2 -> "bc", 3 -> "d")
+ }
+
test("SPARK-11436: we should rebind right encoder when join 2 datasets") {
val ds1 = Seq("1", "2").toDS().as("a")
val ds2 = Seq(2, 3).toDS().as("b")