author     Felix Cheung <felixcheung_m@hotmail.com>  2016-06-17 21:36:01 -0700
committer  Shivaram Venkataraman <shivaram@cs.berkeley.edu>  2016-06-17 21:36:01 -0700
commit     8c198e246d64b5779dc3a2625d06ec958553a20b (patch)
tree       8e882c1a467cb454863b08c74124a36d30120314 /sql
parent     edb23f9e47eecfe60992dde0e037ec1985c77e1d (diff)
[SPARK-15159][SPARKR] SparkR SparkSession API
## What changes were proposed in this pull request?

This PR introduces the new SparkSession API for SparkR: `sparkR.session.getOrCreate()` and `sparkR.session.stop()`. "getOrCreate" is a bit unusual in R, but it's important to name this clearly.

For the SparkR implementation:
- SparkSession is the main entrypoint (vs SparkContext, due to the limited functionality supported with SparkContext in SparkR)
- SparkSession replaces SQLContext and HiveContext (both are wrappers around SparkSession, and because of API changes, supporting all 3 would be a lot more work)
- Changes to SparkSession are mostly transparent to users due to SPARK-10903
- Full backward compatibility is expected - users should be able to initialize everything just as in Spark 1.6.1 (`sparkR.init()`), but with a deprecation warning
- Mostly cosmetic changes to the parameter list - users should be able to move to `sparkR.session.getOrCreate()` easily
- An advanced syntax with named parameters (aka varargs, aka "...") is supported; that should be closer to the Builder syntax in Scala/Python (which unfortunately does not work in R because it would look like this: `enableHiveSupport(config(config(master(appName(builder(), "foo"), "local"), "first", "value"), "next", "value"))`)
- Updating config on an existing SparkSession is supported; the behavior is the same as in Python, in which config is applied to both SparkContext and SparkSession
- Some SparkSession changes are not matched in SparkR, mostly because they would be breaking API changes: the `catalog` object, `createOrReplaceTempView`
- Other SQLContext workarounds are replicated in SparkR, e.g. `tables`, `tableNames`
- The `sparkR` shell is updated to use the SparkSession entrypoint (`sqlContext` is removed, just like with Scala/Python)
- All tests are updated to use the SparkSession entrypoint
- A bug in `read.jdbc` is fixed

TODO
- [x] Add more tests
- [ ] Separate PR - update all roxygen2 doc coding examples
- [ ] Separate PR - update SparkR programming guide

## How was this patch tested?

unit tests, manual tests

shivaram sun-rui rxin

Author: Felix Cheung <felixcheung_m@hotmail.com>
Author: felixcheung <felixcheung_m@hotmail.com>

Closes #13635 from felixcheung/rsparksession.
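Below is a minimal Scala sketch (not part of this patch; the object name, local-mode config, and config key are illustrative assumptions) of what `sparkR.session.getOrCreate()` drives on the JVM side through the `SQLUtils.getOrCreateSparkSession` helper added in this change. It is placed in a sub-package of `org.apache.spark.sql` because `SQLUtils` is `private[sql]`.

```scala
// Minimal sketch, not part of this commit: exercising the backend helper
// directly, mirroring what sparkR.session.getOrCreate() does from R.
// Placed under org.apache.spark.sql so the private[sql] object is visible.
package org.apache.spark.sql.api.r

import java.util.{HashMap => JHashMap}

import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

object SparkRSessionSketch {
  def main(args: Array[String]): Unit = {
    // The R frontend first creates (or reuses) a JavaSparkContext...
    val jsc = new JavaSparkContext(
      new SparkConf().setMaster("local[2]").setAppName("sparkr-session-sketch"))

    // ...then asks SQLUtils for a SparkSession, passing user-supplied config
    // as a Java map. enableHiveSupport = false keeps this runnable without
    // a Hive build.
    val sparkConfigMap = new JHashMap[Object, Object]()
    sparkConfigMap.put("spark.sql.shuffle.partitions", "4")

    val spark =
      SQLUtils.getOrCreateSparkSession(jsc, sparkConfigMap, enableHiveSupport = false)

    // setSparkContextSessionConf (invoked inside getOrCreateSparkSession)
    // applies the map to both the session conf and the SparkContext conf,
    // matching the Python behavior described above.
    println(spark.conf.get("spark.sql.shuffle.partitions"))

    spark.stop()
  }
}
```

When Hive classes are on the classpath and `enableHiveSupport` is true, the helper switches the catalog implementation to `hive` before building the session; otherwise it logs a warning and falls back to a session without Hive support, as shown in the diff below.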
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala | 76
1 file changed, 64 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index fe426fa3c7..0a995d2e9d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -18,27 +18,61 @@
package org.apache.spark.sql.api.r
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}
+import java.util.{Map => JMap}
import scala.collection.JavaConverters._
import scala.util.matching.Regex
+import org.apache.spark.internal.Logging
+import org.apache.spark.SparkContext
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.r.SerDe
import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.config.CATALOG_IMPLEMENTATION
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.{DataFrame, RelationalGroupedDataset, Row, SaveMode, SQLContext}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.GenericRowWithSchema
+import org.apache.spark.sql.execution.command.ShowTablesCommand
import org.apache.spark.sql.types._
-private[sql] object SQLUtils {
+private[sql] object SQLUtils extends Logging {
SerDe.registerSqlSerDe((readSqlObject, writeSqlObject))
- def createSQLContext(jsc: JavaSparkContext): SQLContext = {
- SQLContext.getOrCreate(jsc.sc)
+ private[this] def withHiveExternalCatalog(sc: SparkContext): SparkContext = {
+ sc.conf.set(CATALOG_IMPLEMENTATION.key, "hive")
+ sc
}
- def getJavaSparkContext(sqlCtx: SQLContext): JavaSparkContext = {
- new JavaSparkContext(sqlCtx.sparkContext)
+ def getOrCreateSparkSession(
+ jsc: JavaSparkContext,
+ sparkConfigMap: JMap[Object, Object],
+ enableHiveSupport: Boolean): SparkSession = {
+ val spark = if (SparkSession.hiveClassesArePresent && enableHiveSupport) {
+ SparkSession.builder().sparkContext(withHiveExternalCatalog(jsc.sc)).getOrCreate()
+ } else {
+ if (enableHiveSupport) {
+ logWarning("SparkR: enableHiveSupport is requested for SparkSession but " +
+ "Spark is not built with Hive; falling back to without Hive support.")
+ }
+ SparkSession.builder().sparkContext(jsc.sc).getOrCreate()
+ }
+ setSparkContextSessionConf(spark, sparkConfigMap)
+ spark
+ }
+
+ def setSparkContextSessionConf(
+ spark: SparkSession,
+ sparkConfigMap: JMap[Object, Object]): Unit = {
+ for ((name, value) <- sparkConfigMap.asScala) {
+ spark.conf.set(name.toString, value.toString)
+ }
+ for ((name, value) <- sparkConfigMap.asScala) {
+ spark.sparkContext.conf.set(name.toString, value.toString)
+ }
+ }
+
+ def getJavaSparkContext(spark: SparkSession): JavaSparkContext = {
+ new JavaSparkContext(spark.sparkContext)
}
def createStructType(fields : Seq[StructField]): StructType = {
@@ -95,10 +129,10 @@ private[sql] object SQLUtils {
StructField(name, dtObj, nullable)
}
- def createDF(rdd: RDD[Array[Byte]], schema: StructType, sqlContext: SQLContext): DataFrame = {
+ def createDF(rdd: RDD[Array[Byte]], schema: StructType, sparkSession: SparkSession): DataFrame = {
val num = schema.fields.length
val rowRDD = rdd.map(bytesToRow(_, schema))
- sqlContext.createDataFrame(rowRDD, schema)
+ sparkSession.createDataFrame(rowRDD, schema)
}
def dfToRowRDD(df: DataFrame): JavaRDD[Array[Byte]] = {
@@ -191,18 +225,18 @@ private[sql] object SQLUtils {
}
def loadDF(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
source: String,
options: java.util.Map[String, String]): DataFrame = {
- sqlContext.read.format(source).options(options).load()
+ sparkSession.read.format(source).options(options).load()
}
def loadDF(
- sqlContext: SQLContext,
+ sparkSession: SparkSession,
source: String,
schema: StructType,
options: java.util.Map[String, String]): DataFrame = {
- sqlContext.read.format(source).schema(schema).options(options).load()
+ sparkSession.read.format(source).schema(schema).options(options).load()
}
def readSqlObject(dis: DataInputStream, dataType: Char): Object = {
@@ -227,4 +261,22 @@ private[sql] object SQLUtils {
false
}
}
+
+ def getTables(sparkSession: SparkSession, databaseName: String): DataFrame = {
+ databaseName match {
+ case n: String if n != null && n.trim.nonEmpty =>
+ Dataset.ofRows(sparkSession, ShowTablesCommand(Some(n), None))
+ case _ =>
+ Dataset.ofRows(sparkSession, ShowTablesCommand(None, None))
+ }
+ }
+
+ def getTableNames(sparkSession: SparkSession, databaseName: String): Array[String] = {
+ databaseName match {
+ case n: String if n != null && n.trim.nonEmpty =>
+ sparkSession.catalog.listTables(n).collect().map(_.name)
+ case _ =>
+ sparkSession.catalog.listTables().collect().map(_.name)
+ }
+ }
}
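For context, a minimal sketch (not part of this commit; the object name, table name, and local setup are assumptions) of how the `getTables` and `getTableNames` helpers added above could be exercised from Scala, under the same `private[sql]` visibility assumption as before:

```scala
// Minimal sketch, not part of this commit: calling the table helpers that
// back the SparkR `tables` and `tableNames` workarounds.
package org.apache.spark.sql.api.r

import org.apache.spark.sql.SparkSession

object SQLUtilsTablesSketch {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]")
      .appName("sqlutils-tables-sketch")
      .getOrCreate()

    // Register a temporary view so there is something to list.
    spark.range(10).toDF("id").createOrReplaceTempView("people")

    // getTables runs ShowTablesCommand; an empty database name falls through
    // to the default branch, which lists tables in the current database.
    SQLUtils.getTables(spark, "").show()

    // getTableNames goes through the catalog API instead and returns names only.
    SQLUtils.getTableNames(spark, "").foreach(println)

    spark.stop()
  }
}
```

In both helpers, a non-empty database name is forwarded to `ShowTablesCommand(Some(n), None)` or `catalog.listTables(n)` respectively, while an empty or null name lists tables in the current database.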