path: root/sql/core
author     Zongheng Yang <zongheng.y@gmail.com>       2014-06-25 18:06:33 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-06-25 18:06:33 -0700
commit     9d824fed8c62dd6c87b4c855c2fea930c01b58f4 (patch)
tree       641e36b5ec43664a1c3dc5c9aa7821ae7f2ee2ec /sql/core
parent     1132e472eca1a00c2ce10d2f84e8f0e79a5193d3 (diff)
[SQL] SPARK-1800 Add broadcast hash join operator & associated hints.
This PR is based off Michael's [PR 734](https://github.com/apache/spark/pull/734) and includes a bunch of cleanups. Moreover, this PR also
- makes `SparkLogicalPlan` take a `tableName: String`, which facilitates testing.
- moves join-related tests to a single file.

Author: Zongheng Yang <zongheng.y@gmail.com>
Author: Michael Armbrust <michael@databricks.com>

Closes #1163 from concretevitamin/auto-broadcast-hash-join and squashes the following commits:

d0f4991 [Zongheng Yang] Fix bug in broadcast hash join & add test to cover it.
af080d7 [Zongheng Yang] Fix in joinIterators()'s next().
440d277 [Zongheng Yang] Fixes to imports; add back requiredChildDistribution (lost when merging)
208d5f6 [Zongheng Yang] Make LeftSemiJoinHash mix in HashJoin.
ad6c7cc [Zongheng Yang] Minor cleanups.
814b3bf [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
a8a093e [Zongheng Yang] Minor cleanups.
6fd8443 [Zongheng Yang] Cut down size estimation related stuff.
a4267be [Zongheng Yang] Add test for broadcast hash join and related necessary refactorings:
0e64b08 [Zongheng Yang] Scalastyle fix.
91461c2 [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
7c7158b [Zongheng Yang] Prototype of auto conversion to broadcast hash join.
0ad122f [Zongheng Yang] Merge branch 'master' into auto-broadcast-hash-join
3e5d77c [Zongheng Yang] WIP: giant and messy WIP.
a92ed0c [Michael Armbrust] Formatting.
76ca434 [Michael Armbrust] A simple strategy that broadcasts tables only when they are found in a configuration hint.
cf6b381 [Michael Armbrust] Split out generic logic for hash joins and create two concrete physical operators: BroadcastHashJoin and ShuffledHashJoin.
a8420ca [Michael Armbrust] Copy records in executeCollect to avoid issues with mutable rows.
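For context, a minimal sketch of how the new broadcast hint is exercised, mirroring the pattern used in the added `JoinSuite` test. The `SQLContext` instance and the `users`/`events` table names here are hypothetical; only the conf key `spark.sql.join.broadcastTables` and the hint mechanism come from this patch.

```scala
import org.apache.spark.sql.SQLContext

// Sketch only: assumes `sqlContext` already has `users` and `events` registered as tables.
def runBroadcastJoin(sqlContext: SQLContext): Unit = {
  // Comma-separated list of table names that should be broadcast during joins.
  sqlContext.set("spark.sql.join.broadcastTables", "users")

  // With the hint in place, the HashJoin strategy should plan a BroadcastHashJoin
  // (instead of a ShuffledHashJoin) for the hinted side of this query.
  val joined = sqlContext.sql(
    "SELECT * FROM events JOIN users ON events.userId = users.id")
  joined.collect().foreach(println)
}
```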
Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala                   |  17
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                |   6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala       |  15
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  54
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  |   1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala           | 219
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala   |   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala             |  99
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala                 | 173
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                 |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala    |  17
11 files changed, 387 insertions(+), 223 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
index b378252ba2..2fe7f94663 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLConf.scala
@@ -29,9 +29,26 @@ import scala.collection.JavaConverters._
*/
trait SQLConf {
+ /** ************************ Spark SQL Params/Hints ******************* */
+ // TODO: refactor so that these hints accessors don't pollute the name space of SQLContext?
+
/** Number of partitions to use for shuffle operators. */
private[spark] def numShufflePartitions: Int = get("spark.sql.shuffle.partitions", "200").toInt
+ /**
+ * Upper bound on the sizes (in bytes) of the tables qualified for the auto conversion to
+ * a broadcast value during the physical executions of join operations. Setting this to 0
+ * effectively disables auto conversion.
+ * Hive setting: hive.auto.convert.join.noconditionaltask.size.
+ */
+ private[spark] def autoConvertJoinSize: Int =
+ get("spark.sql.auto.convert.join.size", "10000").toInt
+
+ /** A comma-separated list of table names marked to be broadcasted during joins. */
+ private[spark] def joinBroadcastTables: String = get("spark.sql.join.broadcastTables", "")
+
+ /** ********************** SQLConf functionality methods ************ */
+
@transient
private val settings = java.util.Collections.synchronizedMap(
new java.util.HashMap[String, String]())
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 7195f9709d..7edb548678 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -170,7 +170,11 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @group userf
*/
def registerRDDAsTable(rdd: SchemaRDD, tableName: String): Unit = {
- catalog.registerTable(None, tableName, rdd.logicalPlan)
+ val name = tableName
+ val newPlan = rdd.logicalPlan transform {
+ case s @ SparkLogicalPlan(ExistingRdd(_, _), _) => s.copy(tableName = name)
+ }
+ catalog.registerTable(None, tableName, newPlan)
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 07967fe75e..27dc091b85 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -23,9 +23,9 @@ import org.apache.spark.sql.{Logging, Row}
import org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.expressions.GenericRow
-import org.apache.spark.sql.catalyst.plans.{QueryPlan, logical}
+import org.apache.spark.sql.catalyst.plans.QueryPlan
+import org.apache.spark.sql.catalyst.plans.logical.BaseRelation
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.columnar.InMemoryColumnarTableScan
/**
* :: DeveloperApi ::
@@ -66,19 +66,20 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging {
* linking.
*/
@DeveloperApi
-case class SparkLogicalPlan(alreadyPlanned: SparkPlan)
- extends logical.LogicalPlan with MultiInstanceRelation {
+case class SparkLogicalPlan(alreadyPlanned: SparkPlan, tableName: String = "SparkLogicalPlan")
+ extends BaseRelation with MultiInstanceRelation {
def output = alreadyPlanned.output
- def references = Set.empty
- def children = Nil
+ override def references = Set.empty
+ override def children = Nil
override final def newInstance: this.type = {
SparkLogicalPlan(
alreadyPlanned match {
case ExistingRdd(output, rdd) => ExistingRdd(output.map(_.newInstance), rdd)
case _ => sys.error("Multiple instance of the same relation detected.")
- }).asInstanceOf[this.type]
+ }, tableName)
+ .asInstanceOf[this.type]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index bd8ae4cdde..3cd29967d1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -21,10 +21,10 @@ import org.apache.spark.sql.{SQLContext, execution}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.plans.logical.{BaseRelation, LogicalPlan}
import org.apache.spark.sql.catalyst.plans.physical._
-import org.apache.spark.sql.parquet._
import org.apache.spark.sql.columnar.{InMemoryRelation, InMemoryColumnarTableScan}
+import org.apache.spark.sql.parquet._
private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
self: SQLContext#SparkPlanner =>
@@ -45,14 +45,52 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
}
}
+ /**
+ * Uses the HashFilteredJoin pattern to find joins where at least some of the predicates can be
+ * evaluated by matching hash keys.
+ */
object HashJoin extends Strategy with PredicateHelper {
+ private[this] def broadcastHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: LogicalPlan,
+ right: LogicalPlan,
+ condition: Option[Expression],
+ side: BuildSide) = {
+ val broadcastHashJoin = execution.BroadcastHashJoin(
+ leftKeys, rightKeys, side, planLater(left), planLater(right))(sqlContext)
+ condition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin) :: Nil
+ }
+
+ def broadcastTables: Seq[String] = sqlContext.joinBroadcastTables.split(",").toBuffer
+
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
- // Find inner joins where at least some predicates can be evaluated by matching hash keys
- // using the HashFilteredJoin pattern.
+ case HashFilteredJoin(
+ Inner,
+ leftKeys,
+ rightKeys,
+ condition,
+ left,
+ right @ PhysicalOperation(_, _, b: BaseRelation))
+ if broadcastTables.contains(b.tableName) =>
+ broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildRight)
+
+ case HashFilteredJoin(
+ Inner,
+ leftKeys,
+ rightKeys,
+ condition,
+ left @ PhysicalOperation(_, _, b: BaseRelation),
+ right)
+ if broadcastTables.contains(b.tableName) =>
+ broadcastHashJoin(leftKeys, rightKeys, left, right, condition, BuildLeft)
+
case HashFilteredJoin(Inner, leftKeys, rightKeys, condition, left, right) =>
val hashJoin =
- execution.HashJoin(leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
+ execution.ShuffledHashJoin(
+ leftKeys, rightKeys, BuildRight, planLater(left), planLater(right))
condition.map(Filter(_, hashJoin)).getOrElse(hashJoin) :: Nil
+
case _ => Nil
}
}
@@ -62,10 +100,10 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.Aggregate(groupingExpressions, aggregateExpressions, child) =>
// Collect all aggregate expressions.
val allAggregates =
- aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a})
+ aggregateExpressions.flatMap(_ collect { case a: AggregateExpression => a })
// Collect all aggregate expressions that can be computed partially.
val partialAggregates =
- aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p})
+ aggregateExpressions.flatMap(_ collect { case p: PartialAggregate => p })
// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
@@ -242,7 +280,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.ExistingRdd(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
execution.Exchange(HashPartitioning(expressions, numPartitions), planLater(child)) :: Nil
- case SparkLogicalPlan(existingPlan) => existingPlan :: Nil
+ case SparkLogicalPlan(existingPlan, _) => existingPlan :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index b40d4e3a3b..a278f1ca98 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -205,4 +205,3 @@ object ExistingRdd {
case class ExistingRdd(output: Seq[Attribute], rdd: RDD[Row]) extends LeafNode {
override def execute() = rdd
}
-
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 84bdde38b7..32c5f26fe8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -18,12 +18,15 @@
package org.apache.spark.sql.execution
import scala.collection.mutable.{ArrayBuffer, BitSet}
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent._
+import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.catalyst.plans.physical.{ClusteredDistribution, Partitioning}
+import org.apache.spark.sql.catalyst.plans.physical._
@DeveloperApi
sealed abstract class BuildSide
@@ -34,28 +37,19 @@ case object BuildLeft extends BuildSide
@DeveloperApi
case object BuildRight extends BuildSide
-/**
- * :: DeveloperApi ::
- */
-@DeveloperApi
-case class HashJoin(
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
-
- override def outputPartitioning: Partitioning = left.outputPartitioning
+trait HashJoin {
+ val leftKeys: Seq[Expression]
+ val rightKeys: Seq[Expression]
+ val buildSide: BuildSide
+ val left: SparkPlan
+ val right: SparkPlan
- override def requiredChildDistribution =
- ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
-
- val (buildPlan, streamedPlan) = buildSide match {
+ lazy val (buildPlan, streamedPlan) = buildSide match {
case BuildLeft => (left, right)
case BuildRight => (right, left)
}
- val (buildKeys, streamedKeys) = buildSide match {
+ lazy val (buildKeys, streamedKeys) = buildSide match {
case BuildLeft => (leftKeys, rightKeys)
case BuildRight => (rightKeys, leftKeys)
}
@@ -66,73 +60,74 @@ case class HashJoin(
@transient lazy val streamSideKeyGenerator =
() => new MutableProjection(streamedKeys, streamedPlan.output)
- def execute() = {
-
- buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
- // TODO: Use Spark's HashMap implementation.
- val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
- var currentRow: Row = null
-
- // Create a mapping of buildKeys -> rows
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if(!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new ArrayBuffer[Row]()
- hashTable.put(rowKey, newMatchList)
- newMatchList
- } else {
- existingMatchList
- }
- matchList += currentRow.copy()
+ def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
+ // TODO: Use Spark's HashMap implementation.
+
+ val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+ var currentRow: Row = null
+
+ // Create a mapping of buildKeys -> rows
+ while (buildIter.hasNext) {
+ currentRow = buildIter.next()
+ val rowKey = buildSideKeyGenerator(currentRow)
+ if(!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new ArrayBuffer[Row]()
+ hashTable.put(rowKey, newMatchList)
+ newMatchList
+ } else {
+ existingMatchList
}
+ matchList += currentRow.copy()
}
+ }
- new Iterator[Row] {
- private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatches: ArrayBuffer[Row] = _
- private[this] var currentMatchPosition: Int = -1
+ new Iterator[Row] {
+ private[this] var currentStreamedRow: Row = _
+ private[this] var currentHashMatches: ArrayBuffer[Row] = _
+ private[this] var currentMatchPosition: Int = -1
- // Mutable per row objects.
- private[this] val joinRow = new JoinedRow
+ // Mutable per row objects.
+ private[this] val joinRow = new JoinedRow
- private[this] val joinKeys = streamSideKeyGenerator()
+ private[this] val joinKeys = streamSideKeyGenerator()
- override final def hasNext: Boolean =
- (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
+ override final def hasNext: Boolean =
+ (currentMatchPosition != -1 && currentMatchPosition < currentHashMatches.size) ||
(streamIter.hasNext && fetchNext())
- override final def next() = {
- val ret = joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
- currentMatchPosition += 1
- ret
+ override final def next() = {
+ val ret = buildSide match {
+ case BuildRight => joinRow(currentStreamedRow, currentHashMatches(currentMatchPosition))
+ case BuildLeft => joinRow(currentHashMatches(currentMatchPosition), currentStreamedRow)
}
+ currentMatchPosition += 1
+ ret
+ }
- /**
- * Searches the streamed iterator for the next row that has at least one match in hashtable.
- *
- * @return true if the search is successful, and false the streamed iterator runs out of
- * tuples.
- */
- private final def fetchNext(): Boolean = {
- currentHashMatches = null
- currentMatchPosition = -1
-
- while (currentHashMatches == null && streamIter.hasNext) {
- currentStreamedRow = streamIter.next()
- if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashTable.get(joinKeys.currentValue)
- }
+ /**
+ * Searches the streamed iterator for the next row that has at least one match in hashtable.
+ *
+ * @return true if the search is successful, and false if the streamed iterator runs out of
+ * tuples.
+ */
+ private final def fetchNext(): Boolean = {
+ currentHashMatches = null
+ currentMatchPosition = -1
+
+ while (currentHashMatches == null && streamIter.hasNext) {
+ currentStreamedRow = streamIter.next()
+ if (!joinKeys(currentStreamedRow).anyNull) {
+ currentHashMatches = hashTable.get(joinKeys.currentValue)
}
+ }
- if (currentHashMatches == null) {
- false
- } else {
- currentMatchPosition = 0
- true
- }
+ if (currentHashMatches == null) {
+ false
+ } else {
+ currentMatchPosition = 0
+ true
}
}
}
@@ -141,32 +136,49 @@ case class HashJoin(
/**
* :: DeveloperApi ::
- * Build the right table's join keys into a HashSet, and iteratively go through the left
- * table, to find the if join keys are in the Hash set.
+ * Performs an inner hash join of two child relations by first shuffling the data using the join
+ * keys.
*/
@DeveloperApi
-case class LeftSemiJoinHash(
+case class ShuffledHashJoin(
leftKeys: Seq[Expression],
rightKeys: Seq[Expression],
+ buildSide: BuildSide,
left: SparkPlan,
- right: SparkPlan) extends BinaryNode {
+ right: SparkPlan) extends BinaryNode with HashJoin {
override def outputPartitioning: Partitioning = left.outputPartitioning
override def requiredChildDistribution =
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
- val (buildPlan, streamedPlan) = (right, left)
- val (buildKeys, streamedKeys) = (rightKeys, leftKeys)
+ def execute() = {
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) {
+ (buildIter, streamIter) => joinIterators(buildIter, streamIter)
+ }
+ }
+}
- def output = left.output
+/**
+ * :: DeveloperApi ::
+ * Build the right table's join keys into a HashSet, and iteratively go through the left
+ * table, to find the if join keys are in the Hash set.
+ */
+@DeveloperApi
+case class LeftSemiJoinHash(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ left: SparkPlan,
+ right: SparkPlan) extends BinaryNode with HashJoin {
- @transient lazy val buildSideKeyGenerator = new Projection(buildKeys, buildPlan.output)
- @transient lazy val streamSideKeyGenerator =
- () => new MutableProjection(streamedKeys, streamedPlan.output)
+ val buildSide = BuildRight
- def execute() = {
+ override def requiredChildDistribution =
+ ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
+
+ override def output = left.output
+ def execute() = {
buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
val hashSet = new java.util.HashSet[Row]()
var currentRow: Row = null
@@ -191,6 +203,43 @@ case class LeftSemiJoinHash(
}
}
+
+/**
+ * :: DeveloperApi ::
+ * Performs an inner hash join of two child relations. When the output RDD of this operator is
+ * being constructed, a Spark job is asynchronously started to calculate the values for the
+ * broadcasted relation. This data is then placed in a Spark broadcast variable. The streamed
+ * relation is not shuffled.
+ */
+@DeveloperApi
+case class BroadcastHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: SparkPlan,
+ right: SparkPlan)(@transient sqlContext: SQLContext) extends BinaryNode with HashJoin {
+
+ override def otherCopyArgs = sqlContext :: Nil
+
+ override def outputPartitioning: Partitioning = left.outputPartitioning
+
+ override def requiredChildDistribution =
+ UnspecifiedDistribution :: UnspecifiedDistribution :: Nil
+
+ @transient
+ lazy val broadcastFuture = future {
+ sqlContext.sparkContext.broadcast(buildPlan.executeCollect())
+ }
+
+ def execute() = {
+ val broadcastRelation = Await.result(broadcastFuture, 5.minute)
+
+ streamedPlan.execute().mapPartitions { streamedIter =>
+ joinIterators(broadcastRelation.value.iterator, streamedIter)
+ }
+ }
+}
+
/**
* :: DeveloperApi ::
* Using BroadcastNestedLoopJoin to calculate left semi join result when there's no join keys
@@ -220,7 +269,6 @@ case class LeftSemiJoinBNL(
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
-
def execute() = {
val broadcastedRelation =
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
@@ -284,7 +332,6 @@ case class BroadcastNestedLoopJoin(
.map(c => BindReferences.bindReference(c, left.output ++ right.output))
.getOrElse(Literal(true)))
-
def execute() = {
val broadcastedRelation =
sqlContext.sparkContext.broadcast(broadcast.execute().map(_.copy()).collect().toIndexedSeq)
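The asynchronous broadcast in `BroadcastHashJoin` above comes down to the following pattern: start collecting the build side in a future as soon as the operator is constructed, then block (with a timeout) when `execute()` needs the data. This is a simplified, self-contained sketch; `collectBuildSide` stands in for `buildPlan.executeCollect()` and is not part of the patch.

```scala
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.{Await, Future}
import scala.concurrent.duration._

object BroadcastPattern {
  // Stand-in for buildPlan.executeCollect(): materialize the build side.
  def collectBuildSide(): Array[Int] = (1 to 1000).toArray

  // Declared lazily but run asynchronously, so the collection can overlap
  // with planning/scheduling of the streamed side.
  lazy val broadcastFuture: Future[Array[Int]] = Future {
    collectBuildSide()
  }

  def execute(): Int = {
    // Block until the build side is ready, as the operator does with
    // Await.result(broadcastFuture, 5.minute).
    val rows = Await.result(broadcastFuture, 5.minutes)
    rows.length
  }
}
```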
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
index 96c131a7f8..9c4771d1a9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/parquet/ParquetRelation.scala
@@ -44,8 +44,9 @@ import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, LeafNode}
* @param path The path to the Parquet file.
*/
private[sql] case class ParquetRelation(
- val path: String,
- @transient val conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation {
+ path: String,
+ @transient conf: Option[Configuration] = None) extends LeafNode with MultiInstanceRelation {
+
self: Product =>
/** Schema derived from ParquetFile */
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
index fb599e1e01..e4a64a7a48 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DslQuerySuite.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql
import org.apache.spark.sql.catalyst.analysis._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.test._
/* Implicits */
@@ -149,102 +148,4 @@ class DslQuerySuite extends QueryTest {
test("zero count") {
assert(emptyTableData.count() === 0)
}
-
- test("inner join where, one match per row") {
- checkAnswer(
- upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
- Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
- ))
- }
-
- test("inner join ON, one match per row") {
- checkAnswer(
- upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
- Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
- ))
- }
-
- test("inner join, where, multiple matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 1).as('y)
- checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
- (1,1,1,1) ::
- (1,1,1,2) ::
- (1,2,1,1) ::
- (1,2,1,2) :: Nil
- )
- }
-
- test("inner join, no matches") {
- val x = testData2.where('a === 1).as('x)
- val y = testData2.where('a === 2).as('y)
- checkAnswer(
- x.join(y).where("x.a".attr === "y.a".attr),
- Nil)
- }
-
- test("big inner join, 4 matches per row") {
- val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
- val bigDataX = bigData.as('x)
- val bigDataY = bigData.as('y)
-
- checkAnswer(
- bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
- testData.flatMap(
- row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
- }
-
- test("cartisian product join") {
- checkAnswer(
- testData3.join(testData3),
- (1, null, 1, null) ::
- (1, null, 2, 2) ::
- (2, 2, 1, null) ::
- (2, 2, 2, 2) :: Nil)
- }
-
- test("left outer join") {
- checkAnswer(
- upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
- (1, "A", 1, "a") ::
- (2, "B", 2, "b") ::
- (3, "C", 3, "c") ::
- (4, "D", 4, "d") ::
- (5, "E", null, null) ::
- (6, "F", null, null) :: Nil)
- }
-
- test("right outer join") {
- checkAnswer(
- lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
- (1, "a", 1, "A") ::
- (2, "b", 2, "B") ::
- (3, "c", 3, "C") ::
- (4, "d", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
- }
-
- test("full outer join") {
- val left = upperCaseData.where('N <= 4).as('left)
- val right = upperCaseData.where('N >= 3).as('right)
-
- checkAnswer(
- left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
- (1, "A", null, null) ::
- (2, "B", null, null) ::
- (3, "C", 3, "C") ::
- (4, "D", 4, "D") ::
- (null, null, 5, "E") ::
- (null, null, 6, "F") :: Nil)
- }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
new file mode 100644
index 0000000000..3d7d5eedbe
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.catalyst.plans.{LeftOuter, RightOuter, FullOuter, Inner}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.TestSQLContext._
+
+class JoinSuite extends QueryTest {
+
+ // Ensures tables are loaded.
+ TestData
+
+ test("equi-join is hash-join") {
+ val x = testData2.as('x)
+ val y = testData2.as('y)
+ val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
+ val planned = planner.HashJoin(join)
+ assert(planned.size === 1)
+ }
+
+ test("plans broadcast hash join, given hints") {
+
+ def mkTest(buildSide: BuildSide, leftTable: String, rightTable: String) = {
+ TestSQLContext.set("spark.sql.join.broadcastTables",
+ s"${if (buildSide == BuildRight) rightTable else leftTable}")
+ val rdd = sql(s"""SELECT * FROM $leftTable JOIN $rightTable ON key = a""")
+ // Using `sparkPlan` because for relevant patterns in HashJoin to be
+ // matched, other strategies need to be applied.
+ val physical = rdd.queryExecution.sparkPlan
+ val bhj = physical.collect { case j: BroadcastHashJoin if j.buildSide == buildSide => j }
+
+ assert(bhj.size === 1, "planner does not pick up hint to generate broadcast hash join")
+ checkAnswer(
+ rdd,
+ Seq(
+ (1, "1", 1, 1),
+ (1, "1", 1, 2),
+ (2, "2", 2, 1),
+ (2, "2", 2, 2),
+ (3, "3", 3, 1),
+ (3, "3", 3, 2)
+ ))
+ }
+
+ mkTest(BuildRight, "testData", "testData2")
+ mkTest(BuildLeft, "testData", "testData2")
+ }
+
+ test("multiple-key equi-join is hash-join") {
+ val x = testData2.as('x)
+ val y = testData2.as('y)
+ val join = x.join(y, Inner,
+ Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
+ val planned = planner.HashJoin(join)
+ assert(planned.size === 1)
+ }
+
+ test("inner join where, one match per row") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, Inner).where('n === 'N),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join ON, one match per row") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, Inner, Some('n === 'N)),
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ ))
+ }
+
+ test("inner join, where, multiple matches") {
+ val x = testData2.where('a === 1).as('x)
+ val y = testData2.where('a === 1).as('y)
+ checkAnswer(
+ x.join(y).where("x.a".attr === "y.a".attr),
+ (1,1,1,1) ::
+ (1,1,1,2) ::
+ (1,2,1,1) ::
+ (1,2,1,2) :: Nil
+ )
+ }
+
+ test("inner join, no matches") {
+ val x = testData2.where('a === 1).as('x)
+ val y = testData2.where('a === 2).as('y)
+ checkAnswer(
+ x.join(y).where("x.a".attr === "y.a".attr),
+ Nil)
+ }
+
+ test("big inner join, 4 matches per row") {
+ val bigData = testData.unionAll(testData).unionAll(testData).unionAll(testData)
+ val bigDataX = bigData.as('x)
+ val bigDataY = bigData.as('y)
+
+ checkAnswer(
+ bigDataX.join(bigDataY).where("x.key".attr === "y.key".attr),
+ testData.flatMap(
+ row => Seq.fill(16)((row ++ row).toSeq)).collect().toSeq)
+ }
+
+ test("cartisian product join") {
+ checkAnswer(
+ testData3.join(testData3),
+ (1, null, 1, null) ::
+ (1, null, 2, 2) ::
+ (2, 2, 1, null) ::
+ (2, 2, 2, 2) :: Nil)
+ }
+
+ test("left outer join") {
+ checkAnswer(
+ upperCaseData.join(lowerCaseData, LeftOuter, Some('n === 'N)),
+ (1, "A", 1, "a") ::
+ (2, "B", 2, "b") ::
+ (3, "C", 3, "c") ::
+ (4, "D", 4, "d") ::
+ (5, "E", null, null) ::
+ (6, "F", null, null) :: Nil)
+ }
+
+ test("right outer join") {
+ checkAnswer(
+ lowerCaseData.join(upperCaseData, RightOuter, Some('n === 'N)),
+ (1, "a", 1, "A") ::
+ (2, "b", 2, "B") ::
+ (3, "c", 3, "C") ::
+ (4, "d", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+
+ test("full outer join") {
+ val left = upperCaseData.where('N <= 4).as('left)
+ val right = upperCaseData.where('N >= 3).as('right)
+
+ checkAnswer(
+ left.join(right, FullOuter, Some("left.N".attr === "right.N".attr)),
+ (1, "A", null, null) ::
+ (2, "B", null, null) ::
+ (3, "C", 3, "C") ::
+ (4, "D", 4, "D") ::
+ (null, null, 5, "E") ::
+ (null, null, 6, "F") :: Nil)
+ }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index ef84ead2e6..8e1e1971d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -35,7 +35,7 @@ class QueryTest extends PlanTest {
case singleItem => Seq(Seq(singleItem))
}
- val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s}.nonEmpty
+ val isSorted = rdd.logicalPlan.collect { case s: logical.Sort => s }.nonEmpty
def prepareAnswer(answer: Seq[Any]) = if (!isSorted) answer.sortBy(_.toString) else answer
val sparkAnswer = try rdd.collect().toSeq catch {
case e: Exception =>
@@ -48,7 +48,7 @@ class QueryTest extends PlanTest {
""".stripMargin)
}
- if(prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
+ if (prepareAnswer(convertedAnswer) != prepareAnswer(sparkAnswer)) {
fail(s"""
|Results do not match for query:
|${rdd.logicalPlan}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index df6b118360..215618e852 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -57,21 +57,4 @@ class PlannerSuite extends FunSuite {
val planned = PartialAggregation(query)
assert(planned.isEmpty)
}
-
- test("equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner, Some("x.a".attr === "y.a".attr)).queryExecution.analyzed
- val planned = planner.HashJoin(join)
- assert(planned.size === 1)
- }
-
- test("multiple-key equi-join is hash-join") {
- val x = testData2.as('x)
- val y = testData2.as('y)
- val join = x.join(y, Inner,
- Some("x.a".attr === "y.a".attr && "x.b".attr === "y.b".attr)).queryExecution.analyzed
- val planned = planner.HashJoin(join)
- assert(planned.size === 1)
- }
}