From b2842138ca084e8b795cecdcadd5dcbaa9f41d28 Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Fri, 14 Aug 2015 12:37:21 -0700 Subject: [SPARK-9561] Re-enable BroadcastJoinSuite We can do this now that SPARK-9580 is resolved. Author: Andrew Or Closes #8208 from andrewor14/reenable-sql-tests. (cherry picked from commit ece00566e4d5f38585f2810bef38e526cae7d25e) Signed-off-by: Michael Armbrust --- .../sql/execution/joins/BroadcastJoinSuite.scala | 145 ++++++++++----------- 1 file changed, 69 insertions(+), 76 deletions(-) (limited to 'sql/core') diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala index 0554e11d25..53a0e53fd7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/BroadcastJoinSuite.scala @@ -15,80 +15,73 @@ * limitations under the License. */ -// TODO: uncomment the test here! It is currently failing due to -// bad interaction with org.apache.spark.sql.test.TestSQLContext. +package org.apache.spark.sql.execution.joins -// scalastyle:off -//package org.apache.spark.sql.execution.joins -// -//import scala.reflect.ClassTag -// -//import org.scalatest.BeforeAndAfterAll -// -//import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} -//import org.apache.spark.sql.functions._ -//import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} -// -///** -// * Test various broadcast join operators with unsafe enabled. -// * -// * This needs to be its own suite because [[org.apache.spark.sql.test.TestSQLContext]] runs -// * in local mode, but for tests in this suite we need to run Spark in local-cluster mode.
-// * In particular, the use of [[org.apache.spark.unsafe.map.BytesToBytesMap]] in -// * [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered without -// * serializing the hashed relation, which does not happen in local mode. -// */ -//class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { -// private var sc: SparkContext = null -// private var sqlContext: SQLContext = null -// -// /** -// * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. -// */ -// override def beforeAll(): Unit = { -// super.beforeAll() -// val conf = new SparkConf() -// .setMaster("local-cluster[2,1,1024]") -// .setAppName("testing") -// sc = new SparkContext(conf) -// sqlContext = new SQLContext(sc) -// sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) -// sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) -// } -// -// override def afterAll(): Unit = { -// sc.stop() -// sc = null -// sqlContext = null -// } -// -// /** -// * Test whether the specified broadcast join updates the peak execution memory accumulator.
-// */ -// private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { -// AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { -// val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") -// val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") -// // Comparison at the end is for broadcast left semi join -// val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") -// val df3 = df1.join(broadcast(df2), joinExpression, joinType) -// val plan = df3.queryExecution.executedPlan -// assert(plan.collect { case p: T => p }.size === 1) -// plan.executeCollect() -// } -// } -// -// test("unsafe broadcast hash join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") -// } -// -// test("unsafe broadcast hash outer join updates peak execution memory") { -// testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") -// } -// -// test("unsafe broadcast left semi join updates peak execution memory") { -// testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") -// } -// -//} -// scalastyle:on +import scala.reflect.ClassTag + +import org.scalatest.BeforeAndAfterAll + +import org.apache.spark.{AccumulatorSuite, SparkConf, SparkContext} +import org.apache.spark.sql.functions._ +import org.apache.spark.sql.{SQLConf, SQLContext, QueryTest} + +/** + * Test various broadcast join operators with unsafe enabled. + * + * Tests in this suite we need to run Spark in local-cluster mode. In particular, the use of + * unsafe map in [[org.apache.spark.sql.execution.joins.UnsafeHashedRelation]] is not triggered + * without serializing the hashed relation, which does not happen in local mode.
+ */ +class BroadcastJoinSuite extends QueryTest with BeforeAndAfterAll { + private var sc: SparkContext = null + private var sqlContext: SQLContext = null + + /** + * Create a new [[SQLContext]] running in local-cluster mode with unsafe and codegen enabled. + */ + override def beforeAll(): Unit = { + super.beforeAll() + val conf = new SparkConf() + .setMaster("local-cluster[2,1,1024]") + .setAppName("testing") + sc = new SparkContext(conf) + sqlContext = new SQLContext(sc) + sqlContext.setConf(SQLConf.UNSAFE_ENABLED, true) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, true) + } + + override def afterAll(): Unit = { + sc.stop() + sc = null + sqlContext = null + } + + /** + * Test whether the specified broadcast join updates the peak execution memory accumulator. + */ + private def testBroadcastJoin[T: ClassTag](name: String, joinType: String): Unit = { + AccumulatorSuite.verifyPeakExecutionMemorySet(sc, name) { + val df1 = sqlContext.createDataFrame(Seq((1, "4"), (2, "2"))).toDF("key", "value") + val df2 = sqlContext.createDataFrame(Seq((1, "1"), (2, "2"))).toDF("key", "value") + // Comparison at the end is for broadcast left semi join + val joinExpression = df1("key") === df2("key") && df1("value") > df2("value") + val df3 = df1.join(broadcast(df2), joinExpression, joinType) + val plan = df3.queryExecution.executedPlan + assert(plan.collect { case p: T => p }.size === 1) + plan.executeCollect() + } + } + + test("unsafe broadcast hash join updates peak execution memory") { + testBroadcastJoin[BroadcastHashJoin]("unsafe broadcast hash join", "inner") + } + + test("unsafe broadcast hash outer join updates peak execution memory") { + testBroadcastJoin[BroadcastHashOuterJoin]("unsafe broadcast hash outer join", "left_outer") + } + + test("unsafe broadcast left semi join updates peak execution memory") { + testBroadcastJoin[BroadcastLeftSemiJoinHash]("unsafe broadcast left semi join", "leftsemi") + } + +} -- cgit v1.2.3