diff options
Diffstat (limited to 'examples')
-rw-r--r-- | examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 5 | ||||
-rw-r--r-- | examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala | 3 |
2 files changed, 8 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java index a377694507..0b4c0d9ba9 100644 --- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java +++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java @@ -220,6 +220,11 @@ class MyJavaLogisticRegressionModel public int numClasses() { return 2; } /** + * Number of features the model was trained on. + */ + public int numFeatures() { return weights_.size(); } + + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. * <p> diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala index 340c3559b1..3758edc561 100644 --- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala +++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala @@ -172,6 +172,9 @@ private class MyLogisticRegressionModel( /** Number of classes the label can take. 2 indicates binary classification. */ override val numClasses: Int = 2 + /** Number of features the model was trained on. */ + override val numFeatures: Int = weights.size + /** * Create a copy of the model. * The copy is shallow, except for the embedded paramMap, which gets a deep copy. |