From 13f5f8ec97c6886346641b73bd99004e0d70836c Mon Sep 17 00:00:00 2001
From: zsxwing
Date: Sat, 29 Aug 2015 18:10:44 -0700
Subject: [SPARK-9986] [SPARK-9991] [SPARK-9993] [SQL] Create a simple test
 framework for local operators

This PR includes the following changes:
- Add `LocalNodeTest` for local operator tests and add unit tests for FilterNode and ProjectNode.
- Add `LimitNode` and `UnionNode` and their unit tests to show how to use `LocalNodeTest`. (SPARK-9991, SPARK-9993)

Author: zsxwing

Closes #8464 from zsxwing/local-execution.
---
 .../spark/sql/execution/local/FilterNode.scala     |   6 +-
 .../spark/sql/execution/local/LimitNode.scala      |  45 +++++++
 .../spark/sql/execution/local/LocalNode.scala      |  13 +-
 .../spark/sql/execution/local/ProjectNode.scala    |   4 +-
 .../spark/sql/execution/local/SeqScanNode.scala    |   2 +-
 .../spark/sql/execution/local/UnionNode.scala      |  72 ++++++++++
 .../apache/spark/sql/execution/SparkPlanTest.scala |  46 +------
 .../sql/execution/local/FilterNodeSuite.scala      |  41 ++++++
 .../spark/sql/execution/local/LimitNodeSuite.scala |  39 ++++++
 .../spark/sql/execution/local/LocalNodeTest.scala  | 146 +++++++++++++++++++++
 .../sql/execution/local/ProjectNodeSuite.scala     |  44 +++++++
 .../spark/sql/execution/local/UnionNodeSuite.scala |  52 ++++++++
 .../org/apache/spark/sql/test/SQLTestData.scala    |   8 ++
 .../org/apache/spark/sql/test/SQLTestUtils.scala   |  46 ++++++-
 14 files changed, 509 insertions(+), 55 deletions(-)
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
 create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
 create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
index a485a1a1d7..81dd37c7da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
@@ -35,13 +35,13 @@ case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLoca
 
   override def next(): Boolean = {
     var found = false
-    while (child.next() && !found) {
-      found = predicate.apply(child.get())
+    while (!found && child.next()) {
+      found = predicate.apply(child.fetch())
     }
     found
   }
 
-  override def get(): InternalRow = child.get()
+  override def fetch(): InternalRow = child.fetch()
 
   override def close(): Unit = child.close()
 }
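A note on the FilterNode change above: local operators follow a pull-based open()/next()/fetch()/close() protocol, and the reordered loop condition matters. With the old `child.next() && !found`, the child was advanced one more time after a match was found, so the node both skipped the matched row and served an untested one; checking `!found` first short-circuits and leaves the child positioned on the matching row. A minimal sketch of how a caller drives any LocalNode under this protocol (`condition`, `child`, and `process` are hypothetical placeholders, not part of the patch):

```scala
val node = FilterNode(condition, child)  // hypothetical condition and child
node.open()
try {
  // next() advances to the following available row and reports whether one exists;
  // fetch() returns the current row and may be called repeatedly without advancing.
  while (node.next()) {
    process(node.fetch())
  }
} finally {
  node.close()  // release resources even if processing fails
}
```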
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
new file mode 100644
index 0000000000..fffc52abf6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+
+case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+
+  private[this] var count = 0
+
+  override def output: Seq[Attribute] = child.output
+
+  override def open(): Unit = child.open()
+
+  override def close(): Unit = child.close()
+
+  override def fetch(): InternalRow = child.fetch()
+
+  override def next(): Boolean = {
+    if (count < limit) {
+      count += 1
+      child.next()
+    } else {
+      false
+    }
+  }
+
+}
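LimitNode needs no buffering: it stops forwarding next() calls after `limit` successful advances, and fetch()/close() simply delegate to the child. A hedged trace, assuming a hypothetical child that yields rows r1, r2, r3:

```scala
val limited = LimitNode(2, child)  // `child` is a hypothetical LocalNode
limited.open()
limited.next()   // true  -> limited.fetch() returns r1 (count = 1)
limited.next()   // true  -> limited.fetch() returns r2 (count = 2)
limited.next()   // false -> limit reached; the child is not advanced again
limited.close()
```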
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 341c81438e..1c4469acbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -48,10 +48,10 @@ abstract class LocalNode extends TreeNode[LocalNode] {
   /**
    * Returns the current tuple.
    */
-  def get(): InternalRow
+  def fetch(): InternalRow
 
   /**
-   * Closes the iterator and releases all resources.
+   * Closes the iterator and releases all resources. It should be idempotent.
    *
    * Implementations of this must also call the `close()` function of its children.
    */
@@ -64,10 +64,13 @@ abstract class LocalNode extends TreeNode[LocalNode] {
     val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
     val result = new scala.collection.mutable.ArrayBuffer[Row]
     open()
-    while (next()) {
-      result += converter.apply(get()).asInstanceOf[Row]
+    try {
+      while (next()) {
+        result += converter.apply(fetch()).asInstanceOf[Row]
+      }
+    } finally {
+      close()
     }
-    close()
     result
   }
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
index e574d1473c..9b8a4fe493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
@@ -34,8 +34,8 @@ case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) exte
 
   override def next(): Boolean = child.next()
 
-  override def get(): InternalRow = {
-    project.apply(child.get())
+  override def fetch(): InternalRow = {
+    project.apply(child.fetch())
   }
 
   override def close(): Unit = child.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
index 994de8afa9..242cb66e07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
@@ -41,7 +41,7 @@ case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends L
     }
   }
 
-  override def get(): InternalRow = currentRow
+  override def fetch(): InternalRow = currentRow
 
   override def close(): Unit = {
     // Do nothing
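With collect() now wrapped in try/finally, close() runs even when an operator throws mid-iteration. That is also why the contract above asks close() to be idempotent: a node may be closed once by its parent (see UnionNode below) and again by the finally block. A sketch of composing these operators directly, assuming `attributes: Seq[Attribute]` and `rows: Seq[InternalRow]` have already been built for the schema under test (`predicateExpr` and `projectExprs` are hypothetical expressions):

```scala
// Scan -> Filter -> Project, evaluated locally without any RDD machinery.
val scan = SeqScanNode(attributes, rows)
val filtered = FilterNode(predicateExpr, scan)      // predicateExpr: hypothetical Expression
val projected = ProjectNode(projectExprs, filtered) // projectExprs: hypothetical NamedExpressions
val result: Seq[Row] = projected.collect()          // open/next/fetch/close handled internally
```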
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
new file mode 100644
index 0000000000..ba4aa7671a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+
+  override def output: Seq[Attribute] = children.head.output
+
+  private[this] var currentChild: LocalNode = _
+
+  private[this] var nextChildIndex: Int = _
+
+  override def open(): Unit = {
+    currentChild = children.head
+    currentChild.open()
+    nextChildIndex = 1
+  }
+
+  private def advanceToNextChild(): Boolean = {
+    var found = false
+    var exit = false
+    while (!exit && !found) {
+      if (currentChild != null) {
+        currentChild.close()
+      }
+      if (nextChildIndex >= children.size) {
+        found = false
+        exit = true
+      } else {
+        currentChild = children(nextChildIndex)
+        nextChildIndex += 1
+        currentChild.open()
+        found = currentChild.next()
+      }
+    }
+    found
+  }
+
+  override def close(): Unit = {
+    if (currentChild != null) {
+      currentChild.close()
+    }
+  }
+
+  override def fetch(): InternalRow = currentChild.fetch()
+
+  override def next(): Boolean = {
+    if (currentChild.next()) {
+      true
+    } else {
+      advanceToNextChild()
+    }
+  }
+}
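The loop in advanceToNextChild is what lets UnionNode skip empty children: it closes the exhausted child, then opens the following children one by one until some child produces a row or the list runs out. A hedged walk-through, assuming hypothetical children `a`, `b`, `c` where `b` yields no rows:

```scala
val union = UnionNode(Seq(a, b, c))  // a, b, c: hypothetical LocalNodes; b is empty
union.open()     // opens a, nextChildIndex = 1
// union.next() -> a.next() == true while a has rows; fetch() reads from a
// once a exhausts: advanceToNextChild() closes a, opens b, finds b empty,
//                  closes b, opens c, and returns c.next()
// once c exhausts: advanceToNextChild() closes c, runs out of children -> false
union.close()    // closes c a second time, which is why close() must be idempotent
```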
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 3a87f374d9..5ab8f44fae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
 import org.apache.spark.SparkFunSuite
 import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
 import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.test.SQLTestUtils
 
 /**
  * Base class for writing tests for individual physical operators. For an example of how this
@@ -184,7 +184,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+    SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match.
          | Actual result Spark plan:
@@ -229,7 +229,7 @@ object SparkPlanTest {
       return Some(errorMessage)
     }
 
-    compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+    SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
       s"""
          | Results do not match for Spark plan:
         | $outputPlan
      """.stripMargin
    }
  }

-  private def compareAnswers(
-      sparkAnswer: Seq[Row],
-      expectedAnswer: Seq[Row],
-      sort: Boolean): Option[String] = {
-    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
-      // Converts data to types that we can do equality comparison using Scala collections.
-      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
-      // Java's java.math.BigDecimal.compareTo).
-      // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
-      // equality test.
-      // This function is copied from Catalyst's QueryTest
-      val converted: Seq[Row] = answer.map { s =>
-        Row.fromSeq(s.toSeq.map {
-          case d: java.math.BigDecimal => BigDecimal(d)
-          case b: Array[Byte] => b.toSeq
-          case o => o
-        })
-      }
-      if (sort) {
-        converted.sortBy(_.toString())
-      } else {
-        converted
-      }
-    }
-    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
-      val errorMessage =
-        s"""
-        | == Results ==
-        | ${sideBySide(
-        s"== Expected Answer - ${expectedAnswer.size} ==" +:
-          prepareAnswer(expectedAnswer).map(_.toString()),
-        s"== Actual Answer - ${sparkAnswer.size} ==" +:
-          prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
-      """.stripMargin
-      Some(errorMessage)
-    } else {
-      None
-    }
-  }
-
   private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
     // A very simple resolver to make writing tests easier. In contrast to the real resolver
     // this is always case sensitive and does not try to handle scoping or complex type resolution.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
new file mode 100644
index 0000000000..07209f3779
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+  test("basic") {
+    val condition = (testData.col("key") % 2) === 0
+    checkAnswer(
+      testData,
+      node => FilterNode(condition.expr, node),
+      testData.filter(condition).collect()
+    )
+  }
+
+  test("empty") {
+    val condition = (emptyTestData.col("key") % 2) === 0
+    checkAnswer(
+      emptyTestData,
+      node => FilterNode(condition.expr, node),
+      emptyTestData.filter(condition).collect()
+    )
+  }
+}
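FilterNodeSuite shows the pattern every suite in this patch follows: feed a test DataFrame in, build the local operator over it, and cross-check the result against the answer the DataFrame API itself produces for the equivalent operation. A skeleton for testing some new operator, where `MyNode` and `equivalentDataFrameOp` are hypothetical stand-ins:

```scala
class MyNodeSuite extends LocalNodeTest with SharedSQLContext {
  test("basic") {
    checkAnswer(
      testData,              // input; converted to a SeqScanNode internally
      node => MyNode(node),  // build the operator under test
      equivalentDataFrameOp(testData).collect()  // expected rows via the DataFrame API
    )
  }
}
```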
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
new file mode 100644
index 0000000000..523c02f4a6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+  test("basic") {
+    checkAnswer(
+      testData,
+      node => LimitNode(10, node),
+      testData.limit(10).collect()
+    )
+  }
+
+  test("empty") {
+    checkAnswer(
+      emptyTestData,
+      node => LimitNode(10, node),
+      emptyTestData.limit(10).collect()
+    )
+  }
+}
+ */ + protected def checkAnswer2( + left: DataFrame, + right: DataFrame, + nodeFunction: (LocalNode, LocalNode) => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + doCheckAnswer( + left :: right :: Nil, + nodes => nodeFunction(nodes(0), nodes(1)), + expectedAnswer, + sortAnswers) + } + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + protected def doCheckAnswer( + input: Seq[DataFrame], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean = true): Unit = { + LocalNodeTest.checkAnswer( + input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } + + protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = { + new SeqScanNode( + df.queryExecution.sparkPlan.output, + df.queryExecution.toRdd.map(_.copy()).collect()) + } + +} + +/** + * Helper methods for writing tests of individual local physical operators. + */ +object LocalNodeTest { + + /** + * Runs the `LocalNode`s and makes sure the answer matches the expected result. + * @param input the input data to be used. + * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to + * instantiate the local physical operator that's being tested. + * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s. + * @param sortAnswers if true, the answers will be sorted by their toString representations prior + * to being compared. + */ + def checkAnswer( + input: Seq[SeqScanNode], + nodeFunction: Seq[LocalNode] => LocalNode, + expectedAnswer: Seq[Row], + sortAnswers: Boolean): Option[String] = { + + val outputNode = nodeFunction(input) + + val outputResult: Seq[Row] = try { + outputNode.collect() + } catch { + case NonFatal(e) => + val errorMessage = + s""" + | Exception thrown while executing local plan: + | $outputNode + | == Exception == + | $e + | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)} + """.stripMargin + return Some(errorMessage) + } + + SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage => + s""" + | Results do not match for local plan: + | $outputNode + | $errorMessage + """.stripMargin + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala new file mode 100644 index 0000000000..ffcf092e2c --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala @@ -0,0 +1,44 @@ +/* +* Licensed to the Apache Software Foundation (ASF) under one or more +* contributor license agreements. See the NOTICE file distributed with +* this work for additional information regarding copyright ownership. +* The ASF licenses this file to You under the Apache License, Version 2.0 +* (the "License"); you may not use this file except in compliance with +* the License. 
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
new file mode 100644
index 0000000000..ffcf092e2c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+  test("basic") {
+    val output = testData.queryExecution.sparkPlan.output
+    val columns = Seq(output(1), output(0))
+    checkAnswer(
+      testData,
+      node => ProjectNode(columns, node),
+      testData.select("value", "key").collect()
+    )
+  }
+
+  test("empty") {
+    val output = emptyTestData.queryExecution.sparkPlan.output
+    val columns = Seq(output(1), output(0))
+    checkAnswer(
+      emptyTestData,
+      node => ProjectNode(columns, node),
+      emptyTestData.select("value", "key").collect()
+    )
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
new file mode 100644
index 0000000000..34670287c3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+  test("basic") {
+    checkAnswer2(
+      testData,
+      testData,
+      (node1, node2) => UnionNode(Seq(node1, node2)),
+      testData.unionAll(testData).collect()
+    )
+  }
+
+  test("empty") {
+    checkAnswer2(
+      emptyTestData,
+      emptyTestData,
+      (node1, node2) => UnionNode(Seq(node1, node2)),
+      emptyTestData.unionAll(emptyTestData).collect()
+    )
+  }
+
+  test("complicated union") {
+    val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
+      emptyTestData, emptyTestData, testData, emptyTestData)
+    doCheckAnswer(
+      dfs,
+      nodes => UnionNode(nodes),
+      dfs.reduce(_.unionAll(_)).collect()
+    )
+  }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 1374a97476..3fc02df954 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -36,6 +36,13 @@ private[sql] trait SQLTestData { self =>
 
   // Note: all test data should be lazy because the SQLContext is not set up yet.
 
+  protected lazy val emptyTestData: DataFrame = {
+    val df = _sqlContext.sparkContext.parallelize(
+      Seq.empty[Int].map(i => TestData(i, i.toString))).toDF()
+    df.registerTempTable("emptyTestData")
+    df
+  }
+
   protected lazy val testData: DataFrame = {
     val df = _sqlContext.sparkContext.parallelize(
       (1 to 100).map(i => TestData(i, i.toString))).toDF()
@@ -240,6 +247,7 @@ private[sql] trait SQLTestData { self =>
    */
   def loadTestData(): Unit = {
     assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
+    emptyTestData
     testData
     testData2
     testData3
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index cdd691e035..dc08306ad9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,8 +27,9 @@ import org.apache.hadoop.conf.Configuration
 import org.scalatest.BeforeAndAfterAll
 
 import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util._
 import org.apache.spark.util.Utils
 
 /**
@@ -179,3 +180,46 @@ private[sql] trait SQLTestUtils
     DataFrame(_sqlContext, plan)
   }
 }
+
+private[sql] object SQLTestUtils {
+
+  def compareAnswers(
+      sparkAnswer: Seq[Row],
+      expectedAnswer: Seq[Row],
+      sort: Boolean): Option[String] = {
+    def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+      // Converts data to types that we can do equality comparison using Scala collections.
+      // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+      // Java's java.math.BigDecimal.compareTo).
+      // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
+      // equality test.
+      // This function is copied from Catalyst's QueryTest
+      val converted: Seq[Row] = answer.map { s =>
+        Row.fromSeq(s.toSeq.map {
+          case d: java.math.BigDecimal => BigDecimal(d)
+          case b: Array[Byte] => b.toSeq
+          case o => o
+        })
+      }
+      if (sort) {
+        converted.sortBy(_.toString())
+      } else {
+        converted
+      }
+    }
+
+    if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+      val errorMessage =
+        s"""
+        | == Results ==
+        | ${sideBySide(
+        s"== Expected Answer - ${expectedAnswer.size} ==" +:
+          prepareAnswer(expectedAnswer).map(_.toString()),
+        s"== Actual Answer - ${sparkAnswer.size} ==" +:
+          prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+      """.stripMargin
+      Some(errorMessage)
+    } else {
+      None
+    }
+  }
+}
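With compareAnswers moved into the SQLTestUtils companion object, SparkPlanTest and LocalNodeTest now share a single comparison path, and the new emptyTestData table lets every suite exercise the zero-row case. Calling the shared helper directly is straightforward; a minimal sketch:

```scala
import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SQLTestUtils

val err: Option[String] = SQLTestUtils.compareAnswers(
  sparkAnswer = Seq(Row(2, "b"), Row(1, "a")),
  expectedAnswer = Seq(Row(1, "a"), Row(2, "b")),
  sort = true)  // with sort = true, row order does not matter
assert(err.isEmpty, err.getOrElse(""))  // None means the answers matched
```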