diff options
author | sethah <seth.hendrickson16@gmail.com> | 2015-09-23 15:00:52 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-09-23 15:00:52 -0700 |
commit | 098be27ad53c485ee2fc7f5871c47f899020e87b (patch) | |
tree | 1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /examples/src | |
parent | a18208047f06a4244703c17023bb20cbe1f59d73 (diff) | |
download | spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.gz spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.bz2 spark-098be27ad53c485ee2fc7f5871c47f899020e87b.zip |
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #8675 from sethah/SPARK-9715.
Diffstat (limited to 'examples/src')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 5 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala | 3 |
2 files changed, 8 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index a377694507..0b4c0d9ba9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -220,6 +220,11 @@ class MyJavaLogisticRegressionModel public int numClasses() { return 2; } /** + * Number of features the model was trained on. + */ + public int numFeatures() { return weights_.size(); } + + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * <p> diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 340c3559b1..3758edc561 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -172,6 +172,9 @@ private class MyLogisticRegressionModel( /** Number of classes the label can take. 2 indicates binary classification. */ override val numClasses: Int = 2 + /** Number of features the model was trained on. */ + override val numFeatures: Int = weights.size + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. |