aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorGeorge Dittmar <georgedittmar@gmail.com>2015-07-27 11:16:33 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-27 11:16:33 -0700
commit1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42 (patch)
tree512c47906dd738d6cce5db5973b227b82b4cd202
parente2f38167f8b5678ac45794eacb9c7bb9b951af82 (diff)
downloadspark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.tar.gz
spark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.tar.bz2
spark-1f7b3d9dc7c2ed9d31f9083284cf900fd4c21e42.zip
[SPARK-7423] [MLLIB] Modify ClassificationModel and Probabalistic model to use Vector.argmax
Use Vector.argmax call instead of converting to dense vector before calculating predictions. Author: George Dittmar <georgedittmar@gmail.com> Closes #7670 from GeorgeDittmar/sprk-7423 and squashes the following commits: e796747 [George Dittmar] Changing ClassificationModel and ProbabilisticClassificationModel to use Vector.argmax instead of converting to DenseVector
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala2
2 files changed, 2 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
index 85c097bc64..581d8fa774 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -156,5 +156,5 @@ abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[Featur
* This may be overridden to support thresholds which favor particular labels.
* @return predicted label
*/
- protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.toDense.argmax
+ protected def raw2prediction(rawPrediction: Vector): Double = rawPrediction.argmax
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
index 38e8323726..dad4511086 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -173,5 +173,5 @@ private[spark] abstract class ProbabilisticClassificationModel[
* This may be overridden to support thresholds which favor particular labels.
* @return predicted label
*/
- protected def probability2prediction(probability: Vector): Double = probability.toDense.argmax
+ protected def probability2prediction(probability: Vector): Double = probability.argmax
}