aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2015-09-23 15:00:52 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-09-23 15:00:52 -0700
commit098be27ad53c485ee2fc7f5871c47f899020e87b (patch)
tree1e6fe63cc0bb8bd6088b4117bc1951fdd6c42507 /examples
parenta18208047f06a4244703c17023bb20cbe1f59d73 (diff)
downloadspark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.gz
spark-098be27ad53c485ee2fc7f5871c47f899020e87b.tar.bz2
spark-098be27ad53c485ee2fc7f5871c47f899020e87b.zip
[SPARK-9715] [ML] Store numFeatures in all ML PredictionModel types
All prediction models should store `numFeatures` indicating the number of features the model was trained on. Default value of -1 added for backwards compatibility. Author: sethah <seth.hendrickson16@gmail.com> Closes #8675 from sethah/SPARK-9715.
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java5
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala3
2 files changed, 8 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index a377694507..0b4c0d9ba9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -220,6 +220,11 @@ class MyJavaLogisticRegressionModel
public int numClasses() { return 2; }
/**
+ * Number of features the model was trained on.
+ */
+ public int numFeatures() { return weights_.size(); }
+
+ /**
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
* <p>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 340c3559b1..3758edc561 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -172,6 +172,9 @@ private class MyLogisticRegressionModel(
/** Number of classes the label can take. 2 indicates binary classification. */
override val numClasses: Int = 2
+ /** Number of features the model was trained on. */
+ override val numFeatures: Int = weights.size
+
/**
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.