From a2e8d4fddd1446df946b3c05223e8b8ac6312c3c Mon Sep 17 00:00:00 2001 From: Andrew Or Date: Thu, 21 Apr 2016 14:18:18 -0700 Subject: [SPARK-13643][SQL] Implement SparkSession ## What changes were proposed in this pull request? After removing most of `HiveContext` in 8fc267ab3322e46db81e725a5cb1adb5a71b2b4d we can now move existing functionality in `SQLContext` to `SparkSession`. As of this PR `SQLContext` becomes a simple wrapper that has a `SparkSession` and delegates all functionality to it. ## How was this patch tested? Jenkins. Author: Andrew Or Closes #12553 from andrewor14/implement-spark-session. --- .../scala/org/apache/spark/sql/SQLContext.scala | 254 ++---- .../scala/org/apache/spark/sql/SparkSession.scala | 862 ++++++++++++++++++++- .../spark/sql/execution/command/resources.scala | 2 +- .../apache/spark/sql/internal/SessionState.scala | 4 +- .../org/apache/spark/sql/test/TestSQLContext.scala | 36 +- 5 files changed, 961 insertions(+), 197 deletions(-) (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala index d85ddd5a98..4c9977c8c7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala @@ -17,11 +17,10 @@ package org.apache.spark.sql -import java.beans.{BeanInfo, Introspector} +import java.beans.BeanInfo import java.util.Properties import java.util.concurrent.atomic.AtomicReference -import scala.collection.JavaConverters._ import scala.collection.immutable import scala.reflect.runtime.universe.TypeTag @@ -34,18 +33,14 @@ import org.apache.spark.rdd.RDD import org.apache.spark.scheduler.{SparkListener, SparkListenerApplicationEnd} import org.apache.spark.sql.catalyst._ import org.apache.spark.sql.catalyst.catalog._ -import org.apache.spark.sql.catalyst.encoders.encoderFor import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution._ -import org.apache.spark.sql.execution.command.ShowTablesCommand -import org.apache.spark.sql.execution.datasources._ import org.apache.spark.sql.execution.ui.{SQLListener, SQLTab} import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} import org.apache.spark.sql.sources.BaseRelation import org.apache.spark.sql.types._ import org.apache.spark.sql.util.ExecutionListenerManager -import org.apache.spark.util.Utils /** * The entry point for working with structured data (rows and columns) in Spark. Allows the @@ -69,6 +64,9 @@ class SQLContext private[sql]( self => + // Note: Since Spark 2.0 this class has become a wrapper of SparkSession, where the + // real functionality resides. This class remains mainly for backward compatibility. + private[sql] def this(sparkSession: SparkSession) = { this(sparkSession, true) } @@ -79,6 +77,8 @@ class SQLContext private[sql]( def this(sparkContext: JavaSparkContext) = this(sparkContext.sc) + // TODO: move this logic into SparkSession + // If spark.sql.allowMultipleContexts is true, we will throw an exception if a user // wants to create a new root SQLContext (a SQLContext that is not created by newSession). 
private val allowMultipleContexts = @@ -103,12 +103,12 @@ class SQLContext private[sql]( protected[sql] def sessionState: SessionState = sparkSession.sessionState protected[sql] def sharedState: SharedState = sparkSession.sharedState - protected[sql] def conf: SQLConf = sessionState.conf - protected[sql] def cacheManager: CacheManager = sharedState.cacheManager - protected[sql] def listener: SQLListener = sharedState.listener - protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog + protected[sql] def conf: SQLConf = sparkSession.conf + protected[sql] def cacheManager: CacheManager = sparkSession.cacheManager + protected[sql] def listener: SQLListener = sparkSession.listener + protected[sql] def externalCatalog: ExternalCatalog = sparkSession.externalCatalog - def sparkContext: SparkContext = sharedState.sparkContext + def sparkContext: SparkContext = sparkSession.sparkContext /** * Returns a [[SQLContext]] as new session, with separated SQL configurations, temporary @@ -117,16 +117,14 @@ class SQLContext private[sql]( * * @since 1.6.0 */ - def newSession(): SQLContext = { - new SQLContext(sparkSession.newSession(), isRootContext = false) - } + def newSession(): SQLContext = sparkSession.newSession().wrapped /** * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s * that listen for execution metrics. */ @Experimental - def listenerManager: ExecutionListenerManager = sessionState.listenerManager + def listenerManager: ExecutionListenerManager = sparkSession.listenerManager /** * Set Spark SQL configuration properties. @@ -134,13 +132,13 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(props: Properties): Unit = sessionState.setConf(props) + def setConf(props: Properties): Unit = sparkSession.setConf(props) /** * Set the given Spark SQL configuration property. */ private[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { - sessionState.setConf(entry, value) + sparkSession.setConf(entry, value) } /** @@ -149,7 +147,7 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def setConf(key: String, value: String): Unit = sessionState.setConf(key, value) + def setConf(key: String, value: String): Unit = sparkSession.setConf(key, value) /** * Return the value of Spark SQL configuration property for the given key. @@ -157,13 +155,13 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def getConf(key: String): String = conf.getConfString(key) + def getConf(key: String): String = sparkSession.getConf(key) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set * yet, return `defaultValue` in [[ConfigEntry]]. */ - private[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry) + private[sql] def getConf[T](entry: ConfigEntry[T]): T = sparkSession.getConf(entry) /** * Return the value of Spark SQL configuration property for the given key. If the key is not set @@ -171,7 +169,7 @@ class SQLContext private[sql]( * desired one. 
*/ private[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { - conf.getConf(entry, defaultValue) + sparkSession.getConf(entry, defaultValue) } /** @@ -181,7 +179,7 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue) + def getConf(key: String, defaultValue: String): String = sparkSession.getConf(key, defaultValue) /** * Return all the configuration properties that have been set (i.e. not the default). @@ -190,21 +188,14 @@ class SQLContext private[sql]( * @group config * @since 1.0.0 */ - def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + def getAllConfs: immutable.Map[String, String] = sparkSession.getAllConfs - protected[sql] def parseSql(sql: String): LogicalPlan = sessionState.sqlParser.parsePlan(sql) + protected[sql] def parseSql(sql: String): LogicalPlan = sparkSession.parseSql(sql) - protected[sql] def executeSql(sql: String): QueryExecution = executePlan(parseSql(sql)) + protected[sql] def executeSql(sql: String): QueryExecution = sparkSession.executeSql(sql) protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { - sessionState.executePlan(plan) - } - - /** - * Add a jar to SQLContext - */ - protected[sql] def addJar(path: String): Unit = { - sessionState.addJar(path) + sparkSession.executePlan(plan) } /** @@ -217,7 +208,7 @@ class SQLContext private[sql]( */ @Experimental @transient - def experimental: ExperimentalMethods = sessionState.experimentalMethods + def experimental: ExperimentalMethods = sparkSession.experimental /** * :: Experimental :: @@ -227,8 +218,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ @Experimental - @transient - lazy val emptyDataFrame: DataFrame = createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + def emptyDataFrame: DataFrame = sparkSession.emptyDataFrame /** * A collection of methods for registering user-defined functions (UDF). @@ -259,7 +249,7 @@ class SQLContext private[sql]( * @group basic * @since 1.3.0 */ - def udf: UDFRegistration = sessionState.udf + def udf: UDFRegistration = sparkSession.udf /** * Returns true if the table is currently cached in-memory. @@ -267,7 +257,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def isCached(tableName: String): Boolean = { - cacheManager.lookupCachedData(table(tableName)).nonEmpty + sparkSession.isCached(tableName) } /** @@ -276,7 +266,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ private[sql] def isCached(qName: Dataset[_]): Boolean = { - cacheManager.lookupCachedData(qName).nonEmpty + sparkSession.isCached(qName) } /** @@ -285,7 +275,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def cacheTable(tableName: String): Unit = { - cacheManager.cacheQuery(table(tableName), Some(tableName)) + sparkSession.cacheTable(tableName) } /** @@ -293,13 +283,17 @@ class SQLContext private[sql]( * @group cachemgmt * @since 1.3.0 */ - def uncacheTable(tableName: String): Unit = cacheManager.uncacheQuery(table(tableName)) + def uncacheTable(tableName: String): Unit = { + sparkSession.uncacheTable(tableName) + } /** * Removes all cached tables from the in-memory cache. 
* @since 1.3.0 */ - def clearCache(): Unit = cacheManager.clearCache() + def clearCache(): Unit = { + sparkSession.clearCache() + } // scalastyle:off // Disable style checker so "implicits" object can start with lowercase i @@ -331,11 +325,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { - SQLContext.setActive(self) - val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributeSeq = schema.toAttributes - val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) - Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRDD)(self)) + sparkSession.createDataFrame(rdd) } /** @@ -347,10 +337,7 @@ class SQLContext private[sql]( */ @Experimental def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { - SQLContext.setActive(self) - val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributeSeq = schema.toAttributes - Dataset.ofRows(self, LocalRelation.fromProduct(attributeSeq, data)) + sparkSession.createDataFrame(data) } /** @@ -360,7 +347,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { - Dataset.ofRows(this, LogicalRelation(baseRelation)) + sparkSession.baseRelationToDataFrame(baseRelation) } /** @@ -397,7 +384,7 @@ class SQLContext private[sql]( */ @DeveloperApi def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD, schema, needsConversion = true) + sparkSession.createDataFrame(rowRDD, schema) } /** @@ -406,39 +393,20 @@ class SQLContext private[sql]( */ private[sql] def createDataFrame(rowRDD: RDD[Row], schema: StructType, needsConversion: Boolean) = { - // TODO: use MutableProjection when rowRDD is another DataFrame and the applied - // schema differs from the existing schema on any field data type. - val catalystRows = if (needsConversion) { - val converter = CatalystTypeConverters.createToCatalystConverter(schema) - rowRDD.map(converter(_).asInstanceOf[InternalRow]) - } else { - rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} - } - val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) - Dataset.ofRows(this, logicalPlan) + sparkSession.createDataFrame(rowRDD, schema, needsConversion) } def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { - val enc = encoderFor[T] - val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d).copy()) - val plan = new LocalRelation(attributes, encoded) - - Dataset[T](this, plan) + sparkSession.createDataset(data) } def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { - val enc = encoderFor[T] - val attributes = enc.schema.toAttributes - val encoded = data.map(d => enc.toRow(d)) - val plan = LogicalRDD(attributes, encoded)(self) - - Dataset[T](this, plan) + sparkSession.createDataset(data) } def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { - createDataset(data.asScala) + sparkSession.createDataset(data) } /** @@ -447,10 +415,7 @@ class SQLContext private[sql]( */ private[sql] def internalCreateDataFrame(catalystRows: RDD[InternalRow], schema: StructType) = { - // TODO: use MutableProjection when rowRDD is another DataFrame and the applied - // schema differs from the existing schema on any field data type. 
- val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(self) - Dataset.ofRows(this, logicalPlan) + sparkSession.internalCreateDataFrame(catalystRows, schema) } /** @@ -464,7 +429,7 @@ class SQLContext private[sql]( */ @DeveloperApi def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { - createDataFrame(rowRDD.rdd, schema) + sparkSession.createDataFrame(rowRDD, schema) } /** @@ -478,7 +443,7 @@ class SQLContext private[sql]( */ @DeveloperApi def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { - Dataset.ofRows(self, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + sparkSession.createDataFrame(rows, schema) } /** @@ -490,14 +455,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { - val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) - val className = beanClass.getName - val rowRdd = rdd.mapPartitions { iter => - // BeanInfo is not serializable so we must rediscover it remotely for each partition. - val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) - SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) - } - Dataset.ofRows(this, LogicalRDD(attributeSeq, rowRdd)(this)) + sparkSession.createDataFrame(rdd, beanClass) } /** @@ -509,7 +467,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { - createDataFrame(rdd.rdd, beanClass) + sparkSession.createDataFrame(rdd, beanClass) } /** @@ -521,11 +479,7 @@ class SQLContext private[sql]( * @since 1.6.0 */ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { - val attrSeq = getSchema(beanClass) - val className = beanClass.getName - val beanInfo = Introspector.getBeanInfo(beanClass) - val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) - Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq)) + sparkSession.createDataFrame(data, beanClass) } /** @@ -540,7 +494,7 @@ class SQLContext private[sql]( * @since 1.4.0 */ @Experimental - def read: DataFrameReader = new DataFrameReader(this) + def read: DataFrameReader = sparkSession.read /** * :: Experimental :: @@ -552,8 +506,7 @@ class SQLContext private[sql]( */ @Experimental def createExternalTable(tableName: String, path: String): DataFrame = { - val dataSourceName = conf.defaultDataSourceName - createExternalTable(tableName, path, dataSourceName) + sparkSession.createExternalTable(tableName, path) } /** @@ -569,7 +522,7 @@ class SQLContext private[sql]( tableName: String, path: String, source: String): DataFrame = { - createExternalTable(tableName, source, Map("path" -> path)) + sparkSession.createExternalTable(tableName, path, source) } /** @@ -585,7 +538,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, options.asScala.toMap) + sparkSession.createExternalTable(tableName, source, options) } /** @@ -602,18 +555,7 @@ class SQLContext private[sql]( tableName: String, source: String, options: Map[String, String]): DataFrame = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - val cmd = - CreateTableUsing( - tableIdent, - userSpecifiedSchema = None, - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableIdent) + sparkSession.createExternalTable(tableName, source, options) } 
/** @@ -630,7 +572,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: java.util.Map[String, String]): DataFrame = { - createExternalTable(tableName, source, schema, options.asScala.toMap) + sparkSession.createExternalTable(tableName, source, schema, options) } /** @@ -648,18 +590,7 @@ class SQLContext private[sql]( source: String, schema: StructType, options: Map[String, String]): DataFrame = { - val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) - val cmd = - CreateTableUsing( - tableIdent, - userSpecifiedSchema = Some(schema), - source, - temporary = false, - options, - allowExisting = false, - managedIfNoPath = false) - executePlan(cmd).toRdd - table(tableIdent) + sparkSession.createExternalTable(tableName, source, schema, options) } /** @@ -667,10 +598,7 @@ class SQLContext private[sql]( * only during the lifetime of this instance of SQLContext. */ private[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { - sessionState.catalog.createTempTable( - sessionState.sqlParser.parseTableIdentifier(tableName).table, - df.logicalPlan, - overrideIfExists = true) + sparkSession.registerDataFrameAsTable(df, tableName) } /** @@ -682,8 +610,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def dropTempTable(tableName: String): Unit = { - cacheManager.tryUncacheQuery(table(tableName)) - sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true) + sparkSession.dropTempTable(tableName) } /** @@ -695,7 +622,7 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(end: Long): Dataset[java.lang.Long] = range(0, end) + def range(end: Long): Dataset[java.lang.Long] = sparkSession.range(end) /** * :: Experimental :: @@ -706,9 +633,7 @@ class SQLContext private[sql]( * @group dataset */ @Experimental - def range(start: Long, end: Long): Dataset[java.lang.Long] = { - range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) - } + def range(start: Long, end: Long): Dataset[java.lang.Long] = sparkSession.range(start, end) /** * :: Experimental :: @@ -720,7 +645,7 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { - range(start, end, step, numPartitions = sparkContext.defaultParallelism) + sparkSession.range(start, end, step) } /** @@ -734,7 +659,7 @@ class SQLContext private[sql]( */ @Experimental def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { - new Dataset(this, Range(start, end, step, numPartitions), Encoders.LONG) + sparkSession.range(start, end, step, numPartitions) } /** @@ -744,9 +669,7 @@ class SQLContext private[sql]( * @group basic * @since 1.3.0 */ - def sql(sqlText: String): DataFrame = { - Dataset.ofRows(this, parseSql(sqlText)) - } + def sql(sqlText: String): DataFrame = sparkSession.sql(sqlText) /** * Executes a SQL query without parsing it, but instead passing it directly to an underlying @@ -754,7 +677,7 @@ class SQLContext private[sql]( * as Spark can parse all supported Hive DDLs itself. 
*/ private[sql] def runNativeSql(sqlText: String): Seq[Row] = { - sessionState.runNativeSql(sqlText).map { r => Row(r) } + sparkSession.runNativeSql(sqlText) } /** @@ -764,11 +687,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def table(tableName: String): DataFrame = { - table(sessionState.sqlParser.parseTableIdentifier(tableName)) - } - - private def table(tableIdent: TableIdentifier): DataFrame = { - Dataset.ofRows(this, sessionState.catalog.lookupRelation(tableIdent)) + sparkSession.table(tableName) } /** @@ -780,7 +699,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(None, None)) + sparkSession.tables() } /** @@ -792,7 +711,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tables(databaseName: String): DataFrame = { - Dataset.ofRows(this, ShowTablesCommand(Some(databaseName), None)) + sparkSession.tables(databaseName) } /** @@ -801,7 +720,7 @@ class SQLContext private[sql]( * * @since 2.0.0 */ - def streams: ContinuousQueryManager = sessionState.continuousQueryManager + def streams: ContinuousQueryManager = sparkSession.streams /** * Returns the names of tables in the current database as an array. @@ -810,7 +729,7 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(): Array[String] = { - tableNames(sessionState.catalog.getCurrentDatabase) + sparkSession.tableNames() } /** @@ -820,19 +739,16 @@ class SQLContext private[sql]( * @since 1.3.0 */ def tableNames(databaseName: String): Array[String] = { - sessionState.catalog.listTables(databaseName).map(_.table).toArray + sparkSession.tableNames(databaseName) } - @transient - protected[sql] lazy val emptyResult = sparkContext.parallelize(Seq.empty[InternalRow], 1) - /** * Parses the data type in our internal string representation. The data type string should * have the same format as the one generated by `toString` in scala. * It is only used by PySpark. */ protected[sql] def parseDataType(dataTypeString: String): DataType = { - DataType.fromJson(dataTypeString) + sparkSession.parseDataType(dataTypeString) } /** @@ -841,8 +757,7 @@ class SQLContext private[sql]( protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schemaString: String): DataFrame = { - val schema = parseDataType(schemaString).asInstanceOf[StructType] - applySchemaToPythonRDD(rdd, schema) + sparkSession.applySchemaToPythonRDD(rdd, schemaString) } /** @@ -851,20 +766,10 @@ class SQLContext private[sql]( protected[sql] def applySchemaToPythonRDD( rdd: RDD[Array[Any]], schema: StructType): DataFrame = { - - val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) - Dataset.ofRows(this, LogicalRDD(schema.toAttributes, rowRdd)(self)) + sparkSession.applySchemaToPythonRDD(rdd, schema) } - /** - * Returns a Catalyst Schema for the given java bean class. - */ - protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { - val (dataType, _) = JavaTypeInference.inferDataType(beanClass) - dataType.asInstanceOf[StructType].fields.map { f => - AttributeReference(f.name, f.dataType, f.nullable)() - } - } + // TODO: move this logic into SparkSession // Register a successfully instantiated context to the singleton. 
This should be at the end of // the class definition so that the singleton is updated only if there is no exception in the @@ -876,6 +781,7 @@ class SQLContext private[sql]( } }) + sparkSession.setWrappedContext(self) SQLContext.setInstantiatedContext(self) } @@ -980,8 +886,10 @@ object SQLContext { * bean info & schema. This is not related to the singleton, but is a static * method for internal use. */ - private def beansToRows(data: Iterator[_], beanInfo: BeanInfo, attrs: Seq[AttributeReference]): - Iterator[InternalRow] = { + private[sql] def beansToRows( + data: Iterator[_], + beanInfo: BeanInfo, + attrs: Seq[AttributeReference]): Iterator[InternalRow] = { val extractors = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod) val methodsToConverts = extractors.zip(attrs).map { case (e, attr) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala index 17ba299825..70d889b002 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala @@ -17,12 +17,33 @@ package org.apache.spark.sql +import java.beans.Introspector +import java.util.Properties + +import scala.collection.immutable +import scala.collection.JavaConverters._ import scala.reflect.ClassTag +import scala.reflect.runtime.universe.TypeTag import scala.util.control.NonFatal import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION -import org.apache.spark.sql.internal.{SessionState, SharedState} +import org.apache.spark.annotation.{DeveloperApi, Experimental} +import org.apache.spark.api.java.JavaRDD +import org.apache.spark.internal.config.{CATALOG_IMPLEMENTATION, ConfigEntry} +import org.apache.spark.rdd.RDD +import org.apache.spark.sql.catalyst._ +import org.apache.spark.sql.catalyst.catalog._ +import org.apache.spark.sql.catalyst.encoders._ +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Range} +import org.apache.spark.sql.execution._ +import org.apache.spark.sql.execution.command.ShowTablesCommand +import org.apache.spark.sql.execution.datasources.{CreateTableUsing, LogicalRelation} +import org.apache.spark.sql.execution.ui.SQLListener +import org.apache.spark.sql.internal.{SessionState, SharedState, SQLConf} +import org.apache.spark.sql.sources.BaseRelation +import org.apache.spark.sql.types.{DataType, LongType, StructType} +import org.apache.spark.sql.util.ExecutionListenerManager import org.apache.spark.util.Utils @@ -30,22 +51,22 @@ import org.apache.spark.util.Utils * The entry point to Spark execution. */ class SparkSession private( - sparkContext: SparkContext, - existingSharedState: Option[SharedState]) { self => + @transient val sparkContext: SparkContext, + @transient private val existingSharedState: Option[SharedState]) { self => def this(sc: SparkContext) { this(sc, None) } + + /* ----------------------- * + | Session-related state | + * ----------------------- */ + /** - * Start a new session where configurations, temp tables, temp functions etc. are isolated. + * State shared across sessions, including the [[SparkContext]], cached data, listener, + * and a catalog that interacts with external systems. 
*/ - def newSession(): SparkSession = { - // Note: materialize the shared state here to ensure the parent and child sessions are - // initialized with the same shared state. - new SparkSession(sparkContext, Some(sharedState)) - } - @transient protected[sql] lazy val sharedState: SharedState = { existingSharedState.getOrElse( @@ -54,6 +75,10 @@ class SparkSession private( sparkContext)) } + /** + * State isolated across sessions, including SQL configurations, temporary tables, + * registered functions, and everything else that accepts a [[SQLConf]]. + */ @transient protected[sql] lazy val sessionState: SessionState = { SparkSession.reflect[SessionState, SQLContext]( @@ -61,6 +86,821 @@ class SparkSession private( new SQLContext(self, isRootContext = false)) } + /** + * A wrapped version of this session in the form of a [[SQLContext]]. + */ + @transient + private var _wrapped: SQLContext = _ + + protected[sql] def wrapped: SQLContext = { + if (_wrapped == null) { + _wrapped = new SQLContext(self, isRootContext = false) + } + _wrapped + } + + protected[sql] def setWrappedContext(sqlContext: SQLContext): Unit = { + _wrapped = sqlContext + } + + protected[sql] def conf: SQLConf = sessionState.conf + protected[sql] def cacheManager: CacheManager = sharedState.cacheManager + protected[sql] def listener: SQLListener = sharedState.listener + protected[sql] def externalCatalog: ExternalCatalog = sharedState.externalCatalog + + /** + * :: Experimental :: + * An interface to register custom [[org.apache.spark.sql.util.QueryExecutionListener]]s + * that listen for execution metrics. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def listenerManager: ExecutionListenerManager = sessionState.listenerManager + + /** + * :: Experimental :: + * A collection of methods that are considered experimental, but can be used to hook into + * the query planner for advanced functionality. + * + * @group basic + * @since 2.0.0 + */ + @Experimental + def experimental: ExperimentalMethods = sessionState.experimentalMethods + + /** + * A collection of methods for registering user-defined functions (UDF). + * + * The following example registers a Scala closure as UDF: + * {{{ + * sparkSession.udf.register("myUDF", (arg1: Int, arg2: String) => arg2 + arg1) + * }}} + * + * The following example registers a UDF in Java: + * {{{ + * sparkSession.udf().register("myUDF", + * new UDF2() { + * @Override + * public String call(Integer arg1, String arg2) { + * return arg2 + arg1; + * } + * }, DataTypes.StringType); + * }}} + * + * Or, to use Java 8 lambda syntax: + * {{{ + * sparkSession.udf().register("myUDF", + * (Integer arg1, String arg2) -> arg2 + arg1, + * DataTypes.StringType); + * }}} + * + * @group basic + * @since 2.0.0 + */ + def udf: UDFRegistration = sessionState.udf + + /** + * Returns a [[ContinuousQueryManager]] that allows managing all the + * [[org.apache.spark.sql.ContinuousQuery ContinuousQueries]] active on `this`. + * + * @group basic + * @since 2.0.0 + */ + def streams: ContinuousQueryManager = sessionState.continuousQueryManager + + /** + * Start a new session with isolated SQL configurations, temporary tables, registered + * functions are isolated, but sharing the underlying [[SparkContext]] and cached data. + * + * Note: Other than the [[SparkContext]], all shared state is initialized lazily. + * This method will force the initialization of the shared state to ensure that parent + * and child sessions are set up with the same shared state. 
If the underlying catalog + * implementation is Hive, this will initialize the metastore, which may take some time. + * + * @group basic + * @since 2.0.0 + */ + def newSession(): SparkSession = { + new SparkSession(sparkContext, Some(sharedState)) + } + + + /* ------------------------------------------------- * + | Methods for accessing or mutating configurations | + * ------------------------------------------------- */ + + /** + * Set Spark SQL configuration properties. + * + * @group config + * @since 2.0.0 + */ + def setConf(props: Properties): Unit = sessionState.setConf(props) + + /** + * Set the given Spark SQL configuration property. + * + * @group config + * @since 2.0.0 + */ + def setConf(key: String, value: String): Unit = sessionState.setConf(key, value) + + /** + * Return the value of Spark SQL configuration property for the given key. + * + * @group config + * @since 2.0.0 + */ + def getConf(key: String): String = conf.getConfString(key) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. + * + * @group config + * @since 2.0.0 + */ + def getConf(key: String, defaultValue: String): String = conf.getConfString(key, defaultValue) + + /** + * Return all the configuration properties that have been set (i.e. not the default). + * This creates a new copy of the config properties in the form of a Map. + * + * @group config + * @since 2.0.0 + */ + def getAllConfs: immutable.Map[String, String] = conf.getAllConfs + + /** + * Set the given Spark SQL configuration property. + */ + protected[sql] def setConf[T](entry: ConfigEntry[T], value: T): Unit = { + sessionState.setConf(entry, value) + } + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue` in [[ConfigEntry]]. + */ + protected[sql] def getConf[T](entry: ConfigEntry[T]): T = conf.getConf(entry) + + /** + * Return the value of Spark SQL configuration property for the given key. If the key is not set + * yet, return `defaultValue`. This is useful when `defaultValue` in ConfigEntry is not the + * desired one. + */ + protected[sql] def getConf[T](entry: ConfigEntry[T], defaultValue: T): T = { + conf.getConf(entry, defaultValue) + } + + + /* ------------------------------------- * + | Methods related to cache management | + * ------------------------------------- */ + + /** + * Returns true if the table is currently cached in-memory. + * + * @group cachemgmt + * @since 2.0.0 + */ + def isCached(tableName: String): Boolean = { + cacheManager.lookupCachedData(table(tableName)).nonEmpty + } + + /** + * Caches the specified table in-memory. + * + * @group cachemgmt + * @since 2.0.0 + */ + def cacheTable(tableName: String): Unit = { + cacheManager.cacheQuery(table(tableName), Some(tableName)) + } + + /** + * Removes the specified table from the in-memory cache. + * + * @group cachemgmt + * @since 2.0.0 + */ + def uncacheTable(tableName: String): Unit = { + cacheManager.uncacheQuery(table(tableName)) + } + + /** + * Removes all cached tables from the in-memory cache. + * + * @group cachemgmt + * @since 2.0.0 + */ + def clearCache(): Unit = { + cacheManager.clearCache() + } + + /** + * Returns true if the [[Dataset]] is currently cached in-memory. 
+ * + * @group cachemgmt + * @since 2.0.0 + */ + protected[sql] def isCached(qName: Dataset[_]): Boolean = { + cacheManager.lookupCachedData(qName).nonEmpty + } + + /* --------------------------------- * + | Methods for creating DataFrames | + * --------------------------------- */ + + /** + * :: Experimental :: + * Returns a [[DataFrame]] with no rows or columns. + * + * @group dataframes + * @since 2.0.0 + */ + @Experimental + @transient + lazy val emptyDataFrame: DataFrame = { + createDataFrame(sparkContext.emptyRDD[Row], StructType(Nil)) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] from an RDD of Product (e.g. case classes, tuples). + * + * @group dataframes + * @since 2.0.0 + */ + @Experimental + def createDataFrame[A <: Product : TypeTag](rdd: RDD[A]): DataFrame = { + SQLContext.setActive(wrapped) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + val rowRDD = RDDConversions.productToRowRdd(rdd, schema.map(_.dataType)) + Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRDD)(wrapped)) + } + + /** + * :: Experimental :: + * Creates a [[DataFrame]] from a local Seq of Product. + * + * @group dataframes + * @since 2.0.0 + */ + @Experimental + def createDataFrame[A <: Product : TypeTag](data: Seq[A]): DataFrame = { + SQLContext.setActive(wrapped) + val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] + val attributeSeq = schema.toAttributes + Dataset.ofRows(wrapped, LocalRelation.fromProduct(attributeSeq, data)) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[RDD]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * Example: + * {{{ + * import org.apache.spark.sql._ + * import org.apache.spark.sql.types._ + * val sparkSession = new org.apache.spark.sql.SparkSession(sc) + * + * val schema = + * StructType( + * StructField("name", StringType, false) :: + * StructField("age", IntegerType, true) :: Nil) + * + * val people = + * sc.textFile("examples/src/main/resources/people.txt").map( + * _.split(",")).map(p => Row(p(0), p(1).trim.toInt)) + * val dataFrame = sparkSession.createDataFrame(people, schema) + * dataFrame.printSchema + * // root + * // |-- name: string (nullable = false) + * // |-- age: integer (nullable = true) + * + * dataFrame.registerTempTable("people") + * sparkSession.sql("select name from people").collect.foreach(println) + * }}} + * + * @group dataframes + * @since 2.0.0 + */ + @DeveloperApi + def createDataFrame(rowRDD: RDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD, schema, needsConversion = true) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[JavaRDD]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided RDD matches + * the provided schema. Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 2.0.0 + */ + @DeveloperApi + def createDataFrame(rowRDD: JavaRDD[Row], schema: StructType): DataFrame = { + createDataFrame(rowRDD.rdd, schema) + } + + /** + * :: DeveloperApi :: + * Creates a [[DataFrame]] from an [[java.util.List]] containing [[Row]]s using the given schema. + * It is important to make sure that the structure of every [[Row]] of the provided List matches + * the provided schema. 
Otherwise, there will be runtime exception. + * + * @group dataframes + * @since 2.0.0 + */ + @DeveloperApi + def createDataFrame(rows: java.util.List[Row], schema: StructType): DataFrame = { + Dataset.ofRows(wrapped, LocalRelation.fromExternalRows(schema.toAttributes, rows.asScala)) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + * + * @group dataframes + * @since 2.0.0 + */ + def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = { + val attributeSeq: Seq[AttributeReference] = getSchema(beanClass) + val className = beanClass.getName + val rowRdd = rdd.mapPartitions { iter => + // BeanInfo is not serializable so we must rediscover it remotely for each partition. + val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className)) + SQLContext.beansToRows(iter, localBeanInfo, attributeSeq) + } + Dataset.ofRows(wrapped, LogicalRDD(attributeSeq, rowRdd)(wrapped)) + } + + /** + * Applies a schema to an RDD of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + * + * @group dataframes + * @since 2.0.0 + */ + def createDataFrame(rdd: JavaRDD[_], beanClass: Class[_]): DataFrame = { + createDataFrame(rdd.rdd, beanClass) + } + + /** + * Applies a schema to an List of Java Beans. + * + * WARNING: Since there is no guaranteed ordering for fields in a Java Bean, + * SELECT * queries will return the columns in an undefined order. + * @group dataframes + * @since 1.6.0 + */ + def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = { + val attrSeq = getSchema(beanClass) + val beanInfo = Introspector.getBeanInfo(beanClass) + val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq) + Dataset.ofRows(wrapped, LocalRelation(attrSeq, rows.toSeq)) + } + + /** + * Convert a [[BaseRelation]] created for external data sources into a [[DataFrame]]. + * + * @group dataframes + * @since 2.0.0 + */ + def baseRelationToDataFrame(baseRelation: BaseRelation): DataFrame = { + Dataset.ofRows(wrapped, LogicalRelation(baseRelation)) + } + + def createDataset[T : Encoder](data: Seq[T]): Dataset[T] = { + val enc = encoderFor[T] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d).copy()) + val plan = new LocalRelation(attributes, encoded) + Dataset[T](wrapped, plan) + } + + def createDataset[T : Encoder](data: RDD[T]): Dataset[T] = { + val enc = encoderFor[T] + val attributes = enc.schema.toAttributes + val encoded = data.map(d => enc.toRow(d)) + val plan = LogicalRDD(attributes, encoded)(wrapped) + Dataset[T](wrapped, plan) + } + + def createDataset[T : Encoder](data: java.util.List[T]): Dataset[T] = { + createDataset(data.asScala) + } + + /** + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from 0 to `end` (exclusive) with step value 1. + * + * @since 2.0.0 + * @group dataset + */ + @Experimental + def range(end: Long): Dataset[java.lang.Long] = range(0, end) + + /** + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with step value 1. 
+ * + * @since 2.0.0 + * @group dataset + */ + @Experimental + def range(start: Long, end: Long): Dataset[java.lang.Long] = { + range(start, end, step = 1, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value. + * + * @since 2.0.0 + * @group dataset + */ + @Experimental + def range(start: Long, end: Long, step: Long): Dataset[java.lang.Long] = { + range(start, end, step, numPartitions = sparkContext.defaultParallelism) + } + + /** + * :: Experimental :: + * Creates a [[Dataset]] with a single [[LongType]] column named `id`, containing elements + * in an range from `start` to `end` (exclusive) with an step value, with partition number + * specified. + * + * @since 2.0.0 + * @group dataset + */ + @Experimental + def range(start: Long, end: Long, step: Long, numPartitions: Int): Dataset[java.lang.Long] = { + new Dataset(wrapped, Range(start, end, step, numPartitions), Encoders.LONG) + } + + /** + * Creates a [[DataFrame]] from an RDD[Row]. + * User can specify whether the input rows should be converted to Catalyst rows. + */ + protected[sql] def internalCreateDataFrame( + catalystRows: RDD[InternalRow], + schema: StructType): DataFrame = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied + // schema differs from the existing schema on any field data type. + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped) + Dataset.ofRows(wrapped, logicalPlan) + } + + /** + * Creates a [[DataFrame]] from an RDD[Row]. + * User can specify whether the input rows should be converted to Catalyst rows. + */ + protected[sql] def createDataFrame( + rowRDD: RDD[Row], + schema: StructType, + needsConversion: Boolean) = { + // TODO: use MutableProjection when rowRDD is another DataFrame and the applied + // schema differs from the existing schema on any field data type. + val catalystRows = if (needsConversion) { + val converter = CatalystTypeConverters.createToCatalystConverter(schema) + rowRDD.map(converter(_).asInstanceOf[InternalRow]) + } else { + rowRDD.map{r: Row => InternalRow.fromSeq(r.toSeq)} + } + val logicalPlan = LogicalRDD(schema.toAttributes, catalystRows)(wrapped) + Dataset.ofRows(wrapped, logicalPlan) + } + + + /* -------------------------- * + | Methods related to tables | + * -------------------------- */ + + /** + * :: Experimental :: + * Creates an external table from the given path and returns the corresponding DataFrame. + * It will use the default data source configured by spark.sql.sources.default. + * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable(tableName: String, path: String): DataFrame = { + val dataSourceName = conf.defaultDataSourceName + createExternalTable(tableName, path, dataSourceName) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source + * and returns the corresponding DataFrame. + * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable(tableName: String, path: String, source: String): DataFrame = { + createExternalTable(tableName, source, Map("path" -> path)) + } + + /** + * :: Experimental :: + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. 
+ * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, options.asScala.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Creates an external table from the given path based on a data source and a set of options. + * Then, returns the corresponding DataFrame. + * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + options: Map[String, String]): DataFrame = { + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + val cmd = + CreateTableUsing( + tableIdent, + userSpecifiedSchema = None, + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableIdent) + } + + /** + * :: Experimental :: + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: java.util.Map[String, String]): DataFrame = { + createExternalTable(tableName, source, schema, options.asScala.toMap) + } + + /** + * :: Experimental :: + * (Scala-specific) + * Create an external table from the given path based on a data source, a schema and + * a set of options. Then, returns the corresponding DataFrame. + * + * @group ddl_ops + * @since 2.0.0 + */ + @Experimental + def createExternalTable( + tableName: String, + source: String, + schema: StructType, + options: Map[String, String]): DataFrame = { + val tableIdent = sessionState.sqlParser.parseTableIdentifier(tableName) + val cmd = + CreateTableUsing( + tableIdent, + userSpecifiedSchema = Some(schema), + source, + temporary = false, + options, + allowExisting = false, + managedIfNoPath = false) + executePlan(cmd).toRdd + table(tableIdent) + } + + /** + * Drops the temporary table with the given table name in the catalog. + * If the table has been cached/persisted before, it's also unpersisted. + * + * @param tableName the name of the table to be unregistered. + * @group ddl_ops + * @since 2.0.0 + */ + def dropTempTable(tableName: String): Unit = { + cacheManager.tryUncacheQuery(table(tableName)) + sessionState.catalog.dropTable(TableIdentifier(tableName), ignoreIfNotExists = true) + } + + /** + * Returns the specified table as a [[DataFrame]]. + * + * @group ddl_ops + * @since 2.0.0 + */ + def table(tableName: String): DataFrame = { + table(sessionState.sqlParser.parseTableIdentifier(tableName)) + } + + private def table(tableIdent: TableIdentifier): DataFrame = { + Dataset.ofRows(wrapped, sessionState.catalog.lookupRelation(tableIdent)) + } + + /** + * Returns a [[DataFrame]] containing names of existing tables in the current database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). + * + * @group ddl_ops + * @since 2.0.0 + */ + def tables(): DataFrame = { + Dataset.ofRows(wrapped, ShowTablesCommand(None, None)) + } + + /** + * Returns a [[DataFrame]] containing names of existing tables in the given database. + * The returned DataFrame has two columns, tableName and isTemporary (a Boolean + * indicating if a table is a temporary one or not). 
+ * + * @group ddl_ops + * @since 2.0.0 + */ + def tables(databaseName: String): DataFrame = { + Dataset.ofRows(wrapped, ShowTablesCommand(Some(databaseName), None)) + } + + /** + * Returns the names of tables in the current database as an array. + * + * @group ddl_ops + * @since 2.0.0 + */ + def tableNames(): Array[String] = { + tableNames(sessionState.catalog.getCurrentDatabase) + } + + /** + * Returns the names of tables in the given database as an array. + * + * @group ddl_ops + * @since 2.0.0 + */ + def tableNames(databaseName: String): Array[String] = { + sessionState.catalog.listTables(databaseName).map(_.table).toArray + } + + /** + * Registers the given [[DataFrame]] as a temporary table in the catalog. + * Temporary tables exist only during the lifetime of this instance of [[SparkSession]]. + */ + protected[sql] def registerDataFrameAsTable(df: DataFrame, tableName: String): Unit = { + sessionState.catalog.createTempTable( + sessionState.sqlParser.parseTableIdentifier(tableName).table, + df.logicalPlan, + overrideIfExists = true) + } + + + /* ---------------- * + | Everything else | + * ---------------- */ + + /** + * Executes a SQL query using Spark, returning the result as a [[DataFrame]]. + * The dialect that is used for SQL parsing can be configured with 'spark.sql.dialect'. + * + * @group basic + * @since 2.0.0 + */ + def sql(sqlText: String): DataFrame = { + Dataset.ofRows(wrapped, parseSql(sqlText)) + } + + /** + * :: Experimental :: + * Returns a [[DataFrameReader]] that can be used to read data and streams in as a [[DataFrame]]. + * {{{ + * sparkSession.read.parquet("/path/to/file.parquet") + * sparkSession.read.schema(schema).json("/path/to/file.json") + * }}} + * + * @group genericdata + * @since 2.0.0 + */ + @Experimental + def read: DataFrameReader = new DataFrameReader(wrapped) + + + // scalastyle:off + // Disable style checker so "implicits" object can start with lowercase i + /** + * :: Experimental :: + * (Scala-specific) Implicit methods available in Scala for converting + * common Scala objects into [[DataFrame]]s. + * + * {{{ + * val sparkSession = new SparkSession(sc) + * import sparkSession.implicits._ + * }}} + * + * @group basic + * @since 2.0.0 + */ + @Experimental + object implicits extends SQLImplicits with Serializable { + protected override def _sqlContext: SQLContext = wrapped + } + // scalastyle:on + + protected[sql] def parseSql(sql: String): LogicalPlan = { + sessionState.sqlParser.parsePlan(sql) + } + + protected[sql] def executeSql(sql: String): QueryExecution = { + executePlan(parseSql(sql)) + } + + protected[sql] def executePlan(plan: LogicalPlan): QueryExecution = { + sessionState.executePlan(plan) + } + + /** + * Executes a SQL query without parsing it, but instead passing it directly to an underlying + * system to process. This is currently only used for Hive DDLs and will be removed as soon + * as Spark can parse all supported Hive DDLs itself. + */ + protected[sql] def runNativeSql(sqlText: String): Seq[Row] = { + sessionState.runNativeSql(sqlText).map { r => Row(r) } + } + + /** + * Parses the data type in our internal string representation. The data type string should + * have the same format as the one generated by `toString` in scala. + * It is only used by PySpark. + */ + protected[sql] def parseDataType(dataTypeString: String): DataType = { + DataType.fromJson(dataTypeString) + } + + /** + * Apply a schema defined by the schemaString to an RDD. It is only used by PySpark. 
+ */ + protected[sql] def applySchemaToPythonRDD( + rdd: RDD[Array[Any]], + schemaString: String): DataFrame = { + val schema = parseDataType(schemaString).asInstanceOf[StructType] + applySchemaToPythonRDD(rdd, schema) + } + + /** + * Apply a schema defined by the schema to an RDD. It is only used by PySpark. + */ + protected[sql] def applySchemaToPythonRDD( + rdd: RDD[Array[Any]], + schema: StructType): DataFrame = { + val rowRdd = rdd.map(r => python.EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow]) + Dataset.ofRows(wrapped, LogicalRDD(schema.toAttributes, rowRdd)(wrapped)) + } + + /** + * Returns a Catalyst Schema for the given java bean class. + */ + private def getSchema(beanClass: Class[_]): Seq[AttributeReference] = { + val (dataType, _) = JavaTypeInference.inferDataType(beanClass) + dataType.asInstanceOf[StructType].fields.map { f => + AttributeReference(f.name, f.dataType, f.nullable)() + } + } + } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala index e7191e4bfe..fc7ecb11ec 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/resources.scala @@ -32,7 +32,7 @@ case class AddJar(path: String) extends RunnableCommand { } override def run(sqlContext: SQLContext): Seq[Row] = { - sqlContext.addJar(path) + sqlContext.sessionState.addJar(path) Seq(Row(0)) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala index 08a99627bf..8563dc3d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala @@ -56,7 +56,9 @@ private[sql] class SessionState(ctx: SQLContext) { */ lazy val functionRegistry: FunctionRegistry = FunctionRegistry.builtin.copy() - /** A [[FunctionResourceLoader]] that can be used in SessionCatalog. */ + /** + * A class for loading resources specified by a function. + */ lazy val functionResourceLoader: FunctionResourceLoader = { new FunctionResourceLoader { override def loadResource(resource: FunctionResource): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala index 7ab79b12ce..431ac8e2c8 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/TestSQLContext.scala @@ -18,13 +18,20 @@ package org.apache.spark.sql.test import org.apache.spark.{SparkConf, SparkContext} -import org.apache.spark.sql.SQLContext +import org.apache.spark.sql.{SparkSession, SQLContext} import org.apache.spark.sql.internal.{SessionState, SQLConf} /** * A special [[SQLContext]] prepared for testing. 
*/ -private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { self => +private[sql] class TestSQLContext( + @transient private val sparkSession: SparkSession, + isRootContext: Boolean) + extends SQLContext(sparkSession, isRootContext) { self => + + def this(sc: SparkContext) { + this(new TestSparkSession(sc), true) + } def this(sparkConf: SparkConf) { this(new SparkContext("local[2]", "test-sql-context", @@ -35,8 +42,22 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel this(new SparkConf) } + // Needed for Java tests + def loadTestData(): Unit = { + testData.loadTestData() + } + + private object testData extends SQLTestData { + protected override def sqlContext: SQLContext = self + } + +} + + +private[sql] class TestSparkSession(sc: SparkContext) extends SparkSession(sc) { self => + @transient - protected[sql] override lazy val sessionState: SessionState = new SessionState(self) { + protected[sql] override lazy val sessionState: SessionState = new SessionState(wrapped) { override lazy val conf: SQLConf = { new SQLConf { clear() @@ -49,16 +70,9 @@ private[sql] class TestSQLContext(sc: SparkContext) extends SQLContext(sc) { sel } } - // Needed for Java tests - def loadTestData(): Unit = { - testData.loadTestData() - } - - private object testData extends SQLTestData { - protected override def sqlContext: SQLContext = self - } } + private[sql] object TestSQLContext { /** -- cgit v1.2.3
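
Reviewer note (not part of the patch): a minimal usage sketch of the new entry point this PR introduces. It assumes the public `new SparkSession(sc)` constructor shown in the `createDataFrame` scaladoc above; the object name and the local-mode SparkConf settings are illustrative only, not anything defined in this change.

    // Illustrative sketch only: exercises the SparkSession API added in this patch.
    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.sql.SparkSession

    object SparkSessionSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(
          new SparkConf().setMaster("local[2]").setAppName("spark-session-sketch"))

        // SparkSession now owns the SessionState/SharedState; the scaladoc in this
        // patch constructs it directly from a SparkContext.
        val session = new SparkSession(sc)

        // Methods formerly implemented in SQLContext (range, sql, createDataFrame, ...)
        // now execute here; the SQLContext methods above are thin delegations.
        session.range(0, 5).show()
        session.sql("SELECT 1 AS one").show()

        sc.stop()
      }
    }

The point of the sketch is the delegation direction after this PR: all real work happens in SparkSession, and SQLContext remains only as a backward-compatible wrapper around it.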