From 075ce4914fdcbbcc7286c3c30cb940ed28d474d2 Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Wed, 28 Oct 2015 13:58:52 +0100
Subject: [SPARK-11313][SQL] implement cogroup on DataSets (support 2 datasets)

A simpler version of https://github.com/apache/spark/pull/9279 that only supports 2 datasets.

Author: Wenchen Fan

Closes #9324 from cloud-fan/cogroup2.
---
 .../org/apache/spark/sql/GroupedDataset.scala      | 20 +++++
 .../spark/sql/execution/CoGroupedIterator.scala    | 89 ++++++++++++++++++++++
 .../spark/sql/execution/SparkStrategies.scala      |  4 +
 .../spark/sql/execution/basicOperators.scala       | 41 ++++++++++
 .../scala/org/apache/spark/sql/DatasetSuite.scala  | 12 +++
 .../sql/execution/CoGroupedIteratorSuite.scala     | 51 +++++++++++++
 6 files changed, 217 insertions(+)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala

(limited to 'sql/core')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 89a16dd8b0..612f2b60cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -65,4 +65,24 @@ class GroupedDataset[K, T] private[sql](
       sqlContext,
       MapGroups(f, groupingAttributes, logicalPlan))
   }
+
+  /**
+   * Applies the given function to each cogrouped data. For each unique group, the function will
+   * be passed the grouping key and 2 iterators containing all elements in the group from
+   * [[Dataset]] `this` and `other`. The function can return an iterator containing elements of an
+   * arbitrary type which will be returned as a new [[Dataset]].
+   */
+  def cogroup[U, R : Encoder](
+      other: GroupedDataset[K, U])(
+      f: (K, Iterator[T], Iterator[U]) => Iterator[R]): Dataset[R] = {
+    implicit def uEnc: Encoder[U] = other.tEncoder
+    new Dataset[R](
+      sqlContext,
+      CoGroup(
+        f,
+        this.groupingAttributes,
+        other.groupingAttributes,
+        this.logicalPlan,
+        other.logicalPlan))
+  }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
new file mode 100644
index 0000000000..ce5827855e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Ascending, SortOrder, Attribute}
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateOrdering
+
+/**
+ * Iterates over [[GroupedIterator]]s and returns the cogrouped data, i.e. each record is a
+ * grouping key with its associated values from all [[GroupedIterator]]s.
+ * Note: we assume the output of each [[GroupedIterator]] is ordered by the grouping key.
+ */
+class CoGroupedIterator(
+    left: Iterator[(InternalRow, Iterator[InternalRow])],
+    right: Iterator[(InternalRow, Iterator[InternalRow])],
+    groupingSchema: Seq[Attribute])
+  extends Iterator[(InternalRow, Iterator[InternalRow], Iterator[InternalRow])] {
+
+  private val keyOrdering =
+    GenerateOrdering.generate(groupingSchema.map(SortOrder(_, Ascending)), groupingSchema)
+
+  private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
+  private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
+
+  override def hasNext: Boolean = left.hasNext || right.hasNext
+
+  override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    if (currentLeftData.eq(null) && left.hasNext) {
+      currentLeftData = left.next()
+    }
+    if (currentRightData.eq(null) && right.hasNext) {
+      currentRightData = right.next()
+    }
+
+    assert(currentLeftData.ne(null) || currentRightData.ne(null))
+
+    if (currentLeftData.eq(null)) {
+      // left is null, right is not null, consume the right data.
+      rightOnly()
+    } else if (currentRightData.eq(null)) {
+      // left is not null, right is null, consume the left data.
+      leftOnly()
+    } else if (currentLeftData._1 == currentRightData._1) {
+      // left and right have the same grouping key, consume both of them.
+      val result = (currentLeftData._1, currentLeftData._2, currentRightData._2)
+      currentLeftData = null
+      currentRightData = null
+      result
+    } else {
+      val compare = keyOrdering.compare(currentLeftData._1, currentRightData._1)
+      assert(compare != 0)
+      if (compare < 0) {
+        // the grouping key of left is smaller, consume the left data.
+        leftOnly()
+      } else {
+        // the grouping key of right is smaller, consume the right data.
+        rightOnly()
+      }
+    }
+  }
+
+  private def leftOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    val result = (currentLeftData._1, currentLeftData._2, Iterator.empty)
+    currentLeftData = null
+    result
+  }
+
+  private def rightOnly(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+    val result = (currentRightData._1, Iterator.empty, currentRightData._2)
+    currentRightData = null
+    result
+  }
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index ee97162853..32067266b5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -393,6 +393,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
         execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil
       case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) =>
         execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil
+      case logical.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output,
+          leftGroup, rightGroup, left, right) =>
+        execution.CoGroup(f, kEnc, leftEnc, rightEnc, rEnc, output, leftGroup, rightGroup,
+          planLater(left), planLater(right)) :: Nil
 
       case logical.Repartition(numPartitions, shuffle, child) =>
         if (shuffle) {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 89938471ee..d5a803f8c4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -390,3 +390,44 @@ case class MapGroups[K, T, U](
     }
   }
 }
+
+/**
+ * Co-groups the data from left and right children, and calls the function with each group and 2
+ * iterators containing all elements in the group from left and right side.
+ * The result of this function is encoded and flattened before being output.
+ */
+case class CoGroup[K, Left, Right, R](
+    func: (K, Iterator[Left], Iterator[Right]) => Iterator[R],
+    kEncoder: ExpressionEncoder[K],
+    leftEnc: ExpressionEncoder[Left],
+    rightEnc: ExpressionEncoder[Right],
+    rEncoder: ExpressionEncoder[R],
+    output: Seq[Attribute],
+    leftGroup: Seq[Attribute],
+    rightGroup: Seq[Attribute],
+    left: SparkPlan,
+    right: SparkPlan) extends BinaryNode {
+
+  override def requiredChildDistribution: Seq[Distribution] =
+    ClusteredDistribution(leftGroup) :: ClusteredDistribution(rightGroup) :: Nil
+
+  override def requiredChildOrdering: Seq[Seq[SortOrder]] =
+    leftGroup.map(SortOrder(_, Ascending)) :: rightGroup.map(SortOrder(_, Ascending)) :: Nil
+
+  override protected def doExecute(): RDD[InternalRow] = {
+    left.execute().zipPartitions(right.execute()) { (leftData, rightData) =>
+      val leftGrouped = GroupedIterator(leftData, leftGroup, left.output)
+      val rightGrouped = GroupedIterator(rightData, rightGroup, right.output)
+      val groupKeyEncoder = kEncoder.bind(leftGroup)
+
+      new CoGroupedIterator(leftGrouped, rightGrouped, leftGroup).flatMap {
+        case (key, leftResult, rightResult) =>
+          val result = func(
+            groupKeyEncoder.fromRow(key),
+            leftResult.map(leftEnc.fromRow),
+            rightResult.map(rightEnc.fromRow))
+          result.map(rEncoder.toRow)
+      }
+    }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index aebb390a1d..993e6d269e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -202,4 +202,16 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
       agged,
       ("a", 30), ("b", 3), ("c", 1))
   }
+
+  test("cogroup") {
+    val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
+    val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()
+    val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) { case (key, data1, data2) =>
+      Iterator(key -> (data1.map(_._2).mkString + "#" + data2.map(_._2).mkString))
+    }
+
+    checkAnswer(
+      cogrouped,
+      1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er")
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
new file mode 100644
index 0000000000..d1fe81947e
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -0,0 +1,51 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.expressions.ExpressionEvalHelper
+
+class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
+
+  test("basic") {
+    val leftInput = Seq(create_row(1, "a"), create_row(1, "b"), create_row(2, "c")).iterator
+    val rightInput = Seq(create_row(1, 2L), create_row(2, 3L), create_row(3, 4L)).iterator
+    val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+    val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+    val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+    val result = cogrouped.map {
+      case (key, leftData, rightData) =>
+        assert(key.numFields == 1)
+        (key.getInt(0), leftData.toSeq, rightData.toSeq)
+    }.toSeq
+    assert(result ==
+      (1,
+        Seq(create_row(1, "a"), create_row(1, "b")),
+        Seq(create_row(1, 2L))) ::
+      (2,
+        Seq(create_row(2, "c")),
+        Seq(create_row(2, 3L))) ::
+      (3,
+        Seq.empty,
+        Seq(create_row(3, 4L))) ::
+      Nil
+    )
+  }
+}
-- cgit v1.2.3
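For readers who want to try the new API, below is a minimal usage sketch distilled from the `cogroup` test added to DatasetSuite above. It assumes a Spark 1.6-era `sqlContext` whose implicits are in scope; the `import` line is an assumption about the caller's setup, not part of this patch:

import sqlContext.implicits._  // assumed: provides .toDS() on local Seqs

// Two datasets of (key, value) pairs, grouped by the key.
val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()

// cogroup pairs up groups from both sides by key; a key present on only one
// side still yields a group, with an empty iterator for the missing side.
val cogrouped = ds1.groupBy(_._1).cogroup(ds2.groupBy(_._1)) {
  case (key, left, right) =>
    Iterator(key -> (left.map(_._2).mkString + "#" + right.map(_._2).mkString))
}

// Expected pairs (group order is not guaranteed):
// 1 -> "a#", 2 -> "#q", 3 -> "abcfoo#w", 5 -> "hello#er"

Because the physical CoGroup operator requires ClusteredDistribution and an ascending sort on the grouping attributes of both children, each side is shuffled and sorted by key, after which CoGroupedIterator merges the two sorted group streams in a single pass.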