diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala | 13 |
1 files changed, 5 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala index 0d329d2c08..ac0124513f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala @@ -263,11 +263,10 @@ class LogisticRegression @Since("1.2.0") ( protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean): LogisticRegressionModel = { val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol)) - val instances: RDD[Instance] = - dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map { - case Row(label: Double, weight: Double, features: Vector) => - Instance(label, weight, features) - } + val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map { + case Row(label: Double, weight: Double, features: Vector) => + Instance(label, weight, features) + } if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK) @@ -791,7 +790,6 @@ sealed trait LogisticRegressionSummary extends Serializable { /** * :: Experimental :: * Logistic regression training results. - * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance as a vector. @@ -815,7 +813,6 @@ class BinaryLogisticRegressionTrainingSummary private[classification] ( /** * :: Experimental :: * Binary Logistic regression results for a given model. - * * @param predictions dataframe outputted by the model's `transform` method. * @param probabilityCol field in "predictions" which gives the calibrated probability of * each instance. @@ -840,7 +837,7 @@ class BinaryLogisticRegressionSummary private[classification] ( // TODO: Allow the user to vary the number of bins using a setBins method in // BinaryClassificationMetrics. For now the default is set to 100. @transient private val binaryMetrics = new BinaryClassificationMetrics( - predictions.select(probabilityCol, labelCol).rdd.map { + predictions.select(probabilityCol, labelCol).map { case Row(score: Vector, label: Double) => (score(1), label) }, 100 ) |