path: root/sql
author    zsxwing <zsxwing@gmail.com>    2015-09-10 12:06:49 -0700
committer Andrew Or <andrew@databricks.com>    2015-09-10 12:06:49 -0700
commit    d88abb7e212fb55f9b0398a0f76a753c86b85cf1 (patch)
tree      22e9f176c4ffe6ea548a9791bf3acac8915ad215 /sql
parent    a5ef2d0600d5e23ca05fabc1005bb81e5ada0727 (diff)
[SPARK-9990] [SQL] Create local hash join operator
This PR includes the following changes:
- Add SQLConf to LocalNode
- Add HashJoinNode
- Add ConvertToUnsafeNode and ConvertToSafeNode.scala to test unsafe hash join

Author: zsxwing <zsxwing@gmail.com>

Closes #8535 from zsxwing/SPARK-9990.
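For orientation, the intended composition of these operators (the same pattern HashJoinNodeSuite exercises further down) looks roughly like the sketch below. The attribute names, the inline InternalRow data, and the freshly constructed SQLConf are illustrative assumptions, not part of this patch:

    import org.apache.spark.sql.SQLConf
    import org.apache.spark.sql.catalyst.InternalRow
    import org.apache.spark.sql.catalyst.expressions.AttributeReference
    import org.apache.spark.sql.execution.joins.BuildLeft
    import org.apache.spark.sql.types.IntegerType

    val conf = new SQLConf                       // in the tests this is sqlContext.conf
    val leftKey = AttributeReference("id", IntegerType)()
    val rightKey = AttributeReference("id", IntegerType)()

    // Two tiny in-memory inputs; SeqScanNode now takes the SQLConf as its first argument.
    val leftNode = SeqScanNode(conf, Seq(leftKey), Seq(InternalRow(1), InternalRow(2)))
    val rightNode = SeqScanNode(conf, Seq(rightKey), Seq(InternalRow(2), InternalRow(3)))

    // Hash the left (build) side and stream the right side.
    val join = HashJoinNode(conf, Seq(leftKey), Seq(rightKey), BuildLeft, leftNode, rightNode)
    join.open()                                  // builds the HashedRelation from the build side
    while (join.next()) {                        // true while another joined row is available
      val row = join.fetch()                     // left output columns followed by right output columns
    }
    join.close()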
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala | 40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala | 40
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala | 137
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala | 3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala | 83
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala | 4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala | 3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala | 130
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala | 9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala | 4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala | 6
16 files changed, 455 insertions, 24 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 6c0196c21a..0cff21ca61 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -38,7 +38,7 @@ import org.apache.spark.{SparkConf, SparkEnv}
* Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
* object.
*/
-private[joins] sealed trait HashedRelation {
+private[execution] sealed trait HashedRelation {
def get(key: InternalRow): Seq[InternalRow]
// This is a helper method to implement Externalizable, and is used by
@@ -111,7 +111,7 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR
// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
-private[joins] object HashedRelation {
+private[execution] object HashedRelation {
def apply(
input: Iterator[InternalRow],
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
new file mode 100644
index 0000000000..b31c5a8638
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToSafeNode.scala
@@ -0,0 +1,40 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, FromUnsafeProjection, Projection}
+
+case class ConvertToSafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {
+
+ override def output: Seq[Attribute] = child.output
+
+ private[this] var convertToSafe: Projection = _
+
+ override def open(): Unit = {
+ child.open()
+ convertToSafe = FromUnsafeProjection(child.schema)
+ }
+
+ override def next(): Boolean = child.next()
+
+ override def fetch(): InternalRow = convertToSafe(child.fetch())
+
+ override def close(): Unit = child.close()
+}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
new file mode 100644
index 0000000000..de2f4e661a
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ConvertToUnsafeNode.scala
@@ -0,0 +1,40 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Projection, UnsafeProjection}
+
+case class ConvertToUnsafeNode(conf: SQLConf, child: LocalNode) extends UnaryLocalNode(conf) {
+
+ override def output: Seq[Attribute] = child.output
+
+ private[this] var convertToUnsafe: Projection = _
+
+ override def open(): Unit = {
+ child.open()
+ convertToUnsafe = UnsafeProjection.create(child.schema)
+ }
+
+ override def next(): Boolean = child.next()
+
+ override def fetch(): InternalRow = convertToUnsafe(child.fetch())
+
+ override def close(): Unit = child.close()
+}
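ConvertToUnsafeNode and ConvertToSafeNode exist mainly so the unsafe code path of HashJoinNode can be tested: the suite below sandwiches the join between the two converters. A hedged sketch of that pattern, assuming conf, leftKeys, rightKeys, leftChild and rightChild are already in scope:

    // Feed UnsafeRows into the join, then convert its output back to safe rows.
    val unsafeAwareJoin = ConvertToSafeNode(conf,
      HashJoinNode(conf, leftKeys, rightKeys, BuildLeft,
        ConvertToUnsafeNode(conf, leftChild),
        ConvertToUnsafeNode(conf, rightChild)))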
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
index 81dd37c7da..dd1113b672 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/FilterNode.scala
@@ -17,12 +17,14 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
import org.apache.spark.sql.catalyst.expressions.codegen.GeneratePredicate
-case class FilterNode(condition: Expression, child: LocalNode) extends UnaryLocalNode {
+case class FilterNode(conf: SQLConf, condition: Expression, child: LocalNode)
+ extends UnaryLocalNode(conf) {
private[this] var predicate: (InternalRow) => Boolean = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
new file mode 100644
index 0000000000..a3e68d6a7c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -0,0 +1,137 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins._
+import org.apache.spark.sql.execution.metric.SQLMetrics
+
+/**
+ * Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
+ */
+case class HashJoinNode(
+ conf: SQLConf,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: LocalNode,
+ right: LocalNode) extends BinaryLocalNode(conf) {
+
+ private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
+ case BuildLeft => (left, leftKeys, right, rightKeys)
+ case BuildRight => (right, rightKeys, left, leftKeys)
+ }
+
+ private[this] var currentStreamedRow: InternalRow = _
+ private[this] var currentHashMatches: Seq[InternalRow] = _
+ private[this] var currentMatchPosition: Int = -1
+
+ private[this] var joinRow: JoinedRow = _
+ private[this] var resultProjection: (InternalRow) => InternalRow = _
+
+ private[this] var hashed: HashedRelation = _
+ private[this] var joinKeys: Projection = _
+
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ private[this] def isUnsafeMode: Boolean = {
+ (codegenEnabled && unsafeEnabled
+ && UnsafeProjection.canSupport(buildKeys)
+ && UnsafeProjection.canSupport(schema))
+ }
+
+ private[this] def buildSideKeyGenerator: Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(buildKeys, buildNode.output)
+ } else {
+ newMutableProjection(buildKeys, buildNode.output)()
+ }
+ }
+
+ private[this] def streamSideKeyGenerator: Projection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(streamedKeys, streamedNode.output)
+ } else {
+ newMutableProjection(streamedKeys, streamedNode.output)()
+ }
+ }
+
+ override def open(): Unit = {
+ buildNode.open()
+ hashed = HashedRelation.apply(
+ new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator)
+ streamedNode.open()
+ joinRow = new JoinedRow
+ resultProjection = {
+ if (isUnsafeMode) {
+ UnsafeProjection.create(schema)
+ } else {
+ identity[InternalRow]
+ }
+ }
+ joinKeys = streamSideKeyGenerator
+ }
+
+ override def next(): Boolean = {
+ currentMatchPosition += 1
+ if (currentHashMatches == null || currentMatchPosition >= currentHashMatches.size) {
+ fetchNextMatch()
+ } else {
+ true
+ }
+ }
+
+ /**
+ * Populate `currentHashMatches` with build-side rows matching the next streamed row.
+ * @return whether matches are found such that subsequent calls to `fetch` are valid.
+ */
+ private def fetchNextMatch(): Boolean = {
+ currentHashMatches = null
+ currentMatchPosition = -1
+
+ while (currentHashMatches == null && streamedNode.next()) {
+ currentStreamedRow = streamedNode.fetch()
+ val key = joinKeys(currentStreamedRow)
+ if (!key.anyNull) {
+ currentHashMatches = hashed.get(key)
+ }
+ }
+
+ if (currentHashMatches == null) {
+ false
+ } else {
+ currentMatchPosition = 0
+ true
+ }
+ }
+
+ override def fetch(): InternalRow = {
+ val ret = buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
+ }
+ resultProjection(ret)
+ }
+
+ override def close(): Unit = {
+ left.close()
+ right.close()
+ }
+}
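One detail worth noting from fetch(): regardless of buildSide, the output row keeps the left columns before the right columns; BuildLeft versus BuildRight only changes which input gets hashed and which is streamed. A hedged sketch, reusing the assumed node and key names from the earlier example:

    // Same logical join either way; both variants emit left columns followed by right columns.
    val hashLeft  = HashJoinNode(conf, leftKeys, rightKeys, BuildLeft,  leftNode, rightNode)
    val hashRight = HashJoinNode(conf, leftKeys, rightKeys, BuildRight, leftNode, rightNode)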
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
index fffc52abf6..401b10a5ed 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LimitNode.scala
@@ -17,11 +17,12 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-case class LimitNode(limit: Int, child: LocalNode) extends UnaryLocalNode {
+case class LimitNode(conf: SQLConf, limit: Int, child: LocalNode) extends UnaryLocalNode(conf) {
private[this] var count = 0
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index 1c4469acbf..c4f8ae304d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -17,9 +17,13 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.Row
+import scala.util.control.NonFatal
+
+import org.apache.spark.Logging
+import org.apache.spark.sql.{SQLConf, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.GenerateMutableProjection
import org.apache.spark.sql.catalyst.trees.TreeNode
import org.apache.spark.sql.types.StructType
@@ -29,7 +33,15 @@ import org.apache.spark.sql.types.StructType
* Before consuming the iterator, open function must be called.
* After consuming the iterator, close function must be called.
*/
-abstract class LocalNode extends TreeNode[LocalNode] {
+abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging {
+
+ protected val codegenEnabled: Boolean = conf.codegenEnabled
+
+ protected val unsafeEnabled: Boolean = conf.unsafeEnabled
+
+ lazy val schema: StructType = StructType.fromAttributes(output)
+
+ private[this] lazy val isTesting: Boolean = sys.props.contains("spark.testing")
def output: Seq[Attribute]
@@ -73,17 +85,78 @@ abstract class LocalNode extends TreeNode[LocalNode] {
}
result
}
+
+ protected def newMutableProjection(
+ expressions: Seq[Expression],
+ inputSchema: Seq[Attribute]): () => MutableProjection = {
+ log.debug(
+ s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
+ if (codegenEnabled) {
+ try {
+ GenerateMutableProjection.generate(expressions, inputSchema)
+ } catch {
+ case NonFatal(e) =>
+ if (isTesting) {
+ throw e
+ } else {
+ log.error("Failed to generate mutable projection, fallback to interpreted", e)
+ () => new InterpretedMutableProjection(expressions, inputSchema)
+ }
+ }
+ } else {
+ () => new InterpretedMutableProjection(expressions, inputSchema)
+ }
+ }
+
}
-abstract class LeafLocalNode extends LocalNode {
+abstract class LeafLocalNode(conf: SQLConf) extends LocalNode(conf) {
override def children: Seq[LocalNode] = Seq.empty
}
-abstract class UnaryLocalNode extends LocalNode {
+abstract class UnaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
def child: LocalNode
override def children: Seq[LocalNode] = Seq(child)
}
+
+abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
+
+ def left: LocalNode
+
+ def right: LocalNode
+
+ override def children: Seq[LocalNode] = Seq(left, right)
+}
+
+/**
+ * A thin wrapper around a [[LocalNode]] that provides an `Iterator` interface.
+ */
+private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
+ private var nextRow: InternalRow = _
+
+ override def hasNext: Boolean = {
+ if (nextRow == null) {
+ val res = localNode.next()
+ if (res) {
+ nextRow = localNode.fetch()
+ }
+ res
+ } else {
+ true
+ }
+ }
+
+ override def next(): InternalRow = {
+ if (hasNext) {
+ val res = nextRow
+ nextRow = null
+ res
+ } else {
+ throw new NoSuchElementException
+ }
+ }
+}
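LocalNodeIterator is what HashJoinNode.open uses to pipe the build side into HashedRelation, and it also gives code in this package a convenient way to drain any operator. A minimal sketch; the drain helper is hypothetical and not part of this patch:

    // Consume a LocalNode through the new Iterator adapter, honoring the open/close contract.
    def drain(node: LocalNode): Seq[InternalRow] = {
      node.open()                                          // open must be called before consuming
      try {
        // fetch() may return a reused row object, so copy each row before buffering it
        new LocalNodeIterator(node).map(_.copy()).toVector
      } finally {
        node.close()                                       // close must be called after consuming
      }
    }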
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
index 9b8a4fe493..11529d6dd9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/ProjectNode.scala
@@ -17,11 +17,13 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, Attribute, NamedExpression}
-case class ProjectNode(projectList: Seq[NamedExpression], child: LocalNode) extends UnaryLocalNode {
+case class ProjectNode(conf: SQLConf, projectList: Seq[NamedExpression], child: LocalNode)
+ extends UnaryLocalNode(conf) {
private[this] var project: UnsafeProjection = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
index 242cb66e07..b8467f6ae5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/SeqScanNode.scala
@@ -17,13 +17,15 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
/**
* An operator that scans some local data collection in the form of Scala Seq.
*/
-case class SeqScanNode(output: Seq[Attribute], data: Seq[InternalRow]) extends LeafLocalNode {
+case class SeqScanNode(conf: SQLConf, output: Seq[Attribute], data: Seq[InternalRow])
+ extends LeafLocalNode(conf) {
private[this] var iterator: Iterator[InternalRow] = _
private[this] var currentRow: InternalRow = _
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
index ba4aa7671a..0f2b8303e7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/UnionNode.scala
@@ -17,10 +17,11 @@
package org.apache.spark.sql.execution.local
+import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Attribute
-case class UnionNode(children: Seq[LocalNode]) extends LocalNode {
+case class UnionNode(conf: SQLConf, children: Seq[LocalNode]) extends LocalNode(conf) {
override def output: Seq[Attribute] = children.head.output
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
index 07209f3779..a12670e347 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/FilterNodeSuite.scala
@@ -25,7 +25,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
val condition = (testData.col("key") % 2) === 0
checkAnswer(
testData,
- node => FilterNode(condition.expr, node),
+ node => FilterNode(conf, condition.expr, node),
testData.filter(condition).collect()
)
}
@@ -34,7 +34,7 @@ class FilterNodeSuite extends LocalNodeTest with SharedSQLContext {
val condition = (emptyTestData.col("key") % 2) === 0
checkAnswer(
emptyTestData,
- node => FilterNode(condition.expr, node),
+ node => FilterNode(conf, condition.expr, node),
emptyTestData.filter(condition).collect()
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
new file mode 100644
index 0000000000..43b6f06aea
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -0,0 +1,130 @@
+/*
+* Licensed to the Apache Software Foundation (ASF) under one or more
+* contributor license agreements. See the NOTICE file distributed with
+* this work for additional information regarding copyright ownership.
+* The ASF licenses this file to You under the Apache License, Version 2.0
+* (the "License"); you may not use this file except in compliance with
+* the License. You may obtain a copy of the License at
+*
+* http://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS,
+* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+* See the License for the specific language governing permissions and
+* limitations under the License.
+*/
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.execution.joins
+
+class HashJoinNodeSuite extends LocalNodeTest {
+
+ import testImplicits._
+
+ private def wrapForUnsafe(
+ f: (LocalNode, LocalNode) => LocalNode): (LocalNode, LocalNode) => LocalNode = {
+ if (conf.unsafeEnabled) {
+ (left: LocalNode, right: LocalNode) => {
+ val _left = ConvertToUnsafeNode(conf, left)
+ val _right = ConvertToUnsafeNode(conf, right)
+ val r = f(_left, _right)
+ ConvertToSafeNode(conf, r)
+ }
+ } else {
+ f
+ }
+ }
+
+ def joinSuite(suiteName: String, confPairs: (String, String)*): Unit = {
+ test(s"$suiteName: inner join with one match per row") {
+ withSQLConf(confPairs: _*) {
+ checkAnswer2(
+ upperCaseData,
+ lowerCaseData,
+ wrapForUnsafe(
+ (node1, node2) => HashJoinNode(
+ conf,
+ Seq(upperCaseData.col("N").expr),
+ Seq(lowerCaseData.col("n").expr),
+ joins.BuildLeft,
+ node1,
+ node2)
+ ),
+ upperCaseData.join(lowerCaseData, $"n" === $"N").collect()
+ )
+ }
+ }
+
+ test(s"$suiteName: inner join with multiple matches") {
+ withSQLConf(confPairs: _*) {
+ val x = testData2.where($"a" === 1).as("x")
+ val y = testData2.where($"a" === 1).as("y")
+ checkAnswer2(
+ x,
+ y,
+ wrapForUnsafe(
+ (node1, node2) => HashJoinNode(
+ conf,
+ Seq(x.col("a").expr),
+ Seq(y.col("a").expr),
+ joins.BuildLeft,
+ node1,
+ node2)
+ ),
+ x.join(y).where($"x.a" === $"y.a").collect()
+ )
+ }
+ }
+
+ test(s"$suiteName: inner join, no matches") {
+ withSQLConf(confPairs: _*) {
+ val x = testData2.where($"a" === 1).as("x")
+ val y = testData2.where($"a" === 2).as("y")
+ checkAnswer2(
+ x,
+ y,
+ wrapForUnsafe(
+ (node1, node2) => HashJoinNode(
+ conf,
+ Seq(x.col("a").expr),
+ Seq(y.col("a").expr),
+ joins.BuildLeft,
+ node1,
+ node2)
+ ),
+ Nil
+ )
+ }
+ }
+
+ test(s"$suiteName: big inner join, 4 matches per row") {
+ withSQLConf(confPairs: _*) {
+ val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+ val bigDataX = bigData.as("x")
+ val bigDataY = bigData.as("y")
+
+ checkAnswer2(
+ bigDataX,
+ bigDataY,
+ wrapForUnsafe(
+ (node1, node2) =>
+ HashJoinNode(
+ conf,
+ Seq(bigDataX.col("key").expr),
+ Seq(bigDataY.col("key").expr),
+ joins.BuildLeft,
+ node1,
+ node2)
+ ),
+ bigDataX.join(bigDataY).where($"x.key" === $"y.key").collect())
+ }
+ }
+ }
+
+ joinSuite(
+ "general", SQLConf.CODEGEN_ENABLED.key -> "false", SQLConf.UNSAFE_ENABLED.key -> "false")
+ joinSuite("tungsten", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "true")
+}
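The joinSuite helper runs the same four join tests under different SQLConf settings, so further combinations can be covered with one more call; for example, a hypothetical codegen-only configuration (not part of this patch) would be:

    joinSuite(
      "codegen-only", SQLConf.CODEGEN_ENABLED.key -> "true", SQLConf.UNSAFE_ENABLED.key -> "false")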
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
index 523c02f4a6..3b18390200 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LimitNodeSuite.scala
@@ -24,7 +24,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
test("basic") {
checkAnswer(
testData,
- node => LimitNode(10, node),
+ node => LimitNode(conf, 10, node),
testData.limit(10).collect()
)
}
@@ -32,7 +32,7 @@ class LimitNodeSuite extends LocalNodeTest with SharedSQLContext {
test("empty") {
checkAnswer(
emptyTestData,
- node => LimitNode(10, node),
+ node => LimitNode(conf, 10, node),
emptyTestData.limit(10).collect()
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 95f06081bd..b95d4ea7f8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -20,10 +20,12 @@ package org.apache.spark.sql.execution.local
import scala.util.control.NonFatal
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
+import org.apache.spark.sql.test.{SharedSQLContext, SQLTestUtils}
-class LocalNodeTest extends SparkFunSuite {
+class LocalNodeTest extends SparkFunSuite with SharedSQLContext {
+
+ def conf: SQLConf = sqlContext.conf
/**
* Runs the LocalNode and makes sure the answer matches the expected result.
@@ -92,6 +94,7 @@ class LocalNodeTest extends SparkFunSuite {
protected def dataFrameToSeqScanNode(df: DataFrame): SeqScanNode = {
new SeqScanNode(
+ conf,
df.queryExecution.sparkPlan.output,
df.queryExecution.toRdd.map(_.copy()).collect())
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
index ffcf092e2c..38e0a230c4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/ProjectNodeSuite.scala
@@ -26,7 +26,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
val columns = Seq(output(1), output(0))
checkAnswer(
testData,
- node => ProjectNode(columns, node),
+ node => ProjectNode(conf, columns, node),
testData.select("value", "key").collect()
)
}
@@ -36,7 +36,7 @@ class ProjectNodeSuite extends LocalNodeTest with SharedSQLContext {
val columns = Seq(output(1), output(0))
checkAnswer(
emptyTestData,
- node => ProjectNode(columns, node),
+ node => ProjectNode(conf, columns, node),
emptyTestData.select("value", "key").collect()
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
index 34670287c3..eedd732090 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/UnionNodeSuite.scala
@@ -25,7 +25,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
checkAnswer2(
testData,
testData,
- (node1, node2) => UnionNode(Seq(node1, node2)),
+ (node1, node2) => UnionNode(conf, Seq(node1, node2)),
testData.unionAll(testData).collect()
)
}
@@ -34,7 +34,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
checkAnswer2(
emptyTestData,
emptyTestData,
- (node1, node2) => UnionNode(Seq(node1, node2)),
+ (node1, node2) => UnionNode(conf, Seq(node1, node2)),
emptyTestData.unionAll(emptyTestData).collect()
)
}
@@ -44,7 +44,7 @@ class UnionNodeSuite extends LocalNodeTest with SharedSQLContext {
emptyTestData, emptyTestData, testData, emptyTestData)
doCheckAnswer(
dfs,
- nodes => UnionNode(nodes),
+ nodes => UnionNode(conf, nodes),
dfs.reduce(_.unionAll(_)).collect()
)
}