aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala5
1 files changed, 4 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
index 995b1ef03b..add4d49110 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/GeneralizedLinearRegressionWrapper.scala
@@ -29,6 +29,7 @@ import org.apache.spark.ml.regression._
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.r.RWrapperUtils._
import org.apache.spark.ml.util._
import org.apache.spark.sql._
import org.apache.spark.sql.functions._
@@ -64,6 +65,7 @@ private[r] class GeneralizedLinearRegressionWrapper private (
.drop(PREDICTED_LABEL_PROB_COL)
.drop(PREDICTED_LABEL_INDEX_COL)
.drop(glm.getFeaturesCol)
+ .drop(glm.getLabelCol)
} else {
pipeline.transform(dataset)
.drop(glm.getFeaturesCol)
@@ -92,7 +94,7 @@ private[r] object GeneralizedLinearRegressionWrapper
regParam: Double): GeneralizedLinearRegressionWrapper = {
val rFormula = new RFormula().setFormula(formula)
if (family == "binomial") rFormula.setForceIndexLabel(true)
- RWrapperUtils.checkDataColumns(rFormula, data)
+ checkDataColumns(rFormula, data)
val rFormulaModel = rFormula.fit(data)
// get labels and feature names from output schema
val schema = rFormulaModel.transform(data).schema
@@ -109,6 +111,7 @@ private[r] object GeneralizedLinearRegressionWrapper
.setWeightCol(weightCol)
.setRegParam(regParam)
.setFeaturesCol(rFormula.getFeaturesCol)
+ .setLabelCol(rFormula.getLabelCol)
val pipeline = if (family == "binomial") {
// Convert prediction from probability to label index.
val probToPred = new ProbabilityToPrediction()