From 12f5eaeee1235850a076ce5716d069bd2f1205a5 Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Fri, 5 Jun 2015 10:19:03 -0700
Subject: [SPARK-8085] [SPARKR] Support user-specified schema in read.df

cc davies sun-rui

Author: Shivaram Venkataraman

Closes #6620 from shivaram/sparkr-read-schema and squashes the following commits:

16a6726 [Shivaram Venkataraman] Fix loadDF to pass schema. Also add a unit test
a229877 [Shivaram Venkataraman] Use wrapper function to DataFrameReader
ee70ba8 [Shivaram Venkataraman] Support user-specified schema in read.df
---
 R/pkg/R/SQLContext.R                                       | 14 ++++++++++----
 R/pkg/inst/tests/test_sparkSQL.R                           | 13 +++++++++++++
 .../main/scala/org/apache/spark/sql/api/r/SQLUtils.scala   | 15 +++++++++++++++
 3 files changed, 38 insertions(+), 4 deletions(-)

diff --git a/R/pkg/R/SQLContext.R b/R/pkg/R/SQLContext.R
index 88e1a508f3..22a4b5bf86 100644
--- a/R/pkg/R/SQLContext.R
+++ b/R/pkg/R/SQLContext.R
@@ -452,7 +452,7 @@ dropTempTable <- function(sqlContext, tableName) {
 #' df <- read.df(sqlContext, "path/to/file.json", source = "json")
 #' }
 
-read.df <- function(sqlContext, path = NULL, source = NULL, ...) {
+read.df <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
   options <- varargsToEnv(...)
   if (!is.null(path)) {
     options[['path']] <- path
@@ -462,15 +462,21 @@ read.df <- function(sqlContext, path = NULL, source = NULL, ...) {
     source <- callJMethod(sqlContext, "getConf", "spark.sql.sources.default",
                           "org.apache.spark.sql.parquet")
   }
-  sdf <- callJMethod(sqlContext, "load", source, options)
+  if (!is.null(schema)) {
+    stopifnot(class(schema) == "structType")
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source,
+                       schema$jobj, options)
+  } else {
+    sdf <- callJStatic("org.apache.spark.sql.api.r.SQLUtils", "loadDF", sqlContext, source, options)
+  }
   dataFrame(sdf)
 }
 
 #' @aliases loadDF
 #' @export
-loadDF <- function(sqlContext, path = NULL, source = NULL, ...) {
-  read.df(sqlContext, path, source, ...)
+loadDF <- function(sqlContext, path = NULL, source = NULL, schema = NULL, ...) {
+  read.df(sqlContext, path, source, schema, ...)
 }
 
 #' Create an external table
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index d2d82e791e..30edfc8a7b 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -504,6 +504,19 @@ test_that("read.df() from json file", {
   df <- read.df(sqlContext, jsonPath, "json")
   expect_true(inherits(df, "DataFrame"))
   expect_true(count(df) == 3)
+
+  # Check if we can apply a user defined schema
+  schema <- structType(structField("name", type = "string"),
+                       structField("age", type = "double"))
+
+  df1 <- read.df(sqlContext, jsonPath, "json", schema)
+  expect_true(inherits(df1, "DataFrame"))
+  expect_equal(dtypes(df1), list(c("name", "string"), c("age", "double")))
+
+  # Run the same with loadDF
+  df2 <- loadDF(sqlContext, jsonPath, "json", schema)
+  expect_true(inherits(df2, "DataFrame"))
+  expect_equal(dtypes(df2), list(c("name", "string"), c("age", "double")))
 })
 
 test_that("write.df() as parquet file", {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
index 604f3124e2..43b62f0e82 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/r/SQLUtils.scala
@@ -139,4 +139,19 @@ private[r] object SQLUtils {
       case "ignore" => SaveMode.Ignore
     }
   }
+
+  def loadDF(
+      sqlContext: SQLContext,
+      source: String,
+      options: java.util.Map[String, String]): DataFrame = {
+    sqlContext.read.format(source).options(options).load()
+  }
+
+  def loadDF(
+      sqlContext: SQLContext,
+      source: String,
+      schema: StructType,
+      options: java.util.Map[String, String]): DataFrame = {
+    sqlContext.read.format(source).schema(schema).options(options).load()
+  }
 }
-- 
cgit v1.2.3
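
For context, here is a minimal sketch of how the new schema argument is used from SparkR once this patch is applied. The session setup and the people.json path are illustrative assumptions, not part of the patch:

    library(SparkR)

    # Assumed local setup; any SparkR 1.4-era initialization works here.
    sc <- sparkR.init(master = "local")
    sqlContext <- sparkRSQL.init(sc)

    # Build an explicit schema instead of letting the json source infer one.
    schema <- structType(structField("name", type = "string"),
                         structField("age", type = "double"))

    # read.df forwards the schema to the SQLUtils.loadDF overload on the JVM
    # side; the input path below is a hypothetical example file.
    df <- read.df(sqlContext, "examples/src/main/resources/people.json",
                  source = "json", schema = schema)
    printSchema(df)
    dtypes(df)  # list(c("name", "string"), c("age", "double"))

Routing the call through a pair of overloaded SQLUtils.loadDF helpers, rather than chaining DataFrameReader methods over the R-JVM bridge, keeps the R side to a single callJStatic whether or not a schema is supplied.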