diff options
author | George Dittmar <georgedittmar@gmail.com> | 2015-07-27 11:16:33 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-07-27 11:16:33 -0700 |
commit | 1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42 (patch) | |
tree | 512c47906dd738d6cce5db5973b227b82b4cd202 | |
parent | e2f38167f8b5678ac45794eacb9c7bb9b951af82 (diff) | |
download | spark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.tar.gz spark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.tar.bz2 spark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.zip |
[SPARK-7423] [MLLIB] Modify ClassificationModel and Probabalistic model to use Vector.argmax
Use Vector.argmax call instead of converting to dense vector before calculating predictions.
Author: George Dittmar <georgedittmar@gmail.com>
Closes #7670 from GeorgeDittmar/sprk-7423 and squashes the following commits:
e796747 [George Dittmar] Changing ClassificationModel and ProbabilisticClassificationModel to use Vector.argmax instead of converting to DenseVector
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala | 2 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala | 2 |
2 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala index 85c097bc64..581d8fa774 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala @@ -156,5 +156,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax + protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala index 38e8323726..dad4511086 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala @@ -173,5 +173,5 @@ private[spark] abstract class ProbabilisticClassificationModel[ * This may be overridden to support thresholds which favor particular labels. * @return predicted label */ - protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax + protected def probability2prediction(probability: Vector): Double = probability.argmax } |