author     Joseph K. Bradley <joseph@databricks.com>  2015-02-05 23:43:47 -0800
committer  Xiangrui Meng <meng@databricks.com>        2015-02-05 23:43:47 -0800
commit     dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f (patch)
tree       745d33737eaddc95a0c55a814e84c7b96f9ecbcf /examples
parent     6b88825a25a0a072c13bbcc57bbfdb102a3f133d (diff)
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
This is part (1a) of the updates from the design doc in
[https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs]

**UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion. Here is a list of changes which are public (a short usage sketch follows the commit log below):
* new output columns: rawPrediction, probabilities
* The "score" column is now called "rawPrediction"
* Classifiers now provide numClasses
* Params.get and .set are now protected instead of private[ml].
* ParamMap now has a size method.
* new classes: LinearRegression, LinearRegressionModel
* LogisticRegression now has an intercept.

### Sketch of APIs (most of which are private[spark] for now)

Abstract classes for learning algorithms (+ corresponding Model abstractions):
* Classifier (+ ClassificationModel)
* ProbabilisticClassifier (+ ProbabilisticClassificationModel)
* Regressor (+ RegressionModel)
* Predictor (+ PredictionModel)
* *For all of these*:
  * There is no strongly typed training-time API.
  * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms.

Concrete classes: learning algorithms
* LinearRegression
* LogisticRegression (updated to use new abstract classes)
  * Also, removed "score" in favor of the "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031)

Other updates:
* params.scala: Changed Params.set/get to be protected instead of private[ml].
  * This was needed for the example of defining a class from outside of the MLlib namespace.
* VectorUDT: Will later change from private[spark] to public.
  * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors.
  * Also, added an equals() method.
* SPARK-4942: ML Transformers should allow output columns to be turned on/off
  * Update validateAndTransformSchema
  * Update transform
* (Updated examples and test suites according to the other changes.)

New examples:
* DeveloperApiExample.scala (example of defining an algorithm from outside of the MLlib namespace)
  * Added a Java version too.

Test suites:
* LinearRegressionSuite
* LogisticRegressionSuite
* + Java versions of the above suites

CC: mengxr etrain shivaram

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits:

405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups
fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types.
8316d5e [Joseph K. Bradley] fixes after rebasing on master
fc62406 [Joseph K. Bradley] fixed test suites after last commit
bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame)
9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api
f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it
216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged
f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are:
  * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters.
  * Made Predictor.featuresDataType have a default value of VectorUDT.
    * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value.
343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package
82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites
0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites
c3c8da5 [Joseph K. Bradley] small cleanup
934f97b [Joseph K. Bradley] Fixed bugs from previous commit.
1c61723 [Joseph K. Bradley] Made ProbabilisticClassificationModel into a subclass of ClassificationModel, and introduced ProbabilisticClassifier. This was to support the output column "probabilityCol" in transform().
4e2f711 [Joseph K. Bradley] rat fix
bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite
8d13233 [Joseph K. Bradley] Added methods:
  * Classifier: batch predictRaw()
  * Predictor: train() without paramMap
  * ProbabilisticClassificationModel.predictProbabilities()
  * Java versions of all the above batch methods + others
1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added a constructor to LabeledPoint which defaults weight to 1.0
adbe50a [Joseph K. Bradley]
  * fixed LinearRegression train() to use the embedded paramMap
  * added Predictor.predict(RDD[Vector]) method
  * updated Linear/LogisticRegressionSuites
58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap.
57d54ab [Joseph K. Bradley]
  * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap.
  * removed threshold_internal from logreg
  * Added Predictor.copy()
  * Extended LogisticRegressionSuite
e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml
54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly
0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup).
601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples.
d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch
52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification
d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet
bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:
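For illustration, a minimal usage sketch of the renamed output columns (not part of the commit; it assumes `training` and `test` are DataFrames, or implicitly convertible RDDs, of (label: Double, features: Vector) rows like those built in the examples below):

    import org.apache.spark.ml.classification.LogisticRegression

    val lr = new LogisticRegression().setMaxIter(10)
    val model = lr.fit(training)

    // transform() no longer emits a "score" column; it appends
    // "rawPrediction", "probability", and "prediction" instead.
    model.transform(test)
      .select("label", "rawPrediction", "probability", "prediction")
      .collect()
      .foreach(println)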
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java            |   6
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java              | 217
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java              |  10
-rw-r--r--  examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java |   4
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala              |   7
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala                | 184
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala                |  16
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala   |   7
8 files changed, 430 insertions(+), 21 deletions(-)
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
index 0fbee6e433..5041e0b6d3 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaCrossValidatorExample.java
@@ -116,10 +116,12 @@ public class JavaCrossValidatorExample {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT id, text, probability, prediction FROM prediction");
for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
new file mode 100644
index 0000000000..42d4d7d0be
--- /dev/null
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaDeveloperApiExample.java
@@ -0,0 +1,217 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml;
+
+import java.util.List;
+
+import com.google.common.collect.Lists;
+
+import org.apache.spark.SparkConf;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import org.apache.spark.ml.classification.Classifier;
+import org.apache.spark.ml.classification.ClassificationModel;
+import org.apache.spark.ml.param.IntParam;
+import org.apache.spark.ml.param.ParamMap;
+import org.apache.spark.ml.param.Params;
+import org.apache.spark.ml.param.Params$;
+import org.apache.spark.mllib.linalg.BLAS;
+import org.apache.spark.mllib.linalg.Vector;
+import org.apache.spark.mllib.linalg.Vectors;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics {@link org.apache.spark.ml.classification.LogisticRegression}.
+ *
+ * Run with
+ * <pre>
+ * bin/run-example ml.JavaDeveloperApiExample
+ * </pre>
+ */
+public class JavaDeveloperApiExample {
+
+ public static void main(String[] args) throws Exception {
+ SparkConf conf = new SparkConf().setAppName("JavaDeveloperApiExample");
+ JavaSparkContext jsc = new JavaSparkContext(conf);
+ SQLContext jsql = new SQLContext(jsc);
+
+ // Prepare training data.
+ List<LabeledPoint> localTraining = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ new LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5)));
+ DataFrame training = jsql.applySchema(jsc.parallelize(localTraining), LabeledPoint.class);
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ MyJavaLogisticRegression lr = new MyJavaLogisticRegression();
+ // Print out the parameters, documentation, and any default values.
+ System.out.println("MyJavaLogisticRegression parameters:\n" + lr.explainParams() + "\n");
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10);
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ MyJavaLogisticRegressionModel model = lr.fit(training);
+
+ // Prepare test data.
+ List<LabeledPoint> localTest = Lists.newArrayList(
+ new LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ new LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ new LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5)));
+ DataFrame test = jsql.applySchema(jsc.parallelize(localTest), LabeledPoint.class);
+
+ // Make predictions on test data.
+ DataFrame results = model.transform(test);
+ double sumPredictions = 0;
+ for (Row r : results.select("features", "label", "prediction").collect()) {
+ sumPredictions += r.getDouble(2);
+ }
+ if (sumPredictions != 0.0) {
+ throw new Exception("MyJavaLogisticRegression predicted something other than 0," +
+ " even though all weights are 0!");
+ }
+
+ jsc.stop();
+ }
+}
+
+/**
+ * Example of defining a type of {@link Classifier}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegression
+ extends Classifier<Vector, MyJavaLogisticRegression, MyJavaLogisticRegressionModel>
+ implements Params {
+
+ /**
+ * Param for max number of iterations
+ * <p/>
+ * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+ * - val myParamName: ParamType
+ * - def getMyParamName
+ * - def setMyParamName
+ */
+ IntParam maxIter = new IntParam(this, "maxIter", "max number of iterations");
+
+ int getMaxIter() { return (Integer) get(maxIter); }
+
+ public MyJavaLogisticRegression() {
+ setMaxIter(100);
+ }
+
+ // The parameter setter is in this class since it should return type MyJavaLogisticRegression.
+ MyJavaLogisticRegression setMaxIter(int value) {
+ return (MyJavaLogisticRegression)set(maxIter, value);
+ }
+
+ // This method is used by fit().
+ // In Java, we have to make it public since Java does not understand Scala's protected modifier.
+ public MyJavaLogisticRegressionModel train(DataFrame dataset, ParamMap paramMap) {
+ // Extract columns from data using helper method.
+ JavaRDD<LabeledPoint> oldDataset = extractLabeledPoints(dataset, paramMap).toJavaRDD();
+
+ // Do learning to estimate the weight vector.
+ int numFeatures = oldDataset.take(1).get(0).features().size();
+ Vector weights = Vectors.zeros(numFeatures); // Learning would happen here.
+
+ // Create a model, and return it.
+ return new MyJavaLogisticRegressionModel(this, paramMap, weights);
+ }
+}
+
+/**
+ * Example of defining a type of {@link ClassificationModel}.
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+class MyJavaLogisticRegressionModel
+ extends ClassificationModel<Vector, MyJavaLogisticRegressionModel> implements Params {
+
+ private MyJavaLogisticRegression parent_;
+ public MyJavaLogisticRegression parent() { return parent_; }
+
+ private ParamMap fittingParamMap_;
+ public ParamMap fittingParamMap() { return fittingParamMap_; }
+
+ private Vector weights_;
+ public Vector weights() { return weights_; }
+
+ public MyJavaLogisticRegressionModel(
+ MyJavaLogisticRegression parent_,
+ ParamMap fittingParamMap_,
+ Vector weights_) {
+ this.parent_ = parent_;
+ this.fittingParamMap_ = fittingParamMap_;
+ this.weights_ = weights_;
+ }
+
+ // This uses the default implementation of transform(), which reads column "features" and outputs
+ // columns "prediction" and "rawPrediction."
+
+ // This uses the default implementation of predict(), which chooses the label corresponding to
+ // the maximum value returned by [[predictRaw()]].
+
+ /**
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public Vector predictRaw(Vector features) {
+ double margin = BLAS.dot(features, weights_);
+ // There are 2 classes (binary classification), so we return a length-2 vector,
+ // where index i corresponds to class i (i = 0, 1).
+ return Vectors.dense(-margin, margin);
+ }
+
+ /**
+ * Number of classes the label can take. 2 indicates binary classification.
+ */
+ public int numClasses() { return 2; }
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ * <p/>
+ * This is used for the default implementation of [[transform()]].
+ *
+ * In Java, we have to make this method public since Java does not understand Scala's protected
+ * modifier.
+ */
+ public MyJavaLogisticRegressionModel copy() {
+ MyJavaLogisticRegressionModel m =
+ new MyJavaLogisticRegressionModel(parent_, fittingParamMap_, weights_);
+ Params$.MODULE$.inheritValues(this.paramMap(), this, m);
+ return m;
+ }
+}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
index eaaa344be4..cc69e6315f 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleParamsExample.java
@@ -81,7 +81,7 @@ public class JavaSimpleParamsExample {
// One can also combine ParamMaps.
ParamMap paramMap2 = new ParamMap();
- paramMap2.put(lr.scoreCol().w("probability")); // Change output column name
+ paramMap2.put(lr.probabilityCol().w("myProbability")); // Change output column name
ParamMap paramMapCombined = paramMap.$plus$plus(paramMap2);
// Now learn a new model using the paramMapCombined parameters.
@@ -98,14 +98,16 @@ public class JavaSimpleParamsExample {
// Make predictions on test documents using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
- // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
- // column since we renamed the lr.scoreCol parameter previously.
+ // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ // 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test).registerTempTable("results");
DataFrame results =
- jsql.sql("SELECT features, label, probability, prediction FROM results");
+ jsql.sql("SELECT features, label, myProbability, prediction FROM results");
for (Row r: results.collect()) {
System.out.println("(" + r.get(0) + ", " + r.get(1) + ") -> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
index 82d665a3e1..d929f1ad20 100644
--- a/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
+++ b/examples/src/main/java/org/apache/spark/examples/ml/JavaSimpleTextClassificationPipeline.java
@@ -85,8 +85,10 @@ public class JavaSimpleTextClassificationPipeline {
model.transform(test).registerTempTable("prediction");
DataFrame predictions = jsql.sql("SELECT id, text, score, prediction FROM prediction");
for (Row r: predictions.collect()) {
- System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> score=" + r.get(2)
+ System.out.println("(" + r.get(0) + ", " + r.get(1) + ") --> prob=" + r.get(2)
+ ", prediction=" + r.get(3));
}
+
+ jsc.stop();
}
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
index b6c30a007d..a2893f78e0 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/CrossValidatorExample.scala
@@ -23,6 +23,7 @@ import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
/**
@@ -100,10 +101,10 @@ object CrossValidatorExample {
// Make predictions on test documents. cvModel uses the best model found (lrModel).
cvModel.transform(test)
- .select("id", "text", "score", "prediction")
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println(s"($id, $text) --> prob=$prob, prediction=$prediction")
}
sc.stop()
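Note the type change in the pattern match above: the former "score" column was a Double, while the new "probability" column is a Vector with one entry per class. A minimal sketch (an illustrative helper, not part of this commit) of extracting the positive-class probability from such a Row:

    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.sql.Row

    // Element 1 of the probability vector is P(label = 1).
    def positiveClassProbability(row: Row): Double = row match {
      case Row(_, _, prob: Vector, _) => prob(1)
    }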
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
new file mode 100644
index 0000000000..aed4423893
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/DeveloperApiExample.scala
@@ -0,0 +1,184 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.examples.ml
+
+import org.apache.spark.{SparkConf, SparkContext}
+import org.apache.spark.ml.classification.{Classifier, ClassifierParams, ClassificationModel}
+import org.apache.spark.ml.param.{Params, IntParam, ParamMap}
+import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+/**
+ * A simple example demonstrating how to write your own learning algorithm using Estimator,
+ * Transformer, and other abstractions.
+ * This mimics [[org.apache.spark.ml.classification.LogisticRegression]].
+ * Run with
+ * {{{
+ * bin/run-example ml.DeveloperApiExample
+ * }}}
+ */
+object DeveloperApiExample {
+
+ def main(args: Array[String]) {
+ val conf = new SparkConf().setAppName("DeveloperApiExample")
+ val sc = new SparkContext(conf)
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ // Prepare training data.
+ val training = sc.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.1, 0.1)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.0, -1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0, 1.3, 1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 1.2, -0.5))))
+
+ // Create a LogisticRegression instance. This instance is an Estimator.
+ val lr = new MyLogisticRegression()
+ // Print out the parameters, documentation, and any default values.
+ println("MyLogisticRegression parameters:\n" + lr.explainParams() + "\n")
+
+ // We may set parameters using setter methods.
+ lr.setMaxIter(10)
+
+ // Learn a LogisticRegression model. This uses the parameters stored in lr.
+ val model = lr.fit(training)
+
+ // Prepare test data.
+ val test = sc.parallelize(Seq(
+ LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
+ LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
+ LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
+
+ // Make predictions on test data.
+ val sumPredictions: Double = model.transform(test)
+ .select("features", "label", "prediction")
+ .collect()
+ .map { case Row(features: Vector, label: Double, prediction: Double) =>
+ prediction
+ }.sum
+ assert(sumPredictions == 0.0,
+ "MyLogisticRegression predicted something other than 0, even though all weights are 0!")
+
+ sc.stop()
+ }
+}
+
+/**
+ * Example of defining a parameter trait for a user-defined type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private trait MyLogisticRegressionParams extends ClassifierParams {
+
+ /**
+ * Param for max number of iterations
+ *
+ * NOTE: The usual way to add a parameter to a model or algorithm is to include:
+ * - val myParamName: ParamType
+ * - def getMyParamName
+ * - def setMyParamName
+ * Here, we have a trait to be mixed in with the Estimator and Model (MyLogisticRegression
+ * and MyLogisticRegressionModel). We place the setter (setMaxIter) method in the Estimator
+ * class since the maxIter parameter is only used during training (not in the Model).
+ */
+ val maxIter: IntParam = new IntParam(this, "maxIter", "max number of iterations")
+ def getMaxIter: Int = get(maxIter)
+}
+
+/**
+ * Example of defining a type of [[Classifier]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegression
+ extends Classifier[Vector, MyLogisticRegression, MyLogisticRegressionModel]
+ with MyLogisticRegressionParams {
+
+ setMaxIter(100) // Initialize
+
+ // The parameter setter is in this class since it should return type MyLogisticRegression.
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ // This method is used by fit()
+ override protected def train(
+ dataset: DataFrame,
+ paramMap: ParamMap): MyLogisticRegressionModel = {
+ // Extract columns from data using helper method.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+
+ // Do learning to estimate the weight vector.
+ val numFeatures = oldDataset.take(1)(0).features.size
+ val weights = Vectors.zeros(numFeatures) // Learning would happen here.
+
+ // Create a model, and return it.
+ new MyLogisticRegressionModel(this, paramMap, weights)
+ }
+}
+
+/**
+ * Example of defining a type of [[ClassificationModel]].
+ *
+ * NOTE: This is private since it is an example. In practice, you may not want it to be private.
+ */
+private class MyLogisticRegressionModel(
+ override val parent: MyLogisticRegression,
+ override val fittingParamMap: ParamMap,
+ val weights: Vector)
+ extends ClassificationModel[Vector, MyLogisticRegressionModel]
+ with MyLogisticRegressionParams {
+
+ // This uses the default implementation of transform(), which reads column "features" and outputs
+ // columns "prediction" and "rawPrediction."
+
+ // This uses the default implementation of predict(), which chooses the label corresponding to
+ // the maximum value returned by [[predictRaw()]].
+
+ /**
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ */
+ override protected def predictRaw(features: Vector): Vector = {
+ val margin = BLAS.dot(features, weights)
+ // There are 2 classes (binary classification), so we return a length-2 vector,
+ // where index i corresponds to class i (i = 0, 1).
+ Vectors.dense(-margin, margin)
+ }
+
+ /** Number of classes the label can take. 2 indicates binary classification. */
+ override val numClasses: Int = 2
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ *
+ * This is used for the default implementation of [[transform()]].
+ */
+ override protected def copy(): MyLogisticRegressionModel = {
+ val m = new MyLogisticRegressionModel(parent, fittingParamMap, weights)
+ Params.inheritValues(this.paramMap, this, m)
+ m
+ }
+}
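For intuition, the (-margin, margin) rawPrediction vector produced by predictRaw above relates to the "probability" column through the logistic function, which is how a probabilistic classifier such as LogisticRegression derives class probabilities. A minimal sketch (illustration only, assuming the binary (-margin, margin) layout above; not code from this commit):

    import org.apache.spark.mllib.linalg.{Vector, Vectors}

    def rawToProbability(rawPrediction: Vector): Vector = {
      val margin = rawPrediction(1)
      val p1 = 1.0 / (1.0 + math.exp(-margin))  // logistic(margin) = P(label = 1)
      Vectors.dense(1.0 - p1, p1)               // (P(label = 0), P(label = 1))
    }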
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
index 4d1530cd13..80c9f5ff57 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleParamsExample.scala
@@ -72,7 +72,7 @@ object SimpleParamsExample {
paramMap.put(lr.regParam -> 0.1, lr.threshold -> 0.55) // Specify multiple Params.
// One can also combine ParamMaps.
- val paramMap2 = ParamMap(lr.scoreCol -> "probability") // Change output column name
+ val paramMap2 = ParamMap(lr.probabilityCol -> "myProbability") // Change output column name
val paramMapCombined = paramMap ++ paramMap2
// Now learn a new model using the paramMapCombined parameters.
@@ -80,21 +80,21 @@ object SimpleParamsExample {
val model2 = lr.fit(training, paramMapCombined)
println("Model 2 was fit using parameters: " + model2.fittingParamMap)
- // Prepare test documents.
+ // Prepare test data.
val test = sc.parallelize(Seq(
LabeledPoint(1.0, Vectors.dense(-1.0, 1.5, 1.3)),
LabeledPoint(0.0, Vectors.dense(3.0, 2.0, -0.1)),
LabeledPoint(1.0, Vectors.dense(0.0, 2.2, -1.5))))
- // Make predictions on test documents using the Transformer.transform() method.
+ // Make predictions on test data using the Transformer.transform() method.
// LogisticRegression.transform will only use the 'features' column.
- // Note that model2.transform() outputs a 'probability' column instead of the usual 'score'
- // column since we renamed the lr.scoreCol parameter previously.
+ // Note that model2.transform() outputs a 'myProbability' column instead of the usual
+ // 'probability' column since we renamed the lr.probabilityCol parameter previously.
model2.transform(test)
- .select("features", "label", "probability", "prediction")
+ .select("features", "label", "myProbability", "prediction")
.collect()
- .foreach { case Row(features: Vector, label: Double, prob: Double, prediction: Double) =>
- println("(" + features + ", " + label + ") -> prob=" + prob + ", prediction=" + prediction)
+ .foreach { case Row(features: Vector, label: Double, prob: Vector, prediction: Double) =>
+ println("($features, $label) -> prob=$prob, prediction=$prediction")
}
sc.stop()
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
index dbbe01dd5c..968cb29212 100644
--- a/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/SimpleTextClassificationPipeline.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.feature.{HashingTF, Tokenizer}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
@BeanInfo
@@ -79,10 +80,10 @@ object SimpleTextClassificationPipeline {
// Make predictions on test documents.
model.transform(test)
- .select("id", "text", "score", "prediction")
+ .select("id", "text", "probability", "prediction")
.collect()
- .foreach { case Row(id: Long, text: String, score: Double, prediction: Double) =>
- println("(" + id + ", " + text + ") --> score=" + score + ", prediction=" + prediction)
+ .foreach { case Row(id: Long, text: String, prob: Vector, prediction: Double) =>
+ println("($id, $text) --> prob=$prob, prediction=$prediction")
}
sc.stop()