From 1cbdd8991898912a8471a7070c472a0edb92487c Mon Sep 17 00:00:00 2001 From: Eric Liang Date: Mon, 20 Jul 2015 20:49:38 -0700 Subject: [SPARK-9201] [ML] Initial integration of MLlib + SparkR using RFormula This exposes the SparkR:::glm() and SparkR:::predict() APIs. It was necessary to change RFormula to silently drop the label column if it was missing from the input dataset, which is kind of a hack but necessary to integrate with the Pipeline API. The umbrella design doc for MLlib + SparkR integration can be viewed here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit mengxr Author: Eric Liang Closes #7483 from ericl/spark-8774 and squashes the following commits: 3dfac0c [Eric Liang] update 17ef516 [Eric Liang] more comments 1753a0f [Eric Liang] make glm generic b0f50f8 [Eric Liang] equivalence test 550d56d [Eric Liang] export methods c015697 [Eric Liang] second pass 117949a [Eric Liang] comments 5afbc67 [Eric Liang] test label columns 6b7f15f [Eric Liang] Fri Jul 17 14:20:22 PDT 2015 3a63ae5 [Eric Liang] Fri Jul 17 13:41:52 PDT 2015 ce61367 [Eric Liang] Fri Jul 17 13:41:17 PDT 2015 0299c59 [Eric Liang] Fri Jul 17 13:40:32 PDT 2015 e37603f [Eric Liang] Fri Jul 17 12:15:03 PDT 2015 d417d0c [Eric Liang] Merge remote-tracking branch 'upstream/master' into spark-8774 29a2ce7 [Eric Liang] Merge branch 'spark-8774-1' into spark-8774 d1959d2 [Eric Liang] clarify comment 2db68aa [Eric Liang] second round of comments dc3c943 [Eric Liang] address comments 5765ec6 [Eric Liang] fix style checks 1f361b0 [Eric Liang] doc d33211b [Eric Liang] r support fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer --- R/pkg/DESCRIPTION | 1 + R/pkg/NAMESPACE | 4 +++ R/pkg/R/generics.R | 4 +++ R/pkg/R/mllib.R | 73 +++++++++++++++++++++++++++++++++++++++++++ R/pkg/inst/tests/test_mllib.R | 42 +++++++++++++++++++++++++ 5 files changed, 124 insertions(+) create mode 100644 R/pkg/R/mllib.R create mode 100644 R/pkg/inst/tests/test_mllib.R (limited to 'R/pkg') diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index d028821534..4949d86d20 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -29,6 +29,7 @@ Collate: 'client.R' 'context.R' 'deserialize.R' + 'mllib.R' 'serialize.R' 'sparkR.R' 'utils.R' diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE index 331307c207..5834813319 100644 --- a/R/pkg/NAMESPACE +++ b/R/pkg/NAMESPACE @@ -10,6 +10,10 @@ export("sparkR.init") export("sparkR.stop") export("print.jobj") +# MLlib integration +exportMethods("glm", + "predict") + # Job group lifecycle management methods export("setJobGroup", "clearJobGroup", diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R index ebe6fbd97c..39b5586f7c 100644 --- a/R/pkg/R/generics.R +++ b/R/pkg/R/generics.R @@ -661,3 +661,7 @@ setGeneric("toRadians", function(x) { standardGeneric("toRadians") }) #' @rdname column #' @export setGeneric("upper", function(x) { standardGeneric("upper") }) + +#' @rdname glm +#' @export +setGeneric("glm") diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R new file mode 100644 index 0000000000..258e354081 --- /dev/null +++ b/R/pkg/R/mllib.R @@ -0,0 +1,73 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# mllib.R: Provides methods for MLlib integration + +#' @title S4 class that represents a PipelineModel +#' @param model A Java object reference to the backing Scala PipelineModel +#' @export +setClass("PipelineModel", representation(model = "jobj")) + +#' Fits a generalized linear model +#' +#' Fits a generalized linear model, similarly to R's glm(). Also see the glmnet package. +#' +#' @param formula A symbolic description of the model to be fitted. Currently only a few formula +#' operators are supported, including '~' and '+'. +#' @param data DataFrame for training +#' @param family Error distribution. "gaussian" -> linear regression, "binomial" -> logistic reg. +#' @param lambda Regularization parameter +#' @param alpha Elastic-net mixing parameter (see glmnet's documentation for details) +#' @return a fitted MLlib model +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' sc <- sparkR.init() +#' sqlContext <- sparkRSQL.init(sc) +#' data(iris) +#' df <- createDataFrame(sqlContext, iris) +#' model <- glm(Sepal_Length ~ Sepal_Width, df) +#'} +setMethod("glm", signature(formula = "formula", family = "ANY", data = "DataFrame"), + function(formula, family = c("gaussian", "binomial"), data, lambda = 0, alpha = 0) { + family <- match.arg(family) + model <- callJStatic("org.apache.spark.ml.api.r.SparkRWrappers", + "fitRModelFormula", deparse(formula), data@sdf, family, lambda, + alpha) + return(new("PipelineModel", model = model)) + }) + +#' Make predictions from a model +#' +#' Makes predictions from a model produced by glm(), similarly to R's predict(). +#' +#' @param model A fitted MLlib model +#' @param newData DataFrame for testing +#' @return DataFrame containing predicted values +#' @rdname glm +#' @export +#' @examples +#'\dontrun{ +#' model <- glm(y ~ x, trainingData) +#' predicted <- predict(model, testData) +#' showDF(predicted) +#'} +setMethod("predict", signature(object = "PipelineModel"), + function(object, newData) { + return(dataFrame(callJMethod(object@model, "transform", newData@sdf))) + }) diff --git a/R/pkg/inst/tests/test_mllib.R b/R/pkg/inst/tests/test_mllib.R new file mode 100644 index 0000000000..a492763344 --- /dev/null +++ b/R/pkg/inst/tests/test_mllib.R @@ -0,0 +1,42 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +library(testthat) + +context("MLlib functions") + +# Tests for MLlib functions in SparkR + +sc <- sparkR.init() + +sqlContext <- sparkRSQL.init(sc) + +test_that("glm and predict", { + training <- createDataFrame(sqlContext, iris) + test <- select(training, "Sepal_Length") + model <- glm(Sepal_Width ~ Sepal_Length, training, family = "gaussian") + prediction <- predict(model, test) + expect_equal(typeof(take(select(prediction, "prediction"), 1)$prediction), "double") +}) + +test_that("predictions match with native glm", { + training <- createDataFrame(sqlContext, iris) + model <- glm(Sepal_Width ~ Sepal_Length, data = training) + vals <- collect(select(predict(model, training), "prediction")) + rVals <- predict(glm(Sepal.Width ~ Sepal.Length, data = iris), iris) + expect_true(all(abs(rVals - vals) < 1e-9), rVals - vals) +}) -- cgit v1.2.3