author     Andrew Or <andrew@databricks.com>    2015-09-15 17:24:32 -0700
committer  Andrew Or <andrew@databricks.com>    2015-09-15 17:24:32 -0700
commit     35a19f3357d2ec017cfefb90f1018403e9617de4 (patch)
tree       8e28187f8bfc5bafee664121ca07ba5efc4ebbe1 /sql
parent     38700ea40cb1dd0805cc926a9e629f93c99527ad (diff)
[SPARK-10613] [SPARK-10624] [SQL] Reduce LocalNode tests dependency on SQLContext
Instead of relying on `DataFrames` to verify our answers, we can just use simple arrays. This significantly simplifies the test logic for `LocalNode`s and removes a lot of code duplicated from `SparkPlanTest`. This also fixes an additional issue, [SPARK-10624](https://issues.apache.org/jira/browse/SPARK-10624), where the output of `TakeOrderedAndProjectNode` is not actually ordered.

Author: Andrew Or <andrew@databricks.com>

Closes #8764 from andrewor14/sql-local-tests-cleanup.
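As context for the SPARK-10624 half of this change: before the patch, `TakeOrderedAndProjectNode` exposed `queue.iterator` directly, but a priority queue's iterator yields elements in heap order rather than sort order, which is why the output was not actually ordered. The standalone sketch below (using `scala.collection.mutable.PriorityQueue` as a stand-in for the node's internal bounded priority queue; it is an illustration, not the node's actual implementation) shows the difference and why the patched code drains and sorts the queue (`queue.toArray.sorted(ord).iterator`) before emitting rows.

```scala
import scala.collection.mutable.PriorityQueue

object QueueIteratorOrder {
  def main(args: Array[String]): Unit = {
    // A max-heap over Ints; only the head is guaranteed to be the largest element.
    val queue = PriorityQueue(3, 1, 4, 1, 5, 9, 2, 6)

    // The raw iterator reflects the underlying heap layout, not sorted order,
    // e.g. List(9, 6, 5, 4, 3, 1, 2, 1) or similar depending on insertion order.
    println(queue.iterator.toList)

    // Draining the queue into an array and sorting it explicitly restores the
    // intended ordering, analogous to queue.toArray.sorted(ord).iterator above.
    println(queue.toArray.sorted.toList) // List(1, 1, 2, 3, 4, 5, 6, 9)
  }
}
```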
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala |  16
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala |   2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala |  68
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala |  54
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala |  34
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala | 141
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala |  24
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala |  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala |  73
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala | 165
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala | 316
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala |  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala |  35
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala |  50
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala |  49
17 files changed, 468 insertions(+), 636 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 569cff565c..f96b62a67a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.codegen._
-import org.apache.spark.sql.catalyst.trees.TreeNode
+import org.apache.spark.sql.catalyst.plans.QueryPlan
import org.apache.spark.sql.types.StructType
/**
@@ -33,18 +33,14 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
-abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+abstract class LocalNode(conf: SQLConf) extends QueryPlan[LocalNode] with Logging {
protected val codegenEnabled: Boolean = conf.codegenEnabled
protected val unsafeEnabled: Boolean = conf.unsafeEnabled
- lazy val schema: StructType = StructType.fromAttributes(output)
-
private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
- def output: Seq[Attribute]
-
/**
* Called before open(). Prepare can be used to reserve memory needed. It must NOT consume
* any input data.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
index abf3df1c0c..793700803f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SampleNode.scala
@@ -17,13 +17,12 @@
package org.apache.spark.sql.execution.local
-import java.util.Random
-
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
/**
* Sample the dataset.
*
@@ -51,18 +50,15 @@ case class SampleNode(
override def open(): Unit = {
child.open()
- val (sampler, _seed) = if (withReplacement) {
- val random = new Random(seed)
+ val sampler =
+ if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
// requiring us to copy the row, which is more expensive than the random number generator.
- (new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false),
- // Use the seed for partition 0 like PartitionwiseSampledRDD to generate the same result
- // of DataFrame
- random.nextLong())
+ new PoissonSampler[InternalRow](upperBound - lowerBound, useGapSamplingIfPossible = false)
} else {
- (new BernoulliCellSampler[InternalRow](lowerBound, upperBound), seed)
+ new BernoulliCellSampler[InternalRow](lowerBound, upperBound)
}
- sampler.setSeed(_seed)
+ sampler.setSeed(seed)
iterator = sampler.sample(child.asIterator)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
index 53f1dcc65d..ae672fbca8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNode.scala
@@ -50,7 +50,7 @@ case class TakeOrderedAndProjectNode(
}
// Close it eagerly since we don't need it.
child.close()
- iterator = queue.iterator
+ iterator = queue.toArray.sorted(ord).iterator
}
override def next(): Boolean = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index de45ae4635..3d218f01c9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -238,7 +238,7 @@ object SparkPlanTest {
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
- plan.transformExpressions {
+ plan transformExpressions {
case UnresolvedAttribute(Seq(u)) =>
inputMap.getOrElse(u,
sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala
new file mode 100644
index 0000000000..efc3227dd6
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/DummyNode.scala
@@ -0,0 +1,68 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.plans.logical.LocalRelation
+
+/**
+ * A dummy [[LocalNode]] that just returns rows from a [[LocalRelation]].
+ */
+private[local] case class DummyNode(
+ output: Seq[Attribute],
+ relation: LocalRelation,
+ conf: SQLConf)
+ extends LocalNode(conf) {
+
+ import DummyNode._
+
+ private var index: Int = CLOSED
+ private val input: Seq[InternalRow] = relation.data
+
+ def this(output: Seq[Attribute], data: Seq[Product], conf: SQLConf = new SQLConf) {
+ this(output, LocalRelation.fromProduct(output, data), conf)
+ }
+
+ def isOpen: Boolean = index != CLOSED
+
+ override def children: Seq[LocalNode] = Seq.empty
+
+ override def open(): Unit = {
+ index = -1
+ }
+
+ override def next(): Boolean = {
+ index += 1
+ index < input.size
+ }
+
+ override def fetch(): InternalRow = {
+ assert(index >= 0 && index < input.size)
+ input(index)
+ }
+
+ override def close(): Unit = {
+ index = CLOSED
+ }
+}
+
+private object DummyNode {
+ val CLOSED: Int = Int.MinValue
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
index cfa7f3f6dc..bbd94d8da2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ExpandNodeSuite.scala
@@ -17,35 +17,33 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.catalyst.dsl.expressions._
+
+
class ExpandNodeSuite extends LocalNodeTest {
- import testImplicits._
-
- test("expand") {
- val input = Seq((1, 1), (2, 2), (3, 3), (4, 4), (5, 5)).toDF("key", "value")
- checkAnswer(
- input,
- node =>
- ExpandNode(conf, Seq(
- Seq(
- input.col("key") + input.col("value"), input.col("key") - input.col("value")
- ).map(_.expr),
- Seq(
- input.col("key") * input.col("value"), input.col("key") / input.col("value")
- ).map(_.expr)
- ), node.output, node),
- Seq(
- (2, 0),
- (1, 1),
- (4, 0),
- (4, 1),
- (6, 0),
- (9, 1),
- (8, 0),
- (16, 1),
- (10, 0),
- (25, 1)
- ).toDF().collect()
- )
+ private def testExpand(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+ val inputNode = new DummyNode(kvIntAttributes, inputData)
+ val projections = Seq(Seq('k + 'v, 'k - 'v), Seq('k * 'v, 'k / 'v))
+ val expandNode = new ExpandNode(conf, projections, inputNode.output, inputNode)
+ val resolvedNode = resolveExpressions(expandNode)
+ val expectedOutput = {
+ val firstHalf = inputData.map { case (k, v) => (k + v, k - v) }
+ val secondHalf = inputData.map { case (k, v) => (k * v, k / v) }
+ firstHalf ++ secondHalf
+ }
+ val actualOutput = resolvedNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput.toSet === expectedOutput.toSet)
+ }
+
+ test("empty") {
+ testExpand()
}
+
+ test("basic") {
+ testExpand((1 to 100).map { i => (i, i * 1000) }.toArray)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
index a12670e347..4eadce646d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -17,25 +17,29 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.catalyst.dsl.expressions._
-class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
- test("basic") {
- val condition = (testData.col("key") % 2) === 0
- checkAnswer(
- testData,
- node => FilterNode(conf, condition.expr, node),
- testData.filter(condition).collect()
- )
+class FilterNodeSuite extends LocalNodeTest {
+
+ private def testFilter(inputData: Array[(Int, Int)] = Array.empty): Unit = {
+ val cond = 'k % 2 === 0
+ val inputNode = new DummyNode(kvIntAttributes, inputData)
+ val filterNode = new FilterNode(conf, cond, inputNode)
+ val resolvedNode = resolveExpressions(filterNode)
+ val expectedOutput = inputData.filter { case (k, _) => k % 2 == 0 }
+ val actualOutput = resolvedNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput === expectedOutput)
}
test("empty") {
- val condition = (emptyTestData.col("key") % 2) === 0
- checkAnswer(
- emptyTestData,
- node => FilterNode(conf, condition.expr, node),
- emptyTestData.filter(condition).collect()
- )
+ testFilter()
+ }
+
+ test("basic") {
+ testFilter((1 to 100).map { i => (i, i) }.toArray)
}
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index 78d891351f..5c1bdb088e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -18,99 +18,80 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
+
class HashJoinNodeSuite extends LocalNodeTest {
- import testImplicits._
+ // Test all combinations of the two dimensions: with/out unsafe and build sides
+ private val maybeUnsafeAndCodegen = Seq(false, true)
+ private val buildSides = Seq(BuildLeft, BuildRight)
+ maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
+ buildSides.foreach { buildSide =>
+ testJoin(unsafeAndCodegen, buildSide)
+ }
+ }
- def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
- test(s"$suiteName: inner join with one match per row") {
- withSQLConf(confPairs: _*) {
- checkAnswer2(
- upperCaseData,
- lowerCaseData,
- wrapForUnsafe(
- (node1, node2) => HashJoinNode(
- conf,
- Seq(upperCaseData.col("N").expr),
- Seq(lowerCaseData.col("n").expr),
- joins.BuildLeft,
- node1,
- node2)
- ),
- upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
- )
+ /**
+ * Test inner hash join with varying degrees of matches.
+ */
+ private def testJoin(
+ unsafeAndCodegen: Boolean,
+ buildSide: BuildSide): Unit = {
+ val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
+ val testNamePrefix = s"$simpleOrUnsafe / $buildSide"
+ val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
+ val conf = new SQLConf
+ conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
+ conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
+
+ // Actual test body
+ def runTest(leftInput: Array[(Int, String)], rightInput: Array[(Int, String)]): Unit = {
+ val rightInputMap = rightInput.toMap
+ val leftNode = new DummyNode(joinNameAttributes, leftInput)
+ val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
+ val makeNode = (node1: LocalNode, node2: LocalNode) => {
+ resolveExpressions(new HashJoinNode(
+ conf, Seq('id1), Seq('id2), buildSide, node1, node2))
+ }
+ val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+ val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+ val expectedOutput = leftInput
+ .filter { case (k, _) => rightInputMap.contains(k) }
+ .map { case (k, v) => (k, v, k, rightInputMap(k)) }
+ val actualOutput = hashJoinNode.collect().map { row =>
+ // (id, name, id, nickname)
+ (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
}
+ assert(actualOutput === expectedOutput)
}
- test(s"$suiteName: inner join with multiple matches") {
- withSQLConf(confPairs: _*) {
- val x = testData2.where($"a" === 1).as("x")
- val y = testData2.where($"a" === 1).as("y")
- checkAnswer2(
- x,
- y,
- wrapForUnsafe(
- (node1, node2) => HashJoinNode(
- conf,
- Seq(x.col("a").expr),
- Seq(y.col("a").expr),
- joins.BuildLeft,
- node1,
- node2)
- ),
- x.join(y).where($"x.a" === $"y.a").collect()
- )
- }
+ test(s"$testNamePrefix: empty") {
+ runTest(Array.empty, Array.empty)
+ runTest(someData, Array.empty)
+ runTest(Array.empty, someData)
}
- test(s"$suiteName: inner join, no matches") {
- withSQLConf(confPairs: _*) {
- val x = testData2.where($"a" === 1).as("x")
- val y = testData2.where($"a" === 2).as("y")
- checkAnswer2(
- x,
- y,
- wrapForUnsafe(
- (node1, node2) => HashJoinNode(
- conf,
- Seq(x.col("a").expr),
- Seq(y.col("a").expr),
- joins.BuildLeft,
- node1,
- node2)
- ),
- Nil
- )
- }
+ test(s"$testNamePrefix: no matches") {
+ val someIrrelevantData = (10000 to 100100).map { i => (i, "piper" + i) }.toArray
+ runTest(someData, Array.empty)
+ runTest(Array.empty, someData)
+ runTest(someData, someIrrelevantData)
+ runTest(someIrrelevantData, someData)
}
- test(s"$suiteName: big inner join, 4 matches per row") {
- withSQLConf(confPairs: _*) {
- val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.as("x")
- val bigDataY = bigData.as("y")
+ test(s"$testNamePrefix: partial matches") {
+ val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
+ runTest(someData, someOtherData)
+ runTest(someOtherData, someData)
+ }
- checkAnswer2(
- bigDataX,
- bigDataY,
- wrapForUnsafe(
- (node1, node2) =>
- HashJoinNode(
- conf,
- Seq(bigDataX.col("key").expr),
- Seq(bigDataY.col("key").expr),
- joins.BuildLeft,
- node1,
- node2)
- ),
- bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
- }
+ test(s"$testNamePrefix: full matches") {
+ val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }.toArray
+ runTest(someData, someSuperRelevantData)
+ runTest(someSuperRelevantData, someData)
}
}
- joinSuite(
- "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
- joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
index 7deaa375fc..c0ad2021b2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/IntersectNodeSuite.scala
@@ -17,19 +17,21 @@
package org.apache.spark.sql.execution.local
-class IntersectNodeSuite extends LocalNodeTest {
- import testImplicits._
+class IntersectNodeSuite extends LocalNodeTest {
test("basic") {
- val input1 = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
- val input2 = (1 to 10).filter(_ % 2 == 0).map(i => (i, i.toString)).toDF("key", "value")
-
- checkAnswer2(
- input1,
- input2,
- (node1, node2) => IntersectNode(conf, node1, node2),
- input1.intersect(input2).collect()
- )
+ val n = 100
+ val leftData = (1 to n).filter { i => i % 2 == 0 }.map { i => (i, i) }.toArray
+ val rightData = (1 to n).filter { i => i % 3 == 0 }.map { i => (i, i) }.toArray
+ val leftNode = new DummyNode(kvIntAttributes, leftData)
+ val rightNode = new DummyNode(kvIntAttributes, rightData)
+ val intersectNode = new IntersectNode(conf, leftNode, rightNode)
+ val expectedOutput = leftData.intersect(rightData)
+ val actualOutput = intersectNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput === expectedOutput)
}
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
index 3b18390200..fb790636a3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -17,23 +17,25 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.test.SharedSQLContext
-class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
+class LimitNodeSuite extends LocalNodeTest {
- test("basic") {
- checkAnswer(
- testData,
- node => LimitNode(conf, 10, node),
- testData.limit(10).collect()
- )
+ private def testLimit(inputData: Array[(Int, Int)] = Array.empty, limit: Int = 10): Unit = {
+ val inputNode = new DummyNode(kvIntAttributes, inputData)
+ val limitNode = new LimitNode(conf, limit, inputNode)
+ val expectedOutput = inputData.take(limit)
+ val actualOutput = limitNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput === expectedOutput)
}
test("empty") {
- checkAnswer(
- emptyTestData,
- node => LimitNode(conf, 10, node),
- emptyTestData.limit(10).collect()
- )
+ testLimit()
}
+
+ test("basic") {
+ testLimit((1 to 100).map { i => (i, i) }.toArray, 20)
+ }
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
index b89fa46f8b..0d1ed99eec 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
@@ -17,28 +17,24 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.types.IntegerType
-class LocalNodeSuite extends SparkFunSuite {
- private val data = (1 to 100).toArray
+class LocalNodeSuite extends LocalNodeTest {
+ private val data = (1 to 100).map { i => (i, i) }.toArray
test("basic open, next, fetch, close") {
- val node = new DummyLocalNode(data)
+ val node = new DummyNode(kvIntAttributes, data)
assert(!node.isOpen)
node.open()
assert(node.isOpen)
- data.foreach { i =>
+ data.foreach { case (k, v) =>
assert(node.next())
// fetch should be idempotent
val fetched = node.fetch()
assert(node.fetch() === fetched)
assert(node.fetch() === fetched)
- assert(node.fetch().numFields === 1)
- assert(node.fetch().getInt(0) === i)
+ assert(node.fetch().numFields === 2)
+ assert(node.fetch().getInt(0) === k)
+ assert(node.fetch().getInt(1) === v)
}
assert(!node.next())
node.close()
@@ -46,16 +42,17 @@ class LocalNodeSuite extends SparkFunSuite {
}
test("asIterator") {
- val node = new DummyLocalNode(data)
+ val node = new DummyNode(kvIntAttributes, data)
val iter = node.asIterator
node.open()
- data.foreach { i =>
+ data.foreach { case (k, v) =>
// hasNext should be idempotent
assert(iter.hasNext)
assert(iter.hasNext)
val item = iter.next()
- assert(item.numFields === 1)
- assert(item.getInt(0) === i)
+ assert(item.numFields === 2)
+ assert(item.getInt(0) === k)
+ assert(item.getInt(1) === v)
}
intercept[NoSuchElementException] {
iter.next()
@@ -64,53 +61,13 @@ class LocalNodeSuite extends SparkFunSuite {
}
test("collect") {
- val node = new DummyLocalNode(data)
+ val node = new DummyNode(kvIntAttributes, data)
node.open()
val collected = node.collect()
assert(collected.size === data.size)
- assert(collected.forall(_.size === 1))
- assert(collected.map(_.getInt(0)) === data)
+ assert(collected.forall(_.size === 2))
+ assert(collected.map { case row => (row.getInt(0), row.getInt(0)) } === data)
node.close()
}
}
-
-/**
- * A dummy [[LocalNode]] that just returns one row per integer in the input.
- */
-private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
- private var index = Int.MinValue
-
- def this(input: Array[Int]) {
- this(new SQLConf, input)
- }
-
- def isOpen: Boolean = {
- index != Int.MinValue
- }
-
- override def output: Seq[Attribute] = {
- Seq(AttributeReference("something", IntegerType)())
- }
-
- override def children: Seq[LocalNode] = Seq.empty
-
- override def open(): Unit = {
- index = -1
- }
-
- override def next(): Boolean = {
- index += 1
- index < input.size
- }
-
- override def fetch(): InternalRow = {
- assert(index >= 0 && index < input.size)
- val values = Array(input(index).asInstanceOf[Any])
- new GenericInternalRow(values)
- }
-
- override def close(): Unit = {
- index = Int.MinValue
- }
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 86dd28064c..098050bcd2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -17,147 +17,54 @@
package org.apache.spark.sql.execution.local
-import scala.util.control.NonFatal
-
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row, SQLConf}
-import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.types.{IntegerType, StringType}
-class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
- def conf: SQLConf = sqlContext.conf
+class LocalNodeTest extends SparkFunSuite {
- protected def wrapForUnsafe(
- f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
- if (conf.unsafeEnabled) {
- (left: LocalNode, right: LocalNode) => {
- val _left = ConvertToUnsafeNode(conf, left)
- val _right = ConvertToUnsafeNode(conf, right)
- val r = f(_left, _right)
- ConvertToSafeNode(conf, r)
- }
- } else {
- f
- }
- }
-
- /**
- * Runs the LocalNode and makes sure the answer matches the expected result.
- * @param input the input data to be used.
- * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
- * the local physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
- */
- protected def checkAnswer(
- input: DataFrame,
- nodeFunction: LocalNode => LocalNode,
- expectedAnswer: Seq[Row],
- sortAnswers: Boolean = true): Unit = {
- doCheckAnswer(
- input :: Nil,
- nodes => nodeFunction(nodes.head),
- expectedAnswer,
- sortAnswers)
- }
-
- /**
- * Runs the LocalNode and makes sure the answer matches the expected result.
- * @param left the left input data to be used.
- * @param right the right input data to be used.
- * @param nodeFunction a function which accepts the input LocalNode and uses it to instantiate
- * the local physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
- */
- protected def checkAnswer2(
- left: DataFrame,
- right: DataFrame,
- nodeFunction: (LocalNode, LocalNode) => LocalNode,
- expectedAnswer: Seq[Row],
- sortAnswers: Boolean = true): Unit = {
- doCheckAnswer(
- left :: right :: Nil,
- nodes => nodeFunction(nodes(0), nodes(1)),
- expectedAnswer,
- sortAnswers)
- }
+ protected val conf: SQLConf = new SQLConf
+ protected val kvIntAttributes = Seq(
+ AttributeReference("k", IntegerType)(),
+ AttributeReference("v", IntegerType)())
+ protected val joinNameAttributes = Seq(
+ AttributeReference("id1", IntegerType)(),
+ AttributeReference("name", StringType)())
+ protected val joinNicknameAttributes = Seq(
+ AttributeReference("id2", IntegerType)(),
+ AttributeReference("nickname", StringType)())
/**
- * Runs the `LocalNode`s and makes sure the answer matches the expected result.
- * @param input the input data to be used.
- * @param nodeFunction a function which accepts a sequence of input `LocalNode`s and uses them to
- * instantiate the local physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
+ * Wrap a function processing two [[LocalNode]]s such that:
+ * (1) all input rows are automatically converted to unsafe rows
+ * (2) all output rows are automatically converted back to safe rows
*/
- protected def doCheckAnswer(
- input: Seq[DataFrame],
- nodeFunction: Seq[LocalNode] => LocalNode,
- expectedAnswer: Seq[Row],
- sortAnswers: Boolean = true): Unit = {
- LocalNodeTest.checkAnswer(
- input.map(dataFrameToSeqScanNode), nodeFunction, expectedAnswer, sortAnswers) match {
- case Some(errorMessage) => fail(errorMessage)
- case None =>
+ protected def wrapForUnsafe(
+ f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
+ (left: LocalNode, right: LocalNode) => {
+ val _left = ConvertToUnsafeNode(conf, left)
+ val _right = ConvertToUnsafeNode(conf, right)
+ val r = f(_left, _right)
+ ConvertToSafeNode(conf, r)
}
}
- protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
- new SeqScanNode(
- conf,
- df.queryExecution.sparkPlan.output,
- df.queryExecution.toRdd.map(_.copy()).collect())
- }
-
-}
-
-/**
- * Helper methods for writing tests of individual local physical operators.
- */
-object LocalNodeTest {
-
/**
- * Runs the `LocalNode`s and makes sure the answer matches the expected result.
- * @param input the input data to be used.
- * @param nodeFunction a function which accepts the input `LocalNode`s and uses them to
- * instantiate the local physical operator that's being tested.
- * @param expectedAnswer the expected result in a [[Seq]] of [[Row]]s.
- * @param sortAnswers if true, the answers will be sorted by their toString representations prior
- * to being compared.
+ * Recursively resolve all expressions in a [[LocalNode]] using the node's attributes.
*/
- def checkAnswer(
- input: Seq[SeqScanNode],
- nodeFunction: Seq[LocalNode] => LocalNode,
- expectedAnswer: Seq[Row],
- sortAnswers: Boolean): Option[String] = {
-
- val outputNode = nodeFunction(input)
-
- val outputResult: Seq[Row] = try {
- outputNode.collect()
- } catch {
- case NonFatal(e) =>
- val errorMessage =
- s"""
- | Exception thrown while executing local plan:
- | $outputNode
- | == Exception ==
- | $e
- | ${org.apache.spark.sql.catalyst.util.stackTraceToString(e)}
- """.stripMargin
- return Some(errorMessage)
- }
-
- SQLTestUtils.compareAnswers(outputResult, expectedAnswer, sortAnswers).map { errorMessage =>
- s"""
- | Results do not match for local plan:
- | $outputNode
- | $errorMessage
- """.stripMargin
+ protected def resolveExpressions(outputNode: LocalNode): LocalNode = {
+ outputNode transform {
+ case node: LocalNode =>
+ val inputMap = node.output.map { a => (a.name, a) }.toMap
+ node transformExpressions {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
}
}
+
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
index b1ef26ba82..40299d9d5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/NestedLoopJoinNodeSuite.scala
@@ -18,222 +18,128 @@
package org.apache.spark.sql.execution.local
import org.apache.spark.sql.SQLConf
-import org.apache.spark.sql.catalyst.plans.{FullOuter, LeftOuter, RightOuter}
+import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.catalyst.plans.{FullOuter, JoinType, LeftOuter, RightOuter}
import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
+
class NestedLoopJoinNodeSuite extends LocalNodeTest {
- import testImplicits._
-
- private def joinSuite(
- suiteName: String, buildSide: BuildSide, confPairs: (String, String)*): Unit = {
- test(s"$suiteName: left outer join") {
- withSQLConf(confPairs: _*) {
- checkAnswer2(
- upperCaseData,
- lowerCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- LeftOuter,
- Some((upperCaseData.col("N") === lowerCaseData.col("n")).expr))
- ),
- upperCaseData.join(lowerCaseData, $"n" === $"N", "left").collect())
-
- checkAnswer2(
- upperCaseData,
- lowerCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- LeftOuter,
- Some(
- (upperCaseData.col("N") === lowerCaseData.col("n") &&
- lowerCaseData.col("n") > 1).expr))
- ),
- upperCaseData.join(lowerCaseData, $"n" === $"N" && $"n" > 1, "left").collect())
-
- checkAnswer2(
- upperCaseData,
- lowerCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- LeftOuter,
- Some(
- (upperCaseData.col("N") === lowerCaseData.col("n") &&
- upperCaseData.col("N") > 1).expr))
- ),
- upperCaseData.join(lowerCaseData, $"n" === $"N" && $"N" > 1, "left").collect())
-
- checkAnswer2(
- upperCaseData,
- lowerCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- LeftOuter,
- Some(
- (upperCaseData.col("N") === lowerCaseData.col("n") &&
- lowerCaseData.col("l") > upperCaseData.col("L")).expr))
- ),
- upperCaseData.join(lowerCaseData, $"n" === $"N" && $"l" > $"L", "left").collect())
+ // Test all combinations of the three dimensions: with/out unsafe, build sides, and join types
+ private val maybeUnsafeAndCodegen = Seq(false, true)
+ private val buildSides = Seq(BuildLeft, BuildRight)
+ private val joinTypes = Seq(LeftOuter, RightOuter, FullOuter)
+ maybeUnsafeAndCodegen.foreach { unsafeAndCodegen =>
+ buildSides.foreach { buildSide =>
+ joinTypes.foreach { joinType =>
+ testJoin(unsafeAndCodegen, buildSide, joinType)
}
}
+ }
- test(s"$suiteName: right outer join") {
- withSQLConf(confPairs: _*) {
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- RightOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N", "right").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- RightOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- lowerCaseData.col("n") > 1).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "right").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- RightOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- upperCaseData.col("N") > 1).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "right").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- RightOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- lowerCaseData.col("l") > upperCaseData.col("L")).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "right").collect())
+ /**
+ * Test outer nested loop joins with varying degrees of matches.
+ */
+ private def testJoin(
+ unsafeAndCodegen: Boolean,
+ buildSide: BuildSide,
+ joinType: JoinType): Unit = {
+ val simpleOrUnsafe = if (!unsafeAndCodegen) "simple" else "unsafe"
+ val testNamePrefix = s"$simpleOrUnsafe / $buildSide / $joinType"
+ val someData = (1 to 100).map { i => (i, "burger" + i) }.toArray
+ val conf = new SQLConf
+ conf.setConf(SQLConf.UNSAFE_ENABLED, unsafeAndCodegen)
+ conf.setConf(SQLConf.CODEGEN_ENABLED, unsafeAndCodegen)
+
+ // Actual test body
+ def runTest(
+ joinType: JoinType,
+ leftInput: Array[(Int, String)],
+ rightInput: Array[(Int, String)]): Unit = {
+ val leftNode = new DummyNode(joinNameAttributes, leftInput)
+ val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
+ val cond = 'id1 === 'id2
+ val makeNode = (node1: LocalNode, node2: LocalNode) => {
+ resolveExpressions(
+ new NestedLoopJoinNode(conf, node1, node2, buildSide, joinType, Some(cond)))
}
+ val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+ val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+ val expectedOutput = generateExpectedOutput(leftInput, rightInput, joinType)
+ val actualOutput = hashJoinNode.collect().map { row =>
+ // (id, name, id, nickname)
+ (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+ }
+ assert(actualOutput.toSet === expectedOutput.toSet)
}
- test(s"$suiteName: full outer join") {
- withSQLConf(confPairs: _*) {
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- FullOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N")).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N", "full").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- FullOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- lowerCaseData.col("n") > 1).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"n" > 1, "full").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- FullOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- upperCaseData.col("N") > 1).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"N" > 1, "full").collect())
-
- checkAnswer2(
- lowerCaseData,
- upperCaseData,
- wrapForUnsafe(
- (node1, node2) => NestedLoopJoinNode(
- conf,
- node1,
- node2,
- buildSide,
- FullOuter,
- Some((lowerCaseData.col("n") === upperCaseData.col("N") &&
- lowerCaseData.col("l") > upperCaseData.col("L")).expr))
- ),
- lowerCaseData.join(upperCaseData, $"n" === $"N" && $"l" > $"L", "full").collect())
- }
+ test(s"$testNamePrefix: empty") {
+ runTest(joinType, Array.empty, Array.empty)
+ }
+
+ test(s"$testNamePrefix: no matches") {
+ val someIrrelevantData = (10000 to 10100).map { i => (i, "piper" + i) }.toArray
+ runTest(joinType, someData, Array.empty)
+ runTest(joinType, Array.empty, someData)
+ runTest(joinType, someData, someIrrelevantData)
+ runTest(joinType, someIrrelevantData, someData)
+ }
+
+ test(s"$testNamePrefix: partial matches") {
+ val someOtherData = (50 to 150).map { i => (i, "finnegan" + i) }.toArray
+ runTest(joinType, someData, someOtherData)
+ runTest(joinType, someOtherData, someData)
+ }
+
+ test(s"$testNamePrefix: full matches") {
+ val someSuperRelevantData = someData.map { case (k, v) => (k, "cooper" + v) }
+ runTest(joinType, someData, someSuperRelevantData)
+ runTest(joinType, someSuperRelevantData, someData)
+ }
+ }
+
+ /**
+ * Helper method to generate the expected output of a test based on the join type.
+ */
+ private def generateExpectedOutput(
+ leftInput: Array[(Int, String)],
+ rightInput: Array[(Int, String)],
+ joinType: JoinType): Array[(Int, String, Int, String)] = {
+ joinType match {
+ case LeftOuter =>
+ val rightInputMap = rightInput.toMap
+ leftInput.map { case (k, v) =>
+ val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
+ val rightValue = rightInputMap.getOrElse(k, null)
+ (k, v, rightKey, rightValue)
+ }
+
+ case RightOuter =>
+ val leftInputMap = leftInput.toMap
+ rightInput.map { case (k, v) =>
+ val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
+ val leftValue = leftInputMap.getOrElse(k, null)
+ (leftKey, leftValue, k, v)
+ }
+
+ case FullOuter =>
+ val leftInputMap = leftInput.toMap
+ val rightInputMap = rightInput.toMap
+ val leftOutput = leftInput.map { case (k, v) =>
+ val rightKey = rightInputMap.get(k).map { _ => k }.getOrElse(0)
+ val rightValue = rightInputMap.getOrElse(k, null)
+ (k, v, rightKey, rightValue)
+ }
+ val rightOutput = rightInput.map { case (k, v) =>
+ val leftKey = leftInputMap.get(k).map { _ => k }.getOrElse(0)
+ val leftValue = leftInputMap.getOrElse(k, null)
+ (leftKey, leftValue, k, v)
+ }
+ (leftOutput ++ rightOutput).distinct
+
+ case other =>
+ throw new IllegalArgumentException(s"Join type $other is not applicable")
}
}
- joinSuite(
- "general-build-left",
- BuildLeft,
- SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
- joinSuite(
- "general-build-right",
- BuildRight,
- SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
- joinSuite(
- "tungsten-build-left",
- BuildLeft,
- SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
- joinSuite(
- "tungsten-build-right",
- BuildRight,
- SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
index 38e0a230c4..02ecb23d34 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -17,28 +17,33 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, NamedExpression}
+import org.apache.spark.sql.types.{IntegerType, StringType}
-class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
- test("basic") {
- val output = testData.queryExecution.sparkPlan.output
- val columns = Seq(output(1), output(0))
- checkAnswer(
- testData,
- node => ProjectNode(conf, columns, node),
- testData.select("value", "key").collect()
- )
+class ProjectNodeSuite extends LocalNodeTest {
+ private val pieAttributes = Seq(
+ AttributeReference("id", IntegerType)(),
+ AttributeReference("age", IntegerType)(),
+ AttributeReference("name", StringType)())
+
+ private def testProject(inputData: Array[(Int, Int, String)] = Array.empty): Unit = {
+ val inputNode = new DummyNode(pieAttributes, inputData)
+ val columns = Seq[NamedExpression](inputNode.output(0), inputNode.output(2))
+ val projectNode = new ProjectNode(conf, columns, inputNode)
+ val expectedOutput = inputData.map { case (id, age, name) => (id, name) }
+ val actualOutput = projectNode.collect().map { case row =>
+ (row.getInt(0), row.getString(1))
+ }
+ assert(actualOutput === expectedOutput)
}
test("empty") {
- val output = emptyTestData.queryExecution.sparkPlan.output
- val columns = Seq(output(1), output(0))
- checkAnswer(
- emptyTestData,
- node => ProjectNode(conf, columns, node),
- emptyTestData.select("value", "key").collect()
- )
+ testProject()
+ }
+
+ test("basic") {
+ testProject((1 to 100).map { i => (i, i + 1, "pie" + i) }.toArray)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
index 87a7da4539..a3e83bbd51 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/SampleNodeSuite.scala
@@ -17,21 +17,32 @@
package org.apache.spark.sql.execution.local
-class SampleNodeSuite extends LocalNodeTest {
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
+
- import testImplicits._
+class SampleNodeSuite extends LocalNodeTest {
private def testSample(withReplacement: Boolean): Unit = {
- test(s"withReplacement: $withReplacement") {
- val seed = 0L
- val input = sqlContext.sparkContext.
- parallelize((1 to 10).map(i => (i, i.toString)), 1). // Should be only 1 partition
- toDF("key", "value")
- checkAnswer(
- input,
- node => SampleNode(conf, 0.0, 0.3, withReplacement, seed, node),
- input.sample(withReplacement, 0.3, seed).collect()
- )
+ val seed = 0L
+ val lowerb = 0.0
+ val upperb = 0.3
+ val maybeOut = if (withReplacement) "" else "out"
+ test(s"with$maybeOut replacement") {
+ val inputData = (1 to 1000).map { i => (i, i) }.toArray
+ val inputNode = new DummyNode(kvIntAttributes, inputData)
+ val sampleNode = new SampleNode(conf, lowerb, upperb, withReplacement, seed, inputNode)
+ val sampler =
+ if (withReplacement) {
+ new PoissonSampler[(Int, Int)](upperb - lowerb, useGapSamplingIfPossible = false)
+ } else {
+ new BernoulliCellSampler[(Int, Int)](lowerb, upperb)
+ }
+ sampler.setSeed(seed)
+ val expectedOutput = sampler.sample(inputData.iterator).toArray
+ val actualOutput = sampleNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput === expectedOutput)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
index ff28b24eef..42ebc7bfca 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/TakeOrderedAndProjectNodeSuite.scala
@@ -17,38 +17,34 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.Column
-import org.apache.spark.sql.catalyst.expressions.{Ascending, Expression, SortOrder}
+import scala.util.Random
-class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.SortOrder
- import testImplicits._
- private def columnToSortOrder(sortExprs: Column*): Seq[SortOrder] = {
- val sortOrder: Seq[SortOrder] = sortExprs.map { col =>
- col.expr match {
- case expr: SortOrder =>
- expr
- case expr: Expression =>
- SortOrder(expr, Ascending)
- }
- }
- sortOrder
- }
+class TakeOrderedAndProjectNodeSuite extends LocalNodeTest {
- private def testTakeOrderedAndProjectNode(desc: Boolean): Unit = {
- val testCaseName = if (desc) "desc" else "asc"
- test(testCaseName) {
- val input = (1 to 10).map(i => (i, i.toString)).toDF("key", "value")
- val sortColumn = if (desc) input.col("key").desc else input.col("key")
- checkAnswer(
- input,
- node => TakeOrderedAndProjectNode(conf, 5, columnToSortOrder(sortColumn), None, node),
- input.sort(sortColumn).limit(5).collect()
- )
+ private def testTakeOrderedAndProject(desc: Boolean): Unit = {
+ val limit = 10
+ val ascOrDesc = if (desc) "desc" else "asc"
+ test(ascOrDesc) {
+ val inputData = Random.shuffle((1 to 100).toList).map { i => (i, i) }.toArray
+ val inputNode = new DummyNode(kvIntAttributes, inputData)
+ val firstColumn = inputNode.output(0)
+ val sortDirection = if (desc) Descending else Ascending
+ val sortOrder = SortOrder(firstColumn, sortDirection)
+ val takeOrderAndProjectNode = new TakeOrderedAndProjectNode(
+ conf, limit, Seq(sortOrder), Some(Seq(firstColumn)), inputNode)
+ val expectedOutput = inputData
+ .map { case (k, _) => k }
+ .sortBy { k => k * (if (desc) -1 else 1) }
+ .take(limit)
+ val actualOutput = takeOrderAndProjectNode.collect().map { row => row.getInt(0) }
+ assert(actualOutput === expectedOutput)
}
}
- testTakeOrderedAndProjectNode(desc = false)
- testTakeOrderedAndProjectNode(desc = true)
+ testTakeOrderedAndProject(desc = false)
+ testTakeOrderedAndProject(desc = true)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
index eedd732090..666b0235c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -17,36 +17,39 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.test.SharedSQLContext
-class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
+class UnionNodeSuite extends LocalNodeTest {
- test("basic") {
- checkAnswer2(
- testData,
- testData,
- (node1, node2) => UnionNode(conf, Seq(node1, node2)),
- testData.unionAll(testData).collect()
- )
+ private def testUnion(inputData: Seq[Array[(Int, Int)]]): Unit = {
+ val inputNodes = inputData.map { data =>
+ new DummyNode(kvIntAttributes, data)
+ }
+ val unionNode = new UnionNode(conf, inputNodes)
+ val expectedOutput = inputData.flatten
+ val actualOutput = unionNode.collect().map { case row =>
+ (row.getInt(0), row.getInt(1))
+ }
+ assert(actualOutput === expectedOutput)
}
test("empty") {
- checkAnswer2(
- emptyTestData,
- emptyTestData,
- (node1, node2) => UnionNode(conf, Seq(node1, node2)),
- emptyTestData.unionAll(emptyTestData).collect()
- )
+ testUnion(Seq(Array.empty))
+ testUnion(Seq(Array.empty, Array.empty))
+ }
+
+ test("self") {
+ val data = (1 to 100).map { i => (i, i) }.toArray
+ testUnion(Seq(data))
+ testUnion(Seq(data, data))
+ testUnion(Seq(data, data, data))
}
- test("complicated union") {
- val dfs = Seq(testData, emptyTestData, emptyTestData, testData, testData, emptyTestData,
- emptyTestData, emptyTestData, testData, emptyTestData)
- doCheckAnswer(
- dfs,
- nodes => UnionNode(conf, nodes),
- dfs.reduce(_.unionAll(_)).collect()
- )
+ test("basic") {
+ val zero = Array.empty[(Int, Int)]
+ val one = (1 to 100).map { i => (i, i) }.toArray
+ val two = (50 to 150).map { i => (i, i) }.toArray
+ val three = (800 to 900).map { i => (i, i) }.toArray
+ testUnion(Seq(zero, one, two, three))
}
}