aboutsummaryrefslogtreecommitdiff
path: root/R/pkg
diff options
context:
space:
mode:
authorNarine Kokhlikyan <narine.kokhlikyan@gmail.com>2015-10-26 15:12:25 -0700
committerShivaram Venkataraman <shivaram@cs.berkeley.edu>2015-10-26 15:12:25 -0700
commit3689beb98b6a6db61e35049fdb57b0cd6aad8019 (patch)
treeb2002c5fd2b1a7de90890657387c226e4daae98d /R/pkg
parent616be29c7f2ebc184bd5ec97210da36a2174d80c (diff)
downloadspark-3689beb98b6a6db61e35049fdb57b0cd6aad8019.tar.gz
spark-3689beb98b6a6db61e35049fdb57b0cd6aad8019.tar.bz2
spark-3689beb98b6a6db61e35049fdb57b0cd6aad8019.zip
[SPARK-10979][SPARKR] Sparkrmerge: Add merge to DataFrame with R signature
Add merge function to DataFrame, which supports R signature. https://stat.ethz.ch/R-manual/R-devel/library/base/html/merge.html Author: Narine Kokhlikyan <narine.kokhlikyan@gmail.com> Closes #9012 from NarineK/sparkrmerge.
Diffstat (limited to 'R/pkg')
-rw-r--r--R/pkg/R/DataFrame.R140
-rw-r--r--R/pkg/inst/tests/test_sparkSQL.R37
2 files changed, 169 insertions, 8 deletions
diff --git a/R/pkg/R/DataFrame.R b/R/pkg/R/DataFrame.R
index 2acbd081cd..c894445954 100644
--- a/R/pkg/R/DataFrame.R
+++ b/R/pkg/R/DataFrame.R
@@ -1457,15 +1457,147 @@ setMethod("join",
dataFrame(sdf)
})
-#' @rdname merge
+#'
#' @name merge
#' @aliases join
+#' @title Merges two data frames
+#' @param x the first data frame to be joined
+#' @param y the second data frame to be joined
+#' @param by a character vector specifying the join columns. If by is not
+#' specified, the common column names in \code{x} and \code{y} will be used.
+#' @param by.x a character vector specifying the joining columns for x.
+#' @param by.y a character vector specifying the joining columns for y.
+#' @param all.x a boolean value indicating whether all the rows in x should
+#' be included in the join
+#' @param all.y a boolean value indicating whether all the rows in y should
+#' be included in the join
+#' @param sort a logical argument indicating whether the resulting rows should be sorted by the \code{by} columns
+#' @details If all.x and all.y are set to FALSE, a natural join will be returned. If
+#' all.x is set to TRUE and all.y is set to FALSE, a left outer join will
+#' be returned. If all.x is set to FALSE and all.y is set to TRUE, a right
+#' outer join will be returned. If all.x and all.y are set to TRUE, a full
+#' outer join will be returned.
+#' @rdname merge
+#' @export
+#' @examples
+#'\dontrun{
+#' sc <- sparkR.init()
+#' sqlContext <- sparkRSQL.init(sc)
+#' df1 <- jsonFile(sqlContext, path)
+#' df2 <- jsonFile(sqlContext, path2)
+#' merge(df1, df2) # Inner join on the common columns; Cartesian product if there are none
+#' merge(df1, df2, by = "col1") # Performs an inner join on column col1
+#' merge(df1, df2, by.x = "col1", by.y = "col2", all.y = TRUE)
+#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE)
+#' merge(df1, df2, by.x = "col1", by.y = "col2", all.x = TRUE, all.y = TRUE)
+#' merge(df1, df2, by.x = "col1", by.y = "col2", all = TRUE, sort = FALSE)
+#' merge(df1, df2, by = "col1", all = TRUE, suffixes = c("-X", "-Y"))
+#' }
setMethod("merge",
signature(x = "DataFrame", y = "DataFrame"),
- function(x, y, joinExpr = NULL, joinType = NULL, ...) {
- join(x, y, joinExpr, joinType)
- })
+ function(x, y, by = intersect(names(x), names(y)), by.x = by, by.y = by,
+ all = FALSE, all.x = all, all.y = all,
+ sort = TRUE, suffixes = c("_x","_y"), ... ) {
+
+ if (length(suffixes) != 2) {
+ stop("suffixes must have length 2")
+ }
+
+ # The join type is derived from all, all.x and all.y. The default is "inner":
+ # R's merge() performs a natural join by default, but a natural join is not
+ # supported in Spark, so an inner join is used instead.
+ joinType <- "inner"
+ if (all || (all.x && all.y)) {
+ joinType <- "outer"
+ } else if (all.x) {
+ joinType <- "left_outer"
+ } else if (all.y) {
+ joinType <- "right_outer"
+ }
+ # The join expression uses by.x and by.y when both are non-empty and have the
+ # same length; otherwise it falls back to by when by is non-empty.
+ if (length(by.x) > 0 && length(by.x) == length(by.y)) {
+ joinX <- by.x
+ joinY <- by.y
+ } else if (length(by) > 0) {
+ # the same column names are used for both data frames
+ # in the join expression
+ joinX <- by
+ joinY <- by
+ } else {
+ # neither by.x/by.y nor by yields join columns, so the Cartesian product is returned
+ joinRes <- join(x, y)
+ return (joinRes)
+ }
+
+ # aliases the 'by' columns of 'x' and 'y' with their suffix so the names become unique
+ colsX <- generateAliasesForIntersectedCols(x, by, suffixes[1])
+ colsY <- generateAliasesForIntersectedCols(y, by, suffixes[2])
+
+ # selects all columns of each data frame, using the aliases
+ # for the columns listed in 'by'
+ xsel <- select(x, colsX)
+ ysel <- select(y, colsY)
+
+ # builds one equality condition per pair of join columns; a column listed in
+ # 'by' was aliased above, so its suffixed name is looked up instead
+ joinColumns <- lapply(seq_len(length(joinX)), function(i) {
+ colX <- joinX[[i]]
+ colY <- joinY[[i]]
+
+ if (colX %in% by) {
+ colX <- paste(colX, suffixes[1], sep = "")
+ }
+ if (colY %in% by) {
+ colY <- paste(colY, suffixes[2], sep = "")
+ }
+
+ colX <- getColumn(xsel, colX)
+ colY <- getColumn(ysel, colY)
+
+ colX == colY
+ })
+
+ # combines the join conditions with '&' and executes the join
+ joinExpr <- Reduce("&", joinColumns)
+ joinRes <- join(xsel, ysel, joinExpr, joinType)
+
+ # if sort = TRUE, sorts the result in ascending order by the 'by' columns aliased with suffixes[2]
+ if (sort && length(by) > 0) {
+ colNameWithSuffix <- paste(by, suffixes[2], sep = "")
+ joinRes <- do.call("arrange", c(joinRes, colNameWithSuffix, decreasing = FALSE))
+ }
+
+ joinRes
+ })
+
+#'
+#' Creates a list of columns by replacing the intersected ones with aliases.
+#' The name of the alias column is formed by concatenating the original column name and a suffix.
+#'
+#' @param x a DataFrame whose columns are returned, aliased where their names are intersected
+#' @param intersectedColNames a list of intersected column names
+#' @param suffix a suffix for the column name
+#' @return a list with one column for each column name of \code{x}, in the order of \code{names(x)}
+#'
+generateAliasesForIntersectedCols <- function (x, intersectedColNames, suffix) {
+ allColNames <- names(x)
+ # aliases every column of 'x' whose name is in intersectedColNames; other columns are kept as is
+ cols <- lapply(allColNames, function(colName) {
+ col <- getColumn(x, colName)
+ if (colName %in% intersectedColNames) {
+ newJoin <- paste(colName, suffix, sep = "")
+ if (newJoin %in% allColNames){
+ stop ("The following column name: ", newJoin, " occurs more than once in the 'DataFrame'.",
+ "Please use different suffixes for the intersected columns.")
+ }
+ col <- alias(col, newJoin)
+ }
+ col
+ })
+ cols
+}
#' UnionAll
#'
diff --git a/R/pkg/inst/tests/test_sparkSQL.R b/R/pkg/inst/tests/test_sparkSQL.R
index 67d8b23cd7..540854d114 100644
--- a/R/pkg/inst/tests/test_sparkSQL.R
+++ b/R/pkg/inst/tests/test_sparkSQL.R
@@ -1105,11 +1105,40 @@ test_that("join() and merge() on a DataFrame", {
expect_equal(count(joined9), 4)
expect_true(is.na(collect(orderBy(joined9, joined9$age))$age[2]))
- merged <- select(merge(df, df2, df$name == df2$name, "outer"),
- alias(df$age + 5, "newAge"), df$name, df2$test)
- expect_equal(names(merged), c("newAge", "name", "test"))
+ merged <- merge(df, df2, by.x = "name", by.y = "name", all.x = TRUE, all.y = TRUE)
expect_equal(count(merged), 4)
- expect_equal(collect(orderBy(merged, merged$name))$newAge[3], 24)
+ expect_equal(names(merged), c("age", "name_x", "name_y", "test"))
+ expect_equal(collect(orderBy(merged, merged$name_x))$age[3], 19)
+
+ merged <- merge(df, df2, suffixes = c("-X","-Y"))
+ expect_equal(count(merged), 3)
+ expect_equal(names(merged), c("age", "name-X", "name-Y", "test"))
+ expect_equal(collect(orderBy(merged, merged$"name-X"))$age[1], 30)
+
+ merged <- merge(df, df2, by = "name", suffixes = c("-X","-Y"), sort = FALSE)
+ expect_equal(count(merged), 3)
+ expect_equal(names(merged), c("age", "name-X", "name-Y", "test"))
+ expect_equal(collect(orderBy(merged, merged$"name-Y"))$"name-X"[3], "Michael")
+
+ merged <- merge(df, df2, by = "name", all = T, sort = T)
+ expect_equal(count(merged), 4)
+ expect_equal(names(merged), c("age", "name_x", "name_y", "test"))
+ expect_equal(collect(orderBy(merged, merged$"name_y"))$"name_x"[1], "Andy")
+
+ merged <- merge(df, df2, by = NULL)
+ expect_equal(count(merged), 12)
+ expect_equal(names(merged), c("age", "name", "name", "test"))
+
+ mockLines3 <- c("{\"name\":\"Michael\", \"name_y\":\"Michael\", \"test\": \"yes\"}",
+ "{\"name\":\"Andy\", \"name_y\":\"Andy\", \"test\": \"no\"}",
+ "{\"name\":\"Justin\", \"name_y\":\"Justin\", \"test\": \"yes\"}",
+ "{\"name\":\"Bob\", \"name_y\":\"Bob\", \"test\": \"yes\"}")
+ jsonPath3 <- tempfile(pattern="sparkr-test", fileext=".tmp")
+ writeLines(mockLines3, jsonPath3)
+ df3 <- jsonFile(sqlContext, jsonPath3)
+ expect_error(merge(df, df3),
+ paste("The following column name: name_y occurs more than once in the 'DataFrame'.",
+ "Please use different suffixes for the intersected columns.", sep = ""))
})
test_that("toJSON() returns an RDD of the correct values", {