Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java   24
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java    4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala      6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala     22
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala               4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala      6
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala      4
7 files changed, 25 insertions(+), 45 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
index 46377a99c4..eac4f898a4 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -28,7 +28,6 @@ import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.ClassificationModel;
import org.apache.spark.ml.param.IntParam;
import org.apache.spark.ml.param.ParamMap;
-import org.apache.spark.ml.param.Params$;
import org.apache.spark.mllib.linalg.BLAS;
import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.linalg.Vectors;
@@ -129,16 +128,16 @@ class MyJavaLogisticRegression
// This method is used by fit().
// In Java, we have to make it public since Java does not understand Scala's protected modifier.
- public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ public MyJavaLogisticRegressionModel train(DataFrame dataset) {
// Extract columns from data using helper method.
- JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset).toJavaRDD();
// Do learning to estimate the weight vector.
int numFeatures = oldDataset.take(1).get(0).features().size();
Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
// Create a model, and return it.
- return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ return new MyJavaLogisticRegressionModel(this, weights);
}
}
@@ -155,18 +154,11 @@ class MyJavaLogisticRegressionModel
private MyJavaLogisticRegression parent_;
public MyJavaLogisticRegression parent() { return parent_; }
- private ParamMap fittingParamMap_;
- public ParamMap fittingParamMap() { return fittingParamMap_; }
-
private Vector weights_;
public Vector weights() { return weights_; }
- public MyJavaLogisticRegressionModel(
- MyJavaLogisticRegression parent_,
- ParamMap fittingParamMap_,
- Vector weights_) {
+ public MyJavaLogisticRegressionModel(MyJavaLogisticRegression parent_, Vector weights_) {
this.parent_ = parent_;
- this.fittingParamMap_ = fittingParamMap_;
this.weights_ = weights_;
}
@@ -210,10 +202,8 @@ class MyJavaLogisticRegressionModel
* In Java, we have to make this method public since Java does not understand Scala's protected
* modifier.
*/
- public MyJavaLogisticRegressionModel copy() {
- MyJavaLogisticRegressionModel m =
- new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
- Params$.MODULE$.inheritValues(this.extractParamMap(), this, m);
- return m;
+ @Override
+ public MyJavaLogisticRegressionModel copy(ParamMap extra) {
+ return copyValues(new MyJavaLogisticRegressionModel(parent_, weights_), extra);
}
}
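Net effect of the change above for user code: instead of the removed Params$.MODULE$.inheritValues plumbing, duplicating a fitted model with adjusted params is now a single call to copy(extra). A minimal Scala sketch, assuming a spark.ml build from this branch and a fitted model of the MyLogisticRegressionModel style shown in this diff (the maxIter param comes from the example's params trait):

    import org.apache.spark.ml.param.ParamMap

    // `model` is a fitted instance whose copy(extra) delegates to copyValues,
    // as in the diff above. The copy keeps all of model's param values except
    // those overridden in `extra`.
    val tweaked = model.copy(ParamMap(model.maxIter -> 10))
    println(tweaked.getMaxIter)  // 10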
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index 4e02acce69..29158d5c85 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -71,7 +71,7 @@ public class JavaSimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- System.out.println("Model 1 was fit using parameters: " + model1.fittingParamMap());
+ System.out.println("Model 1 was fit using parameters: " + model1.parent().extractParamMap());
// We may alternatively specify parameters using a ParamMap.
ParamMap paramMap = new ParamMap();
@@ -87,7 +87,7 @@ public class JavaSimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
LogisticRegressionModel model2 = lr.fit(training, paramMapCombined);
- System.out.println("Model 2 was fit using parameters: " + model2.fittingParamMap());
+ System.out.println("Model 2 was fit using parameters: " + model2.parent().extractParamMap());
// Prepare test documents.
List<LabeledPoint> localTest = Lists.newArrayList(
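Since fittingParamMap is gone, the parameters a model was trained with are now inspected through its parent estimator, as the replacement lines above show. A short Scala sketch of the same pattern, assuming an existing training: DataFrame of (label, features) rows:

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression()
      .setMaxIter(10)
      .setRegParam(0.01)
    val model = lr.fit(training)
    // The estimator that produced the model holds the params used by fit().
    println("Fit used: " + model.parent.extractParamMap())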
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
index 9002e99d82..8340d91101 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DecisionTreeExample.scala
@@ -276,16 +276,14 @@ object DecisionTreeExample {
// Get the trained Decision Tree from the fitted PipelineModel
algo match {
case "classification" =>
- val treeModel = pipelineModel.getModel[DecisionTreeClassificationModel](
- dt.asInstanceOf[DecisionTreeClassifier])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
println(treeModel) // Print model summary.
}
case "regression" =>
- val treeModel = pipelineModel.getModel[DecisionTreeRegressionModel](
- dt.asInstanceOf[DecisionTreeRegressor])
+ val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeRegressionModel]
if (treeModel.numNodes < 20) {
println(treeModel.toDebugString) // Print full model.
} else {
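With PipelineModel.getModel removed, a trained stage is now recovered positionally from stages and cast by the caller. A sketch of the new pattern (Scala; a single-stage pipeline for brevity, where real examples put label/feature indexers before the tree; training: DataFrame assumed as above):

    import org.apache.spark.ml.{Pipeline, PipelineStage}
    import org.apache.spark.ml.classification.{DecisionTreeClassifier, DecisionTreeClassificationModel}

    val dt = new DecisionTreeClassifier()
    val pipeline = new Pipeline().setStages(Array[PipelineStage](dt))
    val pipelineModel = pipeline.fit(training)
    // stages holds the fitted transformers in pipeline order, so the trained
    // tree is the last entry; the cast is the caller's responsibility now.
    val treeModel = pipelineModel.stages.last.asInstanceOf[DecisionTreeClassificationModel]
    println(treeModel)  // model summary; treeModel.toDebugString prints the full tree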
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
index 2245fa429f..2a2d067727 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -18,13 +18,12 @@
package org.apache.spark.examples.ml
import org.apache.spark.{SparkConf, SparkContext}
-import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
-import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.ml.classification.{ClassificationModel, Classifier, ClassifierParams}
+import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.{DataFrame, Row, SQLContext}
-
/**
* A simple example demonstrating how to write your own learning algorithm using Estimator,
* Transformer, and other abstractions.
@@ -99,7 +98,7 @@ private trait MyLogisticRegressionParams extends ClassifierParams {
* class since the maxIter parameter is only used during training (not in the Model).
*/
val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
- def getMaxIter: Int = getOrDefault(maxIter)
+ def getMaxIter: Int = $(maxIter)
}
/**
@@ -117,18 +116,16 @@ private class MyLogisticRegression
def setMaxIter(value: Int): this.type = set(maxIter, value)
// This method is used by fit()
- override protected def train(
- dataset: DataFrame,
- paramMap: ParamMap): MyLogisticRegressionModel = {
+ override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
// Extract columns from data using helper method.
- val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val oldDataset = extractLabeledPoints(dataset)
// Do learning to estimate the weight vector.
val numFeatures = oldDataset.take(1)(0).features.size
val weights = Vectors.zeros(numFeatures) // Learning would happen here.
// Create a model, and return it.
- new MyLogisticRegressionModel(this, paramMap, weights)
+ new MyLogisticRegressionModel(this, weights)
}
}
@@ -139,7 +136,6 @@ private class MyLogisticRegression
*/
private class MyLogisticRegressionModel(
override val parent: MyLogisticRegression,
- override val fittingParamMap: ParamMap,
val weights: Vector)
extends ClassificationModel[Vector, MyLogisticRegressionModel]
with MyLogisticRegressionParams {
@@ -176,9 +172,7 @@ private class MyLogisticRegressionModel(
*
* This is used for the default implementation of [[transform()]].
*/
- override protected def copy(): MyLogisticRegressionModel = {
- val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
- Params.inheritValues(extractParamMap(), this, m)
- m
+ override def copy(extra: ParamMap): MyLogisticRegressionModel = {
+ copyValues(new MyLogisticRegressionModel(parent, weights), extra)
}
}
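Two developer-API idioms from this file are worth calling out: param values are read with the $ shorthand rather than getOrDefault, and train() no longer threads a ParamMap because fit() merges any caller-supplied params onto the estimator before invoking it. A condensed sketch of the resulting train() body (names from the diff above; the learning step is a placeholder, as in the example itself):

    import org.apache.spark.mllib.linalg.Vectors
    import org.apache.spark.sql.DataFrame

    // Inside MyLogisticRegression. fit() has already merged caller params
    // onto `this`, so $(maxIter) reads the effective value directly; $ is
    // sugar for getOrDefault(maxIter).
    override protected def train(dataset: DataFrame): MyLogisticRegressionModel = {
      val iterations = $(maxIter)  // would bound the learning loop here
      val numFeatures = extractLabeledPoints(dataset).first().features.size
      new MyLogisticRegressionModel(this, Vectors.zeros(numFeatures))
    }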
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
index 5fccb142d4..c5899b6683 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/GBTExample.scala
@@ -201,14 +201,14 @@ object GBTExample {
// Get the trained GBT from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[GBTClassificationModel](dt.asInstanceOf[GBTClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[GBTRegressionModel](dt.asInstanceOf[GBTRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[GBTRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
index 9b909324ec..7f88d2681b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/RandomForestExample.scala
@@ -209,16 +209,14 @@ object RandomForestExample {
// Get the trained Random Forest from the fitted PipelineModel
algo match {
case "classification" =>
- val rfModel = pipelineModel.getModel[RandomForestClassificationModel](
- dt.asInstanceOf[RandomForestClassifier])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestClassificationModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
println(rfModel) // Print model summary.
}
case "regression" =>
- val rfModel = pipelineModel.getModel[RandomForestRegressionModel](
- dt.asInstanceOf[RandomForestRegressor])
+ val rfModel = pipelineModel.stages.last.asInstanceOf[RandomForestRegressionModel]
if (rfModel.totalNumNodes < 30) {
println(rfModel.toDebugString) // Print full model.
} else {
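The asInstanceOf casts in these examples are safe because each pipeline ends with the one estimator of interest. Code that handles both branches without an up-front cast can instead pattern match on the last stage; this is a variant not used in the examples, sketched here against the same fitted pipelineModel:

    import org.apache.spark.ml.classification.RandomForestClassificationModel
    import org.apache.spark.ml.regression.RandomForestRegressionModel

    // Matching avoids a ClassCastException if the pipeline layout changes.
    pipelineModel.stages.last match {
      case m: RandomForestClassificationModel if m.totalNumNodes < 30 => println(m.toDebugString)
      case m: RandomForestRegressionModel if m.totalNumNodes < 30 => println(m.toDebugString)
      case other => println(other)  // summary only
    }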
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index bf805149d0..e8a991f50e 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -63,7 +63,7 @@ object SimpleParamsExample {
// we can view the parameters it used during fit().
// This prints the parameter (name: value) pairs, where names are unique IDs for this
// LogisticRegression instance.
- println("Model 1 was fit using parameters: " + model1.fittingParamMap)
+ println("Model 1 was fit using parameters: " + model1.parent.extractParamMap())
// We may alternatively specify parameters using a ParamMap,
// which supports several methods for specifying parameters.
@@ -78,7 +78,7 @@ object SimpleParamsExample {
// Now learn a new model using the paramMapCombined parameters.
// paramMapCombined overrides all parameters set earlier via lr.set* methods.
val model2 = lr.fit(training.toDF(), paramMapCombined)
- println("Model 2 was fit using parameters: " + model2.fittingParamMap)
+ println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())
// Prepare test data.
val test = sc.parallelize(Seq(
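For completeness, the ParamMap-based fit path that model2 exercises: parameters passed at fit() time override anything set earlier through lr.set* methods, and the merged result is what parent.extractParamMap() reports. A sketch using the ParamMap syntax from the spark.ml programming guide of this era (lr and training as above):

    import org.apache.spark.ml.param.ParamMap

    val paramMap = ParamMap(lr.maxIter -> 20)
      .put(lr.maxIter, 30)                         // the later put() wins: maxIter = 30
      .put(lr.regParam -> 0.1, lr.threshold -> 0.55)
    // fit()-time params override values set earlier via lr.set* methods.
    val model2 = lr.fit(training, paramMap)
    println("Model 2 was fit using parameters: " + model2.parent.extractParamMap())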