author     Andrew Or <andrew@databricks.com>    2015-08-13 17:42:01 -0700
committer  Reynold Xin <rxin@databricks.com>    2015-08-13 17:42:01 -0700
commit     8187b3ae477e2b2987ae9acc5368d57b1d5653b2 (patch)
tree       e80b71bbbfbf39b0fdca5a5bfca567ae8e0ca6a3
parent     c50f97dafd2d5bf5a8351efcc1c8d3e2b87efc72 (diff)
[SPARK-9580] [SQL] Replace singletons in SQL tests
A fundamental limitation of the existing SQL tests is that *there is simply no way to create your own `SparkContext`*. This is a serious limitation because the user may wish to use a different master or config. As a case in point, `BroadcastJoinSuite` is entirely commented out because there is no way to make it pass with the existing infrastructure. This patch removes the singletons `TestSQLContext` and `TestData`, and instead introduces a `SharedSQLContext` that starts a context per suite. Unfortunately the singletons were so ingrained in the SQL tests that this patch necessarily needed to touch *all* the SQL test files.

Author: Andrew Or <andrew@databricks.com>

Closes #8111 from andrewor14/sql-tests-refactor.
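For orientation, here is a minimal sketch of what a suite looks like after this refactor, based on the `SharedSQLContext` usage visible in the diffs below (`testImplicits`, `ctx`, `sql`, and `checkAnswer` come from the new test traits and `QueryTest`; the suite name and the table/column names are illustrative only):

```scala
package org.apache.spark.sql

import org.apache.spark.sql.test.SharedSQLContext

// Each suite now gets its own SQLContext via SharedSQLContext,
// instead of referencing the old TestSQLContext singleton.
class ExampleSuite extends QueryTest with SharedSQLContext {
  import testImplicits._  // per-suite replacement for TestSQLContext.implicits

  test("query against the per-suite context") {
    val df = Seq((1, "a"), (2, "b")).toDF("key", "value")
    df.registerTempTable("example")
    checkAnswer(sql("SELECT key FROM example WHERE value = 'a'"), Row(1))
    // The underlying SparkContext is also available if a test needs it.
    assert(ctx.sparkContext.parallelize(1 to 3).count() === 3)
  }
}
```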
-rw-r--r--  project/MimaExcludes.scala  10
-rw-r--r--  project/SparkBuild.scala  16
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala  97
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala  123
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java  10
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java  39
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java  10
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala  37
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala  18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala  19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala  13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala  47
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala  28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala  19
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala  13
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala  29
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala  17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/TestData.scala  197
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala  31
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala  39
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala  22
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala  36
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala  21
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala  42
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala  40
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala  12
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala  43
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala  8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala  6
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala  10
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala  248
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala  125
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala  113
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala  62
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala  20
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala  18
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala  3
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala  11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala  17
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala  9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala  30
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala  14
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala  11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala  26
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala  15
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala  290
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala  92
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala  68
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala (renamed from sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala)  46
-rw-r--r--  sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala  2
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala  5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala  11
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala  5
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala  3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala  9
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala  3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala  3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala  4
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala  3
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala  6
-rw-r--r--  sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala  5
96 files changed, 1460 insertions, 1203 deletions
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 784f83c10e..88745dc086 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -179,6 +179,16 @@ object MimaExcludes {
ProblemFilters.exclude[MissingMethodProblem](
"org.apache.spark.SparkContext.supportDynamicAllocation")
) ++ Seq(
+ // SPARK-9580: Remove SQL test singletons
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.test.LocalSQLContext$SQLSession"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.test.LocalSQLContext"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.test.TestSQLContext"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.spark.sql.test.TestSQLContext$")
+ ) ++ Seq(
// SPARK-9704 Made ProbabilisticClassifier, Identifiable, VectorUDT public APIs
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.mllib.linalg.VectorUDT.serialize")
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 74f815f941..04e0d49b17 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -319,6 +319,8 @@ object SQL {
lazy val settings = Seq(
initialCommands in console :=
"""
+ |import org.apache.spark.SparkContext
+ |import org.apache.spark.sql.SQLContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._
@@ -328,9 +330,14 @@ object SQL {
|import org.apache.spark.sql.catalyst.util._
|import org.apache.spark.sql.execution
|import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.test.TestSQLContext._
- |import org.apache.spark.sql.types._""".stripMargin,
- cleanupCommands in console := "sparkContext.stop()"
+ |import org.apache.spark.sql.types._
+ |
+ |val sc = new SparkContext("local[*]", "dev-shell")
+ |val sqlContext = new SQLContext(sc)
+ |import sqlContext.implicits._
+ |import sqlContext._
+ """.stripMargin,
+ cleanupCommands in console := "sc.stop()"
)
}
@@ -340,8 +347,6 @@ object Hive {
javaOptions += "-XX:MaxPermSize=256m",
// Specially disable assertions since some Hive tests fail them
javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
- // Multiple queries rely on the TestHive singleton. See comments there for more details.
- parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
// only for this subproject.
scalacOptions <<= scalacOptions map { currentOpts: Seq[String] =>
@@ -349,6 +354,7 @@ object Hive {
},
initialCommands in console :=
"""
+ |import org.apache.spark.SparkContext
|import org.apache.spark.sql.catalyst.analysis._
|import org.apache.spark.sql.catalyst.dsl._
|import org.apache.spark.sql.catalyst.errors._
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
index 63b475b636..f60d11c988 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisErrorSuite.scala
@@ -17,14 +17,10 @@
package org.apache.spark.sql.catalyst.analysis
-import org.scalatest.BeforeAndAfter
-
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.plans.Inner
-import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.types._
@@ -42,7 +38,7 @@ case class UnresolvedTestPlan() extends LeafNode {
override def output: Seq[Attribute] = Nil
}
-class AnalysisErrorSuite extends AnalysisTest with BeforeAndAfter {
+class AnalysisErrorSuite extends AnalysisTest {
import TestRelations._
def errorTest(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 4bf00b3399..53de10d5fa 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -23,7 +23,6 @@ import java.util.concurrent.atomic.AtomicReference
import scala.collection.JavaConversions._
import scala.collection.immutable
-import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
@@ -41,10 +40,9 @@ import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.{InternalRow, ParserDialect, _}
import org.apache.spark.sql.execution._
import org.apache.spark.sql.execution.datasources._
+import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
-import org.apache.spark.unsafe.types.UTF8String
import org.apache.spark.util.Utils
/**
@@ -334,97 +332,10 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
@Experimental
- object implicits extends Serializable {
- // scalastyle:on
-
- /**
- * Converts $"col name" into an [[Column]].
- * @since 1.3.0
- */
- implicit class StringToColumn(val sc: StringContext) {
- def $(args: Any*): ColumnName = {
- new ColumnName(sc.s(args: _*))
- }
- }
-
- /**
- * An implicit conversion that turns a Scala `Symbol` into a [[Column]].
- * @since 1.3.0
- */
- implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
-
- /**
- * Creates a DataFrame from an RDD of case classes or tuples.
- * @since 1.3.0
- */
- implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
- DataFrameHolder(self.createDataFrame(rdd))
- }
-
- /**
- * Creates a DataFrame from a local Seq of Product.
- * @since 1.3.0
- */
- implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
- {
- DataFrameHolder(self.createDataFrame(data))
- }
-
- // Do NOT add more implicit conversions. They are likely to break source compatibility by
- // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
- // because of [[DoubleRDDFunctions]].
-
- /**
- * Creates a single column DataFrame from an RDD[Int].
- * @since 1.3.0
- */
- implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
- val dataType = IntegerType
- val rows = data.mapPartitions { iter =>
- val row = new SpecificMutableRow(dataType :: Nil)
- iter.map { v =>
- row.setInt(0, v)
- row: InternalRow
- }
- }
- DataFrameHolder(
- self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
- }
-
- /**
- * Creates a single column DataFrame from an RDD[Long].
- * @since 1.3.0
- */
- implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
- val dataType = LongType
- val rows = data.mapPartitions { iter =>
- val row = new SpecificMutableRow(dataType :: Nil)
- iter.map { v =>
- row.setLong(0, v)
- row: InternalRow
- }
- }
- DataFrameHolder(
- self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
- }
-
- /**
- * Creates a single column DataFrame from an RDD[String].
- * @since 1.3.0
- */
- implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
- val dataType = StringType
- val rows = data.mapPartitions { iter =>
- val row = new SpecificMutableRow(dataType :: Nil)
- iter.map { v =>
- row.update(0, UTF8String.fromString(v))
- row: InternalRow
- }
- }
- DataFrameHolder(
- self.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
- }
+ object implicits extends SQLImplicits with Serializable {
+ protected override def _sqlContext: SQLContext = self
}
+ // scalastyle:on
/**
* :: Experimental ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
new file mode 100644
index 0000000000..5f82372700
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLImplicits.scala
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import scala.language.implicitConversions
+import scala.reflect.runtime.universe.TypeTag
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.SpecificMutableRow
+import org.apache.spark.sql.types.StructField
+import org.apache.spark.unsafe.types.UTF8String
+
+/**
+ * A collection of implicit methods for converting common Scala objects into [[DataFrame]]s.
+ */
+private[sql] abstract class SQLImplicits {
+ protected def _sqlContext: SQLContext
+
+ /**
+ * Converts $"col name" into an [[Column]].
+ * @since 1.3.0
+ */
+ implicit class StringToColumn(val sc: StringContext) {
+ def $(args: Any*): ColumnName = {
+ new ColumnName(sc.s(args: _*))
+ }
+ }
+
+ /**
+ * An implicit conversion that turns a Scala `Symbol` into a [[Column]].
+ * @since 1.3.0
+ */
+ implicit def symbolToColumn(s: Symbol): ColumnName = new ColumnName(s.name)
+
+ /**
+ * Creates a DataFrame from an RDD of case classes or tuples.
+ * @since 1.3.0
+ */
+ implicit def rddToDataFrameHolder[A <: Product : TypeTag](rdd: RDD[A]): DataFrameHolder = {
+ DataFrameHolder(_sqlContext.createDataFrame(rdd))
+ }
+
+ /**
+ * Creates a DataFrame from a local Seq of Product.
+ * @since 1.3.0
+ */
+ implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder =
+ {
+ DataFrameHolder(_sqlContext.createDataFrame(data))
+ }
+
+ // Do NOT add more implicit conversions. They are likely to break source compatibility by
+ // making existing implicit conversions ambiguous. In particular, RDD[Double] is dangerous
+ // because of [[DoubleRDDFunctions]].
+
+ /**
+ * Creates a single column DataFrame from an RDD[Int].
+ * @since 1.3.0
+ */
+ implicit def intRddToDataFrameHolder(data: RDD[Int]): DataFrameHolder = {
+ val dataType = IntegerType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setInt(0, v)
+ row: InternalRow
+ }
+ }
+ DataFrameHolder(
+ _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+ }
+
+ /**
+ * Creates a single column DataFrame from an RDD[Long].
+ * @since 1.3.0
+ */
+ implicit def longRddToDataFrameHolder(data: RDD[Long]): DataFrameHolder = {
+ val dataType = LongType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.setLong(0, v)
+ row: InternalRow
+ }
+ }
+ DataFrameHolder(
+ _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+ }
+
+ /**
+ * Creates a single column DataFrame from an RDD[String].
+ * @since 1.3.0
+ */
+ implicit def stringRddToDataFrameHolder(data: RDD[String]): DataFrameHolder = {
+ val dataType = StringType
+ val rows = data.mapPartitions { iter =>
+ val row = new SpecificMutableRow(dataType :: Nil)
+ iter.map { v =>
+ row.update(0, UTF8String.fromString(v))
+ row: InternalRow
+ }
+ }
+ DataFrameHolder(
+ _sqlContext.internalCreateDataFrame(rows, StructType(StructField("_1", dataType) :: Nil)))
+ }
+}
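As a side note, the behavior of the implicits moved into `SQLImplicits` above is unchanged; a rough usage sketch against a locally created context, mirroring the `sbt console` setup added to SparkBuild.scala earlier in this patch (the app name and sample data are placeholders):

```scala
import org.apache.spark.SparkContext
import org.apache.spark.sql.SQLContext

val sc = new SparkContext("local[*]", "implicits-demo")  // placeholder app name
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._  // now backed by SQLImplicits via `object implicits`

// localSeqToDataFrameHolder: a local Seq of tuples becomes a DataFrame
val df = Seq((1, "a"), (2, "b")).toDF("key", "value")

// intRddToDataFrameHolder: an RDD[Int] becomes a single-column DataFrame ("_1")
val ints = sc.parallelize(1 to 10).toDF()

// StringToColumn: $"..." builds a Column reference
df.select($"key").show()

sc.stop()
```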
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
index e912eb835d..bf693c7c39 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaApplySchemaSuite.java
@@ -27,6 +27,7 @@ import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
@@ -34,7 +35,6 @@ import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -48,14 +48,16 @@ public class JavaApplySchemaSuite implements Serializable {
@Before
public void setUp() {
- sqlContext = TestSQLContext$.MODULE$;
- javaCtx = new JavaSparkContext(sqlContext.sparkContext());
+ SparkContext context = new SparkContext("local[*]", "testing");
+ javaCtx = new JavaSparkContext(context);
+ sqlContext = new SQLContext(context);
}
@After
public void tearDown() {
- javaCtx = null;
+ sqlContext.sparkContext().stop();
sqlContext = null;
+ javaCtx = null;
}
public static class Person implements Serializable {
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 7302361ab9..7abdd3db80 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,44 +17,45 @@
package test.org.apache.spark.sql;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+import java.util.Map;
+
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Ints;
+import org.junit.*;
+import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
+import static org.apache.spark.sql.functions.*;
import org.apache.spark.sql.test.TestSQLContext;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
-import org.junit.*;
-
-import scala.collection.JavaConversions;
-import scala.collection.Seq;
-
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Map;
-
-import static org.apache.spark.sql.functions.*;
public class JavaDataFrameSuite {
private transient JavaSparkContext jsc;
- private transient SQLContext context;
+ private transient TestSQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
- TestData$.MODULE$.testData();
- jsc = new JavaSparkContext(TestSQLContext.sparkContext());
- context = TestSQLContext$.MODULE$;
+ SparkContext sc = new SparkContext("local[*]", "testing");
+ jsc = new JavaSparkContext(sc);
+ context = new TestSQLContext(sc);
+ context.loadTestData();
}
@After
public void tearDown() {
- jsc = null;
+ context.sparkContext().stop();
context = null;
+ jsc = null;
}
@Test
@@ -230,7 +231,7 @@ public class JavaDataFrameSuite {
@Test
public void testSampleBy() {
- DataFrame df = context.range(0, 100).select(col("id").mod(3).as("key"));
+ DataFrame df = context.range(0, 100, 1, 2).select(col("id").mod(3).as("key"));
DataFrame sampled = df.stat().<Integer>sampleBy("key", ImmutableMap.of(0, 0.1, 1, 0.2), 0L);
Row[] actual = sampled.groupBy("key").count().orderBy("key").collect();
Row[] expected = new Row[] {RowFactory.create(0, 5), RowFactory.create(1, 8)};
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
index 79d92734ff..bb02b58cca 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaUDFSuite.java
@@ -23,12 +23,12 @@ import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SparkContext;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.api.java.UDF2;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.DataTypes;
// The test suite itself is Serializable so that anonymous Function implementations can be
@@ -40,12 +40,16 @@ public class JavaUDFSuite implements Serializable {
@Before
public void setUp() {
- sqlContext = TestSQLContext$.MODULE$;
- sc = new JavaSparkContext(sqlContext.sparkContext());
+ SparkContext _sc = new SparkContext("local[*]", "testing");
+ sqlContext = new SQLContext(_sc);
+ sc = new JavaSparkContext(_sc);
}
@After
public void tearDown() {
+ sqlContext.sparkContext().stop();
+ sqlContext = null;
+ sc = null;
}
@SuppressWarnings("unchecked")
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
index 2706e01bd2..6f9e7f68dc 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/sources/JavaSaveLoadSuite.java
@@ -21,13 +21,14 @@ import java.io.File;
import java.io.IOException;
import java.util.*;
+import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.SparkContext;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
@@ -52,8 +53,9 @@ public class JavaSaveLoadSuite {
@Before
public void setUp() throws IOException {
- sqlContext = TestSQLContext$.MODULE$;
- sc = new JavaSparkContext(sqlContext.sparkContext());
+ SparkContext _sc = new SparkContext("local[*]", "testing");
+ sqlContext = new SQLContext(_sc);
+ sc = new JavaSparkContext(_sc);
originalDefaultSource = sqlContext.conf().defaultDataSourceName();
path =
@@ -71,6 +73,13 @@ public class JavaSaveLoadSuite {
df.registerTempTable("jsonTable");
}
+ @After
+ public void tearDown() {
+ sqlContext.sparkContext().stop();
+ sqlContext = null;
+ sc = null;
+ }
+
@Test
public void saveAndLoad() {
Map<String, String> options = new HashMap<String, String>();
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
index a88df91b10..af7590c3d3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala
@@ -18,24 +18,20 @@
package org.apache.spark.sql
import scala.concurrent.duration._
-import scala.language.{implicitConversions, postfixOps}
+import scala.language.postfixOps
import org.scalatest.concurrent.Eventually._
import org.apache.spark.Accumulators
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.columnar._
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.storage.{StorageLevel, RDDBlockId}
-case class BigData(s: String)
+private case class BigData(s: String)
-class CachedTableSuite extends QueryTest {
- TestData // Load test tables.
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
- import ctx.sql
+class CachedTableSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
def rddIdOf(tableName: String): Int = {
val executedPlan = ctx.table(tableName).queryExecution.executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
index 6a09a3b72c..ee74e3e83d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ColumnExpressionSuite.scala
@@ -21,16 +21,20 @@ import org.scalatest.Matchers._
import org.apache.spark.sql.execution.{Project, TungstenProject}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.SQLTestUtils
-class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
- import org.apache.spark.sql.TestData._
+class ColumnExpressionSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
-
- override def sqlContext(): SQLContext = ctx
+ private lazy val booleanData = {
+ ctx.createDataFrame(ctx.sparkContext.parallelize(
+ Row(false, false) ::
+ Row(false, true) ::
+ Row(true, false) ::
+ Row(true, true) :: Nil),
+ StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
+ }
test("column names with space") {
val df = Seq((1, "a")).toDF("name with space", "name.with.dot")
@@ -258,7 +262,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) eq null))
checkAnswer(
- ctx.sql("select isnull(null), isnull(1)"),
+ sql("select isnull(null), isnull(1)"),
Row(true, false))
}
@@ -268,7 +272,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
nullStrings.collect().toSeq.filter(r => r.getString(1) ne null))
checkAnswer(
- ctx.sql("select isnotnull(null), isnotnull('a')"),
+ sql("select isnotnull(null), isnotnull('a')"),
Row(false, true))
}
@@ -289,7 +293,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
Row(true, true) :: Row(true, true) :: Row(false, false) :: Row(false, false) :: Nil)
checkAnswer(
- ctx.sql("select isnan(15), isnan('invalid')"),
+ sql("select isnan(15), isnan('invalid')"),
Row(false, false))
}
@@ -309,7 +313,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
testData.registerTempTable("t")
checkAnswer(
- ctx.sql(
+ sql(
"select nanvl(a, 5), nanvl(b, 10), nanvl(10, b), nanvl(c, null), nanvl(d, 10), " +
" nanvl(b, e), nanvl(e, f) from t"),
Row(null, 3.0, 10.0, null, Double.PositiveInfinity, 3.0, 1.0)
@@ -433,13 +437,6 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
}
}
- val booleanData = ctx.createDataFrame(ctx.sparkContext.parallelize(
- Row(false, false) ::
- Row(false, true) ::
- Row(true, false) ::
- Row(true, true) :: Nil),
- StructType(Seq(StructField("a", BooleanType), StructField("b", BooleanType))))
-
test("&&") {
checkAnswer(
booleanData.filter($"a" && true),
@@ -523,7 +520,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
- ctx.sql("SELECT upper('aB'), ucase('cDe')"),
+ sql("SELECT upper('aB'), ucase('cDe')"),
Row("AB", "CDE"))
}
@@ -544,7 +541,7 @@ class ColumnExpressionSuite extends QueryTest with SQLTestUtils {
)
checkAnswer(
- ctx.sql("SELECT lower('aB'), lcase('cDe')"),
+ sql("SELECT lower('aB'), lcase('cDe')"),
Row("ab", "cde"))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
index f9cff7440a..72cf7aab0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala
@@ -17,15 +17,13 @@
package org.apache.spark.sql
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.types.{BinaryType, DecimalType}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.DecimalType
-class DataFrameAggregateSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class DataFrameAggregateSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("groupBy") {
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
index 03116a374f..9d965258e3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameFunctionsSuite.scala
@@ -17,17 +17,15 @@
package org.apache.spark.sql
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* Test suite for functions in [[org.apache.spark.sql.functions]].
*/
-class DataFrameFunctionsSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class DataFrameFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("array with column name") {
val df = Seq((0, 1)).toDF("a", "b")
@@ -119,11 +117,11 @@ class DataFrameFunctionsSuite extends QueryTest {
test("constant functions") {
checkAnswer(
- ctx.sql("SELECT E()"),
+ sql("SELECT E()"),
Row(scala.math.E)
)
checkAnswer(
- ctx.sql("SELECT PI()"),
+ sql("SELECT PI()"),
Row(scala.math.Pi)
)
}
@@ -153,7 +151,7 @@ class DataFrameFunctionsSuite extends QueryTest {
test("nvl function") {
checkAnswer(
- ctx.sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
+ sql("SELECT nvl(null, 'x'), nvl('y', 'x'), nvl(null, null)"),
Row("x", "y", null))
}
@@ -222,7 +220,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(-1)
)
checkAnswer(
- ctx.sql("SELECT least(a, 2) as l from testData2 order by l"),
+ sql("SELECT least(a, 2) as l from testData2 order by l"),
Seq(Row(1), Row(1), Row(2), Row(2), Row(2), Row(2))
)
}
@@ -233,7 +231,7 @@ class DataFrameFunctionsSuite extends QueryTest {
Row(3)
)
checkAnswer(
- ctx.sql("SELECT greatest(a, 2) as g from testData2 order by g"),
+ sql("SELECT greatest(a, 2) as g from testData2 order by g"),
Seq(Row(2), Row(2), Row(2), Row(2), Row(3), Row(3))
)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
index fbb30706a4..e5d7d63441 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameImplicitsSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql
-class DataFrameImplicitsSuite extends QueryTest {
+import org.apache.spark.sql.test.SharedSQLContext
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class DataFrameImplicitsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("RDD of tuples") {
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
index e1c6c70624..e2716d7841 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameJoinSuite.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.execution.joins.BroadcastHashJoin
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
-class DataFrameJoinSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class DataFrameJoinSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("join - join using") {
val df = Seq(1, 2, 3).map(i => (i, i.toString)).toDF("int", "str")
@@ -59,7 +57,7 @@ class DataFrameJoinSuite extends QueryTest {
checkAnswer(
df1.join(df2, $"df1.key" === $"df2.key"),
- ctx.sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
+ sql("SELECT a.key, b.key FROM testData a JOIN testData b ON a.key = b.key")
.collect().toSeq)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
index dbe3b44ee2..cdaa14ac80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameNaFunctionsSuite.scala
@@ -19,11 +19,11 @@ package org.apache.spark.sql
import scala.collection.JavaConversions._
+import org.apache.spark.sql.test.SharedSQLContext
-class DataFrameNaFunctionsSuite extends QueryTest {
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class DataFrameNaFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
def createDF(): DataFrame = {
Seq[(String, java.lang.Integer, java.lang.Double)](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
index 8f5984e4a8..28bdd6f83b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameStatSuite.scala
@@ -19,20 +19,17 @@ package org.apache.spark.sql
import java.util.Random
-import org.scalatest.Matchers._
-
import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.test.SharedSQLContext
-class DataFrameStatSuite extends QueryTest {
-
- private val sqlCtx = org.apache.spark.sql.test.TestSQLContext
- import sqlCtx.implicits._
+class DataFrameStatSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
private def toLetter(i: Int): String = (i + 97).toChar.toString
test("sample with replacement") {
val n = 100
- val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = true, 0.05, seed = 13),
Seq(5, 10, 52, 73).map(Row(_))
@@ -41,7 +38,7 @@ class DataFrameStatSuite extends QueryTest {
test("sample without replacement") {
val n = 100
- val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
checkAnswer(
data.sample(withReplacement = false, 0.05, seed = 13),
Seq(16, 23, 88, 100).map(Row(_))
@@ -50,7 +47,7 @@ class DataFrameStatSuite extends QueryTest {
test("randomSplit") {
val n = 600
- val data = sqlCtx.sparkContext.parallelize(1 to n, 2).toDF("id")
+ val data = ctx.sparkContext.parallelize(1 to n, 2).toDF("id")
for (seed <- 1 to 5) {
val splits = data.randomSplit(Array[Double](1, 2, 3), seed)
assert(splits.length == 3, "wrong number of splits")
@@ -167,7 +164,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("Frequent Items 2") {
- val rows = sqlCtx.sparkContext.parallelize(Seq.empty[Int], 4)
+ val rows = ctx.sparkContext.parallelize(Seq.empty[Int], 4)
// this is a regression test, where when merging partitions, we omitted values with higher
// counts than those that existed in the map when the map was full. This test should also fail
// if anything like SPARK-9614 is observed once again
@@ -185,7 +182,7 @@ class DataFrameStatSuite extends QueryTest {
}
test("sampleBy") {
- val df = sqlCtx.range(0, 100).select((col("id") % 3).as("key"))
+ val df = ctx.range(0, 100).select((col("id") % 3).as("key"))
val sampled = df.stat.sampleBy("key", Map(0 -> 0.1, 1 -> 0.2), 0L)
checkAnswer(
sampled.groupBy("key").count().orderBy("key"),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 2feec29955..10bfa9b64f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -23,18 +23,12 @@ import scala.language.postfixOps
import scala.util.Random
import org.apache.spark.sql.catalyst.plans.logical.OneRowRelation
-import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.execution.datasources.json.JSONRelation
-import org.apache.spark.sql.execution.datasources.parquet.ParquetRelation
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SQLTestUtils}
+import org.apache.spark.sql.test.{ExamplePointUDT, ExamplePoint, SharedSQLContext}
-class DataFrameSuite extends QueryTest with SQLTestUtils {
- import org.apache.spark.sql.TestData._
-
- lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.implicits._
+class DataFrameSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("analysis error should be eagerly reported") {
// Eager analysis.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
index bf8ef9a97b..77907e9136 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameTungstenSuite.scala
@@ -18,7 +18,7 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
@@ -27,10 +27,8 @@ import org.apache.spark.sql.types._
* This is here for now so I can make sure Tungsten project is tested without refactoring existing
* end-to-end test infra. In the long run this should just go away.
*/
-class DataFrameTungstenSuite extends QueryTest with SQLTestUtils {
-
- override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.implicits._
+class DataFrameTungstenSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("test simple types") {
withSQLConf(SQLConf.UNSAFE_ENABLED.key -> "true") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
index 17897caf95..9080c53c49 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DateFunctionsSuite.scala
@@ -22,19 +22,18 @@ import java.text.SimpleDateFormat
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.types.CalendarInterval
-class DateFunctionsSuite extends QueryTest {
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
-
- import ctx.implicits._
+class DateFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("function current_date") {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
val d0 = DateTimeUtils.millisToDays(System.currentTimeMillis())
val d1 = DateTimeUtils.fromJavaDate(df1.select(current_date()).collect().head.getDate(0))
val d2 = DateTimeUtils.fromJavaDate(
- ctx.sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
+ sql("""SELECT CURRENT_DATE()""").collect().head.getDate(0))
val d3 = DateTimeUtils.millisToDays(System.currentTimeMillis())
assert(d0 <= d1 && d1 <= d2 && d2 <= d3 && d3 - d0 <= 1)
}
@@ -44,9 +43,9 @@ class DateFunctionsSuite extends QueryTest {
val df1 = Seq((1, 2), (3, 1)).toDF("a", "b")
checkAnswer(df1.select(countDistinct(current_timestamp())), Row(1))
// Execution in one query should return the same value
- checkAnswer(ctx.sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
+ checkAnswer(sql("""SELECT CURRENT_TIMESTAMP() = CURRENT_TIMESTAMP()"""),
Row(true))
- assert(math.abs(ctx.sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
+ assert(math.abs(sql("""SELECT CURRENT_TIMESTAMP()""").collect().head.getTimestamp(
0).getTime - System.currentTimeMillis()) < 5000)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
index ae07eaf91c..f5c5046a8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala
@@ -17,22 +17,15 @@
package org.apache.spark.sql
-import org.scalatest.BeforeAndAfterEach
-
-import org.apache.spark.sql.TestData._
import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
-class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
- // Ensures tables are loaded.
- TestData
+class JoinSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
- override def sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
- lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
- import ctx.logicalPlanToSparkQuery
+ setupTestData()
test("equi-join is hash-join") {
val x = testData2.as("x")
@@ -43,7 +36,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
def assertJoin(sqlString: String, c: Class[_]): Any = {
- val df = ctx.sql(sqlString)
+ val df = sql(sqlString)
val physical = df.queryExecution.sparkPlan
val operators = physical.collect {
case j: ShuffledHashJoin => j
@@ -126,7 +119,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted hash join operator selection") {
ctx.cacheManager.clearCache()
- ctx.sql("CACHE TABLE testData")
+ sql("CACHE TABLE testData")
for (sortMergeJoinEnabled <- Seq(true, false)) {
withClue(s"sortMergeJoinEnabled=$sortMergeJoinEnabled") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> s"$sortMergeJoinEnabled") {
@@ -141,12 +134,12 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
}
- ctx.sql("UNCACHE TABLE testData")
+ sql("UNCACHE TABLE testData")
}
test("broadcasted hash outer join operator selection") {
ctx.cacheManager.clearCache()
- ctx.sql("CACHE TABLE testData")
+ sql("CACHE TABLE testData")
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
Seq(
("SELECT * FROM testData LEFT JOIN testData2 ON key = a",
@@ -167,7 +160,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
classOf[BroadcastHashOuterJoin])
).foreach { case (query, joinClass) => assertJoin(query, joinClass) }
}
- ctx.sql("UNCACHE TABLE testData")
+ sql("UNCACHE TABLE testData")
}
test("multiple-key equi-join is hash-join") {
@@ -279,7 +272,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing left.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -293,7 +286,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(6, 1) :: Nil)
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l LEFT OUTER JOIN allNulls r ON (l.N = r.a)
@@ -339,7 +332,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are choosing right.outputPartitioning as the
// outputPartitioning for the outer join operator.
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -348,7 +341,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 6))
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l RIGHT OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -400,7 +393,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
// Make sure we are UnknownPartitioning as the outputPartitioning for the outer join operator.
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT l.a, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -409,7 +402,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 10))
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT r.N, count(*)
|FROM allNulls l FULL OUTER JOIN upperCaseData r ON (l.a = r.N)
@@ -424,7 +417,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT l.N, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -439,7 +432,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
Row(null, 4) :: Nil)
checkAnswer(
- ctx.sql(
+ sql(
"""
|SELECT r.a, count(*)
|FROM upperCaseData l FULL OUTER JOIN allNulls r ON (l.N = r.a)
@@ -450,7 +443,7 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
test("broadcasted left semi join operator selection") {
ctx.cacheManager.clearCache()
- ctx.sql("CACHE TABLE testData")
+ sql("CACHE TABLE testData")
withSQLConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "1000000000") {
Seq(
@@ -469,11 +462,11 @@ class JoinSuite extends QueryTest with SQLTestUtils with BeforeAndAfterEach {
}
}
- ctx.sql("UNCACHE TABLE testData")
+ sql("UNCACHE TABLE testData")
}
test("left semi join") {
- val df = ctx.sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
+ val df = sql("SELECT * FROM testData2 LEFT SEMI JOIN testData ON key = a")
checkAnswer(df,
Row(1, 1) ::
Row(1, 2) ::
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
index 71c26a6f8d..045fea82e4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/JsonFunctionsSuite.scala
@@ -17,10 +17,10 @@
package org.apache.spark.sql
-class JsonFunctionsSuite extends QueryTest {
+import org.apache.spark.sql.test.SharedSQLContext
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class JsonFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("function get_json_object") {
val df: DataFrame = Seq(("""{"name": "alice", "age": 5}""", "")).toDF("a", "b")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
index 2089660c52..babf8835d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ListTablesSuite.scala
@@ -19,12 +19,11 @@ package org.apache.spark.sql
import org.scalatest.BeforeAndAfter
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{BooleanType, StringType, StructField, StructType}
-class ListTablesSuite extends QueryTest with BeforeAndAfter {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class ListTablesSuite extends QueryTest with BeforeAndAfter with SharedSQLContext {
+ import testImplicits._
private lazy val df = (1 to 10).map(i => (i, s"str$i")).toDF("key", "value")
@@ -42,7 +41,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
- ctx.sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
+ sql("SHOW tables").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@@ -55,7 +54,7 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
Row("ListTablesSuiteTable", true))
checkAnswer(
- ctx.sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
+ sql("show TABLES in DB").filter("tableName = 'ListTablesSuiteTable'"),
Row("ListTablesSuiteTable", true))
ctx.catalog.unregisterTable(Seq("ListTablesSuiteTable"))
@@ -67,13 +66,13 @@ class ListTablesSuite extends QueryTest with BeforeAndAfter {
StructField("tableName", StringType, false) ::
StructField("isTemporary", BooleanType, false) :: Nil)
- Seq(ctx.tables(), ctx.sql("SHOW TABLes")).foreach {
+ Seq(ctx.tables(), sql("SHOW TABLes")).foreach {
case tableDF =>
assert(expectedSchema === tableDF.schema)
tableDF.registerTempTable("tables")
checkAnswer(
- ctx.sql(
+ sql(
"SELECT isTemporary, tableName from tables WHERE tableName = 'ListTablesSuiteTable'"),
Row(true, "ListTablesSuiteTable")
)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
index 8cf2ef5957..30289c3c1d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/MathExpressionsSuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.sql
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.{log => logarithm}
+import org.apache.spark.sql.test.SharedSQLContext
private object MathExpressionsTestData {
case class DoubleData(a: java.lang.Double, b: java.lang.Double)
case class NullDoubles(a: java.lang.Double)
}
-class MathExpressionsSuite extends QueryTest {
-
+class MathExpressionsSuite extends QueryTest with SharedSQLContext {
import MathExpressionsTestData._
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+ import testImplicits._
private lazy val doubleData = (1 to 10).map(i => DoubleData(i * 0.2 - 1, i * -0.2 + 1)).toDF()
@@ -149,7 +147,7 @@ class MathExpressionsSuite extends QueryTest {
test("toDegrees") {
testOneToOneMathFunction(toDegrees, math.toDegrees)
checkAnswer(
- ctx.sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
+ sql("SELECT degrees(0), degrees(1), degrees(1.5)"),
Seq((1, 2)).toDF().select(toDegrees(lit(0)), toDegrees(lit(1)), toDegrees(lit(1.5)))
)
}
@@ -157,7 +155,7 @@ class MathExpressionsSuite extends QueryTest {
test("toRadians") {
testOneToOneMathFunction(toRadians, math.toRadians)
checkAnswer(
- ctx.sql("SELECT radians(0), radians(1), radians(1.5)"),
+ sql("SELECT radians(0), radians(1), radians(1.5)"),
Seq((1, 2)).toDF().select(toRadians(lit(0)), toRadians(lit(1)), toRadians(lit(1.5)))
)
}
@@ -169,7 +167,7 @@ class MathExpressionsSuite extends QueryTest {
test("ceil and ceiling") {
testOneToOneMathFunction(ceil, math.ceil)
checkAnswer(
- ctx.sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
+ sql("SELECT ceiling(0), ceiling(1), ceiling(1.5)"),
Row(0.0, 1.0, 2.0))
}
@@ -214,7 +212,7 @@ class MathExpressionsSuite extends QueryTest {
val pi = 3.1415
checkAnswer(
- ctx.sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
+ sql(s"SELECT round($pi, -3), round($pi, -2), round($pi, -1), " +
s"round($pi, 0), round($pi, 1), round($pi, 2), round($pi, 3)"),
Seq(Row(BigDecimal("0E3"), BigDecimal("0E2"), BigDecimal("0E1"), BigDecimal(3),
BigDecimal("3.1"), BigDecimal("3.14"), BigDecimal("3.142")))
@@ -233,7 +231,7 @@ class MathExpressionsSuite extends QueryTest {
testOneToOneMathFunction[Double](signum, math.signum)
checkAnswer(
- ctx.sql("SELECT sign(10), signum(-11)"),
+ sql("SELECT sign(10), signum(-11)"),
Row(1, -1))
}
@@ -241,7 +239,7 @@ class MathExpressionsSuite extends QueryTest {
testTwoToOneMathFunction(pow, pow, math.pow)
checkAnswer(
- ctx.sql("SELECT pow(1, 2), power(2, 1)"),
+ sql("SELECT pow(1, 2), power(2, 1)"),
Seq((1, 2)).toDF().select(pow(lit(1), lit(2)), pow(lit(2), lit(1)))
)
}
@@ -280,7 +278,7 @@ class MathExpressionsSuite extends QueryTest {
test("log / ln") {
testOneToOneNonNegativeMathFunction(org.apache.spark.sql.functions.log, math.log)
checkAnswer(
- ctx.sql("SELECT ln(0), ln(1), ln(1.5)"),
+ sql("SELECT ln(0), ln(1), ln(1.5)"),
Seq((1, 2)).toDF().select(logarithm(lit(0)), logarithm(lit(1)), logarithm(lit(1.5)))
)
}
@@ -375,7 +373,7 @@ class MathExpressionsSuite extends QueryTest {
df.select(log2("b") + log2("a")),
Row(1))
- checkAnswer(ctx.sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
+ checkAnswer(sql("SELECT LOG2(8), LOG2(null)"), Row(3, null))
}
test("sqrt") {
@@ -384,13 +382,13 @@ class MathExpressionsSuite extends QueryTest {
df.select(sqrt("a"), sqrt("b")),
Row(1.0, 2.0))
- checkAnswer(ctx.sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
+ checkAnswer(sql("SELECT SQRT(4.0), SQRT(null)"), Row(2.0, null))
checkAnswer(df.selectExpr("sqrt(a)", "sqrt(b)", "sqrt(null)"), Row(1.0, 2.0, null))
}
test("negative") {
checkAnswer(
- ctx.sql("SELECT negative(1), negative(0), negative(-1)"),
+ sql("SELECT negative(1), negative(0), negative(-1)"),
Row(-1, 0, 1))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 98ba3c9928..4adcefb7dc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -71,12 +71,6 @@ class QueryTest extends PlanTest {
checkAnswer(df, expectedAnswer.collect())
}
- def sqlTest(sqlString: String, expectedAnswer: Seq[Row])(implicit sqlContext: SQLContext) {
- test(sqlString) {
- checkAnswer(sqlContext.sql(sqlString), expectedAnswer)
- }
- }
-
/**
* Asserts that a given [[DataFrame]] will be executed using the given number of cached results.
*/
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
index 8a679c7865..795d4e983f 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/RowSuite.scala
@@ -20,13 +20,12 @@ package org.apache.spark.sql
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.execution.SparkSqlSerializer
import org.apache.spark.sql.catalyst.expressions.{GenericMutableRow, SpecificMutableRow}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-class RowSuite extends SparkFunSuite {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class RowSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
test("create row") {
val expected = new GenericMutableRow(4)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
index 75791e9d53..7699adadd9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLConfSuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql
+import org.apache.spark.sql.test.SharedSQLContext
-class SQLConfSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+class SQLConfSuite extends QueryTest with SharedSQLContext {
private val testKey = "test.key.0"
private val testVal = "test.val.0"
@@ -52,21 +51,21 @@ class SQLConfSuite extends QueryTest {
test("parse SQL set commands") {
ctx.conf.clear()
- ctx.sql(s"set $testKey=$testVal")
+ sql(s"set $testKey=$testVal")
assert(ctx.getConf(testKey, testVal + "_") === testVal)
assert(ctx.getConf(testKey, testVal + "_") === testVal)
- ctx.sql("set some.property=20")
+ sql("set some.property=20")
assert(ctx.getConf("some.property", "0") === "20")
- ctx.sql("set some.property = 40")
+ sql("set some.property = 40")
assert(ctx.getConf("some.property", "0") === "40")
val key = "spark.sql.key"
val vs = "val0,val_1,val2.3,my_table"
- ctx.sql(s"set $key=$vs")
+ sql(s"set $key=$vs")
assert(ctx.getConf(key, "0") === vs)
- ctx.sql(s"set $key=")
+ sql(s"set $key=")
assert(ctx.getConf(key, "0") === "")
ctx.conf.clear()
@@ -74,14 +73,14 @@ class SQLConfSuite extends QueryTest {
test("deprecated property") {
ctx.conf.clear()
- ctx.sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
+ sql(s"set ${SQLConf.Deprecated.MAPRED_REDUCE_TASKS}=10")
assert(ctx.conf.numShufflePartitions === 10)
}
test("invalid conf value") {
ctx.conf.clear()
val e = intercept[IllegalArgumentException] {
- ctx.sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
+ sql(s"set ${SQLConf.CASE_SENSITIVE.key}=10")
}
assert(e.getMessage === s"${SQLConf.CASE_SENSITIVE.key} should be boolean, but was 10")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
index c8d8796568..007be12950 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLContextSuite.scala
@@ -17,16 +17,17 @@
package org.apache.spark.sql
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
-class SQLContextSuite extends SparkFunSuite with BeforeAndAfterAll {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+class SQLContextSuite extends SparkFunSuite with SharedSQLContext {
override def afterAll(): Unit = {
- SQLContext.setLastInstantiatedContext(ctx)
+ try {
+ SQLContext.setLastInstantiatedContext(ctx)
+ } finally {
+ super.afterAll()
+ }
}
test("getOrCreate instantiates SQLContext") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index b14ef9bab9..8c2c328f81 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -19,28 +19,23 @@ package org.apache.spark.sql
import java.sql.Timestamp
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.DefaultParserDialect
import org.apache.spark.sql.catalyst.errors.DialectException
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
/** A SQL Dialect for testing purpose, and it can not be nested type */
class MyDialect extends DefaultParserDialect
-class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
- // Make sure the tables are loaded.
- TestData
+class SQLQuerySuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
- val sqlContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.implicits._
- import sqlContext.sql
+ setupTestData()
test("having clause") {
Seq(("one", 1), ("two", 2), ("three", 3), ("one", 5)).toDF("k", "v").registerTempTable("hav")
@@ -60,7 +55,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
}
test("show functions") {
- checkAnswer(sql("SHOW functions"), FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
+ checkAnswer(sql("SHOW functions"),
+ FunctionRegistry.builtin.listFunction().sorted.map(Row(_)))
}
test("describe functions") {
@@ -178,7 +174,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
val df = Seq(Tuple1(1), Tuple1(2), Tuple1(3)).toDF("index")
    // we expect the id to be materialized once
- val idUDF = udf(() => UUID.randomUUID().toString)
+ val idUDF = org.apache.spark.sql.functions.udf(() => UUID.randomUUID().toString)
val dfWithId = df.withColumn("id", idUDF())
// Make a new DataFrame (actually the same reference to the old one)
@@ -712,9 +708,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
checkAnswer(
sql(
- """
- |SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3
- """.stripMargin),
+ "SELECT COUNT(a), COUNT(b), COUNT(1), COUNT(DISTINCT a), COUNT(DISTINCT b) FROM testData3"),
Row(2, 1, 2, 2, 1))
}
@@ -1161,7 +1155,8 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
validateMetadata(sql("SELECT * FROM personWithMeta"))
validateMetadata(sql("SELECT id, name FROM personWithMeta"))
validateMetadata(sql("SELECT * FROM personWithMeta JOIN salary ON id = personId"))
- validateMetadata(sql("SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
+ validateMetadata(sql(
+ "SELECT name, salary FROM personWithMeta JOIN salary ON id = personId"))
}
test("SPARK-3371 Renaming a function expression with group by gives error") {
@@ -1627,7 +1622,7 @@ class SQLQuerySuite extends QueryTest with BeforeAndAfterAll with SQLTestUtils {
.toDF("num", "str")
df.registerTempTable("1one")
- checkAnswer(sqlContext.sql("select count(num) from 1one"), Row(10))
+ checkAnswer(sql("select count(num) from 1one"), Row(10))
sqlContext.dropTempTable("1one")
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
index ab6d3dd96d..295f02f9a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ScalaReflectionRelationSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql
import java.sql.{Date, Timestamp}
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
case class ReflectData(
stringField: String,
@@ -71,17 +72,15 @@ case class ComplexReflectData(
mapFieldContainsNull: Map[Int, Option[Long]],
dataField: Data)
-class ScalaReflectionRelationSuite extends SparkFunSuite {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class ScalaReflectionRelationSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
test("query case class RDD") {
val data = ReflectData("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"), new Timestamp(12345), Seq(1, 2, 3))
Seq(data).toDF().registerTempTable("reflectData")
- assert(ctx.sql("SELECT * FROM reflectData").collect().head ===
+ assert(sql("SELECT * FROM reflectData").collect().head ===
Row("a", 1, 1L, 1.toFloat, 1.toDouble, 1.toShort, 1.toByte, true,
new java.math.BigDecimal(1), Date.valueOf("1970-01-01"),
new Timestamp(12345), Seq(1, 2, 3)))
@@ -91,7 +90,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = NullReflectData(null, null, null, null, null, null, null)
Seq(data).toDF().registerTempTable("reflectNullData")
- assert(ctx.sql("SELECT * FROM reflectNullData").collect().head ===
+ assert(sql("SELECT * FROM reflectNullData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@@ -99,7 +98,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
val data = OptionalReflectData(None, None, None, None, None, None, None)
Seq(data).toDF().registerTempTable("reflectOptionalData")
- assert(ctx.sql("SELECT * FROM reflectOptionalData").collect().head ===
+ assert(sql("SELECT * FROM reflectOptionalData").collect().head ===
Row.fromSeq(Seq.fill(7)(null)))
}
@@ -107,7 +106,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
test("query binary data") {
Seq(ReflectBinary(Array[Byte](1))).toDF().registerTempTable("reflectBinary")
- val result = ctx.sql("SELECT data FROM reflectBinary")
+ val result = sql("SELECT data FROM reflectBinary")
.collect().head(0).asInstanceOf[Array[Byte]]
assert(result.toSeq === Seq[Byte](1))
}
@@ -126,7 +125,7 @@ class ScalaReflectionRelationSuite extends SparkFunSuite {
Nested(None, "abc")))
Seq(data).toDF().registerTempTable("reflectComplexData")
- assert(ctx.sql("SELECT * FROM reflectComplexData").collect().head ===
+ assert(sql("SELECT * FROM reflectComplexData").collect().head ===
Row(
Seq(1, 2, 3),
Seq(1, 2, null),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
index e55c9e460b..45d0ee4a8e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SerializationSuite.scala
@@ -19,13 +19,12 @@ package org.apache.spark.sql
import org.apache.spark.{SparkConf, SparkFunSuite}
import org.apache.spark.serializer.JavaSerializer
+import org.apache.spark.sql.test.SharedSQLContext
-class SerializationSuite extends SparkFunSuite {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
+class SerializationSuite extends SparkFunSuite with SharedSQLContext {
test("[SPARK-5235] SQLContext should be serializable") {
- val sqlContext = new SQLContext(ctx.sparkContext)
- new JavaSerializer(new SparkConf()).newInstance().serialize(sqlContext)
+ val _sqlContext = new SQLContext(sqlContext.sparkContext)
+ new JavaSerializer(new SparkConf()).newInstance().serialize(_sqlContext)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index ca298b2434..cc95eede00 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -18,13 +18,12 @@
package org.apache.spark.sql
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.Decimal
-class StringFunctionsSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class StringFunctionsSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("string concat") {
val df = Seq[(String, String, String)](("a", "b", null)).toDF("a", "b", "c")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
deleted file mode 100644
index bd9729c431..0000000000
--- a/sql/core/src/test/scala/org/apache/spark/sql/TestData.scala
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.sql
-
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test._
-
-
-case class TestData(key: Int, value: String)
-
-object TestData {
- val testData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(i, i.toString))).toDF()
- testData.registerTempTable("testData")
-
- val negativeData = TestSQLContext.sparkContext.parallelize(
- (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
- negativeData.registerTempTable("negativeData")
-
- case class LargeAndSmallInts(a: Int, b: Int)
- val largeAndSmallInts =
- TestSQLContext.sparkContext.parallelize(
- LargeAndSmallInts(2147483644, 1) ::
- LargeAndSmallInts(1, 2) ::
- LargeAndSmallInts(2147483645, 1) ::
- LargeAndSmallInts(2, 2) ::
- LargeAndSmallInts(2147483646, 1) ::
- LargeAndSmallInts(3, 2) :: Nil).toDF()
- largeAndSmallInts.registerTempTable("largeAndSmallInts")
-
- case class TestData2(a: Int, b: Int)
- val testData2 =
- TestSQLContext.sparkContext.parallelize(
- TestData2(1, 1) ::
- TestData2(1, 2) ::
- TestData2(2, 1) ::
- TestData2(2, 2) ::
- TestData2(3, 1) ::
- TestData2(3, 2) :: Nil, 2).toDF()
- testData2.registerTempTable("testData2")
-
- case class DecimalData(a: BigDecimal, b: BigDecimal)
-
- val decimalData =
- TestSQLContext.sparkContext.parallelize(
- DecimalData(1, 1) ::
- DecimalData(1, 2) ::
- DecimalData(2, 1) ::
- DecimalData(2, 2) ::
- DecimalData(3, 1) ::
- DecimalData(3, 2) :: Nil).toDF()
- decimalData.registerTempTable("decimalData")
-
- case class BinaryData(a: Array[Byte], b: Int)
- val binaryData =
- TestSQLContext.sparkContext.parallelize(
- BinaryData("12".getBytes(), 1) ::
- BinaryData("22".getBytes(), 5) ::
- BinaryData("122".getBytes(), 3) ::
- BinaryData("121".getBytes(), 2) ::
- BinaryData("123".getBytes(), 4) :: Nil).toDF()
- binaryData.registerTempTable("binaryData")
-
- case class TestData3(a: Int, b: Option[Int])
- val testData3 =
- TestSQLContext.sparkContext.parallelize(
- TestData3(1, None) ::
- TestData3(2, Some(2)) :: Nil).toDF()
- testData3.registerTempTable("testData3")
-
- case class UpperCaseData(N: Int, L: String)
- val upperCaseData =
- TestSQLContext.sparkContext.parallelize(
- UpperCaseData(1, "A") ::
- UpperCaseData(2, "B") ::
- UpperCaseData(3, "C") ::
- UpperCaseData(4, "D") ::
- UpperCaseData(5, "E") ::
- UpperCaseData(6, "F") :: Nil).toDF()
- upperCaseData.registerTempTable("upperCaseData")
-
- case class LowerCaseData(n: Int, l: String)
- val lowerCaseData =
- TestSQLContext.sparkContext.parallelize(
- LowerCaseData(1, "a") ::
- LowerCaseData(2, "b") ::
- LowerCaseData(3, "c") ::
- LowerCaseData(4, "d") :: Nil).toDF()
- lowerCaseData.registerTempTable("lowerCaseData")
-
- case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
- val arrayData =
- TestSQLContext.sparkContext.parallelize(
- ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
- ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
- arrayData.toDF().registerTempTable("arrayData")
-
- case class MapData(data: scala.collection.Map[Int, String])
- val mapData =
- TestSQLContext.sparkContext.parallelize(
- MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
- MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
- MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
- MapData(Map(1 -> "a4", 2 -> "b4")) ::
- MapData(Map(1 -> "a5")) :: Nil)
- mapData.toDF().registerTempTable("mapData")
-
- case class StringData(s: String)
- val repeatedData =
- TestSQLContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
- repeatedData.toDF().registerTempTable("repeatedData")
-
- val nullableRepeatedData =
- TestSQLContext.sparkContext.parallelize(
- List.fill(2)(StringData(null)) ++
- List.fill(2)(StringData("test")))
- nullableRepeatedData.toDF().registerTempTable("nullableRepeatedData")
-
- case class NullInts(a: Integer)
- val nullInts =
- TestSQLContext.sparkContext.parallelize(
- NullInts(1) ::
- NullInts(2) ::
- NullInts(3) ::
- NullInts(null) :: Nil
- ).toDF()
- nullInts.registerTempTable("nullInts")
-
- val allNulls =
- TestSQLContext.sparkContext.parallelize(
- NullInts(null) ::
- NullInts(null) ::
- NullInts(null) ::
- NullInts(null) :: Nil).toDF()
- allNulls.registerTempTable("allNulls")
-
- case class NullStrings(n: Int, s: String)
- val nullStrings =
- TestSQLContext.sparkContext.parallelize(
- NullStrings(1, "abc") ::
- NullStrings(2, "ABC") ::
- NullStrings(3, null) :: Nil).toDF()
- nullStrings.registerTempTable("nullStrings")
-
- case class TableName(tableName: String)
- TestSQLContext
- .sparkContext
- .parallelize(TableName("test") :: Nil)
- .toDF()
- .registerTempTable("tableName")
-
- val unparsedStrings =
- TestSQLContext.sparkContext.parallelize(
- "1, A1, true, null" ::
- "2, B2, false, null" ::
- "3, C3, true, null" ::
- "4, D4, true, 2147483644" :: Nil)
-
- case class IntField(i: Int)
- // An RDD with 4 elements and 8 partitions
- val withEmptyParts = TestSQLContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
- withEmptyParts.toDF().registerTempTable("withEmptyParts")
-
- case class Person(id: Int, name: String, age: Int)
- case class Salary(personId: Int, salary: Double)
- val person = TestSQLContext.sparkContext.parallelize(
- Person(0, "mike", 30) ::
- Person(1, "jim", 20) :: Nil).toDF()
- person.registerTempTable("person")
- val salary = TestSQLContext.sparkContext.parallelize(
- Salary(0, 2000.0) ::
- Salary(1, 1000.0) :: Nil).toDF()
- salary.registerTempTable("salary")
-
- case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
- val complexData =
- TestSQLContext.sparkContext.parallelize(
- ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true)
- :: ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false)
- :: Nil).toDF()
- complexData.registerTempTable("complexData")
-}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
index 183dc3407b..eb275af101 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UDFSuite.scala
@@ -17,16 +17,13 @@
package org.apache.spark.sql
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
-case class FunctionResult(f1: String, f2: String)
+private case class FunctionResult(f1: String, f2: String)
-class UDFSuite extends QueryTest with SQLTestUtils {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
-
- override def sqlContext(): SQLContext = ctx
+class UDFSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("built-in fixed arity expressions") {
val df = ctx.emptyDataFrame
@@ -57,7 +54,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("SPARK-8003 spark_partition_id") {
val df = Seq((1, "Tearing down the walls that divide us")).toDF("id", "saying")
df.registerTempTable("tmp_table")
- checkAnswer(ctx.sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
+ checkAnswer(sql("select spark_partition_id() from tmp_table").toDF(), Row(0))
ctx.dropTempTable("tmp_table")
}
@@ -66,9 +63,9 @@ class UDFSuite extends QueryTest with SQLTestUtils {
val data = ctx.sparkContext.parallelize(0 to 10, 2).toDF("id")
data.write.parquet(dir.getCanonicalPath)
ctx.read.parquet(dir.getCanonicalPath).registerTempTable("test_table")
- val answer = ctx.sql("select input_file_name() from test_table").head().getString(0)
+ val answer = sql("select input_file_name() from test_table").head().getString(0)
assert(answer.contains(dir.getCanonicalPath))
- assert(ctx.sql("select input_file_name() from test_table").distinct().collect().length >= 2)
+ assert(sql("select input_file_name() from test_table").distinct().collect().length >= 2)
ctx.dropTempTable("test_table")
}
}
@@ -91,17 +88,17 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("Simple UDF") {
ctx.udf.register("strLenScala", (_: String).length)
- assert(ctx.sql("SELECT strLenScala('test')").head().getInt(0) === 4)
+ assert(sql("SELECT strLenScala('test')").head().getInt(0) === 4)
}
test("ZeroArgument UDF") {
ctx.udf.register("random0", () => { Math.random()})
- assert(ctx.sql("SELECT random0()").head().getDouble(0) >= 0.0)
+ assert(sql("SELECT random0()").head().getDouble(0) >= 0.0)
}
test("TwoArgument UDF") {
ctx.udf.register("strLenScala", (_: String).length + (_: Int))
- assert(ctx.sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
+ assert(sql("SELECT strLenScala('test', 1)").head().getInt(0) === 5)
}
test("UDF in a WHERE") {
@@ -112,7 +109,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("integerData")
val result =
- ctx.sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
+ sql("SELECT * FROM integerData WHERE oneArgFilter(key)")
assert(result.count() === 20)
}
@@ -124,7 +121,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
- ctx.sql(
+ sql(
"""
| SELECT g, SUM(v) as s
| FROM groupData
@@ -143,7 +140,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
- ctx.sql(
+ sql(
"""
| SELECT SUM(v)
| FROM groupData
@@ -163,7 +160,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
df.registerTempTable("groupData")
val result =
- ctx.sql(
+ sql(
"""
| SELECT timesHundred(SUM(v)) as v100
| FROM groupData
@@ -178,7 +175,7 @@ class UDFSuite extends QueryTest with SQLTestUtils {
ctx.udf.register("returnStruct", (f1: String, f2: String) => FunctionResult(f1, f2))
val result =
- ctx.sql("SELECT returnStruct('test', 'test2') as ret")
+ sql("SELECT returnStruct('test', 'test2') as ret")
.select($"ret.f1").head().getString(0)
assert(result === "test")
}
@@ -186,12 +183,12 @@ class UDFSuite extends QueryTest with SQLTestUtils {
test("udf that is transformed") {
ctx.udf.register("makeStruct", (x: Int, y: Int) => (x, y))
// 1 + 1 is constant folded causing a transformation.
- assert(ctx.sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
+ assert(sql("SELECT makeStruct(1 + 1, 2)").first().getAs[Row](0) === Row(2, 2))
}
test("type coercion for udf inputs") {
ctx.udf.register("intExpected", (x: Int) => x)
// pass a decimal to intExpected.
- assert(ctx.sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
+ assert(sql("SELECT intExpected(1.0)").head().getInt(0) === 1)
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
index 9181222f69..b6d279ae47 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/UserDefinedTypeSuite.scala
@@ -24,6 +24,7 @@ import com.clearspring.analytics.stream.cardinality.HyperLogLog
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.expressions.{OpenHashSetUDT, HyperLogLogUDT}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.collection.OpenHashSet
@@ -66,10 +67,8 @@ private[sql] class MyDenseVectorUDT extends UserDefinedType[MyDenseVector] {
private[spark] override def asNullable: MyDenseVectorUDT = this
}
-class UserDefinedTypeSuite extends QueryTest {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class UserDefinedTypeSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
private lazy val pointsRDD = Seq(
MyLabeledPoint(1.0, new MyDenseVector(Array(0.1, 1.0))),
@@ -94,7 +93,7 @@ class UserDefinedTypeSuite extends QueryTest {
ctx.udf.register("testType", (d: MyDenseVector) => d.isInstanceOf[MyDenseVector])
pointsRDD.registerTempTable("points")
checkAnswer(
- ctx.sql("SELECT testType(features) from points"),
+ sql("SELECT testType(features) from points"),
Seq(Row(true), Row(true)))
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
index 9bca4e7e66..952637c5f9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/InMemoryColumnarQuerySuite.scala
@@ -19,18 +19,16 @@ package org.apache.spark.sql.columnar
import java.sql.{Date, Timestamp}
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{QueryTest, Row, TestData}
import org.apache.spark.storage.StorageLevel.MEMORY_ONLY
-class InMemoryColumnarQuerySuite extends QueryTest {
- // Make sure the tables are loaded.
- TestData
+class InMemoryColumnarQuerySuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
- import ctx.{logicalPlanToSparkQuery, sql}
+ setupTestData()
test("simple columnar query") {
val plan = ctx.executePlan(testData.logicalPlan).executedPlan
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
index 2c0879927a..ab2644eb45 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/columnar/PartitionBatchPruningSuite.scala
@@ -17,20 +17,19 @@
package org.apache.spark.sql.columnar
-import org.scalatest.{BeforeAndAfter, BeforeAndAfterAll}
-
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.test.SQLTestData._
-class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll with BeforeAndAfter {
-
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
+class PartitionBatchPruningSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
private lazy val originalColumnBatchSize = ctx.conf.columnBatchSize
private lazy val originalInMemoryPartitionPruning = ctx.conf.inMemoryPartitionPruning
override protected def beforeAll(): Unit = {
+ super.beforeAll()
// Make a table with 5 partitions, 2 batches per partition, 10 elements per batch
ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, 10)
@@ -44,19 +43,17 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, true)
// Enable in-memory table scan accumulators
ctx.setConf("spark.sql.inMemoryTableScanStatistics.enable", "true")
- }
-
- override protected def afterAll(): Unit = {
- ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
- ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
- }
-
- before {
ctx.cacheTable("pruningData")
}
- after {
- ctx.uncacheTable("pruningData")
+ override protected def afterAll(): Unit = {
+ try {
+ ctx.setConf(SQLConf.COLUMN_BATCH_SIZE, originalColumnBatchSize)
+ ctx.setConf(SQLConf.IN_MEMORY_PARTITION_PRUNING, originalInMemoryPartitionPruning)
+ ctx.uncacheTable("pruningData")
+ } finally {
+ super.afterAll()
+ }
}
// Comparisons
@@ -110,7 +107,7 @@ class PartitionBatchPruningSuite extends SparkFunSuite with BeforeAndAfterAll wi
expectedQueryResult: => Seq[Int]): Unit = {
test(query) {
- val df = ctx.sql(query)
+ val df = sql(query)
val queryExecution = df.queryExecution
assertResult(expectedQueryResult.toArray, s"Wrong query result: $queryExecution") {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 79e903c2bb..8998f51111 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.physical.SinglePartition
+import org.apache.spark.sql.test.SharedSQLContext
-class ExchangeSuite extends SparkPlanTest {
+class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
test("shuffling UnsafeRows in exchange") {
val input = (1 to 1000).map(Tuple1.apply)
checkAnswer(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5582caa0d3..937a108543 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -19,7 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.SparkFunSuite
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{execution, Row, SQLConf}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{Ascending, Attribute, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans._
@@ -27,19 +27,18 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.joins.{BroadcastHashJoin, ShuffledHashJoin}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
-import org.apache.spark.sql.test.TestSQLContext._
-import org.apache.spark.sql.test.TestSQLContext.implicits._
-import org.apache.spark.sql.test.TestSQLContext.planner._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{SQLContext, Row, SQLConf, execution}
-class PlannerSuite extends SparkFunSuite with SQLTestUtils {
+class PlannerSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
- override def sqlContext: SQLContext = TestSQLContext
+ setupTestData()
private def testPartialAggregationPlan(query: LogicalPlan): Unit = {
+ val _ctx = ctx
+ import _ctx.planner._
val plannedOption = HashAggregation(query).headOption.orElse(Aggregation(query).headOption)
val planned =
plannedOption.getOrElse(
@@ -54,6 +53,8 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
}
test("unions are collapsed") {
+ val _ctx = ctx
+ import _ctx.planner._
val query = testData.unionAll(testData).unionAll(testData).logicalPlan
val planned = BasicOperators(query).head
val logicalUnions = query collect { case u: logical.Union => u }
@@ -81,14 +82,14 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
test("sizeInBytes estimation of limit operator for broadcast hash join optimization") {
def checkPlan(fieldTypes: Seq[DataType], newThreshold: Int): Unit = {
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, newThreshold)
val fields = fieldTypes.zipWithIndex.map {
case (dataType, index) => StructField(s"c${index}", dataType, true)
} :+ StructField("key", IntegerType, true)
val schema = StructType(fields)
val row = Row.fromSeq(Seq.fill(fields.size)(null))
- val rowRDD = org.apache.spark.sql.test.TestSQLContext.sparkContext.parallelize(row :: Nil)
- createDataFrame(rowRDD, schema).registerTempTable("testLimit")
+ val rowRDD = ctx.sparkContext.parallelize(row :: Nil)
+ ctx.createDataFrame(rowRDD, schema).registerTempTable("testLimit")
val planned = sql(
"""
@@ -102,10 +103,10 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
- dropTempTable("testLimit")
+ ctx.dropTempTable("testLimit")
}
- val origThreshold = conf.autoBroadcastJoinThreshold
+ val origThreshold = ctx.conf.autoBroadcastJoinThreshold
val simpleTypes =
NullType ::
@@ -137,18 +138,18 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
checkPlan(complexTypes, newThreshold = 901617)
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("InMemoryRelation statistics propagation") {
- val origThreshold = conf.autoBroadcastJoinThreshold
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
+ val origThreshold = ctx.conf.autoBroadcastJoinThreshold
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, 81920)
testData.limit(3).registerTempTable("tiny")
sql("CACHE TABLE tiny")
val a = testData.as("a")
- val b = table("tiny").as("b")
+ val b = ctx.table("tiny").as("b")
val planned = a.join(b, $"a.key" === $"b.key").queryExecution.executedPlan
val broadcastHashJoins = planned.collect { case join: BroadcastHashJoin => join }
@@ -157,12 +158,12 @@ class PlannerSuite extends SparkFunSuite with SQLTestUtils {
assert(broadcastHashJoins.size === 1, "Should use broadcast hash join")
assert(shuffledHashJoins.isEmpty, "Should not use shuffled hash join")
- setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
+ ctx.setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
test("efficient limit -> project -> sort") {
val query = testData.sort('key).select('value).limit(2).logicalPlan
- val planned = planner.TakeOrderedAndProject(query)
+ val planned = ctx.planner.TakeOrderedAndProject(query)
assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
index dd08e9025a..ef6ad59b71 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/RowFormatConvertersSuite.scala
@@ -21,11 +21,11 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Attribute, Literal, IsNull}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StructType, StringType}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{GenericArrayData, ArrayType, StringType}
import org.apache.spark.unsafe.types.UTF8String
-class RowFormatConvertersSuite extends SparkPlanTest {
+class RowFormatConvertersSuite extends SparkPlanTest with SharedSQLContext {
private def getConverters(plan: SparkPlan): Seq[SparkPlan] = plan.collect {
case c: ConvertToUnsafe => c
@@ -39,20 +39,20 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("planner should insert unsafe->safe conversions when required") {
val plan = Limit(10, outputsUnsafe)
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.children.head.isInstanceOf[ConvertToSafe])
}
test("filter can process unsafe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsUnsafe)
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).size === 1)
assert(preparedPlan.outputsUnsafeRows)
}
test("filter can process safe rows") {
val plan = Filter(IsNull(IsNull(Literal(1))), outputsSafe)
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(getConverters(preparedPlan).isEmpty)
assert(!preparedPlan.outputsUnsafeRows)
}
@@ -67,33 +67,33 @@ class RowFormatConvertersSuite extends SparkPlanTest {
test("union requires all of its input rows' formats to agree") {
val plan = Union(Seq(outputsSafe, outputsUnsafe))
assert(plan.canProcessSafeRows && plan.canProcessUnsafeRows)
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("union can process safe rows") {
val plan = Union(Seq(outputsSafe, outputsSafe))
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(!preparedPlan.outputsUnsafeRows)
}
test("union can process unsafe rows") {
val plan = Union(Seq(outputsUnsafe, outputsUnsafe))
- val preparedPlan = TestSQLContext.prepareForExecution.execute(plan)
+ val preparedPlan = ctx.prepareForExecution.execute(plan)
assert(preparedPlan.outputsUnsafeRows)
}
test("round trip with ConvertToUnsafe and ConvertToSafe") {
val input = Seq(("hello", 1), ("world", 2))
checkAnswer(
- TestSQLContext.createDataFrame(input),
+ ctx.createDataFrame(input),
plan => ConvertToSafe(ConvertToUnsafe(plan)),
input.map(Row.fromTuple)
)
}
test("SPARK-9683: copy UTF8String when convert unsafe array/map to safe") {
- SparkPlan.currentContext.set(TestSQLContext)
+ SparkPlan.currentContext.set(ctx)
val schema = ArrayType(StringType)
val rows = (1 to 100).map { i =>
InternalRow(new GenericArrayData(Array[Any](UTF8String.fromString(i.toString))))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
index a2c10fdaf6..8fa77b0fcb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SortSuite.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.execution
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.dsl.expressions._
+import org.apache.spark.sql.test.SharedSQLContext
-class SortSuite extends SparkPlanTest {
+class SortSuite extends SparkPlanTest with SharedSQLContext {
// This test was originally added as an example of how to use [[SparkPlanTest]];
// it's not designed to be a comprehensive test of ExternalSort.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
index f46855edfe..3a87f374d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/SparkPlanTest.scala
@@ -17,29 +17,27 @@
package org.apache.spark.sql.execution
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.util._
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{SQLContext, DataFrame, DataFrameHolder, Row}
-
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.{DataFrame, DataFrameHolder, Row, SQLContext}
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
+import org.apache.spark.sql.catalyst.util._
+
/**
* Base class for writing tests for individual physical operators. For an example of how this
* class's test helper methods can be used, see [[SortSuite]].
*/
-class SparkPlanTest extends SparkFunSuite {
-
- protected def sqlContext: SQLContext = TestSQLContext
+private[sql] abstract class SparkPlanTest extends SparkFunSuite {
+ protected def _sqlContext: SQLContext
/**
* Creates a DataFrame from a local Seq of Product.
*/
implicit def localSeqToDataFrameHolder[A <: Product : TypeTag](data: Seq[A]): DataFrameHolder = {
- sqlContext.implicits.localSeqToDataFrameHolder(data)
+ _sqlContext.implicits.localSeqToDataFrameHolder(data)
}
/**
@@ -100,7 +98,7 @@ class SparkPlanTest extends SparkFunSuite {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean = true): Unit = {
- SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, sqlContext) match {
+ SparkPlanTest.checkAnswer(input, planFunction, expectedAnswer, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -124,7 +122,7 @@ class SparkPlanTest extends SparkFunSuite {
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean = true): Unit = {
SparkPlanTest.checkAnswer(
- input, planFunction, expectedPlanFunction, sortAnswers, sqlContext) match {
+ input, planFunction, expectedPlanFunction, sortAnswers, _sqlContext) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
}
@@ -151,13 +149,13 @@ object SparkPlanTest {
planFunction: SparkPlan => SparkPlan,
expectedPlanFunction: SparkPlan => SparkPlan,
sortAnswers: Boolean,
- sqlContext: SQLContext): Option[String] = {
+ _sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.queryExecution.sparkPlan)
val expectedOutputPlan = expectedPlanFunction(input.queryExecution.sparkPlan)
val expectedAnswer: Seq[Row] = try {
- executePlan(expectedOutputPlan, sqlContext)
+ executePlan(expectedOutputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -172,7 +170,7 @@ object SparkPlanTest {
}
val actualAnswer: Seq[Row] = try {
- executePlan(outputPlan, sqlContext)
+ executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -212,12 +210,12 @@ object SparkPlanTest {
planFunction: Seq[SparkPlan] => SparkPlan,
expectedAnswer: Seq[Row],
sortAnswers: Boolean,
- sqlContext: SQLContext): Option[String] = {
+ _sqlContext: SQLContext): Option[String] = {
val outputPlan = planFunction(input.map(_.queryExecution.sparkPlan))
val sparkAnswer: Seq[Row] = try {
- executePlan(outputPlan, sqlContext)
+ executePlan(outputPlan, _sqlContext)
} catch {
case NonFatal(e) =>
val errorMessage =
@@ -280,10 +278,10 @@ object SparkPlanTest {
}
}
- private def executePlan(outputPlan: SparkPlan, sqlContext: SQLContext): Seq[Row] = {
+ private def executePlan(outputPlan: SparkPlan, _sqlContext: SQLContext): Seq[Row] = {
// A very simple resolver to make writing tests easier. In contrast to the real resolver
// this is always case sensitive and does not try to handle scoping or complex type resolution.
- val resolvedPlan = sqlContext.prepareForExecution.execute(
+ val resolvedPlan = _sqlContext.prepareForExecution.execute(
outputPlan transform {
case plan: SparkPlan =>
val inputMap = plan.children.flatMap(_.output).map(a => (a.name, a)).toMap
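
SparkPlanTest is now an abstract base that only declares `_sqlContext`; concrete operator suites obtain it by also mixing in `SharedSQLContext`, as ExchangeSuite and SortSuite do above. A hypothetical new operator suite would take the same shape; the identity plan function below is a stand-in, and `SharedSQLContext` is assumed to implement the abstract `_sqlContext` member.

    package org.apache.spark.sql.execution

    import org.apache.spark.sql.Row
    import org.apache.spark.sql.test.SharedSQLContext

    // Sketch only, not part of this patch: a physical-operator test in the
    // refactored style.
    class ExampleOperatorSuite extends SparkPlanTest with SharedSQLContext {

      test("identity plan function returns the input rows") {
        val input = (1 to 10).map(Tuple1.apply)
        checkAnswer(
          input.toDF("a"),            // via SparkPlanTest.localSeqToDataFrameHolder
          (plan: SparkPlan) => plan,  // stand-in for the operator under test
          input.map(Row.fromTuple))
      }
    }
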
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
index 88bce0e319..3158458edb 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TungstenSortSuite.scala
@@ -19,25 +19,28 @@ package org.apache.spark.sql.execution
import scala.util.Random
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.AccumulatorSuite
import org.apache.spark.sql.{RandomDataGenerator, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
/**
* A test suite that generates randomized data to test the [[TungstenSort]] operator.
*/
-class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
+class TungstenSortSuite extends SparkPlanTest with SharedSQLContext {
override def beforeAll(): Unit = {
- TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
+ super.beforeAll()
+ ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, true)
}
override def afterAll(): Unit = {
- TestSQLContext.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ try {
+ ctx.conf.setConf(SQLConf.CODEGEN_ENABLED, SQLConf.CODEGEN_ENABLED.defaultValue.get)
+ } finally {
+ super.afterAll()
+ }
}
test("sort followed by limit") {
@@ -61,7 +64,7 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
}
test("sorting updates peak execution memory") {
- val sc = TestSQLContext.sparkContext
+ val sc = ctx.sparkContext
AccumulatorSuite.verifyPeakExecutionMemorySet(sc, "unsafe external sort") {
checkThatPlansAgree(
(1 to 100).map(v => Tuple1(v)).toDF("a"),
@@ -80,8 +83,8 @@ class TungstenSortSuite extends SparkPlanTest with BeforeAndAfterAll {
) {
test(s"sorting on $dataType with nullable=$nullable, sortOrder=$sortOrder") {
val inputData = Seq.fill(1000)(randomDataGenerator())
- val inputDf = TestSQLContext.createDataFrame(
- TestSQLContext.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
+ val inputDf = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Random.shuffle(inputData).map(v => Row(v))),
StructType(StructField("a", dataType, nullable = true) :: Nil)
)
assert(TungstenSort.supportsSchema(inputDf.schema))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
index e03473041c..d1f0b2b1fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeFixedWidthAggregationMapSuite.scala
@@ -26,7 +26,7 @@ import org.scalatest.Matchers
import org.apache.spark.sql.catalyst.expressions.{UnsafeRow, UnsafeProjection}
import org.apache.spark.{TaskContextImpl, TaskContext, SparkFunSuite}
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
import org.apache.spark.unsafe.types.UTF8String
@@ -36,7 +36,10 @@ import org.apache.spark.unsafe.types.UTF8String
*
* Use [[testWithMemoryLeakDetection]] rather than [[test]] to construct test cases.
*/
-class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
+class UnsafeFixedWidthAggregationMapSuite
+ extends SparkFunSuite
+ with Matchers
+ with SharedSQLContext {
import UnsafeFixedWidthAggregationMap._
@@ -171,9 +174,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting") {
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
-
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
@@ -233,8 +233,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with an empty map") {
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
val map = new UnsafeFixedWidthAggregationMap(
emptyAggregationBuffer,
@@ -282,8 +280,6 @@ class UnsafeFixedWidthAggregationMapSuite extends SparkFunSuite with Matchers {
}
testWithMemoryLeakDetection("test external sorting with empty records") {
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
// Memory consumption in the beginning of the task.
val initialMemoryConsumption = shuffleMemoryManager.getMemoryConsumptionForThisTask()
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
index a9515a03ac..d3be568a87 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/UnsafeKVExternalSorterSuite.scala
@@ -23,15 +23,14 @@ import org.apache.spark._
import org.apache.spark.sql.{RandomDataGenerator, Row}
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.expressions.{InterpretedOrdering, UnsafeRow, UnsafeProjection}
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.memory.{ExecutorMemoryManager, MemoryAllocator, TaskMemoryManager}
/**
* Test suite for [[UnsafeKVExternalSorter]], with randomly generated test data.
*/
-class UnsafeKVExternalSorterSuite extends SparkFunSuite {
-
+class UnsafeKVExternalSorterSuite extends SparkFunSuite with SharedSQLContext {
private val keyTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
private val valueTypes = Seq(IntegerType, FloatType, DoubleType, StringType)
@@ -109,8 +108,6 @@ class UnsafeKVExternalSorterSuite extends SparkFunSuite {
inputData: Seq[(InternalRow, InternalRow)],
pageSize: Long,
spill: Boolean): Unit = {
- // Calling this make sure we have block manager and everything else setup.
- TestSQLContext
val taskMemMgr = new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP))
val shuffleMemMgr = new TestShuffleMemoryManager
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
index ac22c2f3c0..5fdb82b067 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIteratorSuite.scala
@@ -21,15 +21,12 @@ import org.apache.spark._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.InterpretedMutableProjection
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.unsafe.memory.TaskMemoryManager
-class TungstenAggregationIteratorSuite extends SparkFunSuite {
+class TungstenAggregationIteratorSuite extends SparkFunSuite with SharedSQLContext {
test("memory acquired on construction") {
- // set up environment
- val ctx = TestSQLContext
-
val taskMemoryManager = new TaskMemoryManager(SparkEnv.get.executorMemoryManager)
val taskContext = new TaskContextImpl(0, 0, 0, 0, taskMemoryManager, null, Seq.empty)
TaskContext.setTaskContext(taskContext)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
index 73d5621897..1174b27732 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/JsonSuite.scala
@@ -24,22 +24,16 @@ import com.fasterxml.jackson.core.JsonFactory
import org.apache.spark.rdd.RDD
import org.scalactic.Tolerance._
-import org.apache.spark.sql.{SQLContext, QueryTest, Row, SQLConf}
-import org.apache.spark.sql.TestData._
+import org.apache.spark.sql.{QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.execution.datasources.{ResolvedDataSource, LogicalRelation}
import org.apache.spark.sql.execution.datasources.json.InferSchema.compatibleType
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.util.Utils
-class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
-
- protected lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- override def sqlContext: SQLContext = ctx // used by SQLTestUtils
-
- import ctx.sql
- import ctx.implicits._
+class JsonSuite extends QueryTest with SharedSQLContext with TestJsonData {
+ import testImplicits._
test("Type promotion") {
def checkTypePromotion(expected: Any, actual: Any) {
@@ -596,7 +590,8 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
val schema = StructType(StructField("a", LongType, true) :: Nil)
val logicalRelation =
- ctx.read.schema(schema).json(path).queryExecution.analyzed.asInstanceOf[LogicalRelation]
+ ctx.read.schema(schema).json(path)
+ .queryExecution.analyzed.asInstanceOf[LogicalRelation]
val relationWithSchema = logicalRelation.relation.asInstanceOf[JSONRelation]
assert(relationWithSchema.paths === Array(path))
assert(relationWithSchema.schema === schema)
@@ -1040,31 +1035,29 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
}
test("JSONRelation equality test") {
- val context = org.apache.spark.sql.test.TestSQLContext
-
val relation0 = new JSONRelation(
Some(empty),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(context)
+ None, None)(ctx)
val logicalRelation0 = LogicalRelation(relation0)
val relation1 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(context)
+ None, None)(ctx)
val logicalRelation1 = LogicalRelation(relation1)
val relation2 = new JSONRelation(
Some(singleRow),
0.5,
Some(StructType(StructField("a", IntegerType, true) :: Nil)),
- None, None)(context)
+ None, None)(ctx)
val logicalRelation2 = LogicalRelation(relation2)
val relation3 = new JSONRelation(
Some(singleRow),
1.0,
Some(StructType(StructField("b", IntegerType, true) :: Nil)),
- None, None)(context)
+ None, None)(ctx)
val logicalRelation3 = LogicalRelation(relation3)
assert(relation0 !== relation1)
@@ -1089,14 +1082,14 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
.map(i => s"""{"a": 1, "b": "str$i"}""").saveAsTextFile(path)
val d1 = ResolvedDataSource(
- context,
+ ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
options = Map("path" -> path))
val d2 = ResolvedDataSource(
- context,
+ ctx,
userSpecifiedSchema = None,
partitionColumns = Array.empty[String],
provider = classOf[DefaultSource].getCanonicalName,
@@ -1162,11 +1155,12 @@ class JsonSuite extends QueryTest with SQLTestUtils with TestJsonData {
"abd")
ctx.read.json(root.getAbsolutePath).registerTempTable("test_myjson_with_part")
- checkAnswer(
- sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
- checkAnswer(
- sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
- checkAnswer(sql("SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
+ checkAnswer(sql(
+ "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abc'"), Row(4))
+ checkAnswer(sql(
+ "SELECT count(a) FROM test_myjson_with_part where d1 = 1 and col1='abd'"), Row(5))
+ checkAnswer(sql(
+ "SELECT count(a) FROM test_myjson_with_part where d1 = 1"), Row(9))
})
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
index 6b62c9a003..2864181cf9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/json/TestJsonData.scala
@@ -20,12 +20,11 @@ package org.apache.spark.sql.execution.datasources.json
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext
-trait TestJsonData {
-
- protected def ctx: SQLContext
+private[json] trait TestJsonData {
+ protected def _sqlContext: SQLContext
def primitiveFieldAndType: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"string":"this is a simple string.",
"integer":10,
"long":21474836470,
@@ -36,7 +35,7 @@ trait TestJsonData {
}""" :: Nil)
def primitiveFieldValueTypeConflict: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"num_num_1":11, "num_num_2":null, "num_num_3": 1.1,
"num_bool":true, "num_str":13.1, "str_bool":"str1"}""" ::
"""{"num_num_1":null, "num_num_2":21474836470.9, "num_num_3": null,
@@ -47,14 +46,14 @@ trait TestJsonData {
"num_bool":null, "num_str":92233720368547758070, "str_bool":null}""" :: Nil)
def jsonNullStruct: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"nullstr":"","ip":"27.31.100.29","headers":{"Host":"1.abc.com","Charset":"UTF-8"}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":{}}""" ::
"""{"nullstr":"","ip":"27.31.100.29","headers":""}""" ::
"""{"nullstr":null,"ip":"27.31.100.29","headers":null}""" :: Nil)
def complexFieldValueTypeConflict: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"num_struct":11, "str_array":[1, 2, 3],
"array":[], "struct_array":[], "struct": {}}""" ::
"""{"num_struct":{"field":false}, "str_array":null,
@@ -65,14 +64,14 @@ trait TestJsonData {
"array":[7], "struct_array":{"field": true}, "struct": {"field": "str"}}""" :: Nil)
def arrayElementTypeConflict: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"array1": [1, 1.1, true, null, [], {}, [2,3,4], {"field":"str"}],
"array2": [{"field":214748364700}, {"field":1}]}""" ::
"""{"array3": [{"field":"str"}, {"field":1}]}""" ::
"""{"array3": [1, 2, 3]}""" :: Nil)
def missingFields: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"a":true}""" ::
"""{"b":21474836470}""" ::
"""{"c":[33, 44]}""" ::
@@ -80,7 +79,7 @@ trait TestJsonData {
"""{"e":"str"}""" :: Nil)
def complexFieldAndType1: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"struct":{"field1": true, "field2": 92233720368547758070},
"structWithArrayFields":{"field1":[4, 5, 6], "field2":["str1", "str2"]},
"arrayOfString":["str1", "str2"],
@@ -96,7 +95,7 @@ trait TestJsonData {
}""" :: Nil)
def complexFieldAndType2: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"arrayOfStruct":[{"field1": true, "field2": "str1"}, {"field1": false}, {"field3": null}],
"complexArrayOfStruct": [
{
@@ -150,7 +149,7 @@ trait TestJsonData {
}""" :: Nil)
def mapType1: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"map": {"a": 1}}""" ::
"""{"map": {"b": 2}}""" ::
"""{"map": {"c": 3}}""" ::
@@ -158,7 +157,7 @@ trait TestJsonData {
"""{"map": {"e": null}}""" :: Nil)
def mapType2: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"map": {"a": {"field1": [1, 2, 3, null]}}}""" ::
"""{"map": {"b": {"field2": 2}}}""" ::
"""{"map": {"c": {"field1": [], "field2": 4}}}""" ::
@@ -167,21 +166,21 @@ trait TestJsonData {
"""{"map": {"f": {"field1": null}}}""" :: Nil)
def nullsInArrays: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{"field1":[[null], [[["Test"]]]]}""" ::
"""{"field2":[null, [{"Test":1}]]}""" ::
"""{"field3":[[null], [{"Test":"2"}]]}""" ::
"""{"field4":[[null, [1,2,3]]]}""" :: Nil)
def jsonArray: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""[{"a":"str_a_1"}]""" ::
"""[{"a":"str_a_2"}, {"b":"str_b_3"}]""" ::
"""{"b":"str_b_4", "a":"str_a_4", "c":"str_c_4"}""" ::
"""[]""" :: Nil)
def corruptRecords: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a":1, b:2}""" ::
@@ -190,7 +189,7 @@ trait TestJsonData {
"""]""" :: Nil)
def emptyRecords: RDD[String] =
- ctx.sparkContext.parallelize(
+ _sqlContext.sparkContext.parallelize(
"""{""" ::
"""""" ::
"""{"a": {}}""" ::
@@ -198,9 +197,8 @@ trait TestJsonData {
"""{"b": [{"c": {}}]}""" ::
"""]""" :: Nil)
- lazy val singleRow: RDD[String] =
- ctx.sparkContext.parallelize(
- """{"a":123}""" :: Nil)
- def empty: RDD[String] = ctx.sparkContext.parallelize(Seq[String]())
+ lazy val singleRow: RDD[String] = _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
+
+ def empty: RDD[String] = _sqlContext.sparkContext.parallelize(Seq[String]())
}
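
TestJsonData now declares only the context it needs (`_sqlContext`) instead of binding to a concrete `ctx`. The same dependency-by-abstract-member pattern can be sketched in isolation as follows (trait and method names are hypothetical; a real suite satisfies `_sqlContext` by also mixing in SharedSQLContext, as JsonSuite does above):

    import org.apache.spark.rdd.RDD
    import org.apache.spark.sql.SQLContext

    // Hypothetical data trait that only declares the context it needs; whichever
    // suite mixes it in decides where that context comes from.
    trait ExampleJsonData {
      protected def _sqlContext: SQLContext

      def singleRecord: RDD[String] =
        _sqlContext.sparkContext.parallelize("""{"a":123}""" :: Nil)
    }

Because the member is abstract, the data trait never forces a particular test context, and the concrete suite wires it up exactly once.
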
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
index 866a975ad5..82d40e2b61 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetAvroCompatibilitySuite.scala
@@ -27,18 +27,16 @@ import org.apache.avro.generic.IndexedRecord
import org.apache.hadoop.fs.Path
import org.apache.parquet.avro.AvroParquetWriter
-import org.apache.spark.sql.execution.datasources.parquet.test.avro.{Nested, ParquetAvroCompat, ParquetEnum, Suit}
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.datasources.parquet.test.avro._
+import org.apache.spark.sql.test.SharedSQLContext
-class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
- override val sqlContext: SQLContext = TestSQLContext
-
private def withWriter[T <: IndexedRecord]
(path: String, schema: Schema)
- (f: AvroParquetWriter[T] => Unit) = {
+ (f: AvroParquetWriter[T] => Unit): Unit = {
val writer = new AvroParquetWriter[T](new Path(path), schema)
try f(writer) finally writer.close()
}
@@ -129,7 +127,7 @@ class ParquetAvroCompatibilitySuite extends ParquetCompatibilityTest {
}
test("SPARK-9407 Don't push down predicates involving Parquet ENUM columns") {
- import sqlContext.implicits._
+ import testImplicits._
withTempPath { dir =>
val path = dir.getCanonicalPath
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
index 0ea64aa2a5..b3406729fc 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetCompatibilityTest.scala
@@ -22,16 +22,18 @@ import scala.collection.JavaConversions._
import org.apache.hadoop.fs.{Path, PathFilter}
import org.apache.parquet.hadoop.ParquetFileReader
import org.apache.parquet.schema.MessageType
-import org.scalatest.BeforeAndAfterAll
import org.apache.spark.sql.QueryTest
-abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest with BeforeAndAfterAll {
- def readParquetSchema(path: String): MessageType = {
+/**
+ * Helper class for testing Parquet compatibility.
+ */
+private[sql] abstract class ParquetCompatibilityTest extends QueryTest with ParquetTest {
+ protected def readParquetSchema(path: String): MessageType = {
readParquetSchema(path, { path => !path.getName.startsWith("_") })
}
- def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
+ protected def readParquetSchema(path: String, pathFilter: Path => Boolean): MessageType = {
val fsPath = new Path(path)
val fs = fsPath.getFileSystem(configuration)
val parquetFiles = fs.listStatus(fsPath, new PathFilter {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
index 7dd9680d8c..5b4e568bb9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetFilterSuite.scala
@@ -20,12 +20,13 @@ package org.apache.spark.sql.execution.datasources.parquet
import org.apache.parquet.filter2.predicate.Operators._
import org.apache.parquet.filter2.predicate.{FilterPredicate, Operators}
+import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.planning.PhysicalOperation
import org.apache.spark.sql.execution.datasources.LogicalRelation
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
/**
* A test suite that tests Parquet filter2 API based filter pushdown optimization.
@@ -39,8 +40,7 @@ import org.apache.spark.sql.{Column, DataFrame, QueryTest, Row, SQLConf}
* 2. `Tuple1(Option(x))` is used together with `AnyVal` types like `Int` to ensure the inferred
* data type is nullable.
*/
-class ParquetFilterSuite extends QueryTest with ParquetTest {
- lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
+class ParquetFilterSuite extends QueryTest with ParquetTest with SharedSQLContext {
private def checkFilterPredicate(
df: DataFrame,
@@ -301,7 +301,7 @@ class ParquetFilterSuite extends QueryTest with ParquetTest {
}
test("SPARK-6554: don't push down predicates which reference partition columns") {
- import sqlContext.implicits._
+ import testImplicits._
withSQLConf(SQLConf.PARQUET_FILTER_PUSHDOWN_ENABLED.key -> "true") {
withTempPath { dir =>
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
index cb166349fd..d819f3ab5e 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetIOSuite.scala
@@ -37,6 +37,7 @@ import org.apache.spark.SparkException
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.util.DateTimeUtils
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
// Write support class for nested groups: ParquetWriter initializes GroupWriteSupport
@@ -62,9 +63,8 @@ private[parquet] class TestGroupWriteSupport(schema: MessageType) extends WriteS
/**
* A test suite that tests basic Parquet I/O.
*/
-class ParquetIOSuite extends QueryTest with ParquetTest {
- lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.implicits._
+class ParquetIOSuite extends QueryTest with ParquetTest with SharedSQLContext {
+ import testImplicits._
/**
* Writes `data` to a Parquet file, reads it back and check file contents.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
index 73152de244..ed8bafb10c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetPartitionDiscoverySuite.scala
@@ -26,13 +26,13 @@ import scala.collection.mutable.ArrayBuffer
import com.google.common.io.Files
import org.apache.hadoop.fs.Path
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.execution.datasources.{LogicalRelation, PartitionSpec, Partition, PartitioningUtils}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-import org.apache.spark.sql._
import org.apache.spark.unsafe.types.UTF8String
-import PartitioningUtils._
// The data where the partitioning key exists only in the directory structure.
case class ParquetData(intField: Int, stringField: String)
@@ -40,11 +40,9 @@ case class ParquetData(intField: Int, stringField: String)
// The data that also includes the partitioning key
case class ParquetDataWithKey(intField: Int, pi: Int, stringField: String, ps: String)
-class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest {
-
- override lazy val sqlContext: SQLContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.implicits._
- import sqlContext.sql
+class ParquetPartitionDiscoverySuite extends QueryTest with ParquetTest with SharedSQLContext {
+ import PartitioningUtils._
+ import testImplicits._
val defaultPartitionName = "__HIVE_DEFAULT_PARTITION__"
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
index 981334cf77..b290429c2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetProtobufCompatibilitySuite.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql.execution.datasources.parquet
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.test.SharedSQLContext
-class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest {
- override def sqlContext: SQLContext = TestSQLContext
+class ParquetProtobufCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
private def readParquetProtobufFile(name: String): DataFrame = {
val url = Thread.currentThread().getContextClassLoader.getResource(name)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
index 5e6d9c1cd4..e2f2a8c744 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala
@@ -21,16 +21,15 @@ import java.io.File
import org.apache.hadoop.fs.Path
-import org.apache.spark.sql.types._
import org.apache.spark.sql.{QueryTest, Row, SQLConf}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
/**
* A test suite that tests various Parquet queries.
*/
-class ParquetQuerySuite extends QueryTest with ParquetTest {
- lazy val sqlContext = org.apache.spark.sql.test.TestSQLContext
- import sqlContext.sql
+class ParquetQuerySuite extends QueryTest with ParquetTest with SharedSQLContext {
test("simple select queries") {
withParquetTable((0 until 10).map(i => (i, i.toString)), "t") {
@@ -41,22 +40,22 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("appending") {
val data = (0 until 10).map(i => (i, i.toString))
- sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT INTO TABLE t SELECT * FROM tmp")
- checkAnswer(sqlContext.table("t"), (data ++ data).map(Row.fromTuple))
+ checkAnswer(ctx.table("t"), (data ++ data).map(Row.fromTuple))
}
- sqlContext.catalog.unregisterTable(Seq("tmp"))
+ ctx.catalog.unregisterTable(Seq("tmp"))
}
test("overwriting") {
val data = (0 until 10).map(i => (i, i.toString))
- sqlContext.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
+ ctx.createDataFrame(data).toDF("c1", "c2").registerTempTable("tmp")
withParquetTable(data, "t") {
sql("INSERT OVERWRITE TABLE t SELECT * FROM tmp")
- checkAnswer(sqlContext.table("t"), data.map(Row.fromTuple))
+ checkAnswer(ctx.table("t"), data.map(Row.fromTuple))
}
- sqlContext.catalog.unregisterTable(Seq("tmp"))
+ ctx.catalog.unregisterTable(Seq("tmp"))
}
test("self-join") {
@@ -119,9 +118,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
val schema = StructType(List(StructField("d", DecimalType(18, 0), false),
StructField("time", TimestampType, false)).toArray)
withTempPath { file =>
- val df = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(data), schema)
+ val df = ctx.createDataFrame(ctx.sparkContext.parallelize(data), schema)
df.write.parquet(file.getCanonicalPath)
- val df2 = sqlContext.read.parquet(file.getCanonicalPath)
+ val df2 = ctx.read.parquet(file.getCanonicalPath)
checkAnswer(df2, df.collect().toSeq)
}
}
@@ -130,12 +129,12 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
// delete summary files, so if we don't merge part-files, one column will not be included.
Utils.deleteRecursively(new File(basePath + "/foo=1/_metadata"))
Utils.deleteRecursively(new File(basePath + "/foo=1/_common_metadata"))
- assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+ assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -154,9 +153,9 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
def testSchemaMerging(expectedColumnNumber: Int): Unit = {
withTempDir { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
- assert(sqlContext.read.parquet(basePath).columns.length === expectedColumnNumber)
+ ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=2").toString)
+ assert(ctx.read.parquet(basePath).columns.length === expectedColumnNumber)
}
}
@@ -172,19 +171,19 @@ class ParquetQuerySuite extends QueryTest with ParquetTest {
test("SPARK-8990 DataFrameReader.parquet() should respect user specified options") {
withTempPath { dir =>
val basePath = dir.getCanonicalPath
- sqlContext.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
- sqlContext.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
+ ctx.range(0, 10).toDF("a").write.parquet(new Path(basePath, "foo=1").toString)
+ ctx.range(0, 10).toDF("b").write.parquet(new Path(basePath, "foo=a").toString)
// Disables the global SQL option for schema merging
withSQLConf(SQLConf.PARQUET_SCHEMA_MERGING_ENABLED.key -> "false") {
assertResult(2) {
// Disables schema merging via data source option
- sqlContext.read.option("mergeSchema", "false").parquet(basePath).columns.length
+ ctx.read.option("mergeSchema", "false").parquet(basePath).columns.length
}
assertResult(3) {
// Enables schema merging via data source option
- sqlContext.read.option("mergeSchema", "true").parquet(basePath).columns.length
+ ctx.read.option("mergeSchema", "true").parquet(basePath).columns.length
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
index 971f71e27b..9dcbc1a047 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetSchemaSuite.scala
@@ -22,13 +22,11 @@ import scala.reflect.runtime.universe.TypeTag
import org.apache.parquet.schema.MessageTypeParser
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.ScalaReflection
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
-abstract class ParquetSchemaTest extends SparkFunSuite with ParquetTest {
- val sqlContext = TestSQLContext
+abstract class ParquetSchemaTest extends ParquetTest with SharedSQLContext {
/**
* Checks whether the reflected Parquet message type for product type `T` conforms `messageType`.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
index 3c6e54db4b..5dbc7d1630 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetTest.scala
@@ -22,9 +22,8 @@ import java.io.File
import scala.reflect.ClassTag
import scala.reflect.runtime.universe.TypeTag
-import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.{DataFrame, SaveMode}
+import org.apache.spark.sql.{DataFrame, SaveMode, SQLContext}
/**
* A helper trait that provides convenient facilities for Parquet testing.
@@ -33,7 +32,9 @@ import org.apache.spark.sql.{DataFrame, SaveMode}
* convenient to use tuples rather than special case classes when writing test cases/suites.
* Especially, `Tuple1.apply` can be used to easily wrap a single type/value.
*/
-private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
+private[sql] trait ParquetTest extends SQLTestUtils {
+ protected def _sqlContext: SQLContext
+
/**
* Writes `data` to a Parquet file, which is then passed to `f` and will be deleted after `f`
* returns.
@@ -42,7 +43,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T])
(f: String => Unit): Unit = {
withTempPath { file =>
- sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
+ _sqlContext.createDataFrame(data).write.parquet(file.getCanonicalPath)
f(file.getCanonicalPath)
}
}
@@ -54,7 +55,7 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
protected def withParquetDataFrame[T <: Product: ClassTag: TypeTag]
(data: Seq[T])
(f: DataFrame => Unit): Unit = {
- withParquetFile(data)(path => f(sqlContext.read.parquet(path)))
+ withParquetFile(data)(path => f(_sqlContext.read.parquet(path)))
}
/**
@@ -66,14 +67,14 @@ private[sql] trait ParquetTest extends SQLTestUtils { this: SparkFunSuite =>
(data: Seq[T], tableName: String)
(f: => Unit): Unit = {
withParquetDataFrame(data) { df =>
- sqlContext.registerDataFrameAsTable(df, tableName)
+ _sqlContext.registerDataFrameAsTable(df, tableName)
withTempTable(tableName)(f)
}
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
data: Seq[T], path: File): Unit = {
- sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
+ _sqlContext.createDataFrame(data).write.mode(SaveMode.Overwrite).parquet(path.getCanonicalPath)
}
protected def makeParquetFile[T <: Product: ClassTag: TypeTag](
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
index 92b1d82217..b789c5a106 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetThriftCompatibilitySuite.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.execution.datasources.parquet
-import org.apache.spark.sql.test.TestSQLContext
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.test.SharedSQLContext
-class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest {
+class ParquetThriftCompatibilitySuite extends ParquetCompatibilityTest with SharedSQLContext {
import ParquetCompatibilityTest._
- override val sqlContext: SQLContext = TestSQLContext
-
private val parquetFilePath =
Thread.currentThread().getContextClassLoader.getResource("parquet-thrift-compat.snappy.parquet")
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
index 239deb7973..22189477d2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/debug/DebuggingSuite.scala
@@ -18,10 +18,10 @@
package org.apache.spark.sql.execution.debug
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.TestData._
-import org.apache.spark.sql.test.TestSQLContext._
+import org.apache.spark.sql.test.SharedSQLContext
+
+class DebuggingSuite extends SparkFunSuite with SharedSQLContext {
-class DebuggingSuite extends SparkFunSuite {
test("DataFrame.debug()") {
testData.debug()
}
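
DebuggingSuite shows another recurring change: fixture DataFrames such as `testData` are referenced directly once the `TestData._` singleton import is gone, so they are presumably provided through the SharedSQLContext mixin (directly or via a mixed-in data trait). A small hypothetical suite in the same spirit:

    package org.apache.spark.sql.example   // hypothetical

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.test.SharedSQLContext

    class ExampleFixtureSuite extends SparkFunSuite with SharedSQLContext {
      test("fixture data comes from the mixin, not a singleton") {
        // `testData` is assumed to be a small DataFrame fixture exposed by the mixin.
        assert(testData.count() > 0L)
      }
    }
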
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
index d33a967093..4c9187a9a7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -23,12 +23,12 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.metric.SQLMetrics
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, StructField, StructType}
import org.apache.spark.util.collection.CompactBuffer
-class HashedRelationSuite extends SparkFunSuite {
+class HashedRelationSuite extends SparkFunSuite with SharedSQLContext {
// Key is simply the record itself
private val keyProjection = new Projection {
@@ -37,7 +37,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("GeneralHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[GeneralHashedRelation])
@@ -53,7 +53,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UniqueKeyHashedRelation") {
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val hashed = HashedRelation(data.iterator, numDataRows, keyProjection)
assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
@@ -73,7 +73,7 @@ class HashedRelationSuite extends SparkFunSuite {
test("UnsafeHashedRelation") {
val schema = StructType(StructField("a", IntegerType, true) :: Nil)
val data = Array(InternalRow(0), InternalRow(1), InternalRow(2), InternalRow(2))
- val numDataRows = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "data")
+ val numDataRows = SQLMetrics.createLongMetric(ctx.sparkContext, "data")
val toUnsafe = UnsafeProjection.create(schema)
val unsafeData = data.map(toUnsafe(_).copy()).toArray
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
index ddff7cebcc..cc649b9bd4 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/InnerJoinSuite.scala
@@ -17,97 +17,19 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.{DataFrame, execution, Row, SQLConf}
+import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
-import org.apache.spark.sql.{SQLConf, execution, Row, DataFrame}
-import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.execution._
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{IntegerType, StringType, StructType}
-class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
-
- private def testInnerJoin(
- testName: String,
- leftRows: DataFrame,
- rightRows: DataFrame,
- condition: Expression,
- expectedAnswer: Seq[Product]): Unit = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
- ExtractEquiJoinKeys.unapply(join).foreach {
- case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
-
- def makeBroadcastHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
- val broadcastHashJoin =
- execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, left, right)
- boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
- }
-
- def makeShuffledHashJoin(left: SparkPlan, right: SparkPlan, side: BuildSide) = {
- val shuffledHashJoin =
- execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, left, right)
- val filteredJoin =
- boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
- EnsureRequirements(sqlContext).apply(filteredJoin)
- }
-
- def makeSortMergeJoin(left: SparkPlan, right: SparkPlan) = {
- val sortMergeJoin =
- execution.joins.SortMergeJoin(leftKeys, rightKeys, left, right)
- val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
- EnsureRequirements(sqlContext).apply(filteredJoin)
- }
-
- test(s"$testName using BroadcastHashJoin (build=left)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- makeBroadcastHashJoin(left, right, joins.BuildLeft),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
-
- test(s"$testName using BroadcastHashJoin (build=right)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- makeBroadcastHashJoin(left, right, joins.BuildRight),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
-
- test(s"$testName using ShuffledHashJoin (build=left)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- makeShuffledHashJoin(left, right, joins.BuildLeft),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
-
- test(s"$testName using ShuffledHashJoin (build=right)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- makeShuffledHashJoin(left, right, joins.BuildRight),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
+class InnerJoinSuite extends SparkPlanTest with SharedSQLContext {
- test(s"$testName using SortMergeJoin") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- makeSortMergeJoin(left, right),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
- }
- }
-
- {
- val upperCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ private lazy val myUpperCaseData = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
Row(1, "A"),
Row(2, "B"),
Row(3, "C"),
@@ -117,7 +39,8 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "G")
)), new StructType().add("N", IntegerType).add("L", StringType))
- val lowerCaseData = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
+ private lazy val myLowerCaseData = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
Row(1, "a"),
Row(2, "b"),
Row(3, "c"),
@@ -125,21 +48,7 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
Row(null, "e")
)), new StructType().add("n", IntegerType).add("l", StringType))
- testInnerJoin(
- "inner join, one match per row",
- upperCaseData,
- lowerCaseData,
- (upperCaseData.col("N") === lowerCaseData.col("n")).expr,
- Seq(
- (1, "A", 1, "a"),
- (2, "B", 2, "b"),
- (3, "C", 3, "c"),
- (4, "D", 4, "d")
- )
- )
- }
-
- private val testData2 = Seq(
+ private lazy val myTestData = Seq(
(1, 1),
(1, 2),
(2, 1),
@@ -148,14 +57,139 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
(3, 2)
).toDF("a", "b")
+ // Note: the input dataframes and expression must be evaluated lazily because
+ // the SQLContext should be used only within a test to keep SQL tests stable
+ private def testInnerJoin(
+ testName: String,
+ leftRows: => DataFrame,
+ rightRows: => DataFrame,
+ condition: () => Expression,
+ expectedAnswer: Seq[Product]): Unit = {
+
+ def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition()))
+ ExtractEquiJoinKeys.unapply(join)
+ }
+
+ def makeBroadcastHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ boundCondition: Option[Expression],
+ leftPlan: SparkPlan,
+ rightPlan: SparkPlan,
+ side: BuildSide) = {
+ val broadcastHashJoin =
+ execution.joins.BroadcastHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
+ boundCondition.map(Filter(_, broadcastHashJoin)).getOrElse(broadcastHashJoin)
+ }
+
+ def makeShuffledHashJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ boundCondition: Option[Expression],
+ leftPlan: SparkPlan,
+ rightPlan: SparkPlan,
+ side: BuildSide) = {
+ val shuffledHashJoin =
+ execution.joins.ShuffledHashJoin(leftKeys, rightKeys, side, leftPlan, rightPlan)
+ val filteredJoin =
+ boundCondition.map(Filter(_, shuffledHashJoin)).getOrElse(shuffledHashJoin)
+ EnsureRequirements(sqlContext).apply(filteredJoin)
+ }
+
+ def makeSortMergeJoin(
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ boundCondition: Option[Expression],
+ leftPlan: SparkPlan,
+ rightPlan: SparkPlan) = {
+ val sortMergeJoin =
+ execution.joins.SortMergeJoin(leftKeys, rightKeys, leftPlan, rightPlan)
+ val filteredJoin = boundCondition.map(Filter(_, sortMergeJoin)).getOrElse(sortMergeJoin)
+ EnsureRequirements(sqlContext).apply(filteredJoin)
+ }
+
+ test(s"$testName using BroadcastHashJoin (build=left)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeBroadcastHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ test(s"$testName using BroadcastHashJoin (build=right)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeBroadcastHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ test(s"$testName using ShuffledHashJoin (build=left)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeShuffledHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildLeft),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ test(s"$testName using ShuffledHashJoin (build=right)") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeShuffledHashJoin(
+ leftKeys, rightKeys, boundCondition, leftPlan, rightPlan, joins.BuildRight),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+
+ test(s"$testName using SortMergeJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (leftPlan: SparkPlan, rightPlan: SparkPlan) =>
+ makeSortMergeJoin(leftKeys, rightKeys, boundCondition, leftPlan, rightPlan),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
+ }
+ }
+ }
+ }
+
+ testInnerJoin(
+ "inner join, one match per row",
+ myUpperCaseData,
+ myLowerCaseData,
+ () => (myUpperCaseData.col("N") === myLowerCaseData.col("n")).expr,
+ Seq(
+ (1, "A", 1, "a"),
+ (2, "B", 2, "b"),
+ (3, "C", 3, "c"),
+ (4, "D", 4, "d")
+ )
+ )
+
{
- val left = testData2.where("a = 1")
- val right = testData2.where("a = 1")
+ lazy val left = myTestData.where("a = 1")
+ lazy val right = myTestData.where("a = 1")
testInnerJoin(
"inner join, multiple matches",
left,
right,
- (left.col("a") === right.col("a")).expr,
+ () => (left.col("a") === right.col("a")).expr,
Seq(
(1, 1, 1, 1),
(1, 1, 1, 2),
@@ -166,13 +200,13 @@ class InnerJoinSuite extends SparkPlanTest with SQLTestUtils {
}
{
- val left = testData2.where("a = 1")
- val right = testData2.where("a = 2")
+ lazy val left = myTestData.where("a = 1")
+ lazy val right = myTestData.where("a = 2")
testInnerJoin(
"inner join, no matches",
left,
right,
- (left.col("a") === right.col("a")).expr,
+ () => (left.col("a") === right.col("a")).expr,
Seq.empty
)
}
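
The note added to InnerJoinSuite ("the input dataframes and expression must be evaluated lazily...") is the key constraint behind the `lazy val` fixtures and by-name parameters above: ScalaTest registers `test(...)` blocks while the suite is being constructed, before the per-suite context should be used, so anything that touches the SQLContext has to be deferred into the test bodies. A stripped-down, hypothetical sketch of that pattern:

    package org.apache.spark.sql.example   // hypothetical

    import org.apache.spark.SparkFunSuite
    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.test.SharedSQLContext

    class ExampleLazyFixtureSuite extends SparkFunSuite with SharedSQLContext {
      import testImplicits._

      // Built lazily so the SQLContext is not touched during suite construction.
      private lazy val myData = Seq((1, 1), (1, 2), (2, 1)).toDF("a", "b")

      // The fixture is taken by name; it is only forced inside the test body.
      private def testFilter(testName: String, input: => DataFrame): Unit = {
        test(testName) {
          assert(input.filter("a = 1").count() === 2L)
        }
      }

      testFilter("by-name fixture, evaluated inside the test", myData)
    }

The same reasoning explains why `ExtractEquiJoinKeys.unapply` now runs inside the registered tests (via `extractJoinParts()`) instead of guarding test registration itself.
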
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
index e16f5e39aa..a1a617d7b7 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/OuterJoinSuite.scala
@@ -17,28 +17,65 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.{DataFrame, Row, SQLConf}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
+import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
+import org.apache.spark.sql.catalyst.expressions.{And, Expression, LessThan}
+import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{IntegerType, DoubleType, StructType}
-import org.apache.spark.sql.{SQLConf, DataFrame, Row}
-import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans._
-import org.apache.spark.sql.execution.{EnsureRequirements, joins, SparkPlan, SparkPlanTest}
-class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
+class OuterJoinSuite extends SparkPlanTest with SharedSQLContext {
+
+ private lazy val left = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
+ Row(1, 2.0),
+ Row(2, 100.0),
+ Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
+ Row(2, 1.0),
+ Row(3, 3.0),
+ Row(5, 1.0),
+ Row(6, 6.0),
+ Row(null, null)
+ )), new StructType().add("a", IntegerType).add("b", DoubleType))
+
+ private lazy val right = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
+ Row(0, 0.0),
+ Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
+ Row(2, -1.0),
+ Row(2, -1.0),
+ Row(2, 3.0),
+ Row(3, 2.0),
+ Row(4, 1.0),
+ Row(5, 3.0),
+ Row(7, 7.0),
+ Row(null, null)
+ )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+ private lazy val condition = {
+ And((left.col("a") === right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+ // Note: the input dataframes and expression must be evaluated lazily because
+ // the SQLContext should be used only within a test to keep SQL tests stable
private def testOuterJoin(
testName: String,
- leftRows: DataFrame,
- rightRows: DataFrame,
+ leftRows: => DataFrame,
+ rightRows: => DataFrame,
joinType: JoinType,
- condition: Expression,
+ condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
- ExtractEquiJoinKeys.unapply(join).foreach {
- case (_, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
- test(s"$testName using ShuffledHashOuterJoin") {
+
+ def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+ ExtractEquiJoinKeys.unapply(join)
+ }
+
+ test(s"$testName using ShuffledHashOuterJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
@@ -46,19 +83,23 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
- }
+ }
+ }
- if (joinType != FullOuter) {
- test(s"$testName using BroadcastHashOuterJoin") {
+ if (joinType != FullOuter) {
+ test(s"$testName using BroadcastHashOuterJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
BroadcastHashOuterJoin(leftKeys, rightKeys, joinType, boundCondition, left, right),
expectedAnswer.map(Row.fromTuple),
sortAnswers = true)
}
- }
+ }
+ }
- test(s"$testName using SortMergeOuterJoin") {
+ test(s"$testName using SortMergeOuterJoin") {
+ extractJoinParts().foreach { case (_, leftKeys, rightKeys, boundCondition, _, _) =>
withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
EnsureRequirements(sqlContext).apply(
@@ -66,57 +107,9 @@ class OuterJoinSuite extends SparkPlanTest with SQLTestUtils {
expectedAnswer.map(Row.fromTuple),
sortAnswers = false)
}
- }
}
- }
-
- test(s"$testName using BroadcastNestedLoopJoin (build=left)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- joins.BroadcastNestedLoopJoin(left, right, joins.BuildLeft, joinType, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
}
}
-
- test(s"$testName using BroadcastNestedLoopJoin (build=right)") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- joins.BroadcastNestedLoopJoin(left, right, joins.BuildRight, joinType, Some(condition)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
- }
- }
-
- val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
- Row(1, 2.0),
- Row(2, 100.0),
- Row(2, 1.0), // This row is duplicated to ensure that we will have multiple buffered matches
- Row(2, 1.0),
- Row(3, 3.0),
- Row(5, 1.0),
- Row(6, 6.0),
- Row(null, null)
- )), new StructType().add("a", IntegerType).add("b", DoubleType))
-
- val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
- Row(0, 0.0),
- Row(2, 3.0), // This row is duplicated to ensure that we will have multiple buffered matches
- Row(2, -1.0),
- Row(2, -1.0),
- Row(2, 3.0),
- Row(3, 2.0),
- Row(4, 1.0),
- Row(5, 3.0),
- Row(7, 7.0),
- Row(null, null)
- )), new StructType().add("c", IntegerType).add("d", DoubleType))
-
- val condition = {
- And(
- (left.col("a") === right.col("c")).expr,
- LessThan(left.col("b").expr, right.col("d").expr))
}
// --- Basic outer joins ------------------------------------------------------------------------
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
index 4503ed251f..baa86e320d 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/SemiJoinSuite.scala
@@ -17,44 +17,80 @@
package org.apache.spark.sql.execution.joins
+import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.planning.ExtractEquiJoinKeys
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical.Join
-import org.apache.spark.sql.test.SQLTestUtils
-import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
-import org.apache.spark.sql.{SQLConf, DataFrame, Row}
import org.apache.spark.sql.catalyst.expressions.{And, LessThan, Expression}
import org.apache.spark.sql.execution.{EnsureRequirements, SparkPlan, SparkPlanTest}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.{DoubleType, IntegerType, StructType}
+
+class SemiJoinSuite extends SparkPlanTest with SharedSQLContext {
-class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
+ private lazy val left = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
+ Row(1, 2.0),
+ Row(1, 2.0),
+ Row(2, 1.0),
+ Row(2, 1.0),
+ Row(3, 3.0),
+ Row(null, null),
+ Row(null, 5.0),
+ Row(6, null)
+ )), new StructType().add("a", IntegerType).add("b", DoubleType))
+ private lazy val right = ctx.createDataFrame(
+ ctx.sparkContext.parallelize(Seq(
+ Row(2, 3.0),
+ Row(2, 3.0),
+ Row(3, 2.0),
+ Row(4, 1.0),
+ Row(null, null),
+ Row(null, 5.0),
+ Row(6, null)
+ )), new StructType().add("c", IntegerType).add("d", DoubleType))
+
+ private lazy val condition = {
+ And((left.col("a") === right.col("c")).expr,
+ LessThan(left.col("b").expr, right.col("d").expr))
+ }
+
+ // Note: the input dataframes and expression must be evaluated lazily because
+ // the SQLContext should be used only within a test to keep SQL tests stable
private def testLeftSemiJoin(
testName: String,
- leftRows: DataFrame,
- rightRows: DataFrame,
- condition: Expression,
+ leftRows: => DataFrame,
+ rightRows: => DataFrame,
+ condition: => Expression,
expectedAnswer: Seq[Product]): Unit = {
- val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
- ExtractEquiJoinKeys.unapply(join).foreach {
- case (joinType, leftKeys, rightKeys, boundCondition, leftChild, rightChild) =>
- test(s"$testName using LeftSemiJoinHash") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- EnsureRequirements(left.sqlContext).apply(
- LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
+
+ def extractJoinParts(): Option[ExtractEquiJoinKeys.ReturnType] = {
+ val join = Join(leftRows.logicalPlan, rightRows.logicalPlan, Inner, Some(condition))
+ ExtractEquiJoinKeys.unapply(join)
+ }
+
+ test(s"$testName using LeftSemiJoinHash") {
+ extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ EnsureRequirements(left.sqlContext).apply(
+ LeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition)),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
}
+ }
+ }
- test(s"$testName using BroadcastLeftSemiJoinHash") {
- withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
- checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
- BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
- expectedAnswer.map(Row.fromTuple),
- sortAnswers = true)
- }
+ test(s"$testName using BroadcastLeftSemiJoinHash") {
+ extractJoinParts().foreach { case (joinType, leftKeys, rightKeys, boundCondition, _, _) =>
+ withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "1") {
+ checkAnswer2(leftRows, rightRows, (left: SparkPlan, right: SparkPlan) =>
+ BroadcastLeftSemiJoinHash(leftKeys, rightKeys, left, right, boundCondition),
+ expectedAnswer.map(Row.fromTuple),
+ sortAnswers = true)
}
+ }
}
test(s"$testName using LeftSemiJoinBNL") {
@@ -67,33 +103,6 @@ class SemiJoinSuite extends SparkPlanTest with SQLTestUtils {
}
}
- val left = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
- Row(1, 2.0),
- Row(1, 2.0),
- Row(2, 1.0),
- Row(2, 1.0),
- Row(3, 3.0),
- Row(null, null),
- Row(null, 5.0),
- Row(6, null)
- )), new StructType().add("a", IntegerType).add("b", DoubleType))
-
- val right = sqlContext.createDataFrame(sqlContext.sparkContext.parallelize(Seq(
- Row(2, 3.0),
- Row(2, 3.0),
- Row(3, 2.0),
- Row(4, 1.0),
- Row(null, null),
- Row(null, 5.0),
- Row(6, null)
- )), new StructType().add("c", IntegerType).add("d", DoubleType))
-
- val condition = {
- And(
- (left.col("a") === right.col("c")).expr,
- LessThan(left.col("b").expr, right.col("d").expr))
- }
-
testLeftSemiJoin(
"basic test",
left,
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index 7383d3f8fe..80006bf077 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -28,17 +28,15 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.sql._
import org.apache.spark.sql.execution.ui.SparkPlanGraph
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.{SQLTestUtils, TestSQLContext}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
-class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
- override val sqlContext = TestSQLContext
-
- import sqlContext.implicits._
+class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
test("LongSQLMetric should not box Long") {
- val l = SQLMetrics.createLongMetric(TestSQLContext.sparkContext, "long")
+ val l = SQLMetrics.createLongMetric(ctx.sparkContext, "long")
val f = () => {
l += 1L
l.add(1L)
@@ -52,7 +50,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Normal accumulator should do boxing") {
// We need this test to make sure BoxingFinder works.
- val l = TestSQLContext.sparkContext.accumulator(0L)
+ val l = ctx.sparkContext.accumulator(0L)
val f = () => { l += 1L }
BoxingFinder.getClassReader(f.getClass).foreach { cl =>
val boxingFinder = new BoxingFinder()
@@ -73,19 +71,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
df: DataFrame,
expectedNumOfJobs: Int,
expectedMetrics: Map[Long, (String, Map[String, Any])]): Unit = {
- val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
+ val previousExecutionIds = ctx.listener.executionIdToData.keySet
df.collect()
- TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
+ val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change it to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= expectedNumOfJobs)
if (jobs.size == expectedNumOfJobs) {
// If we can track all jobs, check the metric values
- val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
+ val metricValues = ctx.listener.getExecutionMetrics(executionId)
val actualMetrics = SparkPlanGraph(df.queryExecution.executedPlan).nodes.filter { node =>
expectedMetrics.contains(node.id)
}.map { node =>
@@ -111,7 +109,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Project(nodeId = 0)
- val df = TestData.person.select('name)
+ val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("Project", Map(
"number of rows" -> 2L)))
@@ -126,7 +124,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "true") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> TungstenProject(nodeId = 0)
- val df = TestData.person.select('name)
+ val df = person.select('name)
testSparkPlanMetrics(df, 1, Map(
0L ->("TungstenProject", Map(
"number of rows" -> 2L)))
@@ -137,7 +135,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("Filter metrics") {
// Assume the execution plan is
// PhysicalRDD(nodeId = 1) -> Filter(nodeId = 0)
- val df = TestData.person.filter('age < 25)
+ val df = person.filter('age < 25)
testSparkPlanMetrics(df, 1, Map(
0L -> ("Filter", Map(
"number of input rows" -> 2L,
@@ -152,7 +150,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
SQLConf.TUNGSTEN_ENABLED.key -> "false") {
// Assume the execution plan is
// ... -> Aggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) -> Aggregate(nodeId = 0)
- val df = TestData.testData2.groupBy().count() // 2 partitions
+ val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@@ -163,7 +161,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
- val df2 = TestData.testData2.groupBy('a).count()
+ val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("Aggregate", Map(
"number of input rows" -> 6L,
@@ -185,7 +183,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> SortBasedAggregate(nodeId = 2) -> TungstenExchange(nodeId = 1) ->
// SortBasedAggregate(nodeId = 0)
- val df = TestData.testData2.groupBy().count() // 2 partitions
+ val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@@ -199,7 +197,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// ... -> SortBasedAggregate(nodeId = 3) -> TungstenExchange(nodeId = 2)
// -> ExternalSort(nodeId = 1)-> SortBasedAggregate(nodeId = 0)
// 2 partitions and each partition contains 2 keys
- val df2 = TestData.testData2.groupBy('a).count()
+ val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
3L -> ("SortBasedAggregate", Map(
"number of input rows" -> 6L,
@@ -219,7 +217,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Assume the execution plan is
// ... -> TungstenAggregate(nodeId = 2) -> Exchange(nodeId = 1)
// -> TungstenAggregate(nodeId = 0)
- val df = TestData.testData2.groupBy().count() // 2 partitions
+ val df = testData2.groupBy().count() // 2 partitions
testSparkPlanMetrics(df, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@@ -230,7 +228,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
)
// 2 partitions and each partition contains 2 keys
- val df2 = TestData.testData2.groupBy('a).count()
+ val df2 = testData2.groupBy('a).count()
testSparkPlanMetrics(df2, 1, Map(
2L -> ("TungstenAggregate", Map(
"number of input rows" -> 6L,
@@ -246,7 +244,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use the deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
- val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@@ -268,7 +266,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
// Because SortMergeOuterJoin may skip different rows if the number of partitions is different,
// this test should use the deterministic number of partitions.
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
- val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@@ -314,7 +312,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("ShuffledHashJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "false") {
- val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@@ -390,7 +388,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("BroadcastNestedLoopJoin metrics") {
withSQLConf(SQLConf.SORTMERGE_JOIN.key -> "true") {
- val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@@ -458,7 +456,7 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
}
test("CartesianProduct metrics") {
- val testDataForJoin = TestData.testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
+ val testDataForJoin = testData2.filter('a < 2) // TestData2(1, 1) :: TestData2(1, 2)
testDataForJoin.registerTempTable("testDataForJoin")
withTempTable("testDataForJoin") {
// Assume the execution plan is
@@ -476,19 +474,19 @@ class SQLMetricsSuite extends SparkFunSuite with SQLTestUtils {
test("save metrics") {
withTempPath { file =>
- val previousExecutionIds = TestSQLContext.listener.executionIdToData.keySet
+ val previousExecutionIds = ctx.listener.executionIdToData.keySet
// Assume the execution plan is
// PhysicalRDD(nodeId = 0)
- TestData.person.select('name).write.format("json").save(file.getAbsolutePath)
- TestSQLContext.sparkContext.listenerBus.waitUntilEmpty(10000)
- val executionIds = TestSQLContext.listener.executionIdToData.keySet.diff(previousExecutionIds)
+ person.select('name).write.format("json").save(file.getAbsolutePath)
+ ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
+ val executionIds = ctx.listener.executionIdToData.keySet.diff(previousExecutionIds)
assert(executionIds.size === 1)
val executionId = executionIds.head
- val jobs = TestSQLContext.listener.getExecution(executionId).get.jobs
+ val jobs = ctx.listener.getExecution(executionId).get.jobs
// Use "<=" because there is a race condition that we may miss some jobs
// TODO Change "<=" to "=" once we fix the race condition that missing the JobStarted event.
assert(jobs.size <= 1)
- val metricValues = TestSQLContext.listener.getExecutionMetrics(executionId)
+ val metricValues = ctx.listener.getExecutionMetrics(executionId)
// Because "save" will create a new DataFrame internally, we cannot get the real metric id.
// However, we still can check the value.
assert(metricValues.values.toSeq === Seq(2L))
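For reference, a minimal sketch of the per-suite shape these metrics tests now take (the suite name, query, and assertion are hypothetical; `ctx`, `testImplicits`, and `person` come from the SharedSQLContext and SQLTestData traits introduced later in this patch):

package org.apache.spark.sql.execution.metric

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: each suite now owns its context via SharedSQLContext
// instead of reaching for the TestSQLContext singleton.
class ExampleMetricsStyleSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("metrics are reachable through the suite-local context") {
    // `person` comes from SQLTestData and registers its temp table on first use
    val df = person.select('name)
    df.collect()
    ctx.sparkContext.listenerBus.waitUntilEmpty(10000)
    // Listener state lives on the per-suite context, not on a global object
    assert(ctx.listener.executionIdToData.nonEmpty)
  }
}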
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
index 41dd1896c1..80d1e88956 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ui/SQLListenerSuite.scala
@@ -25,12 +25,12 @@ import org.apache.spark.sql.execution.metric.LongSQLMetricValue
import org.apache.spark.scheduler._
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.execution.SQLExecution
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
-class SQLListenerSuite extends SparkFunSuite {
+class SQLListenerSuite extends SparkFunSuite with SharedSQLContext {
+ import testImplicits._
private def createTestDataFrame: DataFrame = {
- import TestSQLContext.implicits._
Seq(
(1, 1),
(2, 2)
@@ -74,7 +74,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("basic") {
- val listener = new SQLListener(TestSQLContext)
+ val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
val accumulatorIds =
@@ -212,7 +212,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobSucceeded)") {
- val listener = new SQLListener(TestSQLContext)
+ val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@@ -241,7 +241,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before multiple onJobEnd(JobSucceeded)s") {
- val listener = new SQLListener(TestSQLContext)
+ val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
@@ -281,7 +281,7 @@ class SQLListenerSuite extends SparkFunSuite {
}
test("onExecutionEnd happens before onJobEnd(JobFailed)") {
- val listener = new SQLListener(TestSQLContext)
+ val listener = new SQLListener(ctx)
val executionId = 0
val df = createTestDataFrame
listener.onExecutionStart(
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
index e4dcf4c75d..0edac0848c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala
@@ -25,10 +25,13 @@ import org.h2.jdbc.JdbcSQLException
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
+class JDBCSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+ import testImplicits._
+
val url = "jdbc:h2:mem:testdb0"
val urlWithUserAndPass = "jdbc:h2:mem:testdb0;user=testUser;password=testPass"
var conn: java.sql.Connection = null
@@ -42,10 +45,6 @@ class JDBCSuite extends SparkFunSuite with BeforeAndAfter {
Some(StringType)
}
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
- import ctx.sql
-
before {
Utils.classForName("org.h2.Driver")
// Extra properties that will be specified for our database. We need these to test
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
index 84b52ca2c7..5dc3a2c07b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCWriteSuite.scala
@@ -23,11 +23,13 @@ import java.util.Properties
import org.scalatest.BeforeAndAfter
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.{SaveMode, Row}
+import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
+class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter with SharedSQLContext {
+
val url = "jdbc:h2:mem:testdb2"
var conn: java.sql.Connection = null
val url1 = "jdbc:h2:mem:testdb3"
@@ -37,10 +39,6 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
properties.setProperty("password", "testPass")
properties.setProperty("rowId", "false")
- private lazy val ctx = org.apache.spark.sql.test.TestSQLContext
- import ctx.implicits._
- import ctx.sql
-
before {
Utils.classForName("org.h2.Driver")
conn = DriverManager.getConnection(url)
@@ -58,14 +56,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
"create table test.people1 (name TEXT(32) NOT NULL, theid INTEGER NOT NULL)").executeUpdate()
conn1.commit()
- ctx.sql(
+ sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE
|USING org.apache.spark.sql.jdbc
|OPTIONS (url '$url1', dbtable 'TEST.PEOPLE', user 'testUser', password 'testPass')
""".stripMargin.replaceAll("\n", " "))
- ctx.sql(
+ sql(
s"""
|CREATE TEMPORARY TABLE PEOPLE1
|USING org.apache.spark.sql.jdbc
@@ -144,14 +142,14 @@ class JDBCWriteSuite extends SparkFunSuite with BeforeAndAfter {
}
test("INSERT to JDBC Datasource") {
- ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
test("INSERT to JDBC Datasource with overwrite") {
- ctx.sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
- ctx.sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ sql("INSERT INTO TABLE PEOPLE1 SELECT * FROM PEOPLE")
+ sql("INSERT OVERWRITE TABLE PEOPLE1 SELECT * FROM PEOPLE")
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).count)
assert(2 === ctx.read.jdbc(url1, "TEST.PEOPLE1", properties).collect()(0).length)
}
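The same shorthand applies to the JDBC suites above: `sql` and `ctx` are supplied by the shared test harness. A minimal hypothetical sketch, omitting all H2 setup (the suite name, table, and query are examples only):

package org.apache.spark.sql.jdbc

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: `sql` is the shorthand defined in SQLTestUtils and `ctx` is the
// per-suite TestSQLContext, so explicit `ctx.sql(...)` calls are no longer needed.
class ExampleJdbcStyleSuite extends SparkFunSuite with SharedSQLContext {
  test("the sql shorthand goes through the shared context") {
    // Register a small temp table through the suite-local context
    ctx.range(0, 4).registerTempTable("ids")
    assert(sql("SELECT COUNT(*) FROM ids").collect().head.getLong(0) === 4L)
  }
}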
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
index 562c279067..9bc3f6bcf6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/CreateTableAsSelectSuite.scala
@@ -19,28 +19,32 @@ package org.apache.spark.sql.sources
import java.io.{File, IOException}
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.execution.datasources.DDLException
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
-class CreateTableAsSelectSuite extends DataSourceTest with BeforeAndAfterAll {
-
- import caseInsensitiveContext.sql
+class CreateTableAsSelectSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
+ protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
- var path: File = null
+ private var path: File = null
override def beforeAll(): Unit = {
+ super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str${i}"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
}
override def afterAll(): Unit = {
- caseInsensitiveContext.dropTempTable("jt")
+ try {
+ caseInsensitiveContext.dropTempTable("jt")
+ } finally {
+ super.afterAll()
+ }
}
after {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
index 392da0b082..853707c036 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLSourceLoadSuite.scala
@@ -18,11 +18,12 @@
package org.apache.spark.sql.sources
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
// please note that the META-INF/services had to be modified for the test directory for this to work
-class DDLSourceLoadSuite extends DataSourceTest {
+class DDLSourceLoadSuite extends DataSourceTest with SharedSQLContext {
test("data sources with the same name") {
intercept[RuntimeException] {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
index 84855ce45e..5f8514e1a2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DDLTestSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sources
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
@@ -68,10 +69,12 @@ case class SimpleDDLScan(from: Int, to: Int, table: String)(@transient val sqlCo
}
}
-class DDLTestSuite extends DataSourceTest {
+class DDLTestSuite extends DataSourceTest with SharedSQLContext {
+ protected override lazy val sql = caseInsensitiveContext.sql _
- before {
- caseInsensitiveContext.sql(
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sql(
"""
|CREATE TEMPORARY TABLE ddlPeople
|USING org.apache.spark.sql.sources.DDLScanSource
@@ -105,7 +108,7 @@ class DDLTestSuite extends DataSourceTest {
))
test("SPARK-7686 DescribeCommand should have correct physical plan output attributes") {
- val attributes = caseInsensitiveContext.sql("describe ddlPeople")
+ val attributes = sql("describe ddlPeople")
.queryExecution.executedPlan.output
assert(attributes.map(_.name) === Seq("col_name", "data_type", "comment"))
assert(attributes.map(_.dataType).toSet === Set(StringType))
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
index 00cc7d5ea5..d74d29fb0b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/DataSourceTest.scala
@@ -17,18 +17,23 @@
package org.apache.spark.sql.sources
-import org.scalatest.BeforeAndAfter
-
import org.apache.spark.sql._
-import org.apache.spark.sql.test.TestSQLContext
-abstract class DataSourceTest extends QueryTest with BeforeAndAfter {
+private[sql] abstract class DataSourceTest extends QueryTest {
+ protected def _sqlContext: SQLContext
+
// We want to test some edge cases.
- protected implicit lazy val caseInsensitiveContext = {
- val ctx = new SQLContext(TestSQLContext.sparkContext)
+ protected lazy val caseInsensitiveContext: SQLContext = {
+ val ctx = new SQLContext(_sqlContext.sparkContext)
ctx.setConf(SQLConf.CASE_SENSITIVE, false)
ctx
}
+ protected def sqlTest(sqlString: String, expectedAnswer: Seq[Row]) {
+ test(sqlString) {
+ checkAnswer(caseInsensitiveContext.sql(sqlString), expectedAnswer)
+ }
+ }
+
}
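A hypothetical subclass wiring the reworked DataSourceTest together with SharedSQLContext (the suite name, table, and query are made up for illustration; `sqlTest` and `caseInsensitiveContext` are the helpers defined above):

package org.apache.spark.sql.sources

import org.apache.spark.sql.Row
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: SharedSQLContext supplies _sqlContext, DataSourceTest derives the
// case-insensitive context from it, and sqlTest registers a checkAnswer-based test.
class ExampleDataSourceStyleSuite extends DataSourceTest with SharedSQLContext {
  protected override lazy val sql = caseInsensitiveContext.sql _

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Register a tiny temp table against the case-insensitive context
    caseInsensitiveContext.range(1, 3).registerTempTable("tinyRange")
  }

  sqlTest("SELECT id FROM tinyRange", Seq(Row(1L), Row(2L)))
}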
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
index 5ef365797e..c81c3d3982 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/FilteredScanSuite.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
@@ -96,11 +97,11 @@ object FiltersPushed {
var list: Seq[Filter] = Nil
}
-class FilteredScanSuite extends DataSourceTest {
+class FilteredScanSuite extends DataSourceTest with SharedSQLContext {
+ protected override lazy val sql = caseInsensitiveContext.sql _
- import caseInsensitiveContext.sql
-
- before {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTenFiltered
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
index cdbfaf6455..78bd3e5582 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala
@@ -19,20 +19,17 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.scalatest.BeforeAndAfterAll
-
import org.apache.spark.sql.{SaveMode, AnalysisException, Row}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
-class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
-
- import caseInsensitiveContext.sql
-
+class InsertSuite extends DataSourceTest with SharedSQLContext {
+ protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
- var path: File = null
+ private var path: File = null
override def beforeAll(): Unit = {
+ super.beforeAll()
path = Utils.createTempDir()
val rdd = sparkContext.parallelize((1 to 10).map(i => s"""{"a":$i, "b":"str$i"}"""))
caseInsensitiveContext.read.json(rdd).registerTempTable("jt")
@@ -47,9 +44,13 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
- caseInsensitiveContext.dropTempTable("jsonTable")
- caseInsensitiveContext.dropTempTable("jt")
- Utils.deleteRecursively(path)
+ try {
+ caseInsensitiveContext.dropTempTable("jsonTable")
+ caseInsensitiveContext.dropTempTable("jt")
+ Utils.deleteRecursively(path)
+ } finally {
+ super.afterAll()
+ }
}
test("Simple INSERT OVERWRITE a JSONRelation") {
@@ -221,9 +222,10 @@ class InsertSuite extends DataSourceTest with BeforeAndAfterAll {
sql("SELECT a * 2 FROM jsonTable"),
(1 to 10).map(i => Row(i * 2)).toSeq)
- assertCached(sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
- checkAnswer(
- sql("SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
+ assertCached(sql(
+ "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"), 2)
+ checkAnswer(sql(
+ "SELECT x.a, y.a FROM jsonTable x JOIN jsonTable y ON x.a = y.a + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Insert overwrite and keep the same schema.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
index c86ddd7c83..79b6e9b45c 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PartitionedWriteSuite.scala
@@ -19,21 +19,21 @@ package org.apache.spark.sql.sources
import org.apache.spark.sql.{Row, QueryTest}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.test.TestSQLContext
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.util.Utils
-class PartitionedWriteSuite extends QueryTest {
- import TestSQLContext.implicits._
+class PartitionedWriteSuite extends QueryTest with SharedSQLContext {
+ import testImplicits._
test("write many partitions") {
val path = Utils.createTempDir()
path.delete()
- val df = TestSQLContext.range(100).select($"id", lit(1).as("data"))
+ val df = ctx.range(100).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- TestSQLContext.read.load(path.getCanonicalPath),
+ ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
@@ -43,12 +43,12 @@ class PartitionedWriteSuite extends QueryTest {
val path = Utils.createTempDir()
path.delete()
- val base = TestSQLContext.range(100)
+ val base = ctx.range(100)
val df = base.unionAll(base).select($"id", lit(1).as("data"))
df.write.partitionBy("id").save(path.getCanonicalPath)
checkAnswer(
- TestSQLContext.read.load(path.getCanonicalPath),
+ ctx.read.load(path.getCanonicalPath),
(0 to 99).map(Row(1, _)).toSeq ++ (0 to 99).map(Row(1, _)).toSeq)
Utils.deleteRecursively(path)
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
index 0d5183444a..a89c5f8007 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/PrunedScanSuite.scala
@@ -21,6 +21,7 @@ import scala.language.existentials
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class PrunedScanSource extends RelationProvider {
@@ -51,10 +52,12 @@ case class SimplePrunedScan(from: Int, to: Int)(@transient val sqlContext: SQLCo
}
}
-class PrunedScanSuite extends DataSourceTest {
+class PrunedScanSuite extends DataSourceTest with SharedSQLContext {
+ protected override lazy val sql = caseInsensitiveContext.sql _
- before {
- caseInsensitiveContext.sql(
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sql(
"""
|CREATE TEMPORARY TABLE oneToTenPruned
|USING org.apache.spark.sql.sources.PrunedScanSource
@@ -114,7 +117,7 @@ class PrunedScanSuite extends DataSourceTest {
def testPruning(sqlString: String, expectedColumns: String*): Unit = {
test(s"Columns output ${expectedColumns.mkString(",")}: $sqlString") {
- val queryExecution = caseInsensitiveContext.sql(sqlString).queryExecution
+ val queryExecution = sql(sqlString).queryExecution
val rawPlan = queryExecution.executedPlan.collect {
case p: execution.PhysicalRDD => p
} match {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
index 31730a3d3f..f18546b4c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/SaveLoadSuite.scala
@@ -19,25 +19,22 @@ package org.apache.spark.sql.sources
import java.io.File
-import org.scalatest.BeforeAndAfterAll
+import org.scalatest.BeforeAndAfter
import org.apache.spark.sql.{AnalysisException, SaveMode, SQLConf, DataFrame}
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
-class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
-
- import caseInsensitiveContext.sql
-
+class SaveLoadSuite extends DataSourceTest with SharedSQLContext with BeforeAndAfter {
+ protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val sparkContext = caseInsensitiveContext.sparkContext
-
- var originalDefaultSource: String = null
-
- var path: File = null
-
- var df: DataFrame = null
+ private var originalDefaultSource: String = null
+ private var path: File = null
+ private var df: DataFrame = null
override def beforeAll(): Unit = {
+ super.beforeAll()
originalDefaultSource = caseInsensitiveContext.conf.defaultDataSourceName
path = Utils.createTempDir()
@@ -49,11 +46,14 @@ class SaveLoadSuite extends DataSourceTest with BeforeAndAfterAll {
}
override def afterAll(): Unit = {
- caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ try {
+ caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
+ } finally {
+ super.afterAll()
+ }
}
after {
- caseInsensitiveContext.conf.setConf(SQLConf.DEFAULT_DATA_SOURCE_NAME, originalDefaultSource)
Utils.deleteRecursively(path)
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
index e34e0956d1..12af8068c3 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/TableScanSuite.scala
@@ -22,6 +22,7 @@ import java.sql.{Date, Timestamp}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql._
+import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types._
class DefaultSource extends SimpleScanSource
@@ -95,8 +96,8 @@ case class AllDataTypesScan(
}
}
-class TableScanSuite extends DataSourceTest {
- import caseInsensitiveContext.sql
+class TableScanSuite extends DataSourceTest with SharedSQLContext {
+ protected override lazy val sql = caseInsensitiveContext.sql _
private lazy val tableWithSchemaExpected = (1 to 10).map { i =>
Row(
@@ -122,7 +123,8 @@ class TableScanSuite extends DataSourceTest {
Row(Seq(s"str_$i", s"str_${i + 1}"), Row(Seq(Date.valueOf(s"1970-01-${i + 1}")))))
}.toSeq
- before {
+ override def beforeAll(): Unit = {
+ super.beforeAll()
sql(
"""
|CREATE TEMPORARY TABLE oneToTen
@@ -303,9 +305,10 @@ class TableScanSuite extends DataSourceTest {
sql("SELECT i * 2 FROM oneToTen"),
(1 to 10).map(i => Row(i * 2)).toSeq)
- assertCached(sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
- checkAnswer(
- sql("SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
+ assertCached(sql(
+ "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"), 2)
+ checkAnswer(sql(
+ "SELECT a.i, b.i FROM oneToTen a JOIN oneToTen b ON a.i = b.i + 1"),
(2 to 10).map(i => Row(i, i - 1)).toSeq)
// Verify uncaching
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
new file mode 100644
index 0000000000..1374a97476
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestData.scala
@@ -0,0 +1,290 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+
+/**
+ * A collection of sample data used in SQL tests.
+ */
+private[sql] trait SQLTestData { self =>
+ protected def _sqlContext: SQLContext
+
+ // Helper object to import SQL implicits without a concrete SQLContext
+ private object internalImplicits extends SQLImplicits {
+ protected override def _sqlContext: SQLContext = self._sqlContext
+ }
+
+ import internalImplicits._
+ import SQLTestData._
+
+ // Note: all test data should be lazy because the SQLContext is not set up yet.
+
+ protected lazy val testData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(i, i.toString))).toDF()
+ df.registerTempTable("testData")
+ df
+ }
+
+ protected lazy val testData2: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ TestData2(1, 1) ::
+ TestData2(1, 2) ::
+ TestData2(2, 1) ::
+ TestData2(2, 2) ::
+ TestData2(3, 1) ::
+ TestData2(3, 2) :: Nil, 2).toDF()
+ df.registerTempTable("testData2")
+ df
+ }
+
+ protected lazy val testData3: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ TestData3(1, None) ::
+ TestData3(2, Some(2)) :: Nil).toDF()
+ df.registerTempTable("testData3")
+ df
+ }
+
+ protected lazy val negativeData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ (1 to 100).map(i => TestData(-i, (-i).toString))).toDF()
+ df.registerTempTable("negativeData")
+ df
+ }
+
+ protected lazy val largeAndSmallInts: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ LargeAndSmallInts(2147483644, 1) ::
+ LargeAndSmallInts(1, 2) ::
+ LargeAndSmallInts(2147483645, 1) ::
+ LargeAndSmallInts(2, 2) ::
+ LargeAndSmallInts(2147483646, 1) ::
+ LargeAndSmallInts(3, 2) :: Nil).toDF()
+ df.registerTempTable("largeAndSmallInts")
+ df
+ }
+
+ protected lazy val decimalData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ DecimalData(1, 1) ::
+ DecimalData(1, 2) ::
+ DecimalData(2, 1) ::
+ DecimalData(2, 2) ::
+ DecimalData(3, 1) ::
+ DecimalData(3, 2) :: Nil).toDF()
+ df.registerTempTable("decimalData")
+ df
+ }
+
+ protected lazy val binaryData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ BinaryData("12".getBytes, 1) ::
+ BinaryData("22".getBytes, 5) ::
+ BinaryData("122".getBytes, 3) ::
+ BinaryData("121".getBytes, 2) ::
+ BinaryData("123".getBytes, 4) :: Nil).toDF()
+ df.registerTempTable("binaryData")
+ df
+ }
+
+ protected lazy val upperCaseData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ UpperCaseData(1, "A") ::
+ UpperCaseData(2, "B") ::
+ UpperCaseData(3, "C") ::
+ UpperCaseData(4, "D") ::
+ UpperCaseData(5, "E") ::
+ UpperCaseData(6, "F") :: Nil).toDF()
+ df.registerTempTable("upperCaseData")
+ df
+ }
+
+ protected lazy val lowerCaseData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ LowerCaseData(1, "a") ::
+ LowerCaseData(2, "b") ::
+ LowerCaseData(3, "c") ::
+ LowerCaseData(4, "d") :: Nil).toDF()
+ df.registerTempTable("lowerCaseData")
+ df
+ }
+
+ protected lazy val arrayData: RDD[ArrayData] = {
+ val rdd = _sqlContext.sparkContext.parallelize(
+ ArrayData(Seq(1, 2, 3), Seq(Seq(1, 2, 3))) ::
+ ArrayData(Seq(2, 3, 4), Seq(Seq(2, 3, 4))) :: Nil)
+ rdd.toDF().registerTempTable("arrayData")
+ rdd
+ }
+
+ protected lazy val mapData: RDD[MapData] = {
+ val rdd = _sqlContext.sparkContext.parallelize(
+ MapData(Map(1 -> "a1", 2 -> "b1", 3 -> "c1", 4 -> "d1", 5 -> "e1")) ::
+ MapData(Map(1 -> "a2", 2 -> "b2", 3 -> "c2", 4 -> "d2")) ::
+ MapData(Map(1 -> "a3", 2 -> "b3", 3 -> "c3")) ::
+ MapData(Map(1 -> "a4", 2 -> "b4")) ::
+ MapData(Map(1 -> "a5")) :: Nil)
+ rdd.toDF().registerTempTable("mapData")
+ rdd
+ }
+
+ protected lazy val repeatedData: RDD[StringData] = {
+ val rdd = _sqlContext.sparkContext.parallelize(List.fill(2)(StringData("test")))
+ rdd.toDF().registerTempTable("repeatedData")
+ rdd
+ }
+
+ protected lazy val nullableRepeatedData: RDD[StringData] = {
+ val rdd = _sqlContext.sparkContext.parallelize(
+ List.fill(2)(StringData(null)) ++
+ List.fill(2)(StringData("test")))
+ rdd.toDF().registerTempTable("nullableRepeatedData")
+ rdd
+ }
+
+ protected lazy val nullInts: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ NullInts(1) ::
+ NullInts(2) ::
+ NullInts(3) ::
+ NullInts(null) :: Nil).toDF()
+ df.registerTempTable("nullInts")
+ df
+ }
+
+ protected lazy val allNulls: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ NullInts(null) ::
+ NullInts(null) ::
+ NullInts(null) ::
+ NullInts(null) :: Nil).toDF()
+ df.registerTempTable("allNulls")
+ df
+ }
+
+ protected lazy val nullStrings: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ NullStrings(1, "abc") ::
+ NullStrings(2, "ABC") ::
+ NullStrings(3, null) :: Nil).toDF()
+ df.registerTempTable("nullStrings")
+ df
+ }
+
+ protected lazy val tableName: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(TableName("test") :: Nil).toDF()
+ df.registerTempTable("tableName")
+ df
+ }
+
+ protected lazy val unparsedStrings: RDD[String] = {
+ _sqlContext.sparkContext.parallelize(
+ "1, A1, true, null" ::
+ "2, B2, false, null" ::
+ "3, C3, true, null" ::
+ "4, D4, true, 2147483644" :: Nil)
+ }
+
+ // An RDD with 4 elements and 8 partitions
+ protected lazy val withEmptyParts: RDD[IntField] = {
+ val rdd = _sqlContext.sparkContext.parallelize((1 to 4).map(IntField), 8)
+ rdd.toDF().registerTempTable("withEmptyParts")
+ rdd
+ }
+
+ protected lazy val person: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ Person(0, "mike", 30) ::
+ Person(1, "jim", 20) :: Nil).toDF()
+ df.registerTempTable("person")
+ df
+ }
+
+ protected lazy val salary: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ Salary(0, 2000.0) ::
+ Salary(1, 1000.0) :: Nil).toDF()
+ df.registerTempTable("salary")
+ df
+ }
+
+ protected lazy val complexData: DataFrame = {
+ val df = _sqlContext.sparkContext.parallelize(
+ ComplexData(Map("1" -> 1), TestData(1, "1"), Seq(1, 1, 1), true) ::
+ ComplexData(Map("2" -> 2), TestData(2, "2"), Seq(2, 2, 2), false) ::
+ Nil).toDF()
+ df.registerTempTable("complexData")
+ df
+ }
+
+ /**
+ * Initialize all test data such that all temp tables are properly registered.
+ */
+ def loadTestData(): Unit = {
+ assert(_sqlContext != null, "attempted to initialize test data before the SQLContext is set up.")
+ testData
+ testData2
+ testData3
+ negativeData
+ largeAndSmallInts
+ decimalData
+ binaryData
+ upperCaseData
+ lowerCaseData
+ arrayData
+ mapData
+ repeatedData
+ nullableRepeatedData
+ nullInts
+ allNulls
+ nullStrings
+ tableName
+ unparsedStrings
+ withEmptyParts
+ person
+ salary
+ complexData
+ }
+}
+
+/**
+ * Case classes used in test data.
+ */
+private[sql] object SQLTestData {
+ case class TestData(key: Int, value: String)
+ case class TestData2(a: Int, b: Int)
+ case class TestData3(a: Int, b: Option[Int])
+ case class LargeAndSmallInts(a: Int, b: Int)
+ case class DecimalData(a: BigDecimal, b: BigDecimal)
+ case class BinaryData(a: Array[Byte], b: Int)
+ case class UpperCaseData(N: Int, L: String)
+ case class LowerCaseData(n: Int, l: String)
+ case class ArrayData(data: Seq[Int], nestedData: Seq[Seq[Int]])
+ case class MapData(data: scala.collection.Map[Int, String])
+ case class StringData(s: String)
+ case class IntField(i: Int)
+ case class NullInts(a: Integer)
+ case class NullStrings(n: Int, s: String)
+ case class TableName(tableName: String)
+ case class Person(id: Int, name: String, age: Int)
+ case class Salary(personId: Int, salary: Double)
+ case class ComplexData(m: Map[String, Int], s: TestData, a: Seq[Int], b: Boolean)
+}
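A hypothetical illustration of how this test data behaves in a suite: referencing a lazy val materializes the DataFrame and registers its temp table, after which the table is also reachable by name (the suite name and assertions are examples only):

package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: SQLTestData is mixed in via SharedSQLContext/SQLTestUtils,
// so `testData2` below is the lazily registered DataFrame defined in this file.
class ExampleTestDataSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("lazy test data registers its temp table on first reference") {
    // Touching the lazy val builds the DataFrame and registers "testData2"
    assert(testData2.count() === 6)
    // From then on it is also reachable by name through the context
    assert(ctx.table("testData2").filter('a === 1).count() === 2)
  }
}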
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
index 1066695589..cdd691e035 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SQLTestUtils.scala
@@ -21,15 +21,71 @@ import java.io.File
import java.util.UUID
import scala.util.Try
+import scala.language.implicitConversions
+
+import org.apache.hadoop.conf.Configuration
+import org.scalatest.BeforeAndAfterAll
import org.apache.spark.SparkFunSuite
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{DataFrame, SQLContext, SQLImplicits}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.util.Utils
-trait SQLTestUtils { this: SparkFunSuite =>
- protected def sqlContext: SQLContext
+/**
+ * Helper trait that should be extended by all SQL test suites.
+ *
+ * This allows subclasses to plug in a custom [[SQLContext]]. It comes with test data
+ * prepared in advance as well as all the implicit conversions used extensively by DataFrames.
+ * To use the implicit methods, import `testImplicits._` instead of importing them from the [[SQLContext]].
+ *
+ * Subclasses should *not* create [[SQLContext]]s in the test suite constructor, since doing so
+ * is prone to leaving multiple overlapping [[org.apache.spark.SparkContext]]s in the same JVM.
+ */
+private[sql] trait SQLTestUtils
+ extends SparkFunSuite
+ with BeforeAndAfterAll
+ with SQLTestData { self =>
+
+ protected def _sqlContext: SQLContext
+
+ // Whether to materialize all test data before the first test is run
+ private var loadTestDataBeforeTests = false
+
+ // Shorthand for running a query using our SQLContext
+ protected lazy val sql = _sqlContext.sql _
+
+ /**
+ * A helper object for importing SQL implicits.
+ *
+ * Note that the alternative of importing `sqlContext.implicits._` is not possible here.
+ * This is because we create the [[SQLContext]] immediately before the first test is run,
+ * but the implicits import is needed in the constructor.
+ */
+ protected object testImplicits extends SQLImplicits {
+ protected override def _sqlContext: SQLContext = self._sqlContext
+ }
+
+ /**
+ * Materialize the test data immediately after the [[SQLContext]] is set up.
+ * This is necessary if the data is accessed by name rather than through a direct reference.
+ */
+ protected def setupTestData(): Unit = {
+ loadTestDataBeforeTests = true
+ }
- protected def configuration = sqlContext.sparkContext.hadoopConfiguration
+ protected override def beforeAll(): Unit = {
+ super.beforeAll()
+ if (loadTestDataBeforeTests) {
+ loadTestData()
+ }
+ }
+
+ /**
+ * The Hadoop configuration used by the active [[SQLContext]].
+ */
+ protected def configuration: Configuration = {
+ _sqlContext.sparkContext.hadoopConfiguration
+ }
/**
* Sets all SQL configurations specified in `pairs`, calls `f`, and then restore all SQL
@@ -39,12 +95,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
*/
protected def withSQLConf(pairs: (String, String)*)(f: => Unit): Unit = {
val (keys, values) = pairs.unzip
- val currentValues = keys.map(key => Try(sqlContext.conf.getConfString(key)).toOption)
- (keys, values).zipped.foreach(sqlContext.conf.setConfString)
+ val currentValues = keys.map(key => Try(_sqlContext.conf.getConfString(key)).toOption)
+ (keys, values).zipped.foreach(_sqlContext.conf.setConfString)
try f finally {
keys.zip(currentValues).foreach {
- case (key, Some(value)) => sqlContext.conf.setConfString(key, value)
- case (key, None) => sqlContext.conf.unsetConf(key)
+ case (key, Some(value)) => _sqlContext.conf.setConfString(key, value)
+ case (key, None) => _sqlContext.conf.unsetConf(key)
}
}
}
@@ -76,7 +132,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
* Drops temporary table `tableName` after calling `f`.
*/
protected def withTempTable(tableNames: String*)(f: => Unit): Unit = {
- try f finally tableNames.foreach(sqlContext.dropTempTable)
+ try f finally tableNames.foreach(_sqlContext.dropTempTable)
}
/**
@@ -85,7 +141,7 @@ trait SQLTestUtils { this: SparkFunSuite =>
protected def withTable(tableNames: String*)(f: => Unit): Unit = {
try f finally {
tableNames.foreach { name =>
- sqlContext.sql(s"DROP TABLE IF EXISTS $name")
+ _sqlContext.sql(s"DROP TABLE IF EXISTS $name")
}
}
}
@@ -98,12 +154,12 @@ trait SQLTestUtils { this: SparkFunSuite =>
val dbName = s"db_${UUID.randomUUID().toString.replace('-', '_')}"
try {
- sqlContext.sql(s"CREATE DATABASE $dbName")
+ _sqlContext.sql(s"CREATE DATABASE $dbName")
} catch { case cause: Throwable =>
fail("Failed to create temporary database", cause)
}
- try f(dbName) finally sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
+ try f(dbName) finally _sqlContext.sql(s"DROP DATABASE $dbName CASCADE")
}
/**
@@ -111,7 +167,15 @@ trait SQLTestUtils { this: SparkFunSuite =>
* `f` returns.
*/
protected def activateDatabase(db: String)(f: => Unit): Unit = {
- sqlContext.sql(s"USE $db")
- try f finally sqlContext.sql(s"USE default")
+ _sqlContext.sql(s"USE $db")
+ try f finally _sqlContext.sql(s"USE default")
+ }
+
+ /**
+ * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier
+ * way to construct a [[DataFrame]] directly out of local data without relying on implicits.
+ */
+ protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
+ DataFrame(_sqlContext, plan)
}
}
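A hypothetical sketch of the helpers above in use: `withSQLConf` restores the previous configuration values after the block, and `withTempTable` drops the named tables when the block finishes (the config key, table, and data are examples only):

package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch: both helpers come from SQLTestUtils.
class ExampleUtilsSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  test("withSQLConf and withTempTable clean up after themselves") {
    withSQLConf(SQLConf.SHUFFLE_PARTITIONS.key -> "3") {
      // Inside the block the overridden value is visible
      assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS.key, "5") === "3")
      Seq((1, "a"), (2, "b")).toDF("id", "name").registerTempTable("pairs")
      withTempTable("pairs") {
        assert(sql("SELECT COUNT(*) FROM pairs").collect().head.getLong(0) === 2L)
      }
    }
    // Outside the block the original (unset) value is restored
    assert(ctx.getConf(SQLConf.SHUFFLE_PARTITIONS.key, "5") === "5")
  }
}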
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
new file mode 100644
index 0000000000..3cfd822e2a
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSQLContext.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.test
+
+import org.apache.spark.sql.SQLContext
+
+
+/**
+ * Helper trait for SQL test suites where all tests share a single [[TestSQLContext]].
+ */
+private[sql] trait SharedSQLContext extends SQLTestUtils {
+
+ /**
+ * The [[TestSQLContext]] to use for all tests in this suite.
+ *
+ * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
+ * mode with the default test configurations.
+ */
+ private var _ctx: TestSQLContext = null
+
+ /**
+ * The [[TestSQLContext]] to use for all tests in this suite.
+ */
+ protected def ctx: TestSQLContext = _ctx
+ protected def sqlContext: TestSQLContext = _ctx
+ protected override def _sqlContext: SQLContext = _ctx
+
+ /**
+ * Initialize the [[TestSQLContext]].
+ */
+ protected override def beforeAll(): Unit = {
+ if (_ctx == null) {
+ _ctx = new TestSQLContext
+ }
+ // Ensure we have initialized the context before calling parent code
+ super.beforeAll()
+ }
+
+ /**
+ * Stop the underlying [[org.apache.spark.SparkContext]], if any.
+ */
+ protected override def afterAll(): Unit = {
+ try {
+ if (_ctx != null) {
+ _ctx.sparkContext.stop()
+ _ctx = null
+ }
+ } finally {
+ super.afterAll()
+ }
+ }
+
+}
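A hypothetical sketch of layering suite-local setup and teardown on top of this trait, following the ordering used throughout the patch: call `super.beforeAll()` before touching the context, and keep `super.afterAll()` in a `finally` so the SparkContext is always stopped (the table and data are examples only):

package org.apache.spark.sql

import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.test.SharedSQLContext

// Hypothetical sketch of the beforeAll/afterAll chaining used across this patch.
class ExampleLifecycleSuite extends SparkFunSuite with SharedSQLContext {
  import testImplicits._

  override def beforeAll(): Unit = {
    super.beforeAll() // creates the per-suite TestSQLContext first
    Seq((1, "a"), (2, "b"), (3, "c")).toDF("n", "s").registerTempTable("letters")
  }

  override def afterAll(): Unit = {
    try {
      ctx.dropTempTable("letters")
    } finally {
      super.afterAll() // always stops the underlying SparkContext
    }
  }

  test("suite-local state is visible through the shared context") {
    assert(ctx.table("letters").count() === 3)
  }
}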
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index b3a4231da9..92ef2f7d74 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -17,40 +17,36 @@
package org.apache.spark.sql.test
-import scala.language.implicitConversions
-
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.{DataFrame, SQLConf, SQLContext}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-
-/** A SQLContext that can be used for local testing. */
-class LocalSQLContext
- extends SQLContext(
- new SparkContext("local[2]", "TestSQLContext", new SparkConf()
- .set("spark.sql.testkey", "true")
- // SPARK-8910
- .set("spark.ui.enabled", "false"))) {
-
- override protected[sql] def createSession(): SQLSession = {
- new this.SQLSession()
+import org.apache.spark.sql.{SQLConf, SQLContext}
+
+
+/**
+ * A special [[SQLContext]] prepared for testing.
+ */
+private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
+
+ def this() {
+ this(new SparkContext("local[2]", "test-sql-context",
+ new SparkConf().set("spark.sql.testkey", "true")))
}
+ // Use fewer partitions to speed up testing
+ protected[sql] override def createSession(): SQLSession = new this.SQLSession()
+
+ /** A special [[SQLSession]] that uses fewer shuffle partitions than normal. */
protected[sql] class SQLSession extends super.SQLSession {
protected[sql] override lazy val conf: SQLConf = new SQLConf {
- /** Fewer partitions to speed up testing. */
override def numShufflePartitions: Int = this.getConf(SQLConf.SHUFFLE_PARTITIONS, 5)
}
}
- /**
- * Turn a logical plan into a [[DataFrame]]. This should be removed once we have an easier way to
- * construct [[DataFrame]] directly out of local data without relying on implicits.
- */
- protected[sql] implicit def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = {
- DataFrame(this, plan)
+ // Needed for Java tests
+ def loadTestData(): Unit = {
+ testData.loadTestData()
}
+ private object testData extends SQLTestData {
+ protected override def _sqlContext: SQLContext = self
+ }
}
-
-object TestSQLContext extends LocalSQLContext
-
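This constructor is what makes per-suite contexts possible: a suite that needs a different master or configuration can now build its own `TestSQLContext` over its own `SparkContext`. A hypothetical sketch (the master, app name, and extra setting are examples only):

package org.apache.spark.sql.test

import org.apache.spark.{SparkConf, SparkContext}

// Hypothetical sketch: constructing a TestSQLContext over a custom SparkContext,
// which the old TestSQLContext singleton did not allow.
object ExampleCustomContext {
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf()
      .set("spark.sql.testkey", "true")
      .set("spark.ui.enabled", "false") // example setting, mirroring the test defaults
    val sc = new SparkContext("local[4]", "example-custom-sql-context", conf)
    val ctx = new TestSQLContext(sc)
    try {
      assert(ctx.range(0, 10).count() === 10)
    } finally {
      sc.stop()
    }
  }
}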
diff --git a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
index 806240e6de..bf431cd6b0 100644
--- a/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
+++ b/sql/hive-thriftserver/src/test/scala/org/apache/spark/sql/hive/thriftserver/UISeleniumSuite.scala
@@ -27,7 +27,6 @@ import org.scalatest.concurrent.Eventually._
import org.scalatest.selenium.WebBrowser
import org.scalatest.time.SpanSugar._
-import org.apache.spark.sql.hive.HiveContext
import org.apache.spark.ui.SparkUICssErrorHandler
class UISeleniumSuite
@@ -36,7 +35,6 @@ class UISeleniumSuite
implicit var webDriver: WebDriver = _
var server: HiveThriftServer2 = _
- var hc: HiveContext = _
val uiPort = 20000 + Random.nextInt(10000)
override def mode: ServerMode.Value = ServerMode.binary
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
index 59e65ff97b..574624d501 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveMetastoreCatalogSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.hive.test.TestHive.implicits._
import org.apache.spark.sql.sources.DataSourceTest
import org.apache.spark.sql.test.{ExamplePointUDT, SQLTestUtils}
import org.apache.spark.sql.types.{DecimalType, StringType, StructType}
-import org.apache.spark.sql.{Row, SaveMode}
+import org.apache.spark.sql.{Row, SaveMode, SQLContext}
import org.apache.spark.{Logging, SparkFunSuite}
@@ -53,7 +53,8 @@ class HiveMetastoreCatalogSuite extends SparkFunSuite with Logging {
}
class DataSourceWithHiveMetastoreCatalogSuite extends DataSourceTest with SQLTestUtils {
- override val sqlContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ import testImplicits._
private val testDF = range(1, 3).select(
('id + 0.1) cast DecimalType(10, 3) as 'd1,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
index 1fa005d5f9..fe0db5228d 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/HiveParquetSuite.scala
@@ -19,14 +19,13 @@ package org.apache.spark.sql.hive
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.execution.datasources.parquet.ParquetTest
-import org.apache.spark.sql.{QueryTest, Row}
+import org.apache.spark.sql.{QueryTest, Row, SQLContext}
case class Cases(lower: String, UPPER: String)
class HiveParquetSuite extends QueryTest with ParquetTest {
- val sqlContext = TestHive
-
- import sqlContext._
+ private val ctx = TestHive
+ override def _sqlContext: SQLContext = ctx
test("Case insensitive attribute names") {
withParquetTable((1 to 4).map(i => Cases(i.toString, i.toString)), "cases") {
@@ -54,7 +53,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
test("Converting Hive to Parquet Table via saveAsParquetFile") {
withTempPath { dir =>
sql("SELECT * FROM src").write.parquet(dir.getCanonicalPath)
- read.parquet(dir.getCanonicalPath).registerTempTable("p")
+ ctx.read.parquet(dir.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
checkAnswer(
sql("SELECT * FROM src ORDER BY key"),
@@ -67,7 +66,7 @@ class HiveParquetSuite extends QueryTest with ParquetTest {
withParquetTable((1 to 10).map(i => (i, s"val_$i")), "t") {
withTempPath { file =>
sql("SELECT * FROM t LIMIT 1").write.parquet(file.getCanonicalPath)
- read.parquet(file.getCanonicalPath).registerTempTable("p")
+ ctx.read.parquet(file.getCanonicalPath).registerTempTable("p")
withTempTable("p") {
// let's do three overwrites for good measure
sql("INSERT OVERWRITE TABLE p SELECT * FROM t")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
index 7f36a483a3..20a50586d5 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MetastoreDataSourcesSuite.scala
@@ -22,7 +22,6 @@ import java.io.{IOException, File}
import scala.collection.mutable.ArrayBuffer
import org.apache.hadoop.fs.Path
-import org.apache.hadoop.mapred.InvalidInputException
import org.scalatest.BeforeAndAfterAll
import org.apache.spark.Logging
@@ -42,7 +41,8 @@ import org.apache.spark.util.Utils
*/
class MetastoreDataSourcesSuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll
with Logging {
- override val sqlContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
var jsonFilePath: String = _
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
index 73852f13ad..417e8b0791 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/MultiDatabaseSuite.scala
@@ -22,9 +22,8 @@ import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.{QueryTest, SQLContext, SaveMode}
class MultiDatabaseSuite extends QueryTest with SQLTestUtils {
- override val sqlContext: SQLContext = TestHive
-
- import sqlContext.sql
+ override val _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
private val df = sqlContext.range(10).coalesce(1)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
index 251e0324bf..13452e71a1 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/ParquetHiveCompatibilitySuite.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.{Row, SQLConf, SQLContext}
class ParquetHiveCompatibilitySuite extends ParquetCompatibilityTest {
import ParquetCompatibilityTest.makeNullable
- override val sqlContext: SQLContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
/**
* Set the staging directory (and hence path to ignore Parquet files under)
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
index 9b3ede43ee..7ee1c8d13a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/UDFSuite.scala
@@ -17,14 +17,12 @@
package org.apache.spark.sql.hive
-import org.apache.spark.sql.{Row, QueryTest}
+import org.apache.spark.sql.QueryTest
case class FunctionResult(f1: String, f2: String)
class UDFSuite extends QueryTest {
-
private lazy val ctx = org.apache.spark.sql.hive.test.TestHive
- import ctx.implicits._
test("UDF case insensitive") {
ctx.udf.register("random0", () => { Math.random() })
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
index 7b5aa4763f..a312f84958 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/AggregationQuerySuite.scala
@@ -17,17 +17,18 @@
package org.apache.spark.sql.hive.execution
+import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.sql._
import org.apache.spark.sql.execution.aggregate
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
import org.apache.spark.sql.types.{IntegerType, StringType, StructField, StructType}
-import org.apache.spark.sql._
-import org.scalatest.BeforeAndAfterAll
import _root_.test.org.apache.spark.sql.hive.aggregate.{MyDoubleAvg, MyDoubleSum}
abstract class AggregationQuerySuite extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
-
- override val sqlContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ protected val sqlContext = _sqlContext
import sqlContext.implicits._
var originalUseAggregate2: Boolean = _
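On the Hive side, suites keep the shared `TestHive` singleton but expose it through the new `_sqlContext` hook, as in the hunks above. A minimal hypothetical sketch of that pattern (the suite name and query are examples only):

package org.apache.spark.sql.hive

import org.apache.spark.sql.{QueryTest, SQLContext}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical sketch: _sqlContext points at TestHive, and the local alias keeps
// existing `sqlContext`-style calls compiling, as in the suites above.
class ExampleHiveStyleSuite extends QueryTest with SQLTestUtils {
  override def _sqlContext: SQLContext = TestHive
  private val sqlContext = _sqlContext

  test("queries go through the shared TestHive context") {
    assert(sqlContext.range(0, 5).count() === 5)
  }
}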
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
index 44c5b80392..11d7a872df 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveExplainSuite.scala
@@ -26,8 +26,8 @@ import org.apache.spark.sql.test.SQLTestUtils
* A set of tests that validates support for Hive Explain command.
*/
class HiveExplainSuite extends QueryTest with SQLTestUtils {
-
- def sqlContext: SQLContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
test("explain extended command") {
checkExistence(sql(" explain select * from src where key=123 "), true,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
index 79a136ae6f..8b8f520776 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/SQLQuerySuite.scala
@@ -66,7 +66,8 @@ class MyDialect extends DefaultParserDialect
* valid, but Hive currently cannot execute it.
*/
class SQLQuerySuite extends QueryTest with SQLTestUtils {
- override def sqlContext: SQLContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
test("UDTF") {
sql(s"ADD JAR ${TestHive.getHiveFile("TestUDTF.jar").getCanonicalPath()}")
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
index 0875232aed..9aca40f15a 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/ScriptTransformationSuite.scala
@@ -31,7 +31,8 @@ import org.apache.spark.sql.types.StringType
class ScriptTransformationSuite extends SparkPlanTest {
- override def sqlContext: SQLContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
private val noSerdeIOSchema = HiveScriptIOSchema(
inputRowFormat = Seq.empty,
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
index 145965388d..f7ba20ff41 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/orc/OrcTest.scala
@@ -27,8 +27,8 @@ import org.apache.spark.sql._
import org.apache.spark.sql.test.SQLTestUtils
private[sql] trait OrcTest extends SQLTestUtils { this: SparkFunSuite =>
- lazy val sqlContext = org.apache.spark.sql.hive.test.TestHive
-
+ protected override def _sqlContext: SQLContext = org.apache.spark.sql.hive.test.TestHive
+ protected val sqlContext = _sqlContext
import sqlContext.implicits._
import sqlContext.sparkContext
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
index 50f02432da..34d3434569 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/parquetSuites.scala
@@ -685,7 +685,8 @@ class ParquetSourceSuite extends ParquetPartitioningTest {
* A collection of tests for parquet data with various forms of partitioning.
*/
abstract class ParquetPartitioningTest extends QueryTest with SQLTestUtils with BeforeAndAfterAll {
- override def sqlContext: SQLContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ protected val sqlContext = _sqlContext
var partitionedTableDir: File = null
var normalTableDir: File = null
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
index e976125b37..b4640b1616 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/CommitFailureTestRelationSuite.scala
@@ -18,14 +18,16 @@
package org.apache.spark.sql.sources
import org.apache.hadoop.fs.Path
-import org.apache.spark.deploy.SparkHadoopUtil
import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils
class CommitFailureTestRelationSuite extends SparkFunSuite with SQLTestUtils {
- override val sqlContext = TestHive
+ override def _sqlContext: SQLContext = TestHive
+ private val sqlContext = _sqlContext
// When committing a task, `CommitFailureTestSource` throws an exception for testing purpose.
val dataSourceName: String = classOf[CommitFailureTestSource].getCanonicalName
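The hunks above and below all apply the same refactoring: each Hive test suite now supplies its context by overriding `_sqlContext` (the hook declared by the shared test traits elsewhere in this patch) and caches it in a local `sqlContext` val so that existing `import sqlContext.implicits._` lines and helper calls keep compiling. As a minimal sketch of the convention, this is what a new suite would look like; the suite name and test body are hypothetical and are not part of this patch:

package org.apache.spark.sql.hive.execution

import org.apache.spark.sql.{QueryTest, Row, SQLContext}
import org.apache.spark.sql.hive.test.TestHive
import org.apache.spark.sql.test.SQLTestUtils

// Hypothetical example suite (not part of this patch) illustrating the
// _sqlContext override pattern applied throughout the hunks in this diff.
class ExampleHiveQuerySuite extends QueryTest with SQLTestUtils {
  // Point the shared test infrastructure at the Hive test context...
  override def _sqlContext: SQLContext = TestHive
  // ...and keep a stable val so implicits can still be imported from it.
  protected val sqlContext = _sqlContext
  import sqlContext.implicits._

  test("local Seq round-trips through a DataFrame") {
    val df = Seq((1, "a"), (2, "b")).toDF("i", "s")
    checkAnswer(df.select("i"), Seq(Row(1), Row(2)))
  }
}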
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
index 2a69d331b6..af445626fb 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/sources/hadoopFsRelationSuites.scala
@@ -34,9 +34,8 @@ import org.apache.spark.sql.types._
abstract class HadoopFsRelationTest extends QueryTest with SQLTestUtils {
- override lazy val sqlContext: SQLContext = TestHive
-
- import sqlContext.sql
+ override def _sqlContext: SQLContext = TestHive
+ protected val sqlContext = _sqlContext
import sqlContext.implicits._
val dataSourceName: String