aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 13
1 files changed, 5 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 0d329d2c08..ac0124513f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -263,11 +263,10 @@ class LogisticRegression @Since("1.2.0") (
protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
LogisticRegressionModel = {
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
- val instances: RDD[Instance] =
- dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
- case Row(label: Double, weight: Double, features: Vector) =>
- Instance(label, weight, features)
- }
+ val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
+ case Row(label: Double, weight: Double, features: Vector) =>
+ Instance(label, weight, features)
+ }
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
@@ -791,7 +790,6 @@ sealed trait LogisticRegressionSummary extends Serializable {
/**
* :: Experimental ::
* Logistic regression training results.
- *
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance as a vector.
@@ -815,7 +813,6 @@ class BinaryLogisticRegressionTrainingSummary private[classification] (
/**
* :: Experimental ::
* Binary Logistic regression results for a given model.
- *
* @param predictions dataframe outputted by the model's `transform` method.
* @param probabilityCol field in "predictions" which gives the calibrated probability of
* each instance.
@@ -840,7 +837,7 @@ class BinaryLogisticRegressionSummary private[classification] (
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
- predictions.select(probabilityCol, labelCol).rdd.map {
+ predictions.select(probabilityCol, labelCol).map {
case Row(score: Vector, label: Double) => (score(1), label)
}, 100
)