Diffstat (limited to 'R/pkg/R/SQLContext.R')
-rw-r--r--   R/pkg/R/SQLContext.R | 109 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------------------------------------------
1 file changed, 57 insertions, 52 deletions
diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 914b02a47a..3232241f8a 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -53,7 +53,8 @@ dispatchFunc <- function(newFuncSig, x, ...) {
   # Strip sqlContext from list of parameters and then pass the rest along.
   contextNames <- c("org.apache.spark.sql.SQLContext",
                     "org.apache.spark.sql.hive.HiveContext",
-                    "org.apache.spark.sql.hive.test.TestHiveContext")
+                    "org.apache.spark.sql.hive.test.TestHiveContext",
+                    "org.apache.spark.sql.SparkSession")
   if (missing(x) && length(list(...)) == 0) {
     f()
   } else if (class(x) == "jobj" &&
@@ -65,14 +66,12 @@ dispatchFunc <- function(newFuncSig, x, ...) {
   }
 }
 
-#' return the SQL Context
-getSqlContext <- function() {
-  if (exists(".sparkRHivesc", envir = .sparkREnv)) {
-    get(".sparkRHivesc", envir = .sparkREnv)
-  } else if (exists(".sparkRSQLsc", envir = .sparkREnv)) {
-    get(".sparkRSQLsc", envir = .sparkREnv)
+#' return the SparkSession
+getSparkSession <- function() {
+  if (exists(".sparkRsession", envir = .sparkREnv)) {
+    get(".sparkRsession", envir = .sparkREnv)
   } else {
-    stop("SQL context not initialized")
+    stop("SparkSession not initialized")
   }
 }
 
@@ -109,6 +108,13 @@ infer_type <- function(x) {
   }
 }
 
+getDefaultSqlSource <- function() {
+  sparkSession <- getSparkSession()
+  conf <- callJMethod(sparkSession, "conf")
+  source <- callJMethod(conf, "get", "spark.sql.sources.default", "org.apache.spark.sql.parquet")
+  source
+}
+
 #' Create a SparkDataFrame
 #'
 #' Converts R data.frame or list into SparkDataFrame.
@@ -131,7 +137,7 @@ infer_type <- function(x) {
 
 # TODO(davies): support sampling and infer type from NA
 createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   if (is.data.frame(data)) {
       # get the names of columns, they will be put into RDD
       if (is.null(schema)) {
@@ -158,7 +164,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
     data <- do.call(mapply, append(args, data))
   }
   if (is.list(data)) {
-    sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sqlContext)
+    sc <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getJavaSparkContext", sparkSession)
     rdd <- parallelize(sc, data)
   } else if (inherits(data, "RDD")) {
     rdd <- data
@@ -201,7 +207,7 @@ createDataFrame.default <- function(data, schema = NULL, samplingRatio = 1.0) {
   jrdd <- getJRDD(lapply(rdd, function(x) x), "row")
   srdd <- callJMethod(jrdd, "rdd")
   sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "createDF",
-                     srdd, schema$jobj, sqlContext)
+                     srdd, schema$jobj, sparkSession)
   dataFrame(sdf)
 }
 
@@ -265,10 +271,10 @@ setMethod("toDF", signature(x = "RDD"),
 
 #' @method read.json default
 read.json.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "json", paths)
   dataFrame(sdf)
 }
@@ -336,10 +342,10 @@ jsonRDD <- function(sqlContext, rdd, schema = NULL, samplingRatio = 1.0) {
 
 #' @method read.parquet default
 read.parquet.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "parquet", paths)
   dataFrame(sdf)
 }
@@ -385,10 +391,10 @@ parquetFile <- function(x, ...) {
 
 #' @method read.text default
 read.text.default <- function(path) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   # Allow the user to have a more flexible definiton of the text file path
   paths <- as.list(suppressWarnings(normalizePath(path)))
-  read <- callJMethod(sqlContext, "read")
+  read <- callJMethod(sparkSession, "read")
   sdf <- callJMethod(read, "text", paths)
   dataFrame(sdf)
 }
@@ -418,8 +424,8 @@ read.text <- function(x, ...) {
 
 #' @method sql default
 sql.default <- function(sqlQuery) {
-  sqlContext <- getSqlContext()
-  sdf <- callJMethod(sqlContext, "sql", sqlQuery)
+  sparkSession <- getSparkSession()
+  sdf <- callJMethod(sparkSession, "sql", sqlQuery)
   dataFrame(sdf)
 }
 
@@ -449,8 +455,8 @@ sql <- function(x, ...) {
 
 #' @note since 2.0.0
 tableToDF <- function(tableName) {
-  sqlContext <- getSqlContext()
-  sdf <- callJMethod(sqlContext, "table", tableName)
+  sparkSession <- getSparkSession()
+  sdf <- callJMethod(sparkSession, "table", tableName)
   dataFrame(sdf)
 }
 
@@ -472,12 +478,8 @@ tableToDF <- function(tableName) {
 
 #' @method tables default
 tables.default <- function(databaseName = NULL) {
-  sqlContext <- getSqlContext()
-  jdf <- if (is.null(databaseName)) {
-    callJMethod(sqlContext, "tables")
-  } else {
-    callJMethod(sqlContext, "tables", databaseName)
-  }
+  sparkSession <- getSparkSession()
+  jdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "getTables", sparkSession, databaseName)
   dataFrame(jdf)
 }
 
@@ -503,12 +505,11 @@ tables <- function(x, ...) {
 
 #' @method tableNames default
 tableNames.default <- function(databaseName = NULL) {
-  sqlContext <- getSqlContext()
-  if (is.null(databaseName)) {
-    callJMethod(sqlContext, "tableNames")
-  } else {
-    callJMethod(sqlContext, "tableNames", databaseName)
-  }
+  sparkSession <- getSparkSession()
+  callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+              "getTableNames",
+              sparkSession,
+              databaseName)
 }
 
 tableNames <- function(x, ...) {
@@ -536,8 +537,9 @@ tableNames <- function(x, ...) {
 
 #' @method cacheTable default
 cacheTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "cacheTable", tableName)
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "cacheTable", tableName)
 }
 
 cacheTable <- function(x, ...) {
@@ -565,8 +567,9 @@ cacheTable <- function(x, ...) {
 
 #' @method uncacheTable default
 uncacheTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "uncacheTable", tableName)
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "uncacheTable", tableName)
 }
 
 uncacheTable <- function(x, ...) {
@@ -587,8 +590,9 @@ uncacheTable <- function(x, ...) {
 
 #' @method clearCache default
 clearCache.default <- function() {
-  sqlContext <- getSqlContext()
-  callJMethod(sqlContext, "clearCache")
+  sparkSession <- getSparkSession()
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "clearCache")
 }
 
 clearCache <- function() {
@@ -615,11 +619,12 @@ clearCache <- function() {
 
 #' @method dropTempTable default
 dropTempTable.default <- function(tableName) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   if (class(tableName) != "character") {
     stop("tableName must be a string.")
   }
-  callJMethod(sqlContext, "dropTempTable", tableName)
+  catalog <- callJMethod(sparkSession, "catalog")
+  callJMethod(catalog, "dropTempView", tableName)
 }
 
 dropTempTable <- function(x, ...) {
@@ -655,21 +660,21 @@ dropTempTable <- function(x, ...) {
 
 #' @method read.df default
 read.df.default <- function(path = NULL, source = NULL, schema = NULL, ...) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   options <- varargsToEnv(...)
   if (!is.null(path)) {
     options[["path"]] <- path
   }
   if (is.null(source)) {
-    source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
-                          "org.apache.spark.sql.parquet")
+    source <- getDefaultSqlSource()
   }
   if (!is.null(schema)) {
     stopifnot(class(schema) == "structType")
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sparkSession, source,
                        schema$jobj, options)
   } else {
-    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils",
+                       "loadDF", sparkSession, source, options)
   }
   dataFrame(sdf)
 }
 
@@ -715,12 +720,13 @@ loadDF <- function(x, ...) {
 
 #' @method createExternalTable default
 createExternalTable.default <- function(tableName, path = NULL, source = NULL, ...) {
-  sqlContext <- getSqlContext()
+  sparkSession <- getSparkSession()
   options <- varargsToEnv(...)
   if (!is.null(path)) {
     options[["path"]] <- path
   }
-  sdf <- callJMethod(sqlContext, "createExternalTable", tableName, source, options)
+  catalog <- callJMethod(sparkSession, "catalog")
+  sdf <- callJMethod(catalog, "createExternalTable", tableName, source, options)
   dataFrame(sdf)
 }
 
@@ -767,12 +773,11 @@ read.jdbc <- function(url, tableName, partitionColumn = NULL, lowerBound = NULL,
                       upperBound = NULL, numPartitions = 0L, predicates = list(), ...) {
   jprops <- varargsToJProperties(...)
-
-  read <- callJMethod(sqlContext, "read")
+  sparkSession <- getSparkSession()
+  read <- callJMethod(sparkSession, "read")
   if (!is.null(partitionColumn)) {
     if (is.null(numPartitions) || numPartitions == 0) {
-      sqlContext <- getSqlContext()
-      sc <- callJMethod(sqlContext, "sparkContext")
+      sc <- callJMethod(sparkSession, "sparkContext")
      numPartitions <- callJMethod(sc, "defaultParallelism")
     } else {
       numPartitions <- numToInt(numPartitions)
     }
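For anyone tracking the migration, the change to contextNames is what preserves backward compatibility: dispatchFunc() already strips a deprecated context handle passed as the first argument, and it now recognizes a SparkSession jobj as well. A minimal R sketch of the two call styles, assuming an initialized session and a hypothetical temp table "people":

    df <- sql("SELECT * FROM people")              # new signature, no handle
    df <- sql(sqlContext, "SELECT * FROM people")  # deprecated signature; the handle
                                                   # (SQLContext, HiveContext, or now a
                                                   # SparkSession) is stripped and the
                                                   # remaining arguments are passed along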
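The new getDefaultSqlSource() helper replaces the direct getConf call in read.df.default: the default data source is now resolved through the SparkSession's runtime conf, with "org.apache.spark.sql.parquet" as the fallback. A hedged usage sketch; the path is hypothetical, and sparkR.session() is the 2.0 entry point that populates .sparkRsession, which is not part of this diff:

    sparkR.session(sparkConfig = list(spark.sql.sources.default = "json"))
    df <- read.df("/tmp/people.json")   # no source given: getDefaultSqlSource()
                                        # returns "json" from the session conf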
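Table management (cacheTable, uncacheTable, clearCache, dropTempTable, createExternalTable) no longer calls methods on the context directly; each wrapper fetches the session's Catalog first, and dropTempTable is rewired to the newer dropTempView JVM method. The public R signatures are unchanged. A sketch of the round-trip the wrappers now perform, using only calls that appear in the diff and a hypothetical table "people":

    sparkSession <- getSparkSession()
    catalog <- callJMethod(sparkSession, "catalog")
    callJMethod(catalog, "cacheTable", "people")    # what cacheTable("people") now does
    callJMethod(catalog, "dropTempView", "people")  # what dropTempTable("people") now does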