author     Wenchen Fan <wenchen@databricks.com>       2015-10-28 13:58:52 +0100
committer  Michael Armbrust <michael@databricks.com>  2015-10-28 13:58:52 +0100
commit     075ce4914fdcbbcc7286c3c30cb940ed28d474d2 (patch)
tree       f4eaa13efe6d0322649ad1be161e84ba9dd35e7e /sql/core
parent     5f1cee6f158adb1f9f485ed1d529c56bace68adc (diff)
[SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)
A simpler version of https://github.com/apache/spark/pull/9279 that only supports 2 datasets.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #9324 from cloud-fan/cogroup2.
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala                   | 20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala      | 89
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala        |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala         | 41
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala                     | 12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala | 51
6 files changed, 217 insertions(+), 0 deletions(-)
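For orientation before the diff: the patch adds a `cogroup` method to `GroupedDataset`, a `CoGroupedIterator` that merges two key-sorted streams of groups, and a `CoGroup` physical operator. A minimal usage sketch of the new API, mirroring the test added in DatasetSuite.scala below (assumes a SQLContext whose implicits are in scope):

    import sqlContext.implicits._

    val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello").toDS()
    val ds2 = Seq(3 -> "w", 5 -> "e", 5 -> "r").toDS()

    // For each distinct key, the function receives the key and one iterator of
    // elements per side; here we concatenate the string values of both sides.
    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { (key, left, right) =>
      Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
    }
    // cogrouped contains: 1 -> "a#", 3 -> "abc#w", 5 -> "hello#er"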
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 89a16dd8b0..612f2b60cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
sqlContext,
MapGroups(f, groupingAttributes, logicalPlan))
}
+
+ /**
+ * Applies the given function to each group of cogrouped data. For each unique grouping key,
+ * the function will be passed the key and two iterators containing all elements in the group
+ * from [[Dataset]] `this` and `other`. The function can return an iterator of elements of an
+ * arbitrary type, which will be flattened and returned as a new [[Dataset]].
+ */
+ def cogroup[U, R : Encoder](
+ other: GroupedDataset[K, U])(
+ f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
+ implicit def uEnc: Encoder[U] = other.tEncoder
+ new Dataset[R](
+ sqlContext,
+ CoGroup(
+ f,
+ this.groupingAttributes,
+ other.groupingAttributes,
+ this.logicalPlan,
+ other.logicalPlan))
+ }
}
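Because only the result type carries an encoder context bound (`R : Encoder`), the function may return a type unrelated to either input. A hedged sketch reusing `ds1` and `ds2` from the example above:

    // The result type (Int, Int, Int) is unrelated to the String payloads:
    // per key, emit the key plus how many elements each side contributed.
    val sizes: Dataset[(Int, Int, Int)] =
      ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { (key, left, right) =>
        Iterator((key, left.size, right.size))
      }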
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
new file mode 100644
index 0000000000..ce5827855e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+
+/**
+ * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
+ * grouping key with its associated values from all [[GroupedIterator]]s.
+ * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
+ */
+class CoGroupedIterator(
+ left: Iterator[(InternalRow, Iterator[InternalRow])],
+ right: Iterator[(InternalRow, Iterator[InternalRow])],
+ groupingSchema: Seq[Attribute])
+ extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {
+
+ private val keyOrdering =
+ GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)
+
+ private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
+ private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
+
+ override def hasNext: Boolean = left.hasNext || right.hasNext
+
+ override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+ if (currentLeftData.eq(null) && left.hasNext) {
+ currentLeftData = left.next()
+ }
+ if (currentRightData.eq(null) && right.hasNext) {
+ currentRightData = right.next()
+ }
+
+ assert(currentLeftData.ne(null) || currentRightData.ne(null))
+
+ if (currentLeftData.eq(null)) {
+ // left is null, right is not null, consume the right data.
+ rightOnly()
+ } else if (currentRightData.eq(null)) {
+ // left is not null, right is null, consume the left data.
+ leftOnly()
+ } else if (currentLeftData._1 == currentRightData._1) {
+ // left and right have the same grouping key, consume both of them.
+ val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
+ currentLeftData = null
+ currentRightData = null
+ result
+ } else {
+ val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
+ assert(compare != 0)
+ if (compare < 0) {
+ // the grouping key of left is smaller, consume the left data.
+ leftOnly()
+ } else {
+ // the grouping key of right is smaller, consume the right data.
+ rightOnly()
+ }
+ }
+ }
+
+ private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+ val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
+ currentLeftData = null
+ result
+ }
+
+ private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+ val result = (currentRightData._1, Iterator.empty, currentRightData._2)
+ currentRightData = null
+ result
+ }
+}
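The class above is a two-pointer sort-merge over two key-sorted streams of groups: advance whichever side currently has the smaller key, and pair the groups when the keys match. The same idea on plain Scala collections, as a hypothetical standalone sketch (`cogroupSorted` is not part of the patch; Int keys stand in for the generated key ordering):

    // Cogroup two key-sorted lists of (key, group) pairs, mirroring the
    // left-only / right-only / matched cases of CoGroupedIterator.next().
    def cogroupSorted[A, B](
        left: List[(Int, Seq[A])],
        right: List[(Int, Seq[B])]): List[(Int, Seq[A], Seq[B])] = (left, right) match {
      case (Nil, Nil) => Nil
      case ((k, as) :: lt, Nil) => (k, as, Seq.empty[B]) :: cogroupSorted(lt, Nil)
      case (Nil, (k, bs) :: rt) => (k, Seq.empty[A], bs) :: cogroupSorted(Nil, rt)
      case ((lk, as) :: lt, (rk, bs) :: rt) =>
        if (lk == rk) (lk, as, bs) :: cogroupSorted(lt, rt)                  // match: consume both
        else if (lk < rk) (lk, as, Seq.empty[B]) :: cogroupSorted(lt, right) // left key is smaller
        else (rk, Seq.empty[A], bs) :: cogroupSorted(left, rt)               // right key is smaller
    }

    // cogroupSorted(List(1 -> Seq("a", "b"), 2 -> Seq("c")),
    //               List(2 -> Seq(3L), 3 -> Seq(4L)))
    // => List((1, Seq(a, b), Seq()), (2, Seq(c), Seq(3)), (3, Seq(), Seq(4)))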
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee97162853..32067266b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
+ case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
+ leftGroup, rightGroup, left, right) =>
+ execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
+ planLater(left), planLater(right)) :: Nil
case logical.Repartition(numPartitions, shuffle, child) =>
if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 89938471ee..d5a803f8c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
}
}
}
+
+/**
+ * Co-groups the data from the left and right children, and calls the function with each group
+ * key and two iterators containing all elements in the group from the left and right sides.
+ * The result of the function is encoded and flattened before being output.
+ */
+case class CoGroup[K, Left, Right, R](
+ func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+ kEncoder: ExpressionEncoder[K],
+ leftEnc: ExpressionEncoder[Left],
+ rightEnc: ExpressionEncoder[Right],
+ rEncoder: ExpressionEncoder[R],
+ output: Seq[Attribute],
+ leftGroup: Seq[Attribute],
+ rightGroup: Seq[Attribute],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode {
+
+ override def requiredChildDistribution: Seq[Distribution] =
+ ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+
+ override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+ leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+
+ override protected def doExecute(): RDD[InternalRow] = {
+ left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
+ val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
+ val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
+ val groupKeyEncoder = kEncoder.bind(leftGroup)
+
+ new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
+ case (key, leftResult, rightResult) =>
+ val result = func(
+ groupKeyEncoder.fromRow(key),
+ leftResult.map(leftEnc.fromRow),
+ rightResult.map(rightEnc.fromRow))
+ result.map(rEncoder.toRow)
+ }
+ }
+ }
+}
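`doExecute` can use `zipPartitions` only because `requiredChildDistribution` forces both children to be hash-partitioned on their grouping keys, so partition i holds the same key range on both sides, while `requiredChildOrdering` keeps each partition sorted for `GroupedIterator`. A hedged illustration of that co-partitioning contract with plain RDDs (a standalone sketch, not the operator itself):

    import org.apache.spark.{HashPartitioner, SparkConf, SparkContext}

    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("copartition-demo"))
    val byKey = new HashPartitioner(4)

    // The same partitioner on both sides guarantees equal keys get equal partition ids.
    val left  = sc.parallelize(Seq(1 -> "a", 2 -> "b", 5 -> "c")).partitionBy(byKey)
    val right = sc.parallelize(Seq(1 -> 10L, 5 -> 50L)).partitionBy(byKey)

    // Each call of the function sees the pair of partitions that share a key range,
    // which is the invariant CoGroup.doExecute relies on.
    val paired = left.zipPartitions(right) { (l, r) => Iterator((l.toSeq, r.toSeq)) }
    paired.collect().foreach(println)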
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index aebb390a1d..993e6d269e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
agged,
("a", 30), ("b", 3), ("c", 1))
}
+
+ test("cogroup") {
+ val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
+ val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
+ val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+ Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
+ }
+
+ checkAnswer(
+ cogrouped,
+ 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
new file mode 100644
index 0000000000..d1fe81947e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
+
+class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+ test("basic") {
+ val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
+ val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
+ val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+ val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+ val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+ val result = cogrouped.map {
+ case (key, leftData, rightData) =>
+ assert(key.numFields == 1)
+ (key.getInt(0), leftData.toSeq, rightData.toSeq)
+ }.toSeq
+ assert(result ==
+ (1,
+ Seq(create_row(1, "a"), create_row(1, "b")),
+ Seq(create_row(1, 2L))) ::
+ (2,
+ Seq(create_row(2, "c")),
+ Seq(create_row(2, 3L))) ::
+ (3,
+ Seq.empty,
+ Seq(create_row(3, 4L))) ::
+ Nil
+ )
+ }
+}