From bc0a0e6392c4e729d8f0e4caffc0bd05adb0d950 Mon Sep 17 00:00:00 2001 From: titicaca Date: Sun, 12 Feb 2017 10:42:15 -0800 Subject: [SPARK-19342][SPARKR] bug fixed in collect method for collecting timestamp column ## What changes were proposed in this pull request? Fix a bug in the collect method when collecting a timestamp column. The bug can be reproduced with the following code and its output: ``` library(SparkR) sparkR.session(master = "local") df <- data.frame(col1 = c(0, 1, 2), col2 = c(as.POSIXct("2017-01-01 00:00:01"), NA, as.POSIXct("2017-01-01 12:00:01"))) sdf1 <- createDataFrame(df) print(dtypes(sdf1)) df1 <- collect(sdf1) print(lapply(df1, class)) sdf2 <- filter(sdf1, "col1 > 0") print(dtypes(sdf2)) df2 <- collect(sdf2) print(lapply(df2, class)) ``` As the printed output shows, the type of col2 in df2 is unexpectedly converted to numeric when an NA is at the top of the column. This is caused by `do.call(c, list)`: `c()` dispatches on the class of its first argument, so when the list starts with a plain NA, e.g. `do.call(c, list(NA, as.POSIXct("2017-01-01 12:00:01")))`, the class of the result is numeric instead of POSIXct. Therefore, we need to set the class of the resulting vector explicitly. ## How was this patch tested? The patch can be tested manually with the same code above. Author: titicaca Closes #16689 from titicaca/sparkr-dev. 
--- R/pkg/R/DataFrame.R | 3 ++- R/pkg/R/types.R | 2 +- R/pkg/inst/tests/testthat/test_sparkSQL.R | 42 +++++++++++++++++++++++++++++-- 3 files changed, 43 insertions(+), 4 deletions(-) (limited to 'R') diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R index fefe25b148..5bca4105fc 100644 --- a/R/pkg/R/DataFrame.R +++ b/R/pkg/R/DataFrame.R @@ -415,7 +415,7 @@ setMethod("coltypes", type <- PRIMITIVE_TYPES[[specialtype]] } } - type + type[[1]] }) # Find which types don't have mapping to R @@ -1136,6 +1136,7 @@ setMethod("collect", if (!is.null(PRIMITIVE_TYPES[[colType]]) && colType != "binary") { vec <- do.call(c, col) stopifnot(class(vec) != "list") + class(vec) <- PRIMITIVE_TYPES[[colType]] df[[colIndex]] <- vec } else { df[[colIndex]] <- col diff --git a/R/pkg/R/types.R b/R/pkg/R/types.R index abca703617..ade0f05c02 100644 --- a/R/pkg/R/types.R +++ b/R/pkg/R/types.R @@ -29,7 +29,7 @@ PRIMITIVE_TYPES <- as.environment(list( "string" = "character", "binary" = "raw", "boolean" = "logical", - "timestamp" = "POSIXct", + "timestamp" = c("POSIXct", "POSIXt"), "date" = "Date", # following types are not SQL types returned by dtypes(). They are listed here for usage # by checkType() in schema.R. 
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R b/R/pkg/inst/tests/testthat/test_sparkSQL.R index 233a20c3d3..1494ebb3de 100644 --- a/R/pkg/inst/tests/testthat/test_sparkSQL.R +++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R @@ -1306,9 +1306,9 @@ test_that("column functions", { # Test first(), last() df <- read.json(jsonPath) - expect_equal(collect(select(df, first(df$age)))[[1]], NA) + expect_equal(collect(select(df, first(df$age)))[[1]], NA_real_) expect_equal(collect(select(df, first(df$age, TRUE)))[[1]], 30) - expect_equal(collect(select(df, first("age")))[[1]], NA) + expect_equal(collect(select(df, first("age")))[[1]], NA_real_) expect_equal(collect(select(df, first("age", TRUE)))[[1]], 30) expect_equal(collect(select(df, last(df$age)))[[1]], 19) expect_equal(collect(select(df, last(df$age, TRUE)))[[1]], 19) @@ -2777,6 +2777,44 @@ test_that("Call DataFrameWriter.load() API in Java without path and check argume "Unnamed arguments ignored: 2, 3, a.") }) +test_that("Collect on DataFrame when NAs exists at the top of a timestamp column", { + ldf <- data.frame(col1 = c(0, 1, 2), + col2 = c(as.POSIXct("2017-01-01 00:00:01"), + NA, + as.POSIXct("2017-01-01 12:00:01")), + col3 = c(as.POSIXlt("2016-01-01 00:59:59"), + NA, + as.POSIXlt("2016-01-01 12:01:01"))) + sdf1 <- createDataFrame(ldf) + ldf1 <- collect(sdf1) + expect_equal(dtypes(sdf1), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf1$col1), "numeric") + expect_equal(class(ldf1$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf1$col3), c("POSIXct", "POSIXt")) + + # Columns with NAs at the top + sdf2 <- filter(sdf1, "col1 > 1") + ldf2 <- collect(sdf2) + expect_equal(dtypes(sdf2), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf2$col1), "numeric") + expect_equal(class(ldf2$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf2$col3), c("POSIXct", "POSIXt")) + + # Columns with only NAs, the type 
will also be cast to PRIMITIVE_TYPE + sdf3 <- filter(sdf1, "col1 == 0") + ldf3 <- collect(sdf3) + expect_equal(dtypes(sdf3), list(c("col1", "double"), + c("col2", "timestamp"), + c("col3", "timestamp"))) + expect_equal(class(ldf3$col1), "numeric") + expect_equal(class(ldf3$col2), c("POSIXct", "POSIXt")) + expect_equal(class(ldf3$col3), c("POSIXct", "POSIXt")) +}) + unlink(parquetPath) unlink(orcPath) unlink(jsonPath) -- cgit v1.2.3