about summary refs log tree commit diff
path: root/mllib/src
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala       |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala             |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala          |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala                      |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala                  |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala                      | 34
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala                 | 16
8 files changed, 14 insertions(+), 42 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
index 67d037ed6e..bd965acf56 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/AFTSurvivalRegressionWrapper.scala
@@ -99,6 +99,7 @@ private[r] object AFTSurvivalRegressionWrapper extends MLReadable[AFTSurvivalReg
val aft = new AFTSurvivalRegression()
.setCensorCol(censorCol)
.setFitIntercept(rFormula.hasIntercept)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, aft))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
index b654233a89..b708702959 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GaussianMixtureWrapper.scala
@@ -85,6 +85,7 @@ private[r] object GaussianMixtureWrapper extends MLReadable[GaussianMixtureWrapp
.setK(k)
.setMaxIter(maxIter)
.setTol(tol)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, gm))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 35313258f9..b1bb577e1f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -89,6 +89,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setMaxIter(maxIter)
.setWeightCol(weightCol)
.setRegParam(regParam)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, glr))
.fit(data)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
index 2ed7d7b770..48632316f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/IsotonicRegressionWrapper.scala
@@ -75,6 +75,7 @@ private[r] object IsotonicRegressionWrapper
.setIsotonic(isotonic)
.setFeatureIndex(featureIndex)
.setWeightCol(weightCol)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, isotonicRegression))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
index 8616a8c01e..ea9458525a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/KMeansWrapper.scala
@@ -86,6 +86,7 @@ private[r] object KMeansWrapper extends MLReadable[KMeansWrapper] {
.setK(k)
.setMaxIter(maxIter)
.setInitMode(initMode)
+ .setFeaturesCol(rFormula.getFeaturesCol)
val pipeline = new Pipeline()
.setStages(Array(rFormulaModel, kMeans))
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
index f2cb24b964..d1a39fea76 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/NaiveBayesWrapper.scala
@@ -73,6 +73,7 @@ private[r] object NaiveBayesWrapper extends MLReadable[NaiveBayesWrapper] {
val naiveBayes = new NaiveBayes()
.setSmoothing(smoothing)
.setModelType("bernoulli")
+ .setFeaturesCol(rFormula.getFeaturesCol)
.setPredictionCol(PREDICTED_LABEL_INDEX_COL)
val idxToStr = new IndexToString()
.setInputCol(PREDICTED_LABEL_INDEX_COL)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
index 6a435992e3..379007c4d9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrapperUtils.scala
@@ -19,14 +19,15 @@ package org.apache.spark.ml.r
import org.apache.spark.internal.Logging
import org.apache.spark.ml.feature.RFormula
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.sql.Dataset
object RWrapperUtils extends Logging {
/**
* DataFrame column check.
- * When loading data, default columns "features" and "label" will be added. And these two names
- * would conflict with RFormula default feature and label column names.
+ * When loading libsvm data, default columns "features" and "label" will be added.
+ * And "features" would conflict with RFormula default feature column names.
* Here is to change the column name to avoid "column already exists" error.
*
* @param rFormula RFormula instance
@@ -34,38 +35,11 @@ object RWrapperUtils extends Logging {
* @return Unit
*/
def checkDataColumns(rFormula: RFormula, data: Dataset[_]): Unit = {
- if (data.schema.fieldNames.contains(rFormula.getLabelCol)) {
- val newLabelName = convertToUniqueName(rFormula.getLabelCol, data.schema.fieldNames)
- logWarning(
- s"data containing ${rFormula.getLabelCol} column, using new name $newLabelName instead")
- rFormula.setLabelCol(newLabelName)
- }
-
if (data.schema.fieldNames.contains(rFormula.getFeaturesCol)) {
- val newFeaturesName = convertToUniqueName(rFormula.getFeaturesCol, data.schema.fieldNames)
+ val newFeaturesName = s"${Identifiable.randomUID(rFormula.getFeaturesCol)}"
logWarning(s"data containing ${rFormula.getFeaturesCol} column, " +
s"using new name $newFeaturesName instead")
rFormula.setFeaturesCol(newFeaturesName)
}
}
-
- /**
- * Convert conflicting name to be an unique name.
- * Appending a sequence number, like originalName_output1
- * and incrementing until it is not already there
- *
- * @param originalName Original name
- * @param fieldNames Array of field names in existing schema
- * @return String
- */
- def convertToUniqueName(originalName: String, fieldNames: Array[String]): String = {
- var counter = 1
- var newName = originalName + "_output"
-
- while (fieldNames.contains(newName)) {
- newName = originalName + "_output" + counter
- counter += 1
- }
- newName
- }
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala
index ddc24cb3a6..27b03918d9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/r/RWrapperUtilsSuite.scala
@@ -35,22 +35,14 @@ class RWrapperUtilsSuite extends SparkFunSuite with MLlibTestSparkContext {
// after checking, model build is ok
RWrapperUtils.checkDataColumns(rFormula, data)
- assert(rFormula.getLabelCol == "label_output")
- assert(rFormula.getFeaturesCol == "features_output")
+ assert(rFormula.getLabelCol == "label")
+ assert(rFormula.getFeaturesCol.startsWith("features_"))
val model = rFormula.fit(data)
assert(model.isInstanceOf[RFormulaModel])
- assert(model.getLabelCol == "label_output")
- assert(model.getFeaturesCol == "features_output")
- }
-
- test("generate unique name by appending a sequence number") {
- val originalName = "label"
- val fieldNames = Array("label_output", "label_output1", "label_output2")
- val newName = RWrapperUtils.convertToUniqueName(originalName, fieldNames)
-
- assert(newName === "label_output3")
+ assert(model.getLabelCol == "label")
+ assert(model.getFeaturesCol.startsWith("features_"))
}
}