aboutsummaryrefslogtreecommitdiff
path: root/R
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2016-09-02 01:47:17 -0700
committerFelix Cheung <felixcheung@apache.org>2016-09-02 01:47:17 -0700
commit0f30cdedbdb0d38e8c479efab6bb1c6c376206ff (patch)
treee5b05ff7584f77d551da0dfc2269a58d7fd1da43 /R
parent2ab8dbddaa31e4491b52eb0e495660ebbebfdb9e (diff)
downloadspark-0f30cdedbdb0d38e8c479efab6bb1c6c376206ff.tar.gz
spark-0f30cdedbdb0d38e8c479efab6bb1c6c376206ff.tar.bz2
spark-0f30cdedbdb0d38e8c479efab6bb1c6c376206ff.zip
[SPARK-16883][SPARKR] SQL decimal type is not properly cast to number when collecting SparkDataFrame
## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) registerTempTable(createDataFrame(iris), "iris") str(collect(sql("select cast('1' as double) as x, cast('2' as decimal) as y from iris limit 5"))) 'data.frame': 5 obs. of 2 variables: $ x: num 1 1 1 1 1 $ y:List of 5 ..$ : num 2 ..$ : num 2 ..$ : num 2 ..$ : num 2 ..$ : num 2 The problem is that spark returns `decimal(10, 0)` col type, instead of `decimal`. Thus, `decimal(10, 0)` is not handled correctly. It should be handled as "double". As discussed in JIRA thread, we can have two potential fixes: 1). Scala side fix to add a new case when writing the object back; However, I can't use spark.sql.types._ in Spark core due to dependency issues. I don't find a way of doing type case match; 2). SparkR side fix: Add a helper function to check special type like `"decimal(10, 0)"` and replace it with `double`, which is PRIMITIVE type. This special helper is generic for adding new types handling in the future. I open this PR to discuss pros and cons of both approaches. If we want to do Scala side fix, we need to find a way to match the case of DecimalType and StructType in Spark Core. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Manual test: > str(collect(sql("select cast('1' as double) as x, cast('2' as decimal) as y from iris limit 5"))) 'data.frame': 5 obs. of 2 variables: $ x: num 1 1 1 1 1 $ y: num 2 2 2 2 2 R Unit tests Author: wm624@hotmail.com <wm624@hotmail.com> Closes #14613 from wangmiao1981/type.
Diffstat (limited to 'R')
-rw-r--r--R/pkg/R/DataFrame.R13
-rw-r--r--R/pkg/R/types.R16
-rw-r--r--R/pkg/inst/tests/testthat/test_sparkSQL.R22
3 files changed, 50 insertions, 1 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index e12b58e2ee..a92450274e 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -397,7 +397,11 @@ setMethod("coltypes",
}
if (is.null(type)) {
- stop(paste("Unsupported data type: ", x))
+ specialtype <- specialtypeshandle(x)
+ if (is.null(specialtype)) {
+ stop(paste("Unsupported data type: ", x))
+ }
+ type <- PRIMITIVE_TYPES[[specialtype]]
}
}
type
@@ -1063,6 +1067,13 @@ setMethod("collect",
df[[colIndex]] <- col
} else {
colType <- dtypes[[colIndex]][[2]]
+ if (is.null(PRIMITIVE_TYPES[[colType]])) {
+ specialtype <- specialtypeshandle(colType)
+ if (!is.null(specialtype)) {
+ colType <- specialtype
+ }
+ }
+
# Note that "binary" columns behave like complex types.
if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") {
vec <- do.call(c, col)
diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R
index ad048b1cd1..abca703617 100644
--- a/R/pkg/R/types.R
+++ b/R/pkg/R/types.R
@@ -67,3 +67,19 @@ rToSQLTypes <- as.environment(list(
"double" = "double",
"character" = "string",
"logical" = "boolean"))
+
+# Helper function of coverting decimal type. When backend returns column type in the
+# format of decimal(,) (e.g., decimal(10, 0)), this function coverts the column type
+# as double type. This function converts backend returned types that are not the key
+# of PRIMITIVE_TYPES, but should be treated as PRIMITIVE_TYPES.
+# @param A type returned from the JVM backend.
+# @return A type is the key of the PRIMITIVE_TYPES.
+specialtypeshandle <- function(type) {
+ returntype <- NULL
+ m <- regexec("^decimal(.+)$", type)
+ matchedStrings <- regmatches(type, m)
+ if (length(matchedStrings[[1]]) >= 2) {
+ returntype <- "double"
+ }
+ returntype
+}
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 8ff56eba1f..683a15cb4f 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -526,6 +526,17 @@ test_that(
expect_is(newdf, "SparkDataFrame")
expect_equal(count(newdf), 1)
dropTempView("table1")
+
+ createOrReplaceTempView(df, "dfView")
+ sqlCast <- collect(sql("select cast('2' as decimal) as x from dfView limit 1"))
+ out <- capture.output(sqlCast)
+ expect_true(is.data.frame(sqlCast))
+ expect_equal(names(sqlCast)[1], "x")
+ expect_equal(nrow(sqlCast), 1)
+ expect_equal(ncol(sqlCast), 1)
+ expect_equal(out[1], " x")
+ expect_equal(out[2], "1 2")
+ dropTempView("dfView")
})
test_that("test cache, uncache and clearCache", {
@@ -2089,6 +2100,9 @@ test_that("Method coltypes() to get and set R's data types of a DataFrame", {
# Test primitive types
DF <- createDataFrame(data, schema)
expect_equal(coltypes(DF), c("integer", "logical", "POSIXct"))
+ createOrReplaceTempView(DF, "DFView")
+ sqlCast <- sql("select cast('2' as decimal) as x from DFView limit 1")
+ expect_equal(coltypes(sqlCast), "numeric")
# Test complex types
x <- createDataFrame(list(list(as.environment(
@@ -2132,6 +2146,14 @@ test_that("Method str()", {
"setosa\" \"setosa\" \"setosa\" \"setosa\""))
expect_equal(out[7], " $ col : logi TRUE TRUE TRUE TRUE TRUE TRUE")
+ createOrReplaceTempView(irisDF2, "irisView")
+
+ sqlCast <- sql("select cast('2' as decimal) as x from irisView limit 1")
+ castStr <- capture.output(str(sqlCast))
+ expect_equal(length(castStr), 2)
+ expect_equal(castStr[1], "'SparkDataFrame': 1 variables:")
+ expect_equal(castStr[2], " $ x: num 2")
+
# A random dataset with many columns. This test is to check str limits
# the number of columns. Therefore, it will suffice to check for the
# number of returned rows