author    zsxwing <zsxwing@gmail.com>            2015-08-29 18:10:44 -0700
committer Reynold Xin <rxin@databricks.com>      2015-08-29 18:10:44 -0700
commit    13f5f8ec97c6886346641b73bd99004e0d70836c
tree      91a977694bf4baee865f1af409c1f9b71a1c0b71
parent    097a7e36e0bf7290b1879331375bacc905583bd3
[SPARK-9986] [SPARK-9991] [SPARK-9993] [SQL] Create a simple test framework for local operators
This PR includes the following changes:

- Add `LocalNodeTest` for local operator tests and add unit tests for FilterNode and ProjectNode.
- Add `LimitNode` and `UnionNode` and their unit tests to show how to use `LocalNodeTest`. (SPARK-9991, SPARK-9993)

Author: zsxwing <zsxwing@gmail.com>

Closes #8464 from zsxwing/local-execution.
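All of the operators touched here implement the same pull-based contract on LocalNode: call open(), then alternate next()/fetch() until next() returns false, then close(). Below is a minimal plain-Scala sketch of that contract, using Int in place of InternalRow; SimpleNode and SimpleSeqScan are illustrative stand-ins, not Spark classes.

trait SimpleNode {
  def open(): Unit      // acquire resources and reset state
  def next(): Boolean   // advance; true if a current tuple is available
  def fetch(): Int      // return the current tuple
  def close(): Unit     // release resources
}

class SimpleSeqScan(data: Seq[Int]) extends SimpleNode {
  private[this] var iterator: Iterator[Int] = _
  private[this] var current: Int = _
  override def open(): Unit = iterator = data.iterator
  override def next(): Boolean = {
    val hasNext = iterator.hasNext
    if (hasNext) current = iterator.next()
    hasNext
  }
  override def fetch(): Int = current
  override def close(): Unit = ()  // nothing to release here
}

object SimpleNodeDemo {
  def main(args: Array[String]): Unit = {
    val node = new SimpleSeqScan(Seq(1, 2, 3))
    node.open()
    try {
      while (node.next()) println(node.fetch())  // prints 1, 2, 3
    } finally {
      node.close()
    }
  }
}

The test framework added below (`LocalNodeTest`) drives exactly this loop through `LocalNode.collect()` and compares the output against the result of an equivalent DataFrame operation.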
Diffstat (limited to 'sql')
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala       |   6
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala        |  45
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala        |  13
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala      |   4
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala      |   2
 sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala        |  72
 sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala          |  46
 sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala  |  41
 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala   |  39
 sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala    | 146
 sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala |  44
 sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala   |  52
 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala                 |   8
 sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala                |  46
 14 files changed, 509 insertions(+), 55 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
index a485a1a1d7..81dd37c7da 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
@@ -35,13 +35,13 @@ case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLoca
override def next(): Boolean = {
var found = false
- while (child.next() && !found) {
- found = predicate.apply(child.get())
+ while (!found && child.next()) {
+ found = predicate.apply(child.fetch())
}
found
}
- override def get(): InternalRow = child.get()
+ override def fetch(): InternalRow = child.fetch()
override def close(): Unit = child.close()
}
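Besides the get() → fetch() rename, the operand swap in the while condition is a real bug fix: && evaluates left to right, so `child.next() && !found` advances the child once more after the predicate has already matched, overwriting the matched row before fetch() can return it. A self-contained demonstration of the buggy ordering in plain Scala (illustrative names, not the Spark code):

object ShortCircuitBug {
  def main(args: Array[String]): Unit = {
    val data = Iterator(1, 2, 3, 4)
    var current = 0
    def next(): Boolean = {
      val hasNext = data.hasNext
      if (hasNext) current = data.next()
      hasNext
    }

    var found = false
    // Buggy order: after `found` turns true, the loop condition still calls
    // next() once more before !found is checked, consuming an extra row.
    while (next() && !found) {
      found = current % 2 == 0
    }
    println(current) // 3 -- but the matching row was 2, so fetch() would be wrong
  }
}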
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
new file mode 100644
index 0000000000..fffc52abf6
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+
+case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+
+ private[this] var count = 0
+
+ override def output: Seq[Attribute] = child.output
+
+ override def open(): Unit = child.open()
+
+ override def close(): Unit = child.close()
+
+ override def fetch(): InternalRow = child.fetch()
+
+ override def next(): Boolean = {
+ if (count < limit) {
+ count += 1
+ child.next()
+ } else {
+ false
+ }
+ }
+
+}
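LimitNode is intentionally minimal: open(), close() and fetch() all delegate to the child, and next() forwards at most `limit` advances before reporting false. The same idea rendered over a plain Scala Iterator (LimitIter is a hypothetical, simplified stand-in, not the Spark class):

class LimitIter[A](limit: Int, child: Iterator[A]) {
  private[this] var count = 0
  private[this] var current: Option[A] = None

  // Forwards at most `limit` successful advances, then always reports false.
  def next(): Boolean =
    if (count < limit && child.hasNext) {
      count += 1
      current = Some(child.next())
      true
    } else {
      false
    }

  def fetch(): A = current.get
}

object LimitIterDemo {
  def main(args: Array[String]): Unit = {
    val limited = new LimitIter(3, Iterator(10, 20, 30, 40, 50))
    val rows = scala.collection.mutable.ArrayBuffer[Int]()
    while (limited.next()) rows += limited.fetch()
    println(rows) // ArrayBuffer(10, 20, 30)
  }
}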
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 341c81438e..1c4469acbf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -48,10 +48,10 @@ abstract class LocalNode extends TreeNode[LocalNode] {
/**
* Returns the current tuple.
*/
- def get(): InternalRow
+ def fetch(): InternalRow
/**
- * Closes the iterator and releases all resources.
+ * Closes the iterator and releases all resources. It should be idempotent.
*
* Implementations of this must also call the `close()` function of its children.
*/
@@ -64,10 +64,13 @@ abstract class LocalNode extends TreeNode[LocalNode] {
val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
val result = new scala.collection.mutable.ArrayBuffer[Row]
open()
- while (next()) {
- result += converter.apply(get()).asInstanceOf[Row]
+ try {
+ while (next()) {
+ result += converter.apply(fetch()).asInstanceOf[Row]
+ }
+ } finally {
+ close()
}
- close()
result
}
}
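The rewritten collect() wraps the consumption loop in try/finally so that close() runs even if next() or fetch() throws, which is also why the close() doc now asks for idempotence. The pattern in isolation, with an illustrative signature rather than the real API:

object DrainSafely {
  // Collects every row, guaranteeing close() runs even if next()/fetch() throws.
  def drainSafely[A](open: () => Unit, next: () => Boolean,
                     fetch: () => A, close: () => Unit): Seq[A] = {
    val result = scala.collection.mutable.ArrayBuffer[A]()
    open()
    try {
      while (next()) {
        result += fetch()  // may throw; the finally below still closes
      }
    } finally {
      close()              // always runs, hence the idempotence note on close()
    }
    result.toSeq
  }
}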
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
index e574d1473c..9b8a4fe493 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
@@ -34,8 +34,8 @@ case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) exte
override def next(): Boolean = child.next()
- override def get(): InternalRow = {
- project.apply(child.get())
+ override def fetch(): InternalRow = {
+ project.apply(child.fetch())
}
override def close(): Unit = child.close()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
index 994de8afa9..242cb66e07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
@@ -41,7 +41,7 @@ case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends L
}
}
- override def get(): InternalRow = currentRow
+ override def fetch(): InternalRow = currentRow
override def close(): Unit = {
// Do nothing
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
new file mode 100644
index 0000000000..ba4aa7671a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+
+case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+
+ override def output: Seq[Attribute] = children.head.output
+
+ private[this] var currentChild: LocalNode = _
+
+ private[this] var nextChildIndex: Int = _
+
+ override def open(): Unit = {
+ currentChild = children.head
+ currentChild.open()
+ nextChildIndex = 1
+ }
+
+ private def advanceToNextChild(): Boolean = {
+ var found = false
+ var exit = false
+ while (!exit && !found) {
+ if (currentChild != null) {
+ currentChild.close()
+ }
+ if (nextChildIndex >= children.size) {
+ found = false
+ exit = true
+ } else {
+ currentChild = children(nextChildIndex)
+ nextChildIndex += 1
+ currentChild.open()
+ found = currentChild.next()
+ }
+ }
+ found
+ }
+
+ override def close(): Unit = {
+ if (currentChild != null) {
+ currentChild.close()
+ }
+ }
+
+ override def fetch(): InternalRow = currentChild.fetch()
+
+ override def next(): Boolean = {
+ if (currentChild.next()) {
+ true
+ } else {
+ advanceToNextChild()
+ }
+ }
+}
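UnionNode concatenates its children lazily: next() first tries the current child, and only when that child is exhausted does advanceToNextChild() close it and open successive children until one produces a row or none remain, so empty inputs anywhere in the sequence are skipped. The same control flow over plain Scala iterators (a sketch; hasNext stands in for the child's next()):

object UnionSketch {
  def main(args: Array[String]): Unit = {
    val children = Seq(Seq(1, 2), Seq.empty[Int], Seq(3))
    var current: Iterator[Int] = children.head.iterator
    var nextChildIndex = 1

    // Mirrors advanceToNextChild(): keep opening children until one has a row.
    def advanceToNextChild(): Boolean = {
      var found = false
      var exit = false
      while (!exit && !found) {
        if (nextChildIndex >= children.size) {
          exit = true
        } else {
          current = children(nextChildIndex).iterator
          nextChildIndex += 1
          found = current.hasNext
        }
      }
      found
    }

    val out = scala.collection.mutable.ArrayBuffer[Int]()
    while (current.hasNext || advanceToNextChild()) {
      out += current.next()
    }
    println(out) // ArrayBuffer(1, 2, 3) -- the empty middle child is skipped
  }
}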
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index 3a87f374d9..5ab8f44fae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -24,7 +24,7 @@ import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
+import org.apache.spark.sql.test.SQLTestUtils
/**
* Base class for writing tests for individual physical operators. For an example of how this
@@ -184,7 +184,7 @@ object SparkPlanTest {
return Some(errorMessage)
}
- compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ SQLTestUtils.compareAnswers(actualAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match.
| Actual result Spark plan:
@@ -229,7 +229,7 @@ object SparkPlanTest {
return Some(errorMessage)
}
- compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
+ SQLTestUtils.compareAnswers(sparkAnswer, expectedAnswer, sortAnswers).map { errorMessage =>
s"""
| Results do not match for Spark plan:
| $outputPlan
@@ -238,46 +238,6 @@ object SparkPlanTest {
}
}
- private def compareAnswers(
- sparkAnswer: Seq[Row],
- expectedAnswer: Seq[Row],
- sort: Boolean): Option[String] = {
- def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
- // Converts data to types that we can do equality comparison using Scala collections.
- // For BigDecimal type, the Scala type has a better definition of equality test (similar to
- // Java's java.math.BigDecimal.compareTo).
- // For binary arrays, we convert it to Seq to avoid of calling java.util.Arrays.equals for
- // equality test.
- // This function is copied from Catalyst's QueryTest
- val converted: Seq[Row] = answer.map { s =>
- Row.fromSeq(s.toSeq.map {
- case d: java.math.BigDecimal => BigDecimal(d)
- case b: Array[Byte] => b.toSeq
- case o => o
- })
- }
- if (sort) {
- converted.sortBy(_.toString())
- } else {
- converted
- }
- }
- if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
- val errorMessage =
- s"""
- | == Results ==
- | ${sideBySide(
- s"== Expected Answer - ${expectedAnswer.size} ==" +:
- prepareAnswer(expectedAnswer).map(_.toString()),
- s"== Actual Answer - ${sparkAnswer.size} ==" +:
- prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
- """.stripMargin
- Some(errorMessage)
- } else {
- None
- }
- }
-
private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
new file mode 100644
index 0000000000..07209f3779
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+ test("basic") {
+ val condition = (testData.col("key") % 2) === 0
+ checkAnswer(
+ testData,
+ node => FilterNode(condition.expr, node),
+ testData.filter(condition).collect()
+ )
+ }
+
+ test("empty") {
+ val condition = (emptyTestData.col("key") % 2) === 0
+ checkAnswer(
+ emptyTestData,
+ node => FilterNode(condition.expr, node),
+ emptyTestData.filter(condition).collect()
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
new file mode 100644
index 0000000000..523c02f4a6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+ test("basic") {
+ checkAnswer(
+ testData,
+ node => LimitNode(10, node),
+ testData.limit(10).collect()
+ )
+ }
+
+ test("empty") {
+ checkAnswer(
+ emptyTestData,
+ node => LimitNode(10, node),
+ emptyTestData.limit(10).collect()
+ )
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
new file mode 100644
index 0000000000..95f06081bd
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import scala.util.control.NonFatal
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.test.SQLTestUtils
+
+class LocalNodeTest extends SparkFunSuite {
+
+ /**
+ * Runs the LocalNode and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
+ * the local physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
+ */
+ protected def checkAnswer(
+ input: DataFrame,
+ nodeFunction: LocalNode => LocalNode,
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ input :: Nil,
+ nodes => nodeFunction(nodes.head),
+ expectedAnswer,
+ sortAnswers)
+ }
+
+ /**
+ * Runs the LocalNode and makes sure the answer matches the expected result.
+ * @param left the left input data to be used.
+ * @param right the right input data to be used.
+ * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
+ * the local physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
+ */
+ protected def checkAnswer2(
+ left: DataFrame,
+ right: DataFrame,
+ nodeFunction: (LocalNode, LocalNode) => LocalNode,
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ doCheckAnswer(
+ left :: right :: Nil,
+ nodes => nodeFunction(nodes(0), nodes(1)),
+ expectedAnswer,
+ sortAnswers)
+ }
+
+ /**
+ * Runs the `LocalNode`s and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
+ * instantiate the local physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
+ */
+ protected def doCheckAnswer(
+ input: Seq[DataFrame],
+ nodeFunction: Seq[LocalNode] => LocalNode,
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean = true): Unit = {
+ LocalNodeTest.checkAnswer(
+ input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
+ case Some(errorMessage) => fail(errorMessage)
+ case None =>
+ }
+ }
+
+ protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
+ new SeqScanNode(
+ df.queryExecution.sparkPlan.output,
+ df.queryExecution.toRdd.map(_.copy()).collect())
+ }
+
+}
+
+/**
+ * Helper methods for writing tests of individual local physical operators.
+ */
+object LocalNodeTest {
+
+ /**
+ * Runs the `LocalNode`s and makes sure the answer matches the expected result.
+ * @param input the input data to be used.
+ * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
+ * instantiate the local physical operator that's being tested.
+ * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
+ * @param sortAnswers if true, the answers will be sorted by their toString representations prior
+ * to being compared.
+ */
+ def checkAnswer(
+ input: Seq[SeqScanNode],
+ nodeFunction: Seq[LocalNode] => LocalNode,
+ expectedAnswer: Seq[Row],
+ sortAnswers: Boolean): Option[String] = {
+
+ val outputNode = nodeFunction(input)
+
+ val outputResult: Seq[Row] = try {
+ outputNode.collect()
+ } catch {
+ case NonFatal(e) =>
+ val errorMessage =
+ s"""
+ | Exception thrown while executing local plan:
+ | $outputNode
+ | == Exception ==
+ | $e
+ | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
+ """.stripMargin
+ return Some(errorMessage)
+ }
+
+ SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
+ s"""
+ | Results do not match for local plan:
+ | $outputNode
+ | $errorMessage
+ """.stripMargin
+ }
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
new file mode 100644
index 0000000000..ffcf092e2c
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+ test("basic") {
+ val output = testData.queryExecution.sparkPlan.output
+ val columns = Seq(output(1), output(0))
+ checkAnswer(
+ testData,
+ node => ProjectNode(columns, node),
+ testData.select("value", "key").collect()
+ )
+ }
+
+ test("empty") {
+ val output = emptyTestData.queryExecution.sparkPlan.output
+ val columns = Seq(output(1), output(0))
+ checkAnswer(
+ emptyTestData,
+ node => ProjectNode(columns, node),
+ emptyTestData.select("value", "key").collect()
+ )
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
new file mode 100644
index 0000000000..34670287c3
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.test.SharedSQLContext
+
+class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
+
+ test("basic") {
+ checkAnswer2(
+ testData,
+ testData,
+ (node1, node2) => UnionNode(Seq(node1, node2)),
+ testData.unionAll(testData).collect()
+ )
+ }
+
+ test("empty") {
+ checkAnswer2(
+ emptyTestData,
+ emptyTestData,
+ (node1, node2) => UnionNode(Seq(node1, node2)),
+ emptyTestData.unionAll(emptyTestData).collect()
+ )
+ }
+
+ test("complicated union") {
+ val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
+ emptyTestData, emptyTestData, testData, emptyTestData)
+ doCheckAnswer(
+ dfs,
+ nodes => UnionNode(nodes),
+ dfs.reduce(_.unionAll(_)).collect()
+ )
+ }
+
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
index 1374a97476..3fc02df954 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -36,6 +36,13 @@ private[sql] trait SQLTestData { self =>
// Note: all test data should be lazy because the SQLContext is not set up yet.
+ protected lazy val emptyTestData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ Seq.empty[Int].map(i => TestData(i, i.toString))).toDF()
+ df.registerTempTable("emptyTestData")
+ df
+ }
+
protected lazy val testData: DataFrame = {
val df = _sqlContext.sparkContext.parallelize(
(1 to 100).map(i => TestData(i, i.toString))).toDF()
@@ -240,6 +247,7 @@ private[sql] trait SQLTestData { self =>
*/
def loadTestData(): Unit = {
assert(_sqlContext != null, "attempted to initialize test data before SQLContext.")
+ emptyTestData
testData
testData2
testData3
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index cdd691e035..dc08306ad9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -27,8 +27,9 @@ import org.apache.hadoop.conf.Configuration
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+import org.apache.spark.sql.{DataFrame, Row, SQLContext, SQLImplicits}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util._
import org.apache.spark.util.Utils
/**
@@ -179,3 +180,46 @@ private[sql] trait SQLTestUtils
DataFrame(_sqlContext, plan)
}
}
+
+private[sql] object SQLTestUtils {
+
+ def compareAnswers(
+ sparkAnswer: Seq[Row],
+ expectedAnswer: Seq[Row],
+ sort: Boolean): Option[String] = {
+ def prepareAnswer(answer: Seq[Row]): Seq[Row] = {
+ // Converts data to types that we can do equality comparison using Scala collections.
+ // For BigDecimal type, the Scala type has a better definition of equality test (similar to
+ // Java's java.math.BigDecimal.compareTo).
+ // For binary arrays, we convert them to Seq to avoid calling java.util.Arrays.equals for
+ // equality test.
+ // This function is copied from Catalyst's QueryTest
+ val converted: Seq[Row] = answer.map { s =>
+ Row.fromSeq(s.toSeq.map {
+ case d: java.math.BigDecimal => BigDecimal(d)
+ case b: Array[Byte] => b.toSeq
+ case o => o
+ })
+ }
+ if (sort) {
+ converted.sortBy(_.toString())
+ } else {
+ converted
+ }
+ }
+ if (prepareAnswer(expectedAnswer) != prepareAnswer(sparkAnswer)) {
+ val errorMessage =
+ s"""
+ | == Results ==
+ | ${sideBySide(
+ s"== Expected Answer - ${expectedAnswer.size} ==" +:
+ prepareAnswer(expectedAnswer).map(_.toString()),
+ s"== Actual Answer - ${sparkAnswer.size} ==" +:
+ prepareAnswer(sparkAnswer).map(_.toString())).mkString("\n")}
+ """.stripMargin
+ Some(errorMessage)
+ } else {
+ None
+ }
+ }
+}
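The normalization inside prepareAnswer is needed because two JVM types common in query results have unhelpful equality: arrays compare by reference, and java.math.BigDecimal.equals is scale-sensitive. A small, self-contained illustration of both pitfalls:

object WhyPrepareAnswer {
  def main(args: Array[String]): Unit = {
    // Arrays use reference equality, so identical byte arrays compare unequal.
    println(Array[Byte](1, 2) == Array[Byte](1, 2))             // false
    println(Array[Byte](1, 2).toSeq == Array[Byte](1, 2).toSeq) // true

    // java.math.BigDecimal.equals compares scale as well as value.
    val a = new java.math.BigDecimal("1.0")
    val b = new java.math.BigDecimal("1.00")
    println(a == b)                         // false
    println(BigDecimal(a) == BigDecimal(b)) // true: compareTo-style equality
  }
}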