aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXin Ren <iamshrek@126.com>2016-09-02 01:54:28 -0700
committerFelix Cheung <felixcheung@apache.org>2016-09-02 01:54:28 -0700
commit6969dcc79a33d715250958b24361f2d43552d840 (patch)
tree93ec5d331962f8b9d9f381e8813bf6b4b7bd392e /mllib
parent0f30cdedbdb0d38e8c479efab6bb1c6c376206ff (diff)
downloadspark-6969dcc79a33d715250958b24361f2d43552d840.tar.gz
spark-6969dcc79a33d715250958b24361f2d43552d840.tar.bz2
spark-6969dcc79a33d715250958b24361f2d43552d840.zip
[SPARK-15509][ML][SPARKR] R MLlib algorithms should support input columns "features" and "label"
https://issues.apache.org/jira/browse/SPARK-15509 ## What changes were proposed in this pull request? Currently in SparkR, when you load a LibSVM dataset using the sqlContext and then pass it to an MLlib algorithm, the ML wrappers will fail since they will try to create a "features" column, which conflicts with the existing "features" column from the LibSVM loader. E.g., using the "mnist" dataset from LibSVM: `training <- loadDF(sqlContext, ".../mnist", "libsvm")` `model <- naiveBayes(label ~ features, training)` This fails with: ``` 16/05/24 11:52:41 ERROR RBackendHandler: fit on org.apache.spark.ml.r.NaiveBayesWrapper failed Error in invokeJava(isStatic = TRUE, className, methodName, ...) : java.lang.IllegalArgumentException: Output column features already exists. at org.apache.spark.ml.feature.VectorAssembler.transformSchema(VectorAssembler.scala:120) at org.apache.spark.ml.Pipeline$$anonfun$transformSchema$4.apply(Pipeline.scala:179) at org.apache.spark.ml.Pipeline$$anonfun$transformSchema$4.apply(Pipeline.scala:179) at scala.collection.IndexedSeqOptimized$class.foldl(IndexedSeqOptimized.scala:57) at scala.collection.IndexedSeqOptimized$class.foldLeft(IndexedSeqOptimized.scala:66) at scala.collection.mutable.ArrayOps$ofRef.foldLeft(ArrayOps.scala:186) at org.apache.spark.ml.Pipeline.transformSchema(Pipeline.scala:179) at org.apache.spark.ml.PipelineStage.transformSchema(Pipeline.scala:67) at org.apache.spark.ml.Pipeline.fit(Pipeline.scala:131) at org.apache.spark.ml.feature.RFormula.fit(RFormula.scala:169) at org.apache.spark.ml.r.NaiveBayesWrapper$.fit(NaiveBayesWrapper.scala:62) at org.apache.spark.ml.r.NaiveBayesWrapper.fit(NaiveBayesWrapper.sca The same issue appears for the "label" column once you rename the "features" column. 
``` The cause is: when using `loadDF()` to generate DataFrames, the columns are sometimes given the default names `“label”` and `“features”`, and these two names conflict with the default column names set by `setDefault(labelCol, "label")` and `setDefault(featuresCol, "features")` in `SharedParams.scala`. ## How was this patch tested? Tested on my local machine. Author: Xin Ren <iamshrek@126.com> Closes #13584 from keypointt/SPARK-15509.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala5
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala11
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala71
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala3
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala56
9 files changed, 144 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 5462f80d69..67d037ed6e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -87,6 +87,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
val (rewritedFormula, censorCol) = formulaRewrite(formula)
val rFormula = new RFormula().setFormula(rewritedFormula)
+ RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get feature names from output schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index 1e8b3bbab6..b654233a89 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -68,10 +68,11 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
maxIter: Int,
tol: Double): GaussianMixtureWrapper = {
- val rFormulaModel = new RFormula()
+ val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
- .fit(data)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 7a6ab618a1..35313258f9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -73,6 +73,7 @@ private[r] object GeneralizedLinearRegressionWrapper
regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
+ RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val schema = rFormulaModel.transform(data).schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index a7992debe6..2ed7d7b770 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -57,10 +57,11 @@ private[r] object IsotonicRegressionWrapper
featureIndex: Int,
weightCol: String): IsotonicRegressionWrapper = {
- val rFormulaModel = new RFormula()
+ val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
- .fit(data)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 4d4c303fc8..8616a8c01e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -70,10 +70,11 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
maxIter: Int,
initMode: String): KMeansWrapper = {
- val rFormulaModel = new RFormula()
+ val rFormula = new RFormula()
.setFormula(formula)
.setFeaturesCol("features")
- .fit(data)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
// get feature names from output schema
val schema = rFormulaModel.transform(data).schema
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index 1dac246b03..f2cb24b964 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -59,13 +59,14 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
def fit(formula: String, data: DataFrame, smoothing: Double): NaiveBayesWrapper = {
val rFormula = new RFormula()
.setFormula(formula)
- .fit(data)
+ RWrapperUtils.checkDataColumns(rFormula, data)
+ val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
- val schema = rFormula.transform(data).schema
- val labelAttr = Attribute.fromStructField(schema(rFormula.getLabelCol))
+ val schema = rFormulaModel.transform(data).schema
+ val labelAttr = Attribute.fromStructField(schema(rFormulaModel.getLabelCol))
.asInstanceOf[NominalAttribute]
val labels = labelAttr.values.get
- val featureAttrs = AttributeGroup.fromStructField(schema(rFormula.getFeaturesCol))
+ val featureAttrs = AttributeGroup.fromStructField(schema(rFormulaModel.getFeaturesCol))
.attributes.get
val features = featureAttrs.map(_.name.get)
// assemble and fit the pipeline
@@ -78,7 +79,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
.setOutputCol(PREDICTED_LABEL_COL)
.setLabels(labels)
val pipeline = new Pipeline()
- .setStages(Array(rFormula, naiveBayes, idxToStr))
+ .setStages(Array(rFormulaModel, naiveBayes, idxToStr))
.fit(data)
new NaiveBayesWrapper(pipeline, labels, features)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
new file mode 100644
index 0000000000..6a435992e3
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.sql.Dataset
+
+object RWrapperUtils extends Logging {
+
+ /**
+ * DataFrame column check.
+ * When loading data, default columns "features" and "label" will be added. And these two names
+ * would conflict with RFormula default feature and label column names.
+ * Here is to change the column name to avoid "column already exists" error.
+ *
+ * @param rFormula RFormula instance
+ * @param data Input dataset
+ * @return Unit
+ */
+ def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
+ if (data.schema.fieldNames.contains(rFormula.getLabelCol)) {
+ val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames)
+ logWarning(
+ s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
+ rFormula.setLabelCol(newLabelName)
+ }
+
+ if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
+ val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames)
+ logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
+ s"using new name $newFeaturesName instead")
+ rFormula.setFeaturesCol(newFeaturesName)
+ }
+ }
+
+ /**
+ * Convert conflicting name to be an unique name.
+ * Appending a sequence number, like originalName_output1
+ * and incrementing until it is not already there
+ *
+ * @param originalName Original name
+ * @param fieldNames Array of field names in existing schema
+ * @return String
+ */
+ def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
+ var counter = 1
+ var newName = originalName + "_output"
+
+ while (fieldNames.contains(newName)) {
+ newName = originalName + "_output" + counter
+ counter += 1
+ }
+ newName
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
index c12ab8fe9e..0794a049d9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/RFormulaSuite.scala
@@ -54,9 +54,6 @@ class RFormulaSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
intercept[IllegalArgumentException] {
formula.fit(original)
}
- intercept[IllegalArgumentException] {
- formula.fit(original)
- }
}
test("label column already exists") {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala
new file mode 100644
index 0000000000..ddc24cb3a6
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala
@@ -0,0 +1,56 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.{RFormula, RFormulaModel}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+
+class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("avoid libsvm data column name conflicting") {
+ val rFormula = new RFormula().setFormula("label ~ features")
+ val data = spark.read.format("libsvm").load("../data/mllib/sample_libsvm_data.txt")
+
+ // if not checking column name, then IllegalArgumentException
+ intercept[IllegalArgumentException] {
+ rFormula.fit(data)
+ }
+
+ // after checking, model build is ok
+ RWrapperUtils.checkDataColumns(rFormula, data)
+
+ assert(rFormula.getLabelCol == "label_output")
+ assert(rFormula.getFeaturesCol == "features_output")
+
+ val model = rFormula.fit(data)
+ assert(model.isInstanceOf[RFormulaModel])
+
+ assert(model.getLabelCol == "label_output")
+ assert(model.getFeaturesCol == "features_output")
+ }
+
+ test("generate unique name by appending a sequence number") {
+ val originalName = "label"
+ val fieldNames = Array("label_output", "label_output1", "label_output2")
+ val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames)
+
+ assert(newName === "label_output3")
+ }
+
+}