From 8f16b94afb39e1641c02d4e0be18d34ef7c211cc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Thu, 4 Jun 2015 22:15:58 -0700 Subject: [SPARK-8114][SQL] Remove some wildcard import on TestSQLContext._ I kept some of the sql import there to avoid changing too many lines. Author: Reynold Xin Closes #6661 from rxin/remove-wildcard-import-sqlcontext and squashes the following commits: c265347 [Reynold Xin] Fixed ListTablesSuite failure. de9d491 [Reynold Xin] Fixed tests. 73b5365 [Reynold Xin] Mima. 8f6b642 [Reynold Xin] Fixed style violation. 443f6e8 [Reynold Xin] [SPARK-8113][SQL] Remove some wildcard import on TestSQLContext._ --- .../spark/sql/catalyst/analysis/Analyzer.scala | 12 +- .../org/apache/spark/sql/CachedTableSuite.scala | 160 +++++++++++---------- .../apache/spark/sql/ColumnExpressionSuite.scala | 15 +- .../apache/spark/sql/DataFrameAggregateSuite.scala | 9 +- .../apache/spark/sql/DataFrameFunctionsSuite.scala | 4 +- .../apache/spark/sql/DataFrameImplicitsSuite.scala | 15 +- .../org/apache/spark/sql/DataFrameJoinSuite.scala | 9 +- .../spark/sql/DataFrameNaFunctionsSuite.scala | 5 +- .../org/apache/spark/sql/DataFrameStatSuite.scala | 8 +- .../org/apache/spark/sql/DataFrameSuite.scala | 68 ++++----- .../scala/org/apache/spark/sql/JoinSuite.scala | 65 ++++----- .../org/apache/spark/sql/ListTablesSuite.scala | 35 +++-- .../apache/spark/sql/MathExpressionsSuite.scala | 44 +++--- .../test/scala/org/apache/spark/sql/RowSuite.scala | 7 +- .../scala/org/apache/spark/sql/SQLConfSuite.scala | 67 +++++---- .../org/apache/spark/sql/SQLContextSuite.scala | 16 +-- .../scala/org/apache/spark/sql/SQLQuerySuite.scala | 125 ++++++++-------- .../spark/sql/ScalaReflectionRelationSuite.scala | 31 ++-- .../org/apache/spark/sql/SerializationSuite.scala | 5 +- .../test/scala/org/apache/spark/sql/UDFSuite.scala | 28 ++-- .../apache/spark/sql/UserDefinedTypeSuite.scala | 23 ++- 21 files changed, 373 insertions(+), 378 deletions(-) (limited to 'sql') diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index bc17169f35..5883d938b6 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -235,9 +235,8 @@ class Analyzer( } /** - * Replaces [[UnresolvedAttribute]]s with concrete - * [[catalyst.expressions.AttributeReference AttributeReferences]] from a logical plan node's - * children. + * Replaces [[UnresolvedAttribute]]s with concrete [[AttributeReference]]s from + * a logical plan node's children. */ object ResolveReferences extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transformUp { @@ -455,7 +454,7 @@ class Analyzer( } /** - * Replaces [[UnresolvedFunction]]s with concrete [[catalyst.expressions.Expression Expressions]]. + * Replaces [[UnresolvedFunction]]s with concrete [[Expression]]s. */ object ResolveFunctions extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { @@ -846,9 +845,8 @@ class Analyzer( } /** - * Removes [[catalyst.plans.logical.Subquery Subquery]] operators from the plan. Subqueries are - * only required to provide scoping information for attributes and can be removed once analysis is - * complete. + * Removes [[Subquery]] operators from the plan. Subqueries are only required to provide + * scoping information for attributes and can be removed once analysis is complete. 
*/ object EliminateSubQueries extends Rule[LogicalPlan] { def apply(plan: LogicalPlan): LogicalPlan = plan transform { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 0772e5e187..72e60d9aa7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -25,8 +25,6 @@ import org.scalatest.concurrent.Eventually._ import org.apache.spark.Accumulators import org.apache.spark.sql.TestData._ import org.apache.spark.sql.columnar._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.storage.{RDDBlockId, StorageLevel} case class BigData(s: String) @@ -34,8 +32,12 @@ case class BigData(s: String) class CachedTableSuite extends QueryTest { TestData // Load test tables. + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.sql + def rddIdOf(tableName: String): Int = { - val executedPlan = table(tableName).queryExecution.executedPlan + val executedPlan = ctx.table(tableName).queryExecution.executedPlan executedPlan.collect { case InMemoryColumnarTableScan(_, _, relation) => relation.cachedColumnBuffers.id @@ -45,47 +47,47 @@ class CachedTableSuite extends QueryTest { } def isMaterialized(rddId: Int): Boolean = { - sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty + ctx.sparkContext.env.blockManager.get(RDDBlockId(rddId, 0)).nonEmpty } test("cache temp table") { testData.select('key).registerTempTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable"), 0) - cacheTable("tempTable") + ctx.cacheTable("tempTable") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("unpersist an uncached table will not raise exception") { - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.persist() - assert(None != cacheManager.lookupCachedData(testData)) + assert(None != ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = true) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) testData.unpersist(blocking = false) - assert(None == cacheManager.lookupCachedData(testData)) + assert(None == ctx.cacheManager.lookupCachedData(testData)) } test("cache table as select") { sql("CACHE TABLE tempTable AS SELECT key FROM testData") assertCached(sql("SELECT COUNT(*) FROM tempTable")) - uncacheTable("tempTable") + ctx.uncacheTable("tempTable") } test("uncaching temp table") { testData.select('key).registerTempTable("tempTable1") testData.select('key).registerTempTable("tempTable2") - cacheTable("tempTable1") + ctx.cacheTable("tempTable1") assertCached(sql("SELECT COUNT(*) FROM tempTable1")) assertCached(sql("SELECT COUNT(*) FROM tempTable2")) // Is this valid? - uncacheTable("tempTable2") + ctx.uncacheTable("tempTable2") // Should this be cached? 
assertCached(sql("SELECT COUNT(*) FROM tempTable1"), 0) @@ -93,103 +95,103 @@ class CachedTableSuite extends QueryTest { test("too big for memory") { val data = "*" * 10000 - sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() + ctx.sparkContext.parallelize(1 to 200000, 1).map(_ => BigData(data)).toDF() .registerTempTable("bigData") - table("bigData").persist(StorageLevel.MEMORY_AND_DISK) - assert(table("bigData").count() === 200000L) - table("bigData").unpersist(blocking = true) + ctx.table("bigData").persist(StorageLevel.MEMORY_AND_DISK) + assert(ctx.table("bigData").count() === 200000L) + ctx.table("bigData").unpersist(blocking = true) } test("calling .cache() should use in-memory columnar caching") { - table("testData").cache() - assertCached(table("testData")) - table("testData").unpersist(blocking = true) + ctx.table("testData").cache() + assertCached(ctx.table("testData")) + ctx.table("testData").unpersist(blocking = true) } test("calling .unpersist() should drop in-memory columnar cache") { - table("testData").cache() - table("testData").count() - table("testData").unpersist(blocking = true) - assertCached(table("testData"), 0) + ctx.table("testData").cache() + ctx.table("testData").count() + ctx.table("testData").unpersist(blocking = true) + assertCached(ctx.table("testData"), 0) } test("isCached") { - cacheTable("testData") + ctx.cacheTable("testData") - assertCached(table("testData")) - assert(table("testData").queryExecution.withCachedData match { + assertCached(ctx.table("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => true case _ => false }) - uncacheTable("testData") - assert(!isCached("testData")) - assert(table("testData").queryExecution.withCachedData match { + ctx.uncacheTable("testData") + assert(!ctx.isCached("testData")) + assert(ctx.table("testData").queryExecution.withCachedData match { case _: InMemoryRelation => false case _ => true }) } test("SPARK-1669: cacheTable should be idempotent") { - assume(!table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) + assume(!ctx.table("testData").logicalPlan.isInstanceOf[InMemoryRelation]) - cacheTable("testData") - assertCached(table("testData")) + ctx.cacheTable("testData") + assertCached(ctx.table("testData")) assertResult(1, "InMemoryRelation not found, testData should have been cached") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r: InMemoryRelation => r }.size } - cacheTable("testData") + ctx.cacheTable("testData") assertResult(0, "Double InMemoryRelations found, cacheTable() is not idempotent") { - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case r @ InMemoryRelation(_, _, _, _, _: InMemoryColumnarTableScan, _) => r }.size } - uncacheTable("testData") + ctx.uncacheTable("testData") } test("read from cached table and uncache") { - cacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData")) + ctx.cacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData")) - uncacheTable("testData") - checkAnswer(table("testData"), testData.collect().toSeq) - assertCached(table("testData"), 0) + ctx.uncacheTable("testData") + checkAnswer(ctx.table("testData"), testData.collect().toSeq) + assertCached(ctx.table("testData"), 0) } test("correct error on uncache of non-cached 
table") { intercept[IllegalArgumentException] { - uncacheTable("testData") + ctx.uncacheTable("testData") } } test("SELECT star from cached table") { sql("SELECT * FROM testData").registerTempTable("selectStar") - cacheTable("selectStar") + ctx.cacheTable("selectStar") checkAnswer( sql("SELECT * FROM selectStar WHERE key = 1"), Seq(Row(1, "1"))) - uncacheTable("selectStar") + ctx.uncacheTable("selectStar") } test("Self-join cached") { val unCachedAnswer = sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key").collect() - cacheTable("testData") + ctx.cacheTable("testData") checkAnswer( sql("SELECT * FROM testData a JOIN testData b ON a.key = b.key"), unCachedAnswer.toSeq) - uncacheTable("testData") + ctx.uncacheTable("testData") } test("'CACHE TABLE' and 'UNCACHE TABLE' SQL statement") { sql("CACHE TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -197,7 +199,7 @@ class CachedTableSuite extends QueryTest { "Eagerly cached in-memory table should have already been materialized") sql("UNCACHE TABLE testData") - assert(!isCached("testData"), "Table 'testData' should not be cached") + assert(!ctx.isCached("testData"), "Table 'testData' should not be cached") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") @@ -206,14 +208,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT * FROM anotherTable") { sql("CACHE TABLE testCacheTable AS SELECT * FROM testData") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -221,14 +223,14 @@ class CachedTableSuite extends QueryTest { test("CACHE TABLE tableName AS SELECT ...") { sql("CACHE TABLE testCacheTable AS SELECT key FROM testData LIMIT 10") - assertCached(table("testCacheTable")) + assertCached(ctx.table("testCacheTable")) val rddId = rddIdOf("testCacheTable") assert( isMaterialized(rddId), "Eagerly cached in-memory table should have already been materialized") - uncacheTable("testCacheTable") + ctx.uncacheTable("testCacheTable") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -236,7 +238,7 @@ class CachedTableSuite extends QueryTest { test("CACHE LAZY TABLE tableName") { sql("CACHE LAZY TABLE testData") - assertCached(table("testData")) + assertCached(ctx.table("testData")) val rddId = rddIdOf("testData") assert( @@ -248,7 +250,7 @@ class CachedTableSuite extends QueryTest { isMaterialized(rddId), "Lazily cached in-memory table should have been materialized") - uncacheTable("testData") + ctx.uncacheTable("testData") eventually(timeout(10 seconds)) { assert(!isMaterialized(rddId), "Uncached in-memory table should have been unpersisted") } @@ -256,7 +258,7 @@ class CachedTableSuite extends QueryTest { test("InMemoryRelation statistics") { sql("CACHE TABLE testData") - table("testData").queryExecution.withCachedData.collect { + ctx.table("testData").queryExecution.withCachedData.collect { case cached: InMemoryRelation => val actualSizeInBytes = (1 to 100).map(i => INT.defaultSize + i.toString.length + 4).sum 
assert(cached.statistics.sizeInBytes === actualSizeInBytes) @@ -265,38 +267,38 @@ class CachedTableSuite extends QueryTest { test("Drops temporary table") { testData.select('key).registerTempTable("t1") - table("t1") - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) + ctx.table("t1") + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) } test("Drops cached temporary table") { testData.select('key).registerTempTable("t1") testData.select('key).registerTempTable("t2") - cacheTable("t1") + ctx.cacheTable("t1") - assert(isCached("t1")) - assert(isCached("t2")) + assert(ctx.isCached("t1")) + assert(ctx.isCached("t2")) - dropTempTable("t1") - assert(intercept[RuntimeException](table("t1")).getMessage.startsWith("Table Not Found")) - assert(!isCached("t2")) + ctx.dropTempTable("t1") + assert(intercept[RuntimeException](ctx.table("t1")).getMessage.startsWith("Table Not Found")) + assert(!ctx.isCached("t2")) } test("Clear all cache") { sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") - clearCache() - assert(cacheManager.isEmpty) + ctx.cacheTable("t1") + ctx.cacheTable("t2") + ctx.clearCache() + assert(ctx.cacheManager.isEmpty) sql("SELECT key FROM testData LIMIT 10").registerTempTable("t1") sql("SELECT key FROM testData LIMIT 5").registerTempTable("t2") - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") sql("Clear CACHE") - assert(cacheManager.isEmpty) + assert(ctx.cacheManager.isEmpty) } test("Clear accumulators when uncacheTable to prevent memory leaking") { @@ -305,8 +307,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - cacheTable("t1") - cacheTable("t2") + ctx.cacheTable("t1") + ctx.cacheTable("t2") assert((accsSize + 2) == Accumulators.originals.size) } @@ -317,8 +319,8 @@ class CachedTableSuite extends QueryTest { Accumulators.synchronized { val accsSize = Accumulators.originals.size - uncacheTable("t1") - uncacheTable("t2") + ctx.uncacheTable("t1") + ctx.uncacheTable("t2") assert((accsSize - 2) == Accumulators.originals.size) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala index bfba379d9a..4f5484f136 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala @@ -21,13 +21,14 @@ import org.scalatest.Matchers._ import org.apache.spark.sql.execution.Project import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class ColumnExpressionSuite extends QueryTest { import org.apache.spark.sql.TestData._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("alias") { val df = Seq((1, Seq(1, 2, 3))).toDF("a", "intList") assert(df.select(df("a").as("b")).columns.head === "b") @@ -213,7 +214,7 @@ class ColumnExpressionSuite extends QueryTest { } test("!==") { - val nullData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val nullData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(1, 1) :: Row(1, 2) :: Row(1, null) :: @@ -274,7 
+275,7 @@ class ColumnExpressionSuite extends QueryTest { } test("between") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (0, 1, 2) :: (1, 2, 3) :: (2, 1, 0) :: @@ -287,7 +288,7 @@ class ColumnExpressionSuite extends QueryTest { checkAnswer(testData.filter($"a".between($"b", $"c")), expectAnswer) } - val booleanData = TestSQLContext.createDataFrame(TestSQLContext.sparkContext.parallelize( + val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize( Row(false, false) :: Row(false, true) :: Row(true, false) :: @@ -413,7 +414,7 @@ class ColumnExpressionSuite extends QueryTest { test("monotonicallyIncreasingId") { // Make sure we have 2 partitions, each with 2 records. - val df = TestSQLContext.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => + val df = ctx.sparkContext.parallelize(1 to 2, 2).mapPartitions { iter => Iterator(Tuple1(1), Tuple1(2)) }.toDF("a") checkAnswer( @@ -423,7 +424,7 @@ class ColumnExpressionSuite extends QueryTest { } test("sparkPartitionId") { - val df = TestSQLContext.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") + val df = ctx.sparkContext.parallelize(1 to 1, 1).map(i => (i, i)).toDF("a", "b") checkAnswer( df.select(sparkPartitionId()), Row(0) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 232f05c009..790b405c72 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -19,13 +19,14 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types.DecimalType class DataFrameAggregateSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("groupBy") { checkAnswer( testData2.groupBy("a").agg(sum($"b")), @@ -67,12 +68,12 @@ class DataFrameAggregateSuite extends QueryTest { Seq(Row(1, 3), Row(2, 3), Row(3, 3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "false") + ctx.conf.setConf("spark.sql.retainGroupColumns", "false") checkAnswer( testData2.groupBy("a").agg(sum($"b")), Seq(Row(3), Row(3), Row(3)) ) - TestSQLContext.conf.setConf("spark.sql.retainGroupColumns", "true") + ctx.conf.setConf("spark.sql.retainGroupColumns", "true") } test("agg without groups") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala index b1e0faa310..53c2befb73 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala @@ -19,7 +19,6 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ /** @@ -27,6 +26,9 @@ import org.apache.spark.sql.types._ */ class DataFrameFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("array with column name") { val df = Seq((0, 1)).toDF("a", "b") val row = df.select(array("a", "b")).first() diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala index 2d2367d6e7..fbb30706a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala @@ -17,15 +17,14 @@ package org.apache.spark.sql -import org.apache.spark.sql.test.TestSQLContext.{sparkContext => sc} -import org.apache.spark.sql.test.TestSQLContext.implicits._ - - class DataFrameImplicitsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("RDD of tuples") { checkAnswer( - sc.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), + ctx.sparkContext.parallelize(1 to 10).map(i => (i, i.toString)).toDF("intCol", "strCol"), (1 to 10).map(i => Row(i, i.toString))) } @@ -37,19 +36,19 @@ class DataFrameImplicitsSuite extends QueryTest { test("RDD[Int]") { checkAnswer( - sc.parallelize(1 to 10).toDF("intCol"), + ctx.sparkContext.parallelize(1 to 10).toDF("intCol"), (1 to 10).map(i => Row(i))) } test("RDD[Long]") { checkAnswer( - sc.parallelize(1L to 10L).toDF("longCol"), + ctx.sparkContext.parallelize(1L to 10L).toDF("longCol"), (1L to 10L).map(i => Row(i))) } test("RDD[String]") { checkAnswer( - sc.parallelize(1 to 10).map(_.toString).toDF("stringCol"), + ctx.sparkContext.parallelize(1 to 10).map(_.toString).toDF("stringCol"), (1 to 10).map(i => Row(i.toString))) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala index 787f3f175f..051d13e9a5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala @@ -19,12 +19,12 @@ package org.apache.spark.sql import org.apache.spark.sql.TestData._ import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameJoinSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("join - join using") { val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str") val df2 = Seq(1, 2, 3).map(i => (i, (i + 1).toString)).toDF("int", "str") @@ -49,7 +49,8 @@ class DataFrameJoinSuite extends QueryTest { checkAnswer( df1.join(df2, $"df1.key" === $"df2.key"), - sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key").collect().toSeq) + ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key") + .collect().toSeq) } test("join - using aliases after self join") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala index 41b4f02e6a..495701d4f6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala @@ -19,11 +19,12 @@ package org.apache.spark.sql import scala.collection.JavaConversions._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ - class DataFrameNaFunctionsSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + def createDF(): DataFrame = { Seq[(String, java.lang.Integer, java.lang.Double)]( ("Bob", 
16, 176.5), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala index 438f479459..0d3ff899da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala @@ -20,13 +20,13 @@ package org.apache.spark.sql import org.scalatest.Matchers._ import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ class DataFrameStatSuite extends SparkFunSuite { - val sqlCtx = TestSQLContext - def toLetter(i: Int): String = (i + 97).toChar.toString + private val sqlCtx = org.apache.spark.sql.test.TestSQLContext + import sqlCtx.implicits._ + + private def toLetter(i: Int): String = (i + 97).toChar.toString test("pearson correlation") { val df = Seq.tabulate(10)(i => (i, 2 * i, i * -1.0)).toDF("a", "b", "c") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 8e81dacb86..bb8621abe6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -21,17 +21,19 @@ import scala.language.postfixOps import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ -import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.implicits._ +import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint} class DataFrameSuite extends QueryTest { import org.apache.spark.sql.TestData._ + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("analysis error should be eagerly reported") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis + val oldSetting = ctx.conf.dataFrameEagerAnalysis // Eager analysis. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") intercept[Exception] { testData.select('nonExistentName) } intercept[Exception] { @@ -45,11 +47,11 @@ class DataFrameSuite extends QueryTest { } // No more eager analysis once the flag is turned off - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "false") testData.select('nonExistentName) // Set the flag back to original value before this test. - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("dataframe toString") { @@ -67,12 +69,12 @@ class DataFrameSuite extends QueryTest { } test("invalid plan toString, debug mode") { - val oldSetting = TestSQLContext.conf.dataFrameEagerAnalysis - TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") + val oldSetting = ctx.conf.dataFrameEagerAnalysis + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, "true") // Turn on debug mode so we can see invalid query plans. import org.apache.spark.sql.execution.debug._ - TestSQLContext.debug() + ctx.debug() val badPlan = testData.select('badColumn) @@ -81,7 +83,7 @@ class DataFrameSuite extends QueryTest { badPlan.toString) // Set the flag back to original value before this test. 
- TestSQLContext.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) + ctx.setConf(SQLConf.DATAFRAME_EAGER_ANALYSIS, oldSetting.toString) } test("access complex data") { @@ -97,8 +99,8 @@ class DataFrameSuite extends QueryTest { } test("empty data frame") { - assert(TestSQLContext.emptyDataFrame.columns.toSeq === Seq.empty[String]) - assert(TestSQLContext.emptyDataFrame.count() === 0) + assert(ctx.emptyDataFrame.columns.toSeq === Seq.empty[String]) + assert(ctx.emptyDataFrame.count() === 0) } test("head and take") { @@ -311,7 +313,7 @@ class DataFrameSuite extends QueryTest { } test("replace column using withColumn") { - val df2 = TestSQLContext.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") + val df2 = ctx.sparkContext.parallelize(Array(1, 2, 3)).toDF("x") val df3 = df2.withColumn("x", df2("x") + 1) checkAnswer( df3.select("x"), @@ -392,7 +394,7 @@ class DataFrameSuite extends QueryTest { test("randomSplit") { val n = 600 - val data = TestSQLContext.sparkContext.parallelize(1 to n, 2).toDF("id") + val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id") for (seed <- 1 to 5) { val splits = data.randomSplit(Array[Double](1, 2, 3), seed) assert(splits.length == 3, "wrong number of splits") @@ -487,21 +489,21 @@ class DataFrameSuite extends QueryTest { } test("createDataFrame(RDD[Row], StructType) should convert UDTs (SPARK-6672)") { - val rowRDD = TestSQLContext.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) + val rowRDD = ctx.sparkContext.parallelize(Seq(Row(new ExamplePoint(1.0, 2.0)))) val schema = StructType(Array(StructField("point", new ExamplePointUDT(), false))) - val df = TestSQLContext.createDataFrame(rowRDD, schema) + val df = ctx.createDataFrame(rowRDD, schema) df.rdd.collect() } test("SPARK-6899") { - val originalValue = TestSQLContext.conf.codegenEnabled - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = ctx.conf.codegenEnabled + ctx.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ checkAnswer( decimalData.agg(avg('a)), Row(new java.math.BigDecimal(2.0))) } finally { - TestSQLContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + ctx.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -513,14 +515,14 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7551: support backticks for DataFrame attribute resolution") { - val df = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df = ctx.read.json(ctx.sparkContext.makeRDD( """{"a.b": {"c": {"d..e": {"f": 1}}}}""" :: Nil)) checkAnswer( df.select(df("`a.b`.c.`d..e`.`f`")), Row(1) ) - val df2 = TestSQLContext.read.json(TestSQLContext.sparkContext.makeRDD( + val df2 = ctx.read.json(ctx.sparkContext.makeRDD( """{"a b": {"c": {"d e": {"f": 1}}}}""" :: Nil)) checkAnswer( df2.select(df2("`a b`.c.d e.f")), @@ -540,7 +542,7 @@ class DataFrameSuite extends QueryTest { } test("SPARK-7324 dropDuplicates") { - val testData = TestSQLContext.sparkContext.parallelize( + val testData = ctx.sparkContext.parallelize( (2, 1, 2) :: (1, 1, 1) :: (1, 2, 1) :: (2, 1, 2) :: (2, 2, 2) :: (2, 2, 1) :: @@ -588,49 +590,49 @@ class DataFrameSuite extends QueryTest { test("SPARK-7150 range api") { // numSlice is greater than length - val res1 = TestSQLContext.range(0, 10, 1, 15).select("id") + val res1 = ctx.range(0, 10, 1, 15).select("id") assert(res1.count == 10) assert(res1.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res2 = TestSQLContext.range(3, 15, 3, 2).select("id") + val res2 = ctx.range(3, 15, 3, 2).select("id") assert(res2.count 
== 4) assert(res2.agg(sum("id")).as("sumid").collect() === Seq(Row(30))) - val res3 = TestSQLContext.range(1, -2).select("id") + val res3 = ctx.range(1, -2).select("id") assert(res3.count == 0) // start is positive, end is negative, step is negative - val res4 = TestSQLContext.range(1, -2, -2, 6).select("id") + val res4 = ctx.range(1, -2, -2, 6).select("id") assert(res4.count == 2) assert(res4.agg(sum("id")).as("sumid").collect() === Seq(Row(0))) // start, end, step are negative - val res5 = TestSQLContext.range(-3, -8, -2, 1).select("id") + val res5 = ctx.range(-3, -8, -2, 1).select("id") assert(res5.count == 3) assert(res5.agg(sum("id")).as("sumid").collect() === Seq(Row(-15))) // start, end are negative, step is positive - val res6 = TestSQLContext.range(-8, -4, 2, 1).select("id") + val res6 = ctx.range(-8, -4, 2, 1).select("id") assert(res6.count == 2) assert(res6.agg(sum("id")).as("sumid").collect() === Seq(Row(-14))) - val res7 = TestSQLContext.range(-10, -9, -20, 1).select("id") + val res7 = ctx.range(-10, -9, -20, 1).select("id") assert(res7.count == 0) - val res8 = TestSQLContext.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") + val res8 = ctx.range(Long.MinValue, Long.MaxValue, Long.MaxValue, 100).select("id") assert(res8.count == 3) assert(res8.agg(sum("id")).as("sumid").collect() === Seq(Row(-3))) - val res9 = TestSQLContext.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") + val res9 = ctx.range(Long.MaxValue, Long.MinValue, Long.MinValue, 100).select("id") assert(res9.count == 2) assert(res9.agg(sum("id")).as("sumid").collect() === Seq(Row(Long.MaxValue - 1))) // only end provided as argument - val res10 = TestSQLContext.range(10).select("id") + val res10 = ctx.range(10).select("id") assert(res10.count == 10) assert(res10.agg(sum("id")).as("sumid").collect() === Seq(Row(45))) - val res11 = TestSQLContext.range(-1).select("id") + val res11 = ctx.range(-1).select("id") assert(res11.count == 0) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 407c789657..ffd26c4f5a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -20,27 +20,28 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterEach import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.functions._ import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.execution.joins._ -import org.apache.spark.sql.test.TestSQLContext._ -import org.apache.spark.sql.test.TestSQLContext.implicits._ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Ensures tables are loaded. 
TestData + lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + import ctx.logicalPlanToSparkQuery + test("equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, $"x.a" === $"y.a", "inner").queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } def assertJoin(sqlString: String, c: Class[_]): Any = { - val df = sql(sqlString) + val df = ctx.sql(sqlString) val physical = df.queryExecution.sparkPlan val operators = physical.collect { case j: ShuffledHashJoin => j @@ -61,9 +62,9 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("join operator selection") { - cacheManager.clearCache() + ctx.cacheManager.clearCache() - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]), ("SELECT * FROM testData LEFT SEMI JOIN testData2", classOf[LeftSemiJoinBNL]), @@ -94,22 +95,22 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastNestedLoopJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData JOIN testData2 ON key = a", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a and key = 2", classOf[SortMergeJoin]), ("SELECT * FROM testData JOIN testData2 ON key = a where key = 2", classOf[SortMergeJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } } test("broadcasted hash join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") - val SORTMERGEJOIN_ENABLED: Boolean = conf.sortMergeJoinEnabled + val SORTMERGEJOIN_ENABLED: Boolean = ctx.conf.sortMergeJoinEnabled Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", classOf[BroadcastHashJoin]), @@ -117,7 +118,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } try { - conf.setConf("spark.sql.planner.sortMergeJoin", "true") + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", "true") Seq( ("SELECT * FROM testData join testData2 ON key = a", classOf[BroadcastHashJoin]), ("SELECT * FROM testData join testData2 ON key = a and key = 2", @@ -126,17 +127,17 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { classOf[BroadcastHashJoin]) ).foreach { case (query, joinClass) => assertJoin(query, joinClass) } } finally { - conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) + ctx.conf.setConf("spark.sql.planner.sortMergeJoin", SORTMERGEJOIN_ENABLED.toString) } - sql("UNCACHE TABLE testData") + ctx.sql("UNCACHE TABLE testData") } test("multiple-key equi-join is hash-join") { val x = testData2.as("x") val y = testData2.as("y") val join = x.join(y, ($"x.a" === $"y.a") && ($"x.b" === $"y.b")).queryExecution.optimizedPlan - val planned = planner.HashJoin(join) + val 
planned = ctx.planner.HashJoin(join) assert(planned.size === 1) } @@ -241,7 +242,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing left.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -255,7 +256,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(6, 1) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a) @@ -301,7 +302,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are choosing right.outputPartitioning as the // outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -310,7 +311,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 6)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -362,7 +363,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { // Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator. checkAnswer( - sql( + ctx.sql( """ |SELECT l.a, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -371,7 +372,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 10)) checkAnswer( - sql( + ctx.sql( """ |SELECT r.N, count(*) |FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N) @@ -386,7 +387,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT l.N, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -401,7 +402,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { Row(null, 4) :: Nil) checkAnswer( - sql( + ctx.sql( """ |SELECT r.a, count(*) |FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a) @@ -411,11 +412,11 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { } test("broadcasted left semi join operator selection") { - cacheManager.clearCache() - sql("CACHE TABLE testData") - val tmp = conf.autoBroadcastJoinThreshold + ctx.cacheManager.clearCache() + ctx.sql("CACHE TABLE testData") + val tmp = ctx.conf.autoBroadcastJoinThreshold - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=1000000000") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[BroadcastLeftSemiJoinHash]) @@ -423,7 +424,7 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") + ctx.sql(s"SET ${SQLConf.AUTO_BROADCASTJOIN_THRESHOLD}=-1") Seq( ("SELECT * FROM testData LEFT SEMI JOIN testData2 ON key = a", classOf[LeftSemiJoinHash]) @@ -431,12 +432,12 @@ class JoinSuite extends QueryTest with BeforeAndAfterEach { case (query, joinClass) => assertJoin(query, joinClass) } - setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) - sql("UNCACHE TABLE testData") + ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, tmp.toString) + ctx.sql("UNCACHE TABLE testData") } test("left semi join") { - val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") + val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a") 
checkAnswer(df, Row(1, 1) :: Row(1, 2) :: diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala index 3ce97c3fff..2089660c52 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala @@ -19,49 +19,47 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext._ import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType} class ListTablesSuite extends QueryTest with BeforeAndAfter { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ - val df = - sparkContext.parallelize((1 to 10).map(i => (i, s"str$i"))).toDF("key", "value") + private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value") before { df.registerTempTable("ListTablesSuiteTable") } after { - catalog.unregisterTable(Seq("ListTablesSuiteTable")) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) } test("get all tables") { checkAnswer( - tables().filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables().filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("getting all Tables with a database name has no impact on returned table names") { checkAnswer( - tables("DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.tables("DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) checkAnswer( - sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), + ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"), Row("ListTablesSuiteTable", true)) - catalog.unregisterTable(Seq("ListTablesSuiteTable")) - assert(tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) + ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable")) + assert(ctx.tables().filter("tableName = 'ListTablesSuiteTable'").count() === 0) } test("query the returned DataFrame of tables") { @@ -69,19 +67,20 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter { StructField("tableName", StringType, false) :: StructField("isTemporary", BooleanType, false) :: Nil) - Seq(tables(), sql("SHOW TABLes")).foreach { + Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach { case tableDF => assert(expectedSchema === tableDF.schema) tableDF.registerTempTable("tables") checkAnswer( - sql("SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), + ctx.sql( + "SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"), Row(true, "ListTablesSuiteTable") ) checkAnswer( - tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), + ctx.tables().filter("tableName = 'tables'").select("tableName", "isTemporary"), Row("tables", true)) - dropTempTable("tables") + ctx.dropTempTable("tables") } } } diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala index dd68965444..0a38af2b4c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala @@ -17,36 +17,29 @@ package org.apache.spark.sql -import java.lang.{Double => JavaDouble} - import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ - -private[this] object MathExpressionsTestData { - - case class DoubleData(a: JavaDouble, b: JavaDouble) - val doubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1))).toDF() - - val nnDoubleData = TestSQLContext.sparkContext.parallelize( - (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1))).toDF() - - case class NullDoubles(a: JavaDouble) - val nullDoubles = - TestSQLContext.sparkContext.parallelize( - NullDoubles(1.0) :: - NullDoubles(2.0) :: - NullDoubles(3.0) :: - NullDoubles(null) :: Nil - ).toDF() + + +private object MathExpressionsTestData { + case class DoubleData(a: java.lang.Double, b: java.lang.Double) + case class NullDoubles(a: java.lang.Double) } class MathExpressionsSuite extends QueryTest { import MathExpressionsTestData._ - def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF() + + private lazy val nnDoubleData = (1 to 10).map(i => DoubleData(i * 0.1, i * -0.1)).toDF() + + private lazy val nullDoubles = + Seq(NullDoubles(1.0), NullDoubles(2.0), NullDoubles(3.0), NullDoubles(null)).toDF() + + private def testOneToOneMathFunction[@specialized(Int, Long, Float, Double) T]( c: Column => Column, f: T => T): Unit = { checkAnswer( @@ -65,7 +58,8 @@ class MathExpressionsSuite extends QueryTest { ) } - def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = { + private def testOneToOneNonNegativeMathFunction(c: Column => Column, f: Double => Double): Unit = + { checkAnswer( nnDoubleData.select(c('a)), (1 to 10).map(n => Row(f(n * 0.1))) @@ -89,7 +83,7 @@ class MathExpressionsSuite extends QueryTest { ) } - def testTwoToOneMathFunction( + private def testTwoToOneMathFunction( c: (Column, Column) => Column, d: (Column, Double) => Column, f: (Double, Double) => Double): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala index 513ac915dc..d84b57af9c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.execution.SparkSqlSerializer import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow} -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ class RowSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("create row") { val expected = new GenericMutableRow(4) expected.update(0, 2147483647) @@ -56,7 +57,7 @@ class RowSuite extends SparkFunSuite { test("serialize w/ kryo") 
{ val row = Seq((1, Seq(1), Map(1 -> 1), BigDecimal(1))).toDF().first() - val serializer = new SparkSqlSerializer(TestSQLContext.sparkContext.getConf) + val serializer = new SparkSqlSerializer(ctx.sparkContext.getConf) val instance = serializer.newInstance() val ser = instance.serialize(row) val de = instance.deserialize(ser).asInstanceOf[Row] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala index 3a5f071e2f..76d0dd1744 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala @@ -17,67 +17,64 @@ package org.apache.spark.sql -import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ class SQLConfSuite extends QueryTest { - val testKey = "test.key.0" - val testVal = "test.val.0" + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + + private val testKey = "test.key.0" + private val testVal = "test.val.0" test("propagate from spark conf") { // We create a new context here to avoid order dependence with other tests that might call // clear(). - val newContext = new SQLContext(TestSQLContext.sparkContext) - assert(newContext.getConf("spark.sql.testkey", "false") == "true") + val newContext = new SQLContext(ctx.sparkContext) + assert(newContext.getConf("spark.sql.testkey", "false") === "true") } test("programmatic ways of basic setting and getting") { - conf.clear() - assert(getAllConfs.size === 0) + ctx.conf.clear() + assert(ctx.getAllConfs.size === 0) - setConf(testKey, testVal) - assert(getConf(testKey) == testVal) - assert(getConf(testKey, testVal + "_") == testVal) - assert(getAllConfs.contains(testKey)) + ctx.setConf(testKey, testVal) + assert(ctx.getConf(testKey) === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) // Tests SQLConf as accessed from a SQLContext is mutable after // the latter is initialized, unlike SparkConf inside a SparkContext. 
- assert(TestSQLContext.getConf(testKey) == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getAllConfs.contains(testKey)) + assert(ctx.getConf(testKey) == testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getAllConfs.contains(testKey)) - conf.clear() + ctx.conf.clear() } test("parse SQL set commands") { - conf.clear() - sql(s"set $testKey=$testVal") - assert(getConf(testKey, testVal + "_") == testVal) - assert(TestSQLContext.getConf(testKey, testVal + "_") == testVal) + ctx.conf.clear() + ctx.sql(s"set $testKey=$testVal") + assert(ctx.getConf(testKey, testVal + "_") === testVal) + assert(ctx.getConf(testKey, testVal + "_") === testVal) - sql("set some.property=20") - assert(getConf("some.property", "0") == "20") - sql("set some.property = 40") - assert(getConf("some.property", "0") == "40") + ctx.sql("set some.property=20") + assert(ctx.getConf("some.property", "0") === "20") + ctx.sql("set some.property = 40") + assert(ctx.getConf("some.property", "0") === "40") val key = "spark.sql.key" val vs = "val0,val_1,val2.3,my_table" - sql(s"set $key=$vs") - assert(getConf(key, "0") == vs) + ctx.sql(s"set $key=$vs") + assert(ctx.getConf(key, "0") === vs) - sql(s"set $key=") - assert(getConf(key, "0") == "") + ctx.sql(s"set $key=") + assert(ctx.getConf(key, "0") === "") - conf.clear() + ctx.conf.clear() } test("deprecated property") { - conf.clear() - sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") - assert(getConf(SQLConf.SHUFFLE_PARTITIONS) == "10") + ctx.conf.clear() + ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10") + assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS) === "10") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala index 797d123b48..c8d8796568 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala @@ -20,31 +20,29 @@ package org.apache.spark.sql import org.scalatest.BeforeAndAfterAll import org.apache.spark.SparkFunSuite -import org.apache.spark.sql.test.TestSQLContext class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll { - private val testSqlContext = TestSQLContext - private val testSparkContext = TestSQLContext.sparkContext + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext override def afterAll(): Unit = { - SQLContext.setLastInstantiatedContext(testSqlContext) + SQLContext.setLastInstantiatedContext(ctx) } test("getOrCreate instantiates SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = SQLContext.getOrCreate(testSparkContext) + val sqlContext = SQLContext.getOrCreate(ctx.sparkContext) assert(sqlContext != null, "SQLContext.getOrCreate returned null") - assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext created by SQLContext.getOrCreate not returned by SQLContext.getOrCreate") } test("getOrCreate gets last explicitly instantiated SQLContext") { SQLContext.clearLastInstantiatedContext() - val sqlContext = new SQLContext(testSparkContext) - assert(SQLContext.getOrCreate(testSparkContext) != null, + val sqlContext = new SQLContext(ctx.sparkContext) + assert(SQLContext.getOrCreate(ctx.sparkContext) != null, "SQLContext.getOrCreate after explicitly created SQLContext returned null") - 
assert(SQLContext.getOrCreate(testSparkContext).eq(sqlContext), + assert(SQLContext.getOrCreate(ctx.sparkContext).eq(sqlContext), "SQLContext.getOrCreate after explicitly created SQLContext did not return the context") } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala index 55b68d8e22..5babc4332c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala @@ -24,9 +24,7 @@ import org.apache.spark.sql.catalyst.errors.DialectException import org.apache.spark.sql.execution.GeneratedAggregate import org.apache.spark.sql.functions._ import org.apache.spark.sql.TestData._ -import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext} -import org.apache.spark.sql.test.TestSQLContext.{udf => _, _} - +import org.apache.spark.sql.test.SQLTestUtils import org.apache.spark.sql.types._ /** A SQL Dialect for testing purpose, and it can not be nested type */ @@ -36,8 +34,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { // Make sure the tables are loaded. TestData - val sqlContext = TestSQLContext + val sqlContext = org.apache.spark.sql.test.TestSQLContext import sqlContext.implicits._ + import sqlContext.sql test("SPARK-6743: no columns from cache") { Seq( @@ -46,7 +45,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { (43, 81, 24) ).toDF("a", "b", "c").registerTempTable("cachedData") - cacheTable("cachedData") + sqlContext.cacheTable("cachedData") checkAnswer( sql("SELECT t1.b FROM cachedData, cachedData t1 GROUP BY t1.b"), Row(0) :: Row(81) :: Nil) @@ -94,14 +93,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SQL Dialect Switching to a new SQL parser") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.setConf("spark.sql.dialect", classOf[MyDialect].getCanonicalName()) assert(newContext.getSQLDialect().getClass === classOf[MyDialect]) assert(newContext.sql("SELECT 1").collect() === Array(Row(1))) } test("SQL Dialect Switch to an invalid parser with alias") { - val newContext = new SQLContext(TestSQLContext.sparkContext) + val newContext = new SQLContext(sqlContext.sparkContext) newContext.sql("SET spark.sql.dialect=MyTestClass") intercept[DialectException] { newContext.sql("SELECT 1") @@ -118,7 +117,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("grouping on nested fields") { - read.json(sparkContext.parallelize("""{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.parallelize( + """{"nested": {"attribute": 1}, "value": 2}""" :: Nil)) .registerTempTable("rows") checkAnswer( @@ -135,8 +135,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6201 IN type conversion") { - read.json( - sparkContext.parallelize(Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) + sqlContext.read.json( + sqlContext.sparkContext.parallelize( + Seq("{\"a\": \"1\"}}", "{\"a\": \"2\"}}", "{\"a\": \"3\"}}"))) .registerTempTable("d") checkAnswer( @@ -157,12 +158,12 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("aggregation with codegen") { - val originalValue = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") + val originalValue = 
sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") // Prepare a table that we can group some rows. - table("testData") - .unionAll(table("testData")) - .unionAll(table("testData")) + sqlContext.table("testData") + .unionAll(sqlContext.table("testData")) + .unionAll(sqlContext.table("testData")) .registerTempTable("testData3x") def testCodeGen(sqlText: String, expectedResults: Seq[Row]): Unit = { @@ -254,8 +255,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { "SELECT sum('a'), avg('a'), count(null) FROM testData", Row(0, null, 0) :: Nil) } finally { - dropTempTable("testData3x") - setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) + sqlContext.dropTempTable("testData3x") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, originalValue.toString) } } @@ -447,42 +448,42 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("external sorting") { - val before = conf.externalSortEnabled - setConf(SQLConf.EXTERNAL_SORT, "true") + val before = sqlContext.conf.externalSortEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") sortTest() - setConf(SQLConf.EXTERNAL_SORT, before.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, before.toString) } test("SPARK-6927 sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.EXTERNAL_SORT, "false") - setConf(SQLConf.CODEGEN_ENABLED, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "false") + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") try{ sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } test("SPARK-6927 external sorting with codegen on") { - val externalbefore = conf.externalSortEnabled - val codegenbefore = conf.codegenEnabled - setConf(SQLConf.CODEGEN_ENABLED, "true") - setConf(SQLConf.EXTERNAL_SORT, "true") + val externalbefore = sqlContext.conf.externalSortEnabled + val codegenbefore = sqlContext.conf.codegenEnabled + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, "true") + sqlContext.setConf(SQLConf.EXTERNAL_SORT, "true") try { sortTest() } finally { - setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) - setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) + sqlContext.setConf(SQLConf.EXTERNAL_SORT, externalbefore.toString) + sqlContext.setConf(SQLConf.CODEGEN_ENABLED, codegenbefore.toString) } } @@ -516,7 +517,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Allow only a single WITH clause per query") { intercept[RuntimeException] { - sql("with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") + sql( + "with q1 as (select * from testData) with q2 as (select * from q1) select * from q2") } } @@ -863,7 +865,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SET 
commands semantics using sql()") { - conf.clear() + sqlContext.conf.clear() val testKey = "test.key.0" val testVal = "test.val.0" val nonexistentKey = "nonexistent" @@ -895,17 +897,17 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { sql(s"SET $nonexistentKey"), Row(s"$nonexistentKey=") ) - conf.clear() + sqlContext.conf.clear() } test("SET commands with illegal or inappropriate argument") { - conf.clear() + sqlContext.conf.clear() // Set negative mapred.reduce.tasks for automatically determing // the number of reducers is not supported intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-1")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-01")) intercept[IllegalArgumentException](sql(s"SET mapred.reduce.tasks=-2")) - conf.clear() + sqlContext.conf.clear() } test("apply schema") { @@ -923,7 +925,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(values(0).toInt, values(1), values(2).toBoolean, v4) } - val df1 = createDataFrame(rowRDD1, schema1) + val df1 = sqlContext.createDataFrame(rowRDD1, schema1) df1.registerTempTable("applySchema1") checkAnswer( sql("SELECT * FROM applySchema1"), @@ -953,7 +955,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), Map(values(1) -> v4)) } - val df2 = createDataFrame(rowRDD2, schema2) + val df2 = sqlContext.createDataFrame(rowRDD2, schema2) df2.registerTempTable("applySchema2") checkAnswer( sql("SELECT * FROM applySchema2"), @@ -978,7 +980,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { Row(Row(values(0).toInt, values(2).toBoolean), scala.collection.mutable.Map(values(1) -> v4)) } - val df3 = createDataFrame(rowRDD3, schema2) + val df3 = sqlContext.createDataFrame(rowRDD3, schema2) df3.registerTempTable("applySchema3") checkAnswer( @@ -1023,7 +1025,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { .build() val schemaWithMeta = new StructType(Array( schema("id"), schema("name").copy(metadata = metadata), schema("age"))) - val personWithMeta = createDataFrame(person.rdd, schemaWithMeta) + val personWithMeta = sqlContext.createDataFrame(person.rdd, schemaWithMeta) def validateMetadata(rdd: DataFrame): Unit = { assert(rdd.schema("name").metadata.getString(docKey) == docValue) } @@ -1038,7 +1040,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3371 Renaming a function expression with group by gives error") { - TestSQLContext.udf.register("len", (s: String) => s.length) + sqlContext.udf.register("len", (s: String) => s.length) checkAnswer( sql("SELECT len(value) as temp FROM testData WHERE key = 1 group by len(value)"), Row(1)) @@ -1219,9 +1221,9 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-3483 Special chars in column names") { - val data = sparkContext.parallelize( + val data = sqlContext.sparkContext.parallelize( Seq("""{"key?number1": "value1", "key.number2": "value2"}""")) - read.json(data).registerTempTable("records") + sqlContext.read.json(data).registerTempTable("records") sql("SELECT `key?number1`, `key.number2` FROM records") } @@ -1262,13 +1264,15 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-4322 Grouping field with struct field as sub expression") { - read.json(sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: 
Nil)).registerTempTable("data") + sqlContext.read.json(sqlContext.sparkContext.makeRDD("""{"a": {"b": [{"c": 1}]}}""" :: Nil)) + .registerTempTable("data") checkAnswer(sql("SELECT a.b[0].c FROM data GROUP BY a.b[0].c"), Row(1)) - dropTempTable("data") + sqlContext.dropTempTable("data") - read.json(sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") + sqlContext.read.json( + sqlContext.sparkContext.makeRDD("""{"a": {"b": 1}}""" :: Nil)).registerTempTable("data") checkAnswer(sql("SELECT a.b + 1 FROM data GROUP BY a.b + 1"), Row(2)) - dropTempTable("data") + sqlContext.dropTempTable("data") } test("SPARK-4432 Fix attribute reference resolution error when using ORDER BY") { @@ -1287,10 +1291,10 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Supporting relational operator '<=>' in Spark SQL") { val nullCheckData1 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd1 = sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) + val rdd1 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData1(i))) rdd1.toDF().registerTempTable("nulldata1") val nullCheckData2 = TestData(1, "1") :: TestData(2, null) :: Nil - val rdd2 = sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) + val rdd2 = sqlContext.sparkContext.parallelize((0 to 1).map(i => nullCheckData2(i))) rdd2.toDF().registerTempTable("nulldata2") checkAnswer(sql("SELECT nulldata1.key FROM nulldata1 join " + "nulldata2 on nulldata1.value <=> nulldata2.value"), @@ -1299,22 +1303,23 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { test("Multi-column COUNT(DISTINCT ...)") { val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("distinctData") checkAnswer(sql("SELECT COUNT(DISTINCT key,value) FROM distinctData"), Row(2)) } test("SPARK-4699 case sensitivity SQL query") { - setConf(SQLConf.CASE_SENSITIVE, "false") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "false") val data = TestData(1, "val_1") :: TestData(2, "val_2") :: Nil - val rdd = sparkContext.parallelize((0 to 1).map(i => data(i))) + val rdd = sqlContext.sparkContext.parallelize((0 to 1).map(i => data(i))) rdd.toDF().registerTempTable("testTable1") checkAnswer(sql("SELECT VALUE FROM TESTTABLE1 where KEY = 1"), Row("val_1")) - setConf(SQLConf.CASE_SENSITIVE, "true") + sqlContext.setConf(SQLConf.CASE_SENSITIVE, "true") } test("SPARK-6145: ORDER BY test for nested fields") { - read.json(sparkContext.makeRDD("""{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) + sqlContext.read.json(sqlContext.sparkContext.makeRDD( + """{"a": {"b": 1, "a": {"a": 1}}, "c": [{"d": 1}]}""" :: Nil)) .registerTempTable("nestedOrder") checkAnswer(sql("SELECT 1 FROM nestedOrder ORDER BY a.b"), Row(1)) @@ -1326,14 +1331,14 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils { } test("SPARK-6145: special cases") { - read.json(sparkContext.makeRDD( + sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"b": [1]}, "b": [{"a": 1}], "c0": {"a": 1}}""" :: Nil)).registerTempTable("t") checkAnswer(sql("SELECT a.b[0] FROM t ORDER BY c0.a"), Row(1)) checkAnswer(sql("SELECT b[0].a FROM t ORDER BY c0.a"), Row(1)) } test("SPARK-6898: complete support for special chars in column names") { - read.json(sparkContext.makeRDD( + 
sqlContext.read.json(sqlContext.sparkContext.makeRDD( """{"a": {"c.b": 1}, "b.$q": [{"a@!.q": 1}], "q.w": {"w.i&": [1]}}""" :: Nil)) .registerTempTable("t") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala index d2ede39f0a..ece3d6fdf2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala @@ -21,7 +21,6 @@ import java.sql.{Date, Timestamp} import org.apache.spark.SparkFunSuite import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.test.TestSQLContext._ case class ReflectData( stringField: String, @@ -75,15 +74,15 @@ case class ComplexReflectData( class ScalaReflectionRelationSuite extends SparkFunSuite { - import org.apache.spark.sql.test.TestSQLContext.implicits._ + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ test("query case class RDD") { val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), new Date(12345), new Timestamp(12345), Seq(1, 2, 3)) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectData") + Seq(data).toDF().registerTempTable("reflectData") - assert(sql("SELECT * FROM reflectData").collect().head === + assert(ctx.sql("SELECT * FROM reflectData").collect().head === Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true, new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))) @@ -91,27 +90,26 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { test("query case class RDD with nulls") { val data = NullReflectData(null, null, null, null, null, null, null) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectNullData") + Seq(data).toDF().registerTempTable("reflectNullData") - assert(sql("SELECT * FROM reflectNullData").collect().head === Row.fromSeq(Seq.fill(7)(null))) + assert(ctx.sql("SELECT * FROM reflectNullData").collect().head === + Row.fromSeq(Seq.fill(7)(null))) } test("query case class RDD with Nones") { val data = OptionalReflectData(None, None, None, None, None, None, None) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectOptionalData") + Seq(data).toDF().registerTempTable("reflectOptionalData") - assert(sql("SELECT * FROM reflectOptionalData").collect().head === + assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head === Row.fromSeq(Seq.fill(7)(null))) } // Equality is broken for Arrays, so we test that separately. 
test("query binary data") { - val rdd = sparkContext.parallelize(ReflectBinary(Array[Byte](1)) :: Nil) - rdd.toDF().registerTempTable("reflectBinary") + Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary") - val result = sql("SELECT data FROM reflectBinary").collect().head(0).asInstanceOf[Array[Byte]] + val result = ctx.sql("SELECT data FROM reflectBinary") + .collect().head(0).asInstanceOf[Array[Byte]] assert(result.toSeq === Seq[Byte](1)) } @@ -127,10 +125,9 @@ class ScalaReflectionRelationSuite extends SparkFunSuite { Map(10 -> 100L, 20 -> 200L), Map(10 -> Some(100L), 20 -> Some(200L), 30 -> None), Nested(None, "abc"))) - val rdd = sparkContext.parallelize(data :: Nil) - rdd.toDF().registerTempTable("reflectComplexData") - assert(sql("SELECT * FROM reflectComplexData").collect().head === + Seq(data).toDF().registerTempTable("reflectComplexData") + assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head === new GenericRow(Array[Any]( Seq(1, 2, 3), Seq(1, 2, null), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala index 1e8cde606b..e55c9e460b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala @@ -19,12 +19,13 @@ package org.apache.spark.sql import org.apache.spark.{SparkConf, SparkFunSuite} import org.apache.spark.serializer.JavaSerializer -import org.apache.spark.sql.test.TestSQLContext class SerializationSuite extends SparkFunSuite { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + test("[SPARK-5235] SQLContext should be serializable") { - val sqlContext = new SQLContext(TestSQLContext.sparkContext) + val sqlContext = new SQLContext(ctx.sparkContext) new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala index 1a9ba66416..064c040d2b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala @@ -17,43 +17,41 @@ package org.apache.spark.sql -import org.apache.spark.sql.test._ - -/* Implicits */ -import TestSQLContext._ -import TestSQLContext.implicits._ case class FunctionResult(f1: String, f2: String) class UDFSuite extends QueryTest { + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + test("Simple UDF") { - udf.register("strLenScala", (_: String).length) - assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4) + ctx.udf.register("strLenScala", (_: String).length) + assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4) } test("ZeroArgument UDF") { - udf.register("random0", () => { Math.random()}) - assert(sql("SELECT random0()").head().getDouble(0) >= 0.0) + ctx.udf.register("random0", () => { Math.random()}) + assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0) } test("TwoArgument UDF") { - udf.register("strLenScala", (_: String).length + (_: Int)) - assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) + ctx.udf.register("strLenScala", (_: String).length + (_: Int)) + assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5) } test("struct UDF") { - udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2)) + ctx.udf.register("returnStruct", (f1: String, f2: String) 
=> FunctionResult(f1, f2)) val result = - sql("SELECT returnStruct('test', 'test2') as ret") + ctx.sql("SELECT returnStruct('test', 'test2') as ret") .select($"ret.f1").head().getString(0) assert(result === "test") } test("udf that is transformed") { - udf.register("makeStruct", (x: Int, y: Int) => (x, y)) + ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y)) // 1 + 1 is constant folded causing a transformation. - assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) + assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2)) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala index dc2d43a197..45c9f06941 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala @@ -17,10 +17,6 @@ package org.apache.spark.sql -import java.io.File - -import org.apache.spark.util.Utils - import scala.beans.{BeanInfo, BeanProperty} import com.clearspring.analytics.stream.cardinality.HyperLogLog @@ -28,12 +24,11 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.test.TestSQLContext -import org.apache.spark.sql.test.TestSQLContext.{sparkContext, sql} -import org.apache.spark.sql.test.TestSQLContext.implicits._ import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils import org.apache.spark.util.collection.OpenHashSet + @SQLUserDefinedType(udt = classOf[MyDenseVectorUDT]) private[sql] class MyDenseVector(val data: Array[Double]) extends Serializable { override def equals(other: Any): Boolean = other match { @@ -72,11 +67,13 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] { } class UserDefinedTypeSuite extends QueryTest { - val points = Seq( - MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), - MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))) - val pointsRDD = sparkContext.parallelize(points).toDF() + private lazy val ctx = org.apache.spark.sql.test.TestSQLContext + import ctx.implicits._ + + private lazy val pointsRDD = Seq( + MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))), + MyLabeledPoint(0.0, new MyDenseVector(Array(0.2, 2.0)))).toDF() test("register user type: MyDenseVector for MyLabeledPoint") { val labels: RDD[Double] = pointsRDD.select('label).rdd.map { case Row(v: Double) => v } @@ -94,10 +91,10 @@ class UserDefinedTypeSuite extends QueryTest { } test("UDTs and UDFs") { - TestSQLContext.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) + ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector]) pointsRDD.registerTempTable("points") checkAnswer( - sql("SELECT testType(features) from points"), + ctx.sql("SELECT testType(features) from points"), Seq(Row(true), Row(true))) } -- cgit v1.2.3
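
For reference, the convention these suites converge on can be written as a small standalone test. The sketch below is illustrative only and is not part of the patch: the suite name, case class, and table name (ExampleSuite, Person, people) are made up, and it assumes the Spark 1.4-era test API in which org.apache.spark.sql.test.TestSQLContext is a singleton visible to suites in the org.apache.spark.sql package.

package org.apache.spark.sql

import org.apache.spark.SparkFunSuite

// Hypothetical case class, used only for this illustration.
case class Person(name: String, age: Int)

class ExampleSuite extends SparkFunSuite {
  // Hold the shared test context in a stable identifier instead of
  // wildcard-importing its members via TestSQLContext._
  private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
  import ctx.implicits._  // toDF(), $"col", Symbol-to-Column conversions
  import ctx.sql          // lets bare sql("...") calls keep compiling

  test("explicit context instead of wildcard import") {
    Seq(Person("a", 1), Person("b", 2)).toDF().registerTempTable("people")
    assert(sql("SELECT COUNT(*) FROM people").collect().head.getLong(0) === 2L)
    // Context-level operations are spelled out against ctx, e.g. ctx.table(...)
    assert(ctx.table("people").count() === 2L)
  }
}

Importing members from ctx works because a val (or lazy val) is a stable identifier in Scala; the same import from a def would not compile, which is one reason for pinning the context in a val before importing its implicits.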