author    Yanbo Liang <ybliang8@gmail.com>    2016-09-10 00:27:10 -0700
committer Yanbo Liang <ybliang8@gmail.com>    2016-09-10 00:27:10 -0700
commit    bcdd259c371b1dcdb41baf227867d7e2ecb923c6 (patch)
tree      ed5a57cbff9d437dc59e0517b1e6ad2e94a08316 /mllib/src/main
parent    1fec3ce4e19664aa9f9238d9491b0cb1511f9be1 (diff)
[SPARK-15509][FOLLOW-UP][ML][SPARKR] R MLlib algorithms should support input columns "features" and "label"
## What changes were proposed in this pull request?

#13584 resolved the conflict between the default "features" and "label" columns and the `RFormula` defaults when loading libsvm data, but it left a few issues that still needed to be resolved:

1. It is not necessary to check and rename the label column. By design, `RFormula` can handle the case where the label column already exists (with the restriction that the existing label column must be of numeric/boolean type), so there is no need to change the column name to avoid a conflict. If the label column is not of numeric/boolean type, `RFormula` will throw an exception.
2. We should rename the features column when there is a conflict, but appending a random value is enough, since the column is used only internally. We did similar work when implementing `SQLTransformer`.
3. We should set the correct features column on the estimators (see the sketch after this description). Take `GLM` as an example: the `GLM` estimator should set its features column to the renamed one (`rFormula.getFeaturesCol`) rather than the default "features". Although this makes no difference during training, it causes problems during prediction. The following is the prediction result of GLM before this PR:

![image](https://cloud.githubusercontent.com/assets/1962026/18308227/84c3c452-74a8-11e6-9caa-9d6d846cc957.png)

We should drop the internally used features column; otherwise it shows up in the prediction DataFrame, which would confuse users. This also matches the behavior of scenarios where there is no column name conflict. After this PR:

![image](https://cloud.githubusercontent.com/assets/1962026/18308240/92082a04-74a8-11e6-9226-801f52b856d9.png)

## How was this patch tested?

Existing unit tests.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #14993 from yanboliang/spark-15509.
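As a reference for the diff below, here is a minimal sketch of the pattern each wrapper now follows. It is illustrative only: the `data` DataFrame and the formula string `"label ~ ."` are assumptions, and `RWrapperUtils.checkDataColumns` is callable like this only from code in the `org.apache.spark.ml.r` package, where the wrappers live.

```scala
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.RFormula
import org.apache.spark.ml.regression.GeneralizedLinearRegression

// `data` is assumed to be a DataFrame that may already contain a
// "features" column, e.g. one loaded from libsvm.
val rFormula = new RFormula().setFormula("label ~ .")

// Rename rFormula's features column to a unique internal name if
// `data` already has a column with the default name.
RWrapperUtils.checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)

val glr = new GeneralizedLinearRegression()
  // Point the estimator at the (possibly renamed) features column
  // rather than the default "features".
  .setFeaturesCol(rFormula.getFeaturesCol)

val model = new Pipeline()
  .setStages(Array(rFormulaModel, glr))
  .fit(data)
```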
Diffstat (limited to 'mllib/src/main')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala       | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala             | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala          | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala                      | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala                  | 1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala                      | 34
7 files changed, 10 insertions, 30 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 67d037ed6e..bd965acf56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -99,6 +99,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
val aft = new AFTSurvivalRegression()
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index b654233a89..b708702959 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -85,6 +85,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
.setK(k)
.setMaxIter(maxIter)
.setTol(tol)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, gm))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 35313258f9..b1bb577e1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -89,6 +89,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 2ed7d7b770..48632316f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -75,6 +75,7 @@ private[r] object IsotonicRegressionWrapper
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
.setWeightCol(weightCol)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 8616a8c01e..ea9458525a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -86,6 +86,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
.setK(k)
.setMaxIter(maxIter)
.setInitMode(initMode)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, kMeans))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index f2cb24b964..d1a39fea76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -73,6 +73,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val naiveBayes = new NaiveBayes()
.setSmoothing(smoothing)
.setModelType("bernoulli")
+ .setFeaturesCol(rFormula.getFeaturesCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
index 6a435992e3..379007c4d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
@@ -19,14 +19,15 @@ package org.apache.spark.ml.r
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
object RWrapperUtils extends Logging {
/**
* DataFrame column check.
- * When loading data, default columns "features" and "label" will be added. And these two names
- * would conflict with RFormula default feature and label column names.
+ * When loading libsvm data, default columns "features" and "label" will be added.
+ * And "features" would conflict with RFormula default feature column names.
* Here is to change the column name to avoid "column already exists" error.
*
* @param rFormula RFormula instance
@@ -34,38 +35,11 @@ object RWrapperUtils extends Logging {
* @return Unit
*/
def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
- if (data.schema.fieldNames.contains(rFormula.getLabelCol)) {
- val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames)
- logWarning(
- s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
- rFormula.setLabelCol(newLabelName)
- }
-
if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
- val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames)
+ val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}"
logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
s"using new name $newFeaturesName instead")
rFormula.setFeaturesCol(newFeaturesName)
}
}
-
- /**
- * Convert conflicting name to be an unique name.
- * Appending a sequence number, like originalName_output1
- * and incrementing until it is not already there
- *
- * @param originalName Original name
- * @param fieldNames Array of field names in existing schema
- * @return String
- */
- def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
- var counter = 1
- var newName = originalName + "_output"
-
- while (fieldNames.contains(newName)) {
- newName = originalName + "_output" + counter
- counter += 1
- }
- newName
- }
}
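For illustration, a quick spark-shell sketch of the renaming strategy introduced above (the suffix shown in the comment is only an example; actual values differ per call):

```scala
import org.apache.spark.ml.util.Identifiable

// Identifiable.randomUID appends an underscore and a short random
// suffix derived from a UUID to the given prefix,
// e.g. "features_da5043ee1f35".
val newFeaturesName = Identifiable.randomUID("features")
```

Unlike the removed `convertToUniqueName`, this does not scan the schema for collisions, but a clash with the random suffix is vanishingly unlikely for a column that is only used internally.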