diff options
author | Eric Liang <ekl@databricks.com> | 2015-07-20 20:49:38 -0700 |
---|---|---|
committer | Shivaram Venkataraman <shivaram@cs.berkeley.edu> | 2015-07-20 20:49:38 -0700 |
commit | 1cbdd8991898912a8471a7070c472a0edb92487c (patch) | |
tree | 2ce542693eadb80bad9644be4a9d5a389b4466c9 /mllib | |
parent | 2bdf9914ab709bf9c1cdd17fc5dd7a69f6d46f29 (diff) | |
download | spark-1cbdd8991898912a8471a7070c472a0edb92487c.tar.gz spark-1cbdd8991898912a8471a7070c472a0edb92487c.tar.bz2 spark-1cbdd8991898912a8471a7070c472a0edb92487c.zip |
[SPARK-9201] [ML] Initial integration of MLlib + SparkR using RFormula
This exposes the SparkR:::glm() and SparkR:::predict() APIs. It was necessary to change RFormula to silently drop the label column if it was missing from the input dataset, which is kind of a hack but necessary to integrate with the Pipeline API.
The umbrella design doc for MLlib + SparkR integration can be viewed here: https://docs.google.com/document/d/10NZNSEurN2EdWM31uFYsgayIPfCFHiuIu3pCWrUmP_c/edit
mengxr
Author: Eric Liang <ekl@databricks.com>
Closes #7483 from ericl/spark-8774 and squashes the following commits:
3dfac0c [Eric Liang] update
17ef516 [Eric Liang] more comments
1753a0f [Eric Liang] make glm generic
b0f50f8 [Eric Liang] equivalence test
550d56d [Eric Liang] export methods
c015697 [Eric Liang] second pass
117949a [Eric Liang] comments
5afbc67 [Eric Liang] test label columns
6b7f15f [Eric Liang] Fri Jul 17 14:20:22 PDT 2015
3a63ae5 [Eric Liang] Fri Jul 17 13:41:52 PDT 2015
ce61367 [Eric Liang] Fri Jul 17 13:41:17 PDT 2015
0299c59 [Eric Liang] Fri Jul 17 13:40:32 PDT 2015
e37603f [Eric Liang] Fri Jul 17 12:15:03 PDT 2015
d417d0c [Eric Liang] Merge remote-tracking branch 'upstream/master' into spark-8774
29a2ce7 [Eric Liang] Merge branch 'spark-8774-1' into spark-8774
d1959d2 [Eric Liang] clarify comment
2db68aa [Eric Liang] second round of comments
dc3c943 [Eric Liang] address comments
5765ec6 [Eric Liang] fix style checks
1f361b0 [Eric Liang] doc
d33211b [Eric Liang] r support
fb0826b [Eric Liang] [SPARK-8774] Add R model formula with basic support as a transformer
Diffstat (limited to 'mllib')
3 files changed, 61 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala index 56169f2a01..f7b46efa10 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/RFormula.scala @@ -73,12 +73,16 @@ class RFormula(override val uid: String) val withFeatures = transformFeatures.transformSchema(schema) if (hasLabelCol(schema)) { withFeatures - } else { + } else if (schema.exists(_.name == parsedFormula.get.label)) { val nullable = schema(parsedFormula.get.label).dataType match { case _: NumericType | BooleanType => false case _ => true } StructType(withFeatures.fields :+ StructField($(labelCol), DoubleType, nullable)) + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + withFeatures } } @@ -92,10 +96,10 @@ class RFormula(override val uid: String) override def toString: String = s"RFormula(${get(formula)})" private def transformLabel(dataset: DataFrame): DataFrame = { + val labelName = parsedFormula.get.label if (hasLabelCol(dataset.schema)) { dataset - } else { - val labelName = parsedFormula.get.label + } else if (dataset.schema.exists(_.name == labelName)) { dataset.schema(labelName).dataType match { case _: NumericType | BooleanType => dataset.withColumn($(labelCol), dataset(labelName).cast(DoubleType)) @@ -103,6 +107,10 @@ class RFormula(override val uid: String) case other => throw new IllegalArgumentException("Unsupported type for label: " + other) } + } else { + // Ignore the label field. This is a hack so that this transformer can also work on test + // datasets in a Pipeline. + dataset } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala new file mode 100644 index 0000000000..1ee080641e --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/r/SparkRWrappers.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.api.r + +import org.apache.spark.ml.feature.RFormula +import org.apache.spark.ml.classification.LogisticRegression +import org.apache.spark.ml.regression.LinearRegression +import org.apache.spark.ml.{Pipeline, PipelineModel} +import org.apache.spark.sql.DataFrame + +private[r] object SparkRWrappers { + def fitRModelFormula( + value: String, + df: DataFrame, + family: String, + lambda: Double, + alpha: Double): PipelineModel = { + val formula = new RFormula().setFormula(value) + val estimator = family match { + case "gaussian" => new LinearRegression().setRegParam(lambda).setElasticNetParam(alpha) + case "binomial" => new LogisticRegression().setRegParam(lambda).setElasticNetParam(alpha) + } + val pipeline = new Pipeline().setStages(Array(formula, estimator)) + pipeline.fit(df) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala index fa8611b243..79c4ccf02d 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala @@ -74,6 +74,15 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext { } } + test("allow missing label column for test datasets") { + val formula = new RFormula().setFormula("y ~ x").setLabelCol("label") + val original = sqlContext.createDataFrame(Seq((0, 1.0), (2, 2.0))).toDF("x", "_not_y") + val resultSchema = formula.transformSchema(original.schema) + assert(resultSchema.length == 3) + assert(!resultSchema.exists(_.name == "label")) + assert(resultSchema.toString == formula.transform(original).schema.toString) + } + // TODO(ekl) enable after we implement string label support // test("transform string label") { // val formula = new RFormula().setFormula("name ~ id") |