about summary refs log tree commit diff
path: root/examples/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'examples/src/main')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java | 43
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala | 11
2 files changed, 39 insertions(+), 15 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index eac4f898a4..ec533d174e 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,6 +28,7 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.util.Identifiable$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -103,7 +104,23 @@ public class JavaDeveloperApiExample {
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegression
- extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+ extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel> {
+
+ public MyJavaLogisticRegression() {
+ init();
+ }
+
+ public MyJavaLogisticRegression(String uid) {
+ this.uid_ = uid;
+ init();
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
+ }
/**
* Param for max number of iterations
@@ -117,7 +134,7 @@ class MyJavaLogisticRegression
int getMaxIter() { return (Integer) getOrDefault(maxIter); }
- public MyJavaLogisticRegression() {
+ private void init() {
setMaxIter(100);
}
@@ -137,7 +154,7 @@ class MyJavaLogisticRegression
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, weights);
+ return new MyJavaLogisticRegressionModel(uid(), weights).setParent(this);
}
}
@@ -149,17 +166,21 @@ class MyJavaLogisticRegression
* However, this should still compile and run successfully.
*/
class MyJavaLogisticRegressionModel
- extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
-
- private MyJavaLogisticRegression parent_;
- public MyJavaLogisticRegression parent() { return parent_; }
+ extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> {
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
- this.parent_ = parent_;
- this.weights_ = weights_;
+ public MyJavaLogisticRegressionModel(String uid, Vector weights) {
+ this.uid_ = uid;
+ this.weights_ = weights;
+ }
+
+ private String uid_ = Identifiable$.MODULE$.randomUID("myJavaLogReg");
+
+ @Override
+ public String uid() {
+ return uid_;
}
// This uses the default implementation of transform(), which reads column "features" and outputs
@@ -204,6 +225,6 @@ class MyJavaLogisticRegressionModel
*/
@Override
public MyJavaLogisticRegressionModel copy(ParamMap extra) {
- return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
+ return copyValues(new MyJavaLogisticRegressionModel(uid(), weights_), extra);
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2a2d067727..3ee456edbe 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -20,6 +20,7 @@ package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
import org.apache.spark.ml.param.{IntParam, ParamMap}
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
@@ -106,10 +107,12 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
*
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
-private class MyLogisticRegression
+private class MyLogisticRegression(override val uid: String)
extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
+ def this() = this(Identifiable.randomUID("myLogReg"))
+
setMaxIter(100) // Initialize
// The parameter setter is in this class since it should return type MyLogisticRegression.
@@ -125,7 +128,7 @@ private class MyLogisticRegression
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, weights)
+ new MyLogisticRegressionModel(uid, weights).setParent(this)
}
}
@@ -135,7 +138,7 @@ private class MyLogisticRegression
* NOTE: This is private since it is an example. In practice, you may not want it to be private.
*/
private class MyLogisticRegressionModel(
- override val parent: MyLogisticRegression,
+ override val uid: String,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -173,6 +176,6 @@ private class MyLogisticRegressionModel(
* This is used for the default implementation of [[transform()]].
*/
override def copy(extra: ParamMap): MyLogisticRegressionModel = {
- copyValues(new MyLogisticRegressionModel(parent, weights), extra)
+ copyValues(new MyLogisticRegressionModel(uid, weights), extra)
}
}