author     Andrew Or <andrew@databricks.com>   2015-09-11 15:01:37 -0700
committer  Andrew Or <andrew@databricks.com>   2015-09-11 15:01:37 -0700
commit     c2af42b5f32287ff595ad027a8191d4b75702d8d (patch)
tree       8d190d748dcde88044be442df6ded888f0d9fa3d /sql
parent     e626ac5f5c27dcc74113070f2fec03682bcd12bd (diff)
download   spark-c2af42b5f32287ff595ad027a8191d4b75702d8d.tar.gz
           spark-c2af42b5f32287ff595ad027a8191d4b75702d8d.tar.bz2
           spark-c2af42b5f32287ff595ad027a8191d4b75702d8d.zip
[SPARK-9990] [SQL] Local hash join follow-ups
1. Hide `LocalNodeIterator` behind the `LocalNode#asIterator` method
2. Add tests for this

Author: Andrew Or <andrew@databricks.com>

Closes #8708 from andrewor14/local-hash-join-follow-up.
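For context, `asIterator` itself is not touched by this patch; judging from the call sites below, it is presumably a one-line facade on `LocalNode` along these lines (a sketch, not a line from this diff):

    // Presumed shape of the facade (assumption, not shown in this patch):
    // it hides the iterator wrapper class behind a method on the node itself.
    def asIterator: Iterator[InternalRow] = new LocalNodeIterator(this)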
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala  |   7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala    |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala       |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala  | 116
4 files changed, 125 insertions(+), 5 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index 0cff21ca61..bc255b2750 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -25,7 +25,8 @@ import org.apache.spark.shuffle.ShuffleMemoryManager
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkSqlSerializer
-import org.apache.spark.sql.execution.metric.LongSQLMetric
+import org.apache.spark.sql.execution.local.LocalNode
+import org.apache.spark.sql.execution.metric.{LongSQLMetric, SQLMetrics}
import org.apache.spark.unsafe.Platform
import org.apache.spark.unsafe.map.BytesToBytesMap
import org.apache.spark.unsafe.memory.{MemoryLocation, ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
@@ -113,6 +114,10 @@ final class UniqueKeyHashedRelation(private var hashTable: JavaHashMap[InternalR

private[execution] object HashedRelation {

+  def apply(localNode: LocalNode, keyGenerator: Projection): HashedRelation = {
+    apply(localNode.asIterator, SQLMetrics.nullLongMetric, keyGenerator)
+  }
+
  def apply(
      input: Iterator[InternalRow],
      numInputRows: LongSQLMetric,
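The new overload trims boilerplate at every call site: the caller no longer names the iterator wrapper or passes a dummy metric. A hedged before/after sketch, where `buildNode` and `buildSideKeyGenerator` stand in for whatever node and projection the caller has in scope:

    // Before: wrapper and null metric spelled out by the caller.
    // HashedRelation(new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator)
    // After: the overload supplies both internally.
    val hashed = HashedRelation(buildNode, buildSideKeyGenerator)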
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index a3e68d6a7c..e7b24e3fca 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -75,8 +75,7 @@ case class HashJoinNode(

  override def open(): Unit = {
    buildNode.open()
-    hashed = HashedRelation.apply(
-      new LocalNodeIterator(buildNode), SQLMetrics.nullLongMetric, buildSideKeyGenerator)
+    hashed = HashedRelation(buildNode, buildSideKeyGenerator)
    streamedNode.open()
    joinRow = new JoinedRow
    resultProjection = {
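This was the one remaining external construction of `LocalNodeIterator`; with it gone, the class's visibility can be narrowed from `private[local]` to `private` in the next file.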
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
index a2c275db9b..e540ef8555 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/LocalNode.scala
@@ -77,7 +77,7 @@ abstract class LocalNode(conf: SQLConf) extends TreeNode[LocalNode] with Logging
  /**
   * Returns the content of the iterator from the beginning to the end in the form of a Scala Seq.
   */
-  def collect(): Seq[Row] = {
+  final def collect(): Seq[Row] = {
    val converter = CatalystTypeConverters.createToScalaConverter(StructType.fromAttributes(output))
    val result = new scala.collection.mutable.ArrayBuffer[Row]
    open()
@@ -140,7 +140,7 @@ abstract class BinaryLocalNode(conf: SQLConf) extends LocalNode(conf) {
/**
 * A thin wrapper around a [[LocalNode]] that provides an `Iterator` interface.
 */
-private[local] class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
+private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
  private var nextRow: InternalRow = _

  override def hasNext: Boolean = {
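The body of `LocalNodeIterator` is unchanged by this patch. The new tests below rely on two properties of it — `hasNext` is idempotent, and `next()` on an exhausted iterator throws `NoSuchElementException` — which a one-row-lookahead adapter of roughly this shape provides (a sketch under those assumptions, not the verbatim Spark code):

    private class LocalNodeIterator(localNode: LocalNode) extends Iterator[InternalRow] {
      private var nextRow: InternalRow = _

      // Idempotent: buffers at most one row from the node, so repeated
      // calls without an intervening next() are safe.
      override def hasNext: Boolean = {
        if (nextRow == null && localNode.next()) {
          nextRow = localNode.fetch()
        }
        nextRow != null
      }

      // Hands out the buffered row, or fails like any exhausted Iterator.
      override def next(): InternalRow = {
        if (hasNext) {
          val row = nextRow
          nextRow = null
          row
        } else {
          throw new NoSuchElementException
        }
      }
    }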
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
new file mode 100644
index 0000000000..b89fa46f8b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeSuite.scala
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.types.IntegerType
+
+class LocalNodeSuite extends SparkFunSuite {
+  private val data = (1 to 100).toArray
+
+  test("basic open, next, fetch, close") {
+    val node = new DummyLocalNode(data)
+    assert(!node.isOpen)
+    node.open()
+    assert(node.isOpen)
+    data.foreach { i =>
+      assert(node.next())
+      // fetch should be idempotent
+      val fetched = node.fetch()
+      assert(node.fetch() === fetched)
+      assert(node.fetch() === fetched)
+      assert(node.fetch().numFields === 1)
+      assert(node.fetch().getInt(0) === i)
+    }
+    assert(!node.next())
+    node.close()
+    assert(!node.isOpen)
+  }
+
+  test("asIterator") {
+    val node = new DummyLocalNode(data)
+    val iter = node.asIterator
+    node.open()
+    data.foreach { i =>
+      // hasNext should be idempotent
+      assert(iter.hasNext)
+      assert(iter.hasNext)
+      val item = iter.next()
+      assert(item.numFields === 1)
+      assert(item.getInt(0) === i)
+    }
+    intercept[NoSuchElementException] {
+      iter.next()
+    }
+    node.close()
+  }
+
+  test("collect") {
+    val node = new DummyLocalNode(data)
+    node.open()
+    val collected = node.collect()
+    assert(collected.size === data.size)
+    assert(collected.forall(_.size === 1))
+    assert(collected.map(_.getInt(0)) === data)
+    node.close()
+  }
+
+}
+
+/**
+ * A dummy [[LocalNode]] that just returns one row per integer in the input.
+ */
+private case class DummyLocalNode(conf: SQLConf, input: Array[Int]) extends LocalNode(conf) {
+  private var index = Int.MinValue
+
+  def this(input: Array[Int]) {
+    this(new SQLConf, input)
+  }
+
+  def isOpen: Boolean = {
+    index != Int.MinValue
+  }
+
+  override def output: Seq[Attribute] = {
+    Seq(AttributeReference("something", IntegerType)())
+  }
+
+  override def children: Seq[LocalNode] = Seq.empty
+
+  override def open(): Unit = {
+    index = -1
+  }
+
+  override def next(): Boolean = {
+    index += 1
+    index < input.size
+  }
+
+  override def fetch(): InternalRow = {
+    assert(index >= 0 && index < input.size)
+    val values = Array(input(index).asInstanceOf[Any])
+    new GenericInternalRow(values)
+  }
+
+  override def close(): Unit = {
+    index = Int.MinValue
+  }
+}
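A note on the test double's state machine: `Int.MinValue` is the sentinel for a closed node, `open()` resets the cursor to `-1` so that the first `next()` lands on index `0`, and `fetch()` asserts the cursor is in range, so out-of-contract calls fail fast rather than return garbage.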