author     Andrew Or <andrew@databricks.com>    2016-04-21 14:18:18 -0700
committer  Yin Huai <yhuai@databricks.com>      2016-04-21 14:18:18 -0700
commit     a2e8d4fddd1446df946b3c05223e8b8ac6312c3c (patch)
tree       def6b1e9c95abbad4abc367506a96c4c18c019b7
parent     8e1bb0456db1ad60afa24aa033b574c4a79b9c09 (diff)
[SPARK-13643][SQL] Implement SparkSession
## What changes were proposed in this pull request?

After removing most of `HiveContext` in 8fc267ab3322e46db81e725a5cb1adb5a71b2b4d, we can now move the existing functionality in `SQLContext` to `SparkSession`. As of this PR, `SQLContext` becomes a simple wrapper that holds a `SparkSession` and delegates all functionality to it (a minimal sketch of this delegation pattern follows the diffstat below).

## How was this patch tested?

Jenkins.

Author: Andrew Or <andrew@databricks.com>

Closes #12553 from andrewor14/implement-spark-session.
-rw-r--r--  project/MimaExcludes.scala                                                      |   3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                   | 254
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala                 | 862
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala  |   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala        |   4
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala          |  36
6 files changed, 964 insertions(+), 197 deletions(-)
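The delegation described in the commit message is straightforward: `SQLContext` keeps a reference to a `SparkSession` and forwards every call to it. Below is a minimal, self-contained sketch of that pattern using hypothetical `MySession`/`MyContext` classes, not the real Spark types, which carry far more state and methods.

```scala
// Hypothetical stand-ins for SparkSession and SQLContext; only the delegation
// shape is real, the names and methods here are illustrative.
object DelegationSketch {

  // Holds the actual functionality (plays the role of SparkSession).
  class MySession {
    private var settings = Map.empty[String, String]

    def setConf(key: String, value: String): Unit = settings += key -> value
    def getConf(key: String, default: String): String = settings.getOrElse(key, default)
    def sql(sqlText: String): String = s"executed: $sqlText"
    def newSession(): MySession = new MySession // isolated state in this toy version
  }

  // Thin backward-compatibility wrapper (plays the role of SQLContext):
  // every method forwards to the underlying session.
  class MyContext(val session: MySession) {
    def setConf(key: String, value: String): Unit = session.setConf(key, value)
    def getConf(key: String, default: String): String = session.getConf(key, default)
    def sql(sqlText: String): String = session.sql(sqlText)
    def newSession(): MyContext = new MyContext(session.newSession())
  }

  def main(args: Array[String]): Unit = {
    val ctx = new MyContext(new MySession)
    ctx.setConf("spark.sql.shuffle.partitions", "4")
    println(ctx.getConf("spark.sql.shuffle.partitions", "200")) // 4
    println(ctx.sql("SELECT 1"))                                // executed: SELECT 1
  }
}
```

The real wrapper additionally has to keep the pair mutually consistent; the SQLContext.scala diff below shows how `setWrappedContext` handles that.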
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 3c9f1532f9..9b2a966aaf 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -653,6 +653,9 @@ object MimaExcludes {
ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.shuffleReadMetrics"),
ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.status.api.v1.TaskMetricDistributions.this")
) ++ Seq(
+ // SPARK-13643: Move functionality from SQLContext to SparkSession
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.sql.SQLContext.getSchema")
+ ) ++ Seq(
// [SPARK-14407] Hides HadoopFsRelation related data source API into execution package
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriter"),
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.sql.sources.OutputWriterFactory")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index d85ddd5a98..4c9977c8c7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,11 +17,10 @@
package org.apache.spark.sql
-import java.beans.{BeanInfo, Introspector}
+import java.beans.BeanInfo
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
-import scala.collection.JavaConverters._
import scala.collection.immutable
import scala.reflect.runtime.universe.TypeTag
@@ -34,18 +33,14 @@ import org.apache.spark.rdd.RDD
import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd}
import org.apache.spark.sql.catalyst._
import org.apache.spark.sql.catalyst.catalog._
-import org.apache.spark.sql.catalyst.encoders.encoderFor
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution._
-import org.apache.spark.sql.execution.command.ShowTablesCommand
-import org.apache.spark.sql.execution.datasources._
import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab}
import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.ExecutionListenerManager
-import org.apache.spark.util.Utils
/**
* The entry point for working with structured data (rows and columns) in Spark. Allows the
@@ -69,6 +64,9 @@ class SQLContext private[sql](
self =>
+ // Note: Since Spark 2.0 this class has become a wrapper of SparkSession, where the
+ // real functionality resides. This class remains mainly for backward compatibility.
+
private[sql] def this(sparkSession: SparkSession) = {
this(sparkSession, true)
}
@@ -79,6 +77,8 @@ class SQLContext private[sql](
def this(sparkContext: JavaSparkContext) = this(sparkContext.sc)
+ // TODO: move this logic into SparkSession
+
// If spark.sql.allowMultipleContexts is true, we will throw an exception if a user
// wants to create a new root SQLContext (a SQLContext that is not created by newSession).
private val allowMultipleContexts =
@@ -103,12 +103,12 @@ class SQLContext private[sql](
protected[sql] def sessionState: SessionState = sparkSession.sessionState
protected[sql] def sharedState: SharedState = sparkSession.sharedState
- protected[sql] def conf: SQLConf = sessionState.conf
- protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
- protected[sql] def listener: SQLListener = sharedState.listener
- protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog
+ protected[sql] def conf: SQLConf = sparkSession.conf
+ protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager
+ protected[sql] def listener: SQLListener = sparkSession.listener
+ protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog
- def sparkContext: SparkContext = sharedState.sparkContext
+ def sparkContext: SparkContext = sparkSession.sparkContext
/**
* Returns a [[SQLContext]] as new session, with separated SQL configurations, temporary
@@ -117,16 +117,14 @@ class SQLContext private[sql](
*
* @since 1.6.0
*/
- def newSession(): SQLContext = {
- new SQLContext(sparkSession.newSession(), isRootContext = false)
- }
+ def newSession(): SQLContext = sparkSession.newSession().wrapped
/**
* An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
* that listen for execution metrics.
*/
@Experimental
- def listenerManager: ExecutionListenerManager = sessionState.listenerManager
+ def listenerManager: ExecutionListenerManager = sparkSession.listenerManager
/**
* Set Spark SQL configuration properties.
@@ -134,13 +132,13 @@ class SQLContext private[sql](
* @group config
* @since 1.0.0
*/
- def setConf(props: Properties): Unit = sessionState.setConf(props)
+ def setConf(props: Properties): Unit = sparkSession.setConf(props)
/**
* Set the given Spark SQL configuration property.
*/
private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = {
- sessionState.setConf(entry, value)
+ sparkSession.setConf(entry, value)
}
/**
@@ -149,7 +147,7 @@ class SQLContext private[sql](
* @group config
* @since 1.0.0
*/
- def setConf(key: String, value: String): Unit = sessionState.setConf(key, value)
+ def setConf(key: String, value: String): Unit = sparkSession.setConf(key, value)
/**
* Return the value of Spark SQL configuration property for the given key.
@@ -157,13 +155,13 @@ class SQLContext private[sql](
* @group config
* @since 1.0.0
*/
- def getConf(key: String): String = conf.getConfString(key)
+ def getConf(key: String): String = sparkSession.getConf(key)
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
* yet, return `defaultValue` in [[ConfigEntry]].
*/
- private[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry)
+ private[sql] def getConf[T](entry: ConfigEntry[T]): T = sparkSession.getConf(entry)
/**
* Return the value of Spark SQL configuration property for the given key. If the key is not set
@@ -171,7 +169,7 @@ class SQLContext private[sql](
* desired one.
*/
private[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = {
- conf.getConf(entry, defaultValue)
+ sparkSession.getConf(entry, defaultValue)
}
/**
@@ -181,7 +179,7 @@ class SQLContext private[sql](
* @group config
* @since 1.0.0
*/
- def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue)
+ def getConf(key: String, defaultValue: String): String = sparkSession.getConf(key, defaultValue)
/**
* Return all the configuration properties that have been set (i.e. not the default).
@@ -190,21 +188,14 @@ class SQLContext private[sql](
* @group config
* @since 1.0.0
*/
- def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
+ def getAllConfs: immutable.Map[String, String] = sparkSession.getAllConfs
- protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql)
+ protected[sql] def parseSql(sql: String): LogicalPlan = sparkSession.parseSql(sql)
- protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql))
+ protected[sql] def executeSql(sql: String): QueryExecution = sparkSession.executeSql(sql)
protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = {
- sessionState.executePlan(plan)
- }
-
- /**
- * Add a jar to SQLContext
- */
- protected[sql] def addJar(path: String): Unit = {
- sessionState.addJar(path)
+ sparkSession.executePlan(plan)
}
/**
@@ -217,7 +208,7 @@ class SQLContext private[sql](
*/
@Experimental
@transient
- def experimental: ExperimentalMethods = sessionState.experimentalMethods
+ def experimental: ExperimentalMethods = sparkSession.experimental
/**
* :: Experimental ::
@@ -227,8 +218,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
@Experimental
- @transient
- lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil))
+ def emptyDataFrame: DataFrame = sparkSession.emptyDataFrame
/**
* A collection of methods for registering user-defined functions (UDF).
@@ -259,7 +249,7 @@ class SQLContext private[sql](
* @group basic
* @since 1.3.0
*/
- def udf: UDFRegistration = sessionState.udf
+ def udf: UDFRegistration = sparkSession.udf
/**
* Returns true if the table is currently cached in-memory.
@@ -267,7 +257,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def isCached(tableName: String): Boolean = {
- cacheManager.lookupCachedData(table(tableName)).nonEmpty
+ sparkSession.isCached(tableName)
}
/**
@@ -276,7 +266,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
private[sql] def isCached(qName: Dataset[_]): Boolean = {
- cacheManager.lookupCachedData(qName).nonEmpty
+ sparkSession.isCached(qName)
}
/**
@@ -285,7 +275,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def cacheTable(tableName: String): Unit = {
- cacheManager.cacheQuery(table(tableName), Some(tableName))
+ sparkSession.cacheTable(tableName)
}
/**
@@ -293,13 +283,17 @@ class SQLContext private[sql](
* @group cachemgmt
* @since 1.3.0
*/
- def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName))
+ def uncacheTable(tableName: String): Unit = {
+ sparkSession.uncacheTable(tableName)
+ }
/**
* Removes all cached tables from the in-memory cache.
* @since 1.3.0
*/
- def clearCache(): Unit = cacheManager.clearCache()
+ def clearCache(): Unit = {
+ sparkSession.clearCache()
+ }
// scalastyle:off
// Disable style checker so "implicits" object can start with lowercase i
@@ -331,11 +325,7 @@ class SQLContext private[sql](
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
- SQLContext.setActive(self)
- val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
- val attributeSeq = schema.toAttributes
- val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
- Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self))
+ sparkSession.createDataFrame(rdd)
}
/**
@@ -347,10 +337,7 @@ class SQLContext private[sql](
*/
@Experimental
def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
- SQLContext.setActive(self)
- val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
- val attributeSeq = schema.toAttributes
- Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data))
+ sparkSession.createDataFrame(data)
}
/**
@@ -360,7 +347,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
- Dataset.ofRows(this, LogicalRelation(baseRelation))
+ sparkSession.baseRelationToDataFrame(baseRelation)
}
/**
@@ -397,7 +384,7 @@ class SQLContext private[sql](
*/
@DeveloperApi
def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
- createDataFrame(rowRDD, schema, needsConversion = true)
+ sparkSession.createDataFrame(rowRDD, schema)
}
/**
@@ -406,39 +393,20 @@ class SQLContext private[sql](
*/
private[sql]
def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = {
- // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
- // schema differs from the existing schema on any field data type.
- val catalystRows = if (needsConversion) {
- val converter = CatalystTypeConverters.createToCatalystConverter(schema)
- rowRDD.map(converter(_).asInstanceOf[InternalRow])
- } else {
- rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
- }
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
- Dataset.ofRows(this, logicalPlan)
+ sparkSession.createDataFrame(rowRDD, schema, needsConversion)
}
def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
- val enc = encoderFor[T]
- val attributes = enc.schema.toAttributes
- val encoded = data.map(d => enc.toRow(d).copy())
- val plan = new LocalRelation(attributes, encoded)
-
- Dataset[T](this, plan)
+ sparkSession.createDataset(data)
}
def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
- val enc = encoderFor[T]
- val attributes = enc.schema.toAttributes
- val encoded = data.map(d => enc.toRow(d))
- val plan = LogicalRDD(attributes, encoded)(self)
-
- Dataset[T](this, plan)
+ sparkSession.createDataset(data)
}
def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
- createDataset(data.asScala)
+ sparkSession.createDataset(data)
}
/**
@@ -447,10 +415,7 @@ class SQLContext private[sql](
*/
private[sql]
def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = {
- // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
- // schema differs from the existing schema on any field data type.
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self)
- Dataset.ofRows(this, logicalPlan)
+ sparkSession.internalCreateDataFrame(catalystRows, schema)
}
/**
@@ -464,7 +429,7 @@ class SQLContext private[sql](
*/
@DeveloperApi
def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
- createDataFrame(rowRDD.rdd, schema)
+ sparkSession.createDataFrame(rowRDD, schema)
}
/**
@@ -478,7 +443,7 @@ class SQLContext private[sql](
*/
@DeveloperApi
def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
- Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
+ sparkSession.createDataFrame(rows, schema)
}
/**
@@ -490,14 +455,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
- val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
- val className = beanClass.getName
- val rowRdd = rdd.mapPartitions { iter =>
- // BeanInfo is not serializable so we must rediscover it remotely for each partition.
- val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
- SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
- }
- Dataset.ofRows(this, LogicalRDD(attributeSeq, rowRdd)(this))
+ sparkSession.createDataFrame(rdd, beanClass)
}
/**
@@ -509,7 +467,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
- createDataFrame(rdd.rdd, beanClass)
+ sparkSession.createDataFrame(rdd, beanClass)
}
/**
@@ -521,11 +479,7 @@ class SQLContext private[sql](
* @since 1.6.0
*/
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
- val attrSeq = getSchema(beanClass)
- val className = beanClass.getName
- val beanInfo = Introspector.getBeanInfo(beanClass)
- val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
- Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
+ sparkSession.createDataFrame(data, beanClass)
}
/**
@@ -540,7 +494,7 @@ class SQLContext private[sql](
* @since 1.4.0
*/
@Experimental
- def read: DataFrameReader = new DataFrameReader(this)
+ def read: DataFrameReader = sparkSession.read
/**
* :: Experimental ::
@@ -552,8 +506,7 @@ class SQLContext private[sql](
*/
@Experimental
def createExternalTable(tableName: String, path: String): DataFrame = {
- val dataSourceName = conf.defaultDataSourceName
- createExternalTable(tableName, path, dataSourceName)
+ sparkSession.createExternalTable(tableName, path)
}
/**
@@ -569,7 +522,7 @@ class SQLContext private[sql](
tableName: String,
path: String,
source: String): DataFrame = {
- createExternalTable(tableName, source, Map("path" -> path))
+ sparkSession.createExternalTable(tableName, path, source)
}
/**
@@ -585,7 +538,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: java.util.Map[String, String]): DataFrame = {
- createExternalTable(tableName, source, options.asScala.toMap)
+ sparkSession.createExternalTable(tableName, source, options)
}
/**
@@ -602,18 +555,7 @@ class SQLContext private[sql](
tableName: String,
source: String,
options: Map[String, String]): DataFrame = {
- val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
- val cmd =
- CreateTableUsing(
- tableIdent,
- userSpecifiedSchema = None,
- source,
- temporary = false,
- options,
- allowExisting = false,
- managedIfNoPath = false)
- executePlan(cmd).toRdd
- table(tableIdent)
+ sparkSession.createExternalTable(tableName, source, options)
}
/**
@@ -630,7 +572,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: java.util.Map[String, String]): DataFrame = {
- createExternalTable(tableName, source, schema, options.asScala.toMap)
+ sparkSession.createExternalTable(tableName, source, schema, options)
}
/**
@@ -648,18 +590,7 @@ class SQLContext private[sql](
source: String,
schema: StructType,
options: Map[String, String]): DataFrame = {
- val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
- val cmd =
- CreateTableUsing(
- tableIdent,
- userSpecifiedSchema = Some(schema),
- source,
- temporary = false,
- options,
- allowExisting = false,
- managedIfNoPath = false)
- executePlan(cmd).toRdd
- table(tableIdent)
+ sparkSession.createExternalTable(tableName, source, schema, options)
}
/**
@@ -667,10 +598,7 @@ class SQLContext private[sql](
* only during the lifetime of this instance of SQLContext.
*/
private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
- sessionState.catalog.createTempTable(
- sessionState.sqlParser.parseTableIdentifier(tableName).table,
- df.logicalPlan,
- overrideIfExists = true)
+ sparkSession.registerDataFrameAsTable(df, tableName)
}
/**
@@ -682,8 +610,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def dropTempTable(tableName: String): Unit = {
- cacheManager.tryUncacheQuery(table(tableName))
- sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true)
+ sparkSession.dropTempTable(tableName)
}
/**
@@ -695,7 +622,7 @@ class SQLContext private[sql](
* @group dataset
*/
@Experimental
- def range(end: Long): Dataset[java.lang.Long] = range(0, end)
+ def range(end: Long): Dataset[java.lang.Long] = sparkSession.range(end)
/**
* :: Experimental ::
@@ -706,9 +633,7 @@ class SQLContext private[sql](
* @group dataset
*/
@Experimental
- def range(start: Long, end: Long): Dataset[java.lang.Long] = {
- range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
- }
+ def range(start: Long, end: Long): Dataset[java.lang.Long] = sparkSession.range(start, end)
/**
* :: Experimental ::
@@ -720,7 +645,7 @@ class SQLContext private[sql](
*/
@Experimental
def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
- range(start, end, step, numPartitions = sparkContext.defaultParallelism)
+ sparkSession.range(start, end, step)
}
/**
@@ -734,7 +659,7 @@ class SQLContext private[sql](
*/
@Experimental
def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
- new Dataset(this, Range(start, end, step, numPartitions), Encoders.LONG)
+ sparkSession.range(start, end, step, numPartitions)
}
/**
@@ -744,9 +669,7 @@ class SQLContext private[sql](
* @group basic
* @since 1.3.0
*/
- def sql(sqlText: String): DataFrame = {
- Dataset.ofRows(this, parseSql(sqlText))
- }
+ def sql(sqlText: String): DataFrame = sparkSession.sql(sqlText)
/**
* Executes a SQL query without parsing it, but instead passing it directly to an underlying
@@ -754,7 +677,7 @@ class SQLContext private[sql](
* as Spark can parse all supported Hive DDLs itself.
*/
private[sql] def runNativeSql(sqlText: String): Seq[Row] = {
- sessionState.runNativeSql(sqlText).map { r => Row(r) }
+ sparkSession.runNativeSql(sqlText)
}
/**
@@ -764,11 +687,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def table(tableName: String): DataFrame = {
- table(sessionState.sqlParser.parseTableIdentifier(tableName))
- }
-
- private def table(tableIdent: TableIdentifier): DataFrame = {
- Dataset.ofRows(this, sessionState.catalog.lookupRelation(tableIdent))
+ sparkSession.table(tableName)
}
/**
@@ -780,7 +699,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tables(): DataFrame = {
- Dataset.ofRows(this, ShowTablesCommand(None, None))
+ sparkSession.tables()
}
/**
@@ -792,7 +711,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tables(databaseName: String): DataFrame = {
- Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None))
+ sparkSession.tables(databaseName)
}
/**
@@ -801,7 +720,7 @@ class SQLContext private[sql](
*
* @since 2.0.0
*/
- def streams: ContinuousQueryManager = sessionState.continuousQueryManager
+ def streams: ContinuousQueryManager = sparkSession.streams
/**
* Returns the names of tables in the current database as an array.
@@ -810,7 +729,7 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(): Array[String] = {
- tableNames(sessionState.catalog.getCurrentDatabase)
+ sparkSession.tableNames()
}
/**
@@ -820,19 +739,16 @@ class SQLContext private[sql](
* @since 1.3.0
*/
def tableNames(databaseName: String): Array[String] = {
- sessionState.catalog.listTables(databaseName).map(_.table).toArray
+ sparkSession.tableNames(databaseName)
}
- @transient
- protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1)
-
/**
* Parses the data type in our internal string representation. The data type string should
* have the same format as the one generated by `toString` in scala.
* It is only used by PySpark.
*/
protected[sql] def parseDataType(dataTypeString: String): DataType = {
- DataType.fromJson(dataTypeString)
+ sparkSession.parseDataType(dataTypeString)
}
/**
@@ -841,8 +757,7 @@ class SQLContext private[sql](
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schemaString: String): DataFrame = {
- val schema = parseDataType(schemaString).asInstanceOf[StructType]
- applySchemaToPythonRDD(rdd, schema)
+ sparkSession.applySchemaToPythonRDD(rdd, schemaString)
}
/**
@@ -851,20 +766,10 @@ class SQLContext private[sql](
protected[sql] def applySchemaToPythonRDD(
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
-
- val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
- Dataset.ofRows(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
+ sparkSession.applySchemaToPythonRDD(rdd, schema)
}
- /**
- * Returns a Catalyst Schema for the given java bean class.
- */
- protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
- val (dataType, _) = JavaTypeInference.inferDataType(beanClass)
- dataType.asInstanceOf[StructType].fields.map { f =>
- AttributeReference(f.name, f.dataType, f.nullable)()
- }
- }
+ // TODO: move this logic into SparkSession
// Register a successfully instantiated context to the singleton. This should be at the end of
// the class definition so that the singleton is updated only if there is no exception in the
@@ -876,6 +781,7 @@ class SQLContext private[sql](
}
})
+ sparkSession.setWrappedContext(self)
SQLContext.setInstantiatedContext(self)
}
@@ -980,8 +886,10 @@ object SQLContext {
* bean info & schema. This is not related to the singleton, but is a static
* method for internal use.
*/
- private def beansToRows(data: Iterator[_], beanInfo: BeanInfo, attrs: Seq[AttributeReference]):
- Iterator[InternalRow] = {
+ private[sql] def beansToRows(
+ data: Iterator[_],
+ beanInfo: BeanInfo,
+ attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
val extractors =
beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
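The hunk above ends with `sparkSession.setWrappedContext(self)`: a user-constructed `SQLContext` registers itself as its session's wrapper, while the SparkSession.scala diff below adds a lazily created `wrapped` context for sessions built directly. A minimal sketch of that two-way wiring, with hypothetical `Session`/`Context` classes standing in for the real ones:

```scala
// Hypothetical Session/Context classes; the real pair is SparkSession/SQLContext.
class Session { self =>
  private var _wrapped: Context = _

  // Lazily materialize a wrapper for sessions that were created directly.
  def wrapped: Context = {
    if (_wrapped == null) {
      _wrapped = new Context(self, register = false)
    }
    _wrapped
  }

  // Called by a user-constructed context so the session reuses it as its wrapper.
  def setWrappedContext(ctx: Context): Unit = {
    _wrapped = ctx
  }
}

class Context(val session: Session, register: Boolean) {
  def this(session: Session) = this(session, register = true)
  if (register) session.setWrappedContext(this)
}

object WiringSketch extends App {
  val userCtx = new Context(new Session)
  assert(userCtx.session.wrapped eq userCtx)   // registered path: the user's context is reused

  val session = new Session
  assert(session.wrapped.session eq session)   // lazy path: the wrapper points back at its session
  println("wiring ok")
}
```

Either way the invariant is the same: a session and its wrapper always point at each other, so delegated calls never create a second context.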
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 17ba299825..70d889b002 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,12 +17,33 @@
package org.apache.spark.sql
+import java.beans.Introspector
+import java.util.Properties
+
+import scala.collection.immutable
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
+import scala.reflect.runtime.universe.TypeTag
import scala.util.control.NonFatal
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
-import org.apache.spark.sql.internal.{SessionState, SharedState}
+import org.apache.spark.annotation.{DeveloperApi, Experimental}
+import org.apache.spark.api.java.JavaRDD
+import org.apache.spark.internal.config.{CATALOG_IMPLEMENTATION, ConfigEntry}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst._
+import org.apache.spark.sql.catalyst.catalog._
+import org.apache.spark.sql.catalyst.encoders._
+import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range}
+import org.apache.spark.sql.execution._
+import org.apache.spark.sql.execution.command.ShowTablesCommand
+import org.apache.spark.sql.execution.datasources.{CreateTableUsing, LogicalRelation}
+import org.apache.spark.sql.execution.ui.SQLListener
+import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf}
+import org.apache.spark.sql.sources.BaseRelation
+import org.apache.spark.sql.types.{DataType, LongType, StructType}
+import org.apache.spark.sql.util.ExecutionListenerManager
import org.apache.spark.util.Utils
@@ -30,22 +51,22 @@ import org.apache.spark.util.Utils
* The entry point to Spark execution.
*/
class SparkSession private(
- sparkContext: SparkContext,
- existingSharedState: Option[SharedState]) { self =>
+ @transient val sparkContext: SparkContext,
+ @transient private val existingSharedState: Option[SharedState]) { self =>
def this(sc: SparkContext) {
this(sc, None)
}
+
+ /* ----------------------- *
+ | Session-related state |
+ * ----------------------- */
+
/**
- * Start a new session where configurations, temp tables, temp functions etc. are isolated.
+ * State shared across sessions, including the [[SparkContext]], cached data, listener,
+ * and a catalog that interacts with external systems.
*/
- def newSession(): SparkSession = {
- // Note: materialize the shared state here to ensure the parent and child sessions are
- // initialized with the same shared state.
- new SparkSession(sparkContext, Some(sharedState))
- }
-
@transient
protected[sql] lazy val sharedState: SharedState = {
existingSharedState.getOrElse(
@@ -54,6 +75,10 @@ class SparkSession private(
sparkContext))
}
+ /**
+ * State isolated across sessions, including SQL configurations, temporary tables,
+ * registered functions, and everything else that accepts a [[SQLConf]].
+ */
@transient
protected[sql] lazy val sessionState: SessionState = {
SparkSession.reflect[SessionState, SQLContext](
@@ -61,6 +86,821 @@ class SparkSession private(
new SQLContext(self, isRootContext = false))
}
+ /**
+ * A wrapped version of this session in the form of a [[SQLContext]].
+ */
+ @transient
+ private var _wrapped: SQLContext = _
+
+ protected[sql] def wrapped: SQLContext = {
+ if (_wrapped == null) {
+ _wrapped = new SQLContext(self, isRootContext = false)
+ }
+ _wrapped
+ }
+
+ protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = {
+ _wrapped = sqlContext
+ }
+
+ protected[sql] def conf: SQLConf = sessionState.conf
+ protected[sql] def cacheManager: CacheManager = sharedState.cacheManager
+ protected[sql] def listener: SQLListener = sharedState.listener
+ protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog
+
+ /**
+ * :: Experimental ::
+ * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s
+ * that listen for execution metrics.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ @Experimental
+ def listenerManager: ExecutionListenerManager = sessionState.listenerManager
+
+ /**
+ * :: Experimental ::
+ * A collection of methods that are considered experimental, but can be used to hook into
+ * the query planner for advanced functionality.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ @Experimental
+ def experimental: ExperimentalMethods = sessionState.experimentalMethods
+
+ /**
+ * A collection of methods for registering user-defined functions (UDF).
+ *
+ * The following example registers a Scala closure as UDF:
+ * {{{
+ * sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1)
+ * }}}
+ *
+ * The following example registers a UDF in Java:
+ * {{{
+ * sparkSession.udf().register("myUDF",
+ * new UDF2<Integer, String, String>() {
+ * @Override
+ * public String call(Integer arg1, String arg2) {
+ * return arg2 + arg1;
+ * }
+ * }, DataTypes.StringType);
+ * }}}
+ *
+ * Or, to use Java 8 lambda syntax:
+ * {{{
+ * sparkSession.udf().register("myUDF",
+ * (Integer arg1, String arg2) -> arg2 + arg1,
+ * DataTypes.StringType);
+ * }}}
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ def udf: UDFRegistration = sessionState.udf
+
+ /**
+ * Returns a [[ContinuousQueryManager]] that allows managing all the
+ * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this`.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ def streams: ContinuousQueryManager = sessionState.continuousQueryManager
+
+ /**
+ * Start a new session in which SQL configurations, temporary tables, and registered
+ * functions are isolated, but the underlying [[SparkContext]] and cached data are shared.
+ *
+ * Note: Other than the [[SparkContext]], all shared state is initialized lazily.
+ * This method will force the initialization of the shared state to ensure that parent
+ * and child sessions are set up with the same shared state. If the underlying catalog
+ * implementation is Hive, this will initialize the metastore, which may take some time.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ def newSession(): SparkSession = {
+ new SparkSession(sparkContext, Some(sharedState))
+ }
+
+
+ /* ------------------------------------------------- *
+ | Methods for accessing or mutating configurations |
+ * ------------------------------------------------- */
+
+ /**
+ * Set Spark SQL configuration properties.
+ *
+ * @group config
+ * @since 2.0.0
+ */
+ def setConf(props: Properties): Unit = sessionState.setConf(props)
+
+ /**
+ * Set the given Spark SQL configuration property.
+ *
+ * @group config
+ * @since 2.0.0
+ */
+ def setConf(key: String, value: String): Unit = sessionState.setConf(key, value)
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key.
+ *
+ * @group config
+ * @since 2.0.0
+ */
+ def getConf(key: String): String = conf.getConfString(key)
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue`.
+ *
+ * @group config
+ * @since 2.0.0
+ */
+ def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue)
+
+ /**
+ * Return all the configuration properties that have been set (i.e. not the default).
+ * This creates a new copy of the config properties in the form of a Map.
+ *
+ * @group config
+ * @since 2.0.0
+ */
+ def getAllConfs: immutable.Map[String, String] = conf.getAllConfs
+
+ /**
+ * Set the given Spark SQL configuration property.
+ */
+ protected[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = {
+ sessionState.setConf(entry, value)
+ }
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue` in [[ConfigEntry]].
+ */
+ protected[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry)
+
+ /**
+ * Return the value of Spark SQL configuration property for the given key. If the key is not set
+ * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the
+ * desired one.
+ */
+ protected[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = {
+ conf.getConf(entry, defaultValue)
+ }
+
+
+ /* ------------------------------------- *
+ | Methods related to cache management |
+ * ------------------------------------- */
+
+ /**
+ * Returns true if the table is currently cached in-memory.
+ *
+ * @group cachemgmt
+ * @since 2.0.0
+ */
+ def isCached(tableName: String): Boolean = {
+ cacheManager.lookupCachedData(table(tableName)).nonEmpty
+ }
+
+ /**
+ * Caches the specified table in-memory.
+ *
+ * @group cachemgmt
+ * @since 2.0.0
+ */
+ def cacheTable(tableName: String): Unit = {
+ cacheManager.cacheQuery(table(tableName), Some(tableName))
+ }
+
+ /**
+ * Removes the specified table from the in-memory cache.
+ *
+ * @group cachemgmt
+ * @since 2.0.0
+ */
+ def uncacheTable(tableName: String): Unit = {
+ cacheManager.uncacheQuery(table(tableName))
+ }
+
+ /**
+ * Removes all cached tables from the in-memory cache.
+ *
+ * @group cachemgmt
+ * @since 2.0.0
+ */
+ def clearCache(): Unit = {
+ cacheManager.clearCache()
+ }
+
+ /**
+ * Returns true if the [[Dataset]] is currently cached in-memory.
+ *
+ * @group cachemgmt
+ * @since 2.0.0
+ */
+ protected[sql] def isCached(qName: Dataset[_]): Boolean = {
+ cacheManager.lookupCachedData(qName).nonEmpty
+ }
+
+ /* --------------------------------- *
+ | Methods for creating DataFrames |
+ * --------------------------------- */
+
+ /**
+ * :: Experimental ::
+ * Returns a [[DataFrame]] with no rows or columns.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @Experimental
+ @transient
+ lazy val emptyDataFrame: DataFrame = {
+ createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] from an RDD of Product (e.g. case classes, tuples).
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @Experimental
+ def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = {
+ SQLContext.setActive(wrapped)
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributeSeq = schema.toAttributes
+ val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType))
+ Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRDD)(wrapped))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[DataFrame]] from a local Seq of Product.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @Experimental
+ def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = {
+ SQLContext.setActive(wrapped)
+ val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
+ val attributeSeq = schema.toAttributes
+ Dataset.ofRows(wrapped, LocalRelation.fromProduct(attributeSeq, data))
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s using the given schema.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ * Example:
+ * {{{
+ * import org.apache.spark.sql._
+ * import org.apache.spark.sql.types._
+ * val sparkSession = new org.apache.spark.sql.SparkSession(sc)
+ *
+ * val schema =
+ * StructType(
+ * StructField("name", StringType, false) ::
+ * StructField("age", IntegerType, true) :: Nil)
+ *
+ * val people =
+ * sc.textFile("examples/src/main/resources/people.txt").map(
+ * _.split(",")).map(p => Row(p(0), p(1).trim.toInt))
+ * val dataFrame = sparkSession.createDataFrame(people, schema)
+ * dataFrame.printSchema
+ * // root
+ * // |-- name: string (nullable = false)
+ * // |-- age: integer (nullable = true)
+ *
+ * dataFrame.registerTempTable("people")
+ * sparkSession.sql("select name from people").collect.foreach(println)
+ * }}}
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @DeveloperApi
+ def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD, schema, needsConversion = true)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from a [[JavaRDD]] containing [[Row]]s using the given schema.
+ * It is important to make sure that the structure of every [[Row]] of the provided RDD matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @DeveloperApi
+ def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = {
+ createDataFrame(rowRDD.rdd, schema)
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Creates a [[DataFrame]] from a [[java.util.List]] containing [[Row]]s using the given schema.
+ * It is important to make sure that the structure of every [[Row]] of the provided List matches
+ * the provided schema. Otherwise, a runtime exception will be thrown.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ @DeveloperApi
+ def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = {
+ Dataset.ofRows(wrapped, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala))
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
+ val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
+ val className = beanClass.getName
+ val rowRdd = rdd.mapPartitions { iter =>
+ // BeanInfo is not serializable so we must rediscover it remotely for each partition.
+ val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
+ SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
+ }
+ Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRdd)(wrapped))
+ }
+
+ /**
+ * Applies a schema to an RDD of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = {
+ createDataFrame(rdd.rdd, beanClass)
+ }
+
+ /**
+ * Applies a schema to a List of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ * @group dataframes
+ * @since 1.6.0
+ */
+ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
+ val attrSeq = getSchema(beanClass)
+ val beanInfo = Introspector.getBeanInfo(beanClass)
+ val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+ Dataset.ofRows(wrapped, LocalRelation(attrSeq, rows.toSeq))
+ }
+
+ /**
+ * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]].
+ *
+ * @group dataframes
+ * @since 2.0.0
+ */
+ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = {
+ Dataset.ofRows(wrapped, LogicalRelation(baseRelation))
+ }
+
+ def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = {
+ val enc = encoderFor[T]
+ val attributes = enc.schema.toAttributes
+ val encoded = data.map(d => enc.toRow(d).copy())
+ val plan = new LocalRelation(attributes, encoded)
+ Dataset[T](wrapped, plan)
+ }
+
+ def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = {
+ val enc = encoderFor[T]
+ val attributes = enc.schema.toAttributes
+ val encoded = data.map(d => enc.toRow(d))
+ val plan = LogicalRDD(attributes, encoded)(wrapped)
+ Dataset[T](wrapped, plan)
+ }
+
+ def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = {
+ createDataset(data.asScala)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from 0 to `end` (exclusive) with step value 1.
+ *
+ * @since 2.0.0
+ * @group dataset
+ */
+ @Experimental
+ def range(end: Long): Dataset[java.lang.Long] = range(0, end)
+
+ /**
+ * :: Experimental ::
+ * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with step value 1.
+ *
+ * @since 2.0.0
+ * @group dataset
+ */
+ @Experimental
+ def range(start: Long, end: Long): Dataset[java.lang.Long] = {
+ range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with the given step value.
+ *
+ * @since 2.0.0
+ * @group dataset
+ */
+ @Experimental
+ def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = {
+ range(start, end, step, numPartitions = sparkContext.defaultParallelism)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements
+ * in a range from `start` to `end` (exclusive) with the given step value and the specified
+ * number of partitions.
+ *
+ * @since 2.0.0
+ * @group dataset
+ */
+ @Experimental
+ def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = {
+ new Dataset(wrapped, Range(start, end, step, numPartitions), Encoders.LONG)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from an RDD[Row].
+ * User can specify whether the input rows should be converted to Catalyst rows.
+ */
+ protected[sql] def internalCreateDataFrame(
+ catalystRows: RDD[InternalRow],
+ schema: StructType): DataFrame = {
+ // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
+ // schema differs from the existing schema on any field data type.
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped)
+ Dataset.ofRows(wrapped, logicalPlan)
+ }
+
+ /**
+ * Creates a [[DataFrame]] from an RDD[Row].
+ * User can specify whether the input rows should be converted to Catalyst rows.
+ */
+ protected[sql] def createDataFrame(
+ rowRDD: RDD[Row],
+ schema: StructType,
+ needsConversion: Boolean) = {
+ // TODO: use MutableProjection when rowRDD is another DataFrame and the applied
+ // schema differs from the existing schema on any field data type.
+ val catalystRows = if (needsConversion) {
+ val converter = CatalystTypeConverters.createToCatalystConverter(schema)
+ rowRDD.map(converter(_).asInstanceOf[InternalRow])
+ } else {
+ rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)}
+ }
+ val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped)
+ Dataset.ofRows(wrapped, logicalPlan)
+ }
+
+
+ /* -------------------------- *
+ | Methods related to tables |
+ * -------------------------- */
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path and returns the corresponding DataFrame.
+ * It will use the default data source configured by spark.sql.sources.default.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(tableName: String, path: String): DataFrame = {
+ val dataSourceName = conf.defaultDataSourceName
+ createExternalTable(tableName, path, dataSourceName)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source
+ * and returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(tableName: String, path: String, source: String): DataFrame = {
+ createExternalTable(tableName, source, Map("path" -> path))
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, options.asScala.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source and a set of options.
+ * Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ options: Map[String, String]): DataFrame = {
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
+ val cmd =
+ CreateTableUsing(
+ tableIdent,
+ userSpecifiedSchema = None,
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableIdent)
+ }
+
+ /**
+ * :: Experimental ::
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: java.util.Map[String, String]): DataFrame = {
+ createExternalTable(tableName, source, schema, options.asScala.toMap)
+ }
+
+ /**
+ * :: Experimental ::
+ * (Scala-specific)
+ * Creates an external table from the given path based on a data source, a schema and
+ * a set of options. Then, returns the corresponding DataFrame.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ @Experimental
+ def createExternalTable(
+ tableName: String,
+ source: String,
+ schema: StructType,
+ options: Map[String, String]): DataFrame = {
+ val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName)
+ val cmd =
+ CreateTableUsing(
+ tableIdent,
+ userSpecifiedSchema = Some(schema),
+ source,
+ temporary = false,
+ options,
+ allowExisting = false,
+ managedIfNoPath = false)
+ executePlan(cmd).toRdd
+ table(tableIdent)
+ }
+
+ /**
+ * Drops the temporary table with the given table name in the catalog.
+ * If the table has been cached/persisted before, it's also unpersisted.
+ *
+ * @param tableName the name of the table to be unregistered.
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def dropTempTable(tableName: String): Unit = {
+ cacheManager.tryUncacheQuery(table(tableName))
+ sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true)
+ }
+
+ /**
+ * Returns the specified table as a [[DataFrame]].
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def table(tableName: String): DataFrame = {
+ table(sessionState.sqlParser.parseTableIdentifier(tableName))
+ }
+
+ private def table(tableIdent: TableIdentifier): DataFrame = {
+ Dataset.ofRows(wrapped, sessionState.catalog.lookupRelation(tableIdent))
+ }
+
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the current database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a Boolean
+ * indicating if a table is a temporary one or not).
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def tables(): DataFrame = {
+ Dataset.ofRows(wrapped, ShowTablesCommand(None, None))
+ }
+
+ /**
+ * Returns a [[DataFrame]] containing names of existing tables in the given database.
+ * The returned DataFrame has two columns, tableName and isTemporary (a Boolean
+ * indicating if a table is a temporary one or not).
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def tables(databaseName: String): DataFrame = {
+ Dataset.ofRows(wrapped, ShowTablesCommand(Some(databaseName), None))
+ }
+
+ /**
+ * Returns the names of tables in the current database as an array.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def tableNames(): Array[String] = {
+ tableNames(sessionState.catalog.getCurrentDatabase)
+ }
+
+ /**
+ * Returns the names of tables in the given database as an array.
+ *
+ * @group ddl_ops
+ * @since 2.0.0
+ */
+ def tableNames(databaseName: String): Array[String] = {
+ sessionState.catalog.listTables(databaseName).map(_.table).toArray
+ }
+
+ /**
+ * Registers the given [[DataFrame]] as a temporary table in the catalog.
+ * Temporary tables exist only during the lifetime of this instance of [[SparkSession]].
+ */
+ protected[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = {
+ sessionState.catalog.createTempTable(
+ sessionState.sqlParser.parseTableIdentifier(tableName).table,
+ df.logicalPlan,
+ overrideIfExists = true)
+ }
+
+
+ /* ---------------- *
+ | Everything else |
+ * ---------------- */
+
+ /**
+ * Executes a SQL query using Spark, returning the result as a [[DataFrame]].
+ * The dialect that is used for SQL parsing can be configured with 'spark.sql.dialect'.
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ def sql(sqlText: String): DataFrame = {
+ Dataset.ofRows(wrapped, parseSql(sqlText))
+ }
+
+ /**
+ * :: Experimental ::
+ * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]].
+ * {{{
+ * sparkSession.read.parquet("/path/to/file.parquet")
+ * sparkSession.read.schema(schema).json("/path/to/file.json")
+ * }}}
+ *
+ * @group genericdata
+ * @since 2.0.0
+ */
+ @Experimental
+ def read: DataFrameReader = new DataFrameReader(wrapped)
+
+
+ // scalastyle:off
+ // Disable style checker so "implicits" object can start with lowercase i
+ /**
+ * :: Experimental ::
+ * (Scala-specific) Implicit methods available in Scala for converting
+ * common Scala objects into [[DataFrame]]s.
+ *
+ * {{{
+ * val sparkSession = new SparkSession(sc)
+ * import sparkSession.implicits._
+ * }}}
+ *
+ * @group basic
+ * @since 2.0.0
+ */
+ @Experimental
+ object implicits extends SQLImplicits with Serializable {
+ protected override def _sqlContext: SQLContext = wrapped
+ }
+ // scalastyle:on
+
+ protected[sql] def parseSql(sql: String): LogicalPlan = {
+ sessionState.sqlParser.parsePlan(sql)
+ }
+
+ protected[sql] def executeSql(sql: String): QueryExecution = {
+ executePlan(parseSql(sql))
+ }
+
+ protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = {
+ sessionState.executePlan(plan)
+ }
+
+ /**
+ * Executes a SQL query without parsing it, but instead passing it directly to an underlying
+ * system to process. This is currently only used for Hive DDLs and will be removed as soon
+ * as Spark can parse all supported Hive DDLs itself.
+ */
+ protected[sql] def runNativeSql(sqlText: String): Seq[Row] = {
+ sessionState.runNativeSql(sqlText).map { r => Row(r) }
+ }
+
+ /**
+ * Parses the data type in our internal string representation. The data type string should
+ * have the same format as the one generated by `toString` in scala.
+ * It is only used by PySpark.
+ */
+ protected[sql] def parseDataType(dataTypeString: String): DataType = {
+ DataType.fromJson(dataTypeString)
+ }
+
+ /**
+ * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark.
+ */
+ protected[sql] def applySchemaToPythonRDD(
+ rdd: RDD[Array[Any]],
+ schemaString: String): DataFrame = {
+ val schema = parseDataType(schemaString).asInstanceOf[StructType]
+ applySchemaToPythonRDD(rdd, schema)
+ }
+
+ /**
+ * Apply a schema defined by the schema to an RDD. It is only used by PySpark.
+ */
+ protected[sql] def applySchemaToPythonRDD(
+ rdd: RDD[Array[Any]],
+ schema: StructType): DataFrame = {
+ val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
+ Dataset.ofRows(wrapped, LogicalRDD(schema.toAttributes, rowRdd)(wrapped))
+ }
+
+ /**
+ * Returns a Catalyst Schema for the given java bean class.
+ */
+ private def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
+ val (dataType, _) = JavaTypeInference.inferDataType(beanClass)
+ dataType.asInstanceOf[StructType].fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable)()
+ }
+ }
+
}
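Taking only the signatures visible in the SparkSession.scala diff above (the public `SparkContext` constructor, `range`, `createDataFrame`, `sql`, `cacheTable`, `table`) plus the `registerTempTable` call used in the diff's own scaladoc example, a usage sketch might look like the following; the master URL and app name are placeholders, and this surface may still shift before the 2.0 release:

```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SparkSession

object SparkSessionSketch {
  def main(args: Array[String]): Unit = {
    // Placeholder master/app name; any SparkContext will do.
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("spark-session-sketch"))
    val spark = new SparkSession(sc)

    // Dataset[java.lang.Long] with a single "id" column holding 0 until 10.
    spark.range(10).show()

    // DataFrame from a local Seq of Products (tuples), queried through SQL.
    val df = spark.createDataFrame(Seq((1, "a"), (2, "b"), (3, "c"))).toDF("id", "letter")
    df.registerTempTable("letters")        // as in the diff's scaladoc example
    spark.cacheTable("letters")
    spark.sql("SELECT letter FROM letters WHERE id > 1").show()

    // The same table resolved through the session's catalog.
    spark.table("letters").show()

    sc.stop()
  }
}
```

Note that existing `SQLContext`-based code keeps working unchanged, since every `SQLContext` method now simply forwards to one of these `SparkSession` methods.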
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
index e7191e4bfe..fc7ecb11ec 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala
@@ -32,7 +32,7 @@ case class AddJar(path: String) extends RunnableCommand {
}
override def run(sqlContext: SQLContext): Seq[Row] = {
- sqlContext.addJar(path)
+ sqlContext.sessionState.addJar(path)
Seq(Row(0))
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index 08a99627bf..8563dc3d5a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -56,7 +56,9 @@ private[sql] class SessionState(ctx: SQLContext) {
*/
lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy()
- /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */
+ /**
+ * A class for loading resources specified by a function.
+ */
lazy val functionResourceLoader: FunctionResourceLoader = {
new FunctionResourceLoader {
override def loadResource(resource: FunctionResource): Unit = {
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
index 7ab79b12ce..431ac8e2c8 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala
@@ -18,13 +18,20 @@
package org.apache.spark.sql.test
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.internal.{SessionState, SQLConf}
/**
* A special [[SQLContext]] prepared for testing.
*/
-private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self =>
+private[sql] class TestSQLContext(
+ @transient private val sparkSession: SparkSession,
+ isRootContext: Boolean)
+ extends SQLContext(sparkSession, isRootContext) { self =>
+
+ def this(sc: SparkContext) {
+ this(new TestSparkSession(sc), true)
+ }
def this(sparkConf: SparkConf) {
this(new SparkContext("local[2]", "test-sql-context",
@@ -35,8 +42,22 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel
this(new SparkConf)
}
+ // Needed for Java tests
+ def loadTestData(): Unit = {
+ testData.loadTestData()
+ }
+
+ private object testData extends SQLTestData {
+ protected override def sqlContext: SQLContext = self
+ }
+
+}
+
+
+private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self =>
+
@transient
- protected[sql] override lazy val sessionState: SessionState = new SessionState(self) {
+ protected[sql] override lazy val sessionState: SessionState = new SessionState(wrapped) {
override lazy val conf: SQLConf = {
new SQLConf {
clear()
@@ -49,16 +70,9 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel
}
}
- // Needed for Java tests
- def loadTestData(): Unit = {
- testData.loadTestData()
- }
-
- private object testData extends SQLTestData {
- protected override def sqlContext: SQLContext = self
- }
}
+
private[sql] object TestSQLContext {
/**