path: root/R/pkg/R/SQLContext.R
Diffstat (limited to 'R/pkg/R/SQLContext.R')
-rw-r--r--  R/pkg/R/SQLContext.R | 109
1 file changed, 57 insertions(+), 52 deletions(-)
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 914b02a47a..3232241f8a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -53,7 +53,8 @@ dispatchFunc <- function(newFuncSig, x, ...) {
# Strip sqlContext from list of parameters and then pass the rest along.
contextNames <- c("org.apache.spark.sql.SQLContext",
"org.apache.spark.sql.hive.HiveContext",
- "org.apache.spark.sql.hive.test.TestHiveContext")
+ "org.apache.spark.sql.hive.test.TestHiveContext",
+ "org.apache.spark.sql.SparkSession")
if (missing(x) && length(list(...)) == 0) {
f()
} else if (class(x) == "jobj" &&
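
With "org.apache.spark.sql.SparkSession" added to contextNames, dispatchFunc keeps accepting the deprecated (sqlContext, ...) calling convention for the new session object as well as the old contexts. A minimal sketch of both call styles, assuming sqlContext is a previously obtained context/session jobj:

    # New-style call: no context argument.
    df <- sql("SELECT 1")
    # Deprecated style still works; the first argument is recognized by its
    # jobj class (now including SparkSession) and stripped before dispatch.
    df <- sql(sqlContext, "SELECT 1")
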
@@ -65,14 +66,12 @@ dispatchFunc <- function(newFuncSig, x, ...) {
}
}
-#' return the SQL Context
-getSqlContext <- function() {
- if (exists(".sparkRHivesc", envir = .sparkREnv)) {
- get(".sparkRHivesc", envir = .sparkREnv)
- } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
- get(".sparkRSQLsc", envir = .sparkREnv)
+#' return the SparkSession
+getSparkSession <- function() {
+ if (exists(".sparkRsession", envir = .sparkREnv)) {
+ get(".sparkRsession", envir = .sparkREnv)
} else {
- stop("SQL context not initialized")
+ stop("SparkSession not initialized")
}
}
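
getSparkSession collapses the old two-way lookup (.sparkRHivesc, then .sparkRSQLsc) into a single cached .sparkRsession entry, so every consumer below fails the same way when no session exists. A rough sketch of the error path:

    # Before a session has been created, any SQL entry point now stops with
    # the one uniform message defined above.
    tryCatch(sql("SELECT 1"), error = function(e) conditionMessage(e))
    # "SparkSession not initialized"
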
@@ -109,6 +108,13 @@ infer_type <- function(x) {
}
}
+getDefaultSqlSource <- function() {
+ sparkSession <- getSparkSession()
+ conf <- callJMethod(sparkSession, "conf")
+ source <- callJMethod(conf, "get", "spark.sql.sources.default", "org.apache.spark.sql.parquet")
+ source
+}
+
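
getDefaultSqlSource centralizes the default-source lookup that read.df previously inlined: it reads spark.sql.sources.default from the session's runtime conf, falling back to the built-in Parquet source. A hedged usage sketch (the sparkConfig argument to sparkR.session and the file path are assumptions here):

    # Make JSON the default data source, then read without naming one.
    sparkR.session(sparkConfig = list(spark.sql.sources.default = "json"))
    df <- read.df("examples/src/main/resources/people.json")
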
#' Create a SparkDataFrame
#'
#' Converts an R data.frame or list into a SparkDataFrame.
@@ -131,7 +137,7 @@ infer_type <- function(x) {
# TODO(davies): support sampling and infer type from NA
createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
if (is.data.frame(data)) {
# get the names of columns, they will be put into RDD
if (is.null(schema)) {
@@ -158,7 +164,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
data <- do.call(mapply, append(args, data))
}
if (is.list(data)) {
- sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext)
+ sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
rdd <- parallelize(sc, data)
} else if (inherits(data, "RDD")) {
rdd <- data
@@ -201,7 +207,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
srdd <- callJMethod(jrdd, "rdd")
sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
- srdd, schema$jobj, sqlContext)
+ srdd, schema$jobj, sparkSession)
dataFrame(sdf)
}
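
The user-facing behavior of createDataFrame is unchanged; only the jobj threaded through to SQLUtils.createDF is now the session. A short sketch:

    # From a local R data.frame (faithful ships with base R).
    df <- createDataFrame(faithful)
    head(df)
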
@@ -265,10 +271,10 @@ setMethod("toDF", signature(x = "RDD"),
#' @method read.json default
read.json.default <- function(path) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
- read <- callJMethod(sqlContext, "read")
+ read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "json", paths)
dataFrame(sdf)
}
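
read.json, along with the read.parquet and read.text hunks that follow, now fetches the session's DataFrameReader via callJMethod(sparkSession, "read") and dispatches on format. Usage sketch (the path is illustrative):

    people <- read.json("examples/src/main/resources/people.json")
    printSchema(people)
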
@@ -336,10 +342,10 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) {
#' @method read.parquet default
read.parquet.default <- function(path) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
- read <- callJMethod(sqlContext, "read")
+ read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "parquet", paths)
dataFrame(sdf)
}
@@ -385,10 +391,10 @@ parquetFile <- function(x, ...) {
#' @method read.text default
read.text.default <- function(path) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
# Allow the user to have a more flexible definition of the text file path
paths <- as.list(suppressWarnings(normalizePath(path)))
- read <- callJMethod(sqlContext, "read")
+ read <- callJMethod(sparkSession, "read")
sdf <- callJMethod(read, "text", paths)
dataFrame(sdf)
}
@@ -418,8 +424,8 @@ read.text <- function(x, ...) {
#' @method sql default
sql.default <- function(sqlQuery) {
- sqlContext <- getSqlContext()
- sdf <- callJMethod(sqlContext, "sql", sqlQuery)
+ sparkSession <- getSparkSession()
+ sdf <- callJMethod(sparkSession, "sql", sqlQuery)
dataFrame(sdf)
}
@@ -449,8 +455,8 @@ sql <- function(x, ...) {
#' @note since 2.0.0
tableToDF <- function(tableName) {
- sqlContext <- getSqlContext()
- sdf <- callJMethod(sqlContext, "table", tableName)
+ sparkSession <- getSparkSession()
+ sdf <- callJMethod(sparkSession, "table", tableName)
dataFrame(sdf)
}
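
sql and tableToDF now invoke the corresponding SparkSession methods directly. A combined sketch; registerTempTable is assumed available in this SparkR version:

    df <- createDataFrame(faithful)
    registerTempTable(df, "faithful")
    long <- sql("SELECT * FROM faithful WHERE eruptions > 3.0")
    whole <- tableToDF("faithful")   # the full table as a SparkDataFrame
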
@@ -472,12 +478,8 @@ tableToDF <- function(tableName) {
#' @method tables default
tables.default <- function(databaseName = NULL) {
- sqlContext <- getSqlContext()
- jdf <- if (is.null(databaseName)) {
- callJMethod(sqlContext, "tables")
- } else {
- callJMethod(sqlContext, "tables", databaseName)
- }
+ sparkSession <- getSparkSession()
+ jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName)
dataFrame(jdf)
}
@@ -503,12 +505,11 @@ tables <- function(x, ...) {
#' @method tableNames default
tableNames.default <- function(databaseName = NULL) {
- sqlContext <- getSqlContext()
- if (is.null(databaseName)) {
- callJMethod(sqlContext, "tableNames")
- } else {
- callJMethod(sqlContext, "tableNames", databaseName)
- }
+ sparkSession <- getSparkSession()
+ callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "getTableNames",
+ sparkSession,
+ databaseName)
}
tableNames <- function(x, ...) {
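
tables and tableNames both delegate the NULL-databaseName branching to the SQLUtils static helpers on the JVM side instead of branching in R. User-facing calls are unchanged (sketch):

    tables()                # SparkDataFrame listing tables in the current database
    tableNames("default")   # character vector of table names in "default"
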
@@ -536,8 +537,9 @@ tableNames <- function(x, ...) {
#' @method cacheTable default
cacheTable.default <- function(tableName) {
- sqlContext <- getSqlContext()
- callJMethod(sqlContext, "cacheTable", tableName)
+ sparkSession <- getSparkSession()
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "cacheTable", tableName)
}
cacheTable <- function(x, ...) {
@@ -565,8 +567,9 @@ cacheTable <- function(x, ...) {
#' @method uncacheTable default
uncacheTable.default <- function(tableName) {
- sqlContext <- getSqlContext()
- callJMethod(sqlContext, "uncacheTable", tableName)
+ sparkSession <- getSparkSession()
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "uncacheTable", tableName)
}
uncacheTable <- function(x, ...) {
@@ -587,8 +590,9 @@ uncacheTable <- function(x, ...) {
#' @method clearCache default
clearCache.default <- function() {
- sqlContext <- getSqlContext()
- callJMethod(sqlContext, "clearCache")
+ sparkSession <- getSparkSession()
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "clearCache")
}
clearCache <- function() {
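
cacheTable, uncacheTable, and clearCache all take the same new shape: fetch the session's Catalog, then invoke the method there rather than on a SQLContext. Sketch:

    cacheTable("faithful")     # pin the table's data in memory
    uncacheTable("faithful")
    clearCache()               # drop every cached table
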
@@ -615,11 +619,12 @@ clearCache <- function() {
#' @method dropTempTable default
dropTempTable.default <- function(tableName) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
if (class(tableName) != "character") {
stop("tableName must be a string.")
}
- callJMethod(sqlContext, "dropTempTable", tableName)
+ catalog <- callJMethod(sparkSession, "catalog")
+ callJMethod(catalog, "dropTempView", tableName)
}
dropTempTable <- function(x, ...) {
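
Note the quiet rename inside dropTempTable: the Catalog method is dropTempView, reflecting Spark 2.0's move from temporary tables to temporary views, while the R-level name stays dropTempTable for compatibility. Sketch:

    dropTempTable("faithful")   # now backed by Catalog.dropTempView
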
@@ -655,21 +660,21 @@ dropTempTable <- function(x, ...) {
#' @method read.df default
read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
options[["path"]] <- path
}
if (is.null(source)) {
- source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
- "org.apache.spark.sql.parquet")
+ source <- getDefaultSqlSource()
}
if (!is.null(schema)) {
stopifnot(class(schema) == "structType")
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
schema$jobj, options)
} else {
- sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
+ sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+ "loadDF", sparkSession, source, options)
}
dataFrame(sdf)
}
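
read.df keeps its signature; the behavioral change is that the default source now comes from getDefaultSqlSource above. Sketch (path and source are illustrative):

    df <- read.df("examples/src/main/resources/people.json", source = "json")
    # With source = NULL, spark.sql.sources.default decides (Parquet by default).
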
@@ -715,12 +720,13 @@ loadDF <- function(x, ...) {
#' @method createExternalTable default
createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) {
- sqlContext <- getSqlContext()
+ sparkSession <- getSparkSession()
options <- varargsToEnv(...)
if (!is.null(path)) {
options[["path"]] <- path
}
- sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
+ catalog <- callJMethod(sparkSession, "catalog")
+ sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options)
dataFrame(sdf)
}
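
createExternalTable likewise moves from a SQLContext method to the session's Catalog. Sketch with an illustrative path:

    df <- createExternalTable("people_ext",
                              path = "examples/src/main/resources/people.json",
                              source = "json")
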
@@ -767,12 +773,11 @@ read.jdbc <- function(url, tableName,
partitionColumn = NULL, lowerBound = NULL, upperBound = NULL,
numPartitions = 0L, predicates = list(), ...) {
jprops <- varargsToJProperties(...)
-
- read <- callJMethod(sqlContext, "read")
+ sparkSession <- getSparkSession()
+ read <- callJMethod(sparkSession, "read")
if (!is.null(partitionColumn)) {
if (is.null(numPartitions) || numPartitions == 0) {
- sqlContext <- getSqlContext()
- sc <- callJMethod(sqlContext, "sparkContext")
+ sc <- callJMethod(sparkSession, "sparkContext")
numPartitions <- callJMethod(sc, "defaultParallelism")
} else {
numPartitions <- numToInt(numPartitions)