aboutsummaryrefslogtreecommitdiff
path: root/examples
diff options
context:
space:
mode:
Diffstat (limited to 'examples')
-rw-r--r--examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java5
-rw-r--r--examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala3
2 files changed, 8 insertions, 0 deletions
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index a377694507..0b4c0d9ba9 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -220,6 +220,11 @@ class MyJavaLogisticRegressionModel
public int numClasses() { return 2; }
/**
+ * Number of features the model was trained on.
+ */
+ public int numFeatures() { return weights_.size(); }
+
+ /**
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.
* <p>
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 340c3559b1..3758edc561 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -172,6 +172,9 @@ private class MyLogisticRegressionModel(
/** Number of classes the label can take. 2 indicates binary classification. */
override val numClasses: Int = 2
+ /** Number of features the model was trained on. */
+ override val numFeatures: Int = weights.size
+
/**
* Create a copy of the model.
* The copy is shallow, except for the embedded paramMap, which gets a deep copy.