aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorJoseph K. Bradley <joseph@databricks.com>2015-02-05 23:43:47 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-05 23:43:47 -0800
commitdc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f (patch)
tree745d33737eaddc95a0c55a814e84c7b96f9ecbcf /mllib
parent6b88825a25a0a072c13bbcc57bbfdb102a3f133d (diff)
downloadspark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.tar.gz
spark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.tar.bz2
spark-dc0c4490a12ecedd8ca5a1bb256c7ccbdf0be04f.zip
[SPARK-4789] [SPARK-4942] [SPARK-5031] [mllib] Standardize ML Prediction APIs
This is part (1a) of the updates from the design doc in [https://docs.google.com/document/d/1BH9el33kBX8JiDdgUJXdLW14CA2qhTCWIG46eXZVoJs] **UPDATE**: Most of the APIs are being kept private[spark] to allow further discussion. Here is a list of changes which are public: * new output columns: rawPrediction, probabilities * The “score” column is now called “rawPrediction” * Classifiers now provide numClasses * Params.get and .set are now protected instead of private[ml]. * ParamMap now has a size method. * new classes: LinearRegression, LinearRegressionModel * LogisticRegression now has an intercept. ### Sketch of APIs (most of which are private[spark] for now) Abstract classes for learning algorithms (+ corresponding Model abstractions): * Classifier (+ ClassificationModel) * ProbabilisticClassifier (+ ProbabilisticClassificationModel) * Regressor (+ RegressionModel) * Predictor (+ PredictionModel) * *For all of these*: * There is no strongly typed training-time API. * There is a strongly typed test-time (prediction) API which helps developers implement new algorithms. Concrete classes: learning algorithms * LinearRegression * LogisticRegression (updated to use new abstract classes) * Also, removed "score" in favor of "probability" output column. Changed BinaryClassificationEvaluator to match. (SPARK-5031) Other updates: * params.scala: Changed Params.set/get to be protected instead of private[ml] * This was needed for the example of defining a class from outside of the MLlib namespace. * VectorUDT: Will later change from private[spark] to public. * This is needed for outside users to write their own validateAndTransformSchema() methods using vectors. * Also, added equals() method.f * SPARK-4942 : ML Transformers should allow output cols to be turned on,off * Update validateAndTransformSchema * Update transform * (Updated examples, test suites according to other changes) New examples: * DeveloperApiExample.scala (example of defining algorithm from outside of the MLlib namespace) * Added Java version too Test Suites: * LinearRegressionSuite * LogisticRegressionSuite * + Java versions of above suites CC: mengxr etrain shivaram Author: Joseph K. Bradley <joseph@databricks.com> Closes #3637 from jkbradley/ml-api-part1 and squashes the following commits: 405bfb8 [Joseph K. Bradley] Last edits based on code review. Small cleanups fec348a [Joseph K. Bradley] Added JavaDeveloperApiExample.java and fixed other issues: Made developer API private[spark] for now. Added constructors Java can understand to specialized Param types. 8316d5e [Joseph K. Bradley] fixes after rebasing on master fc62406 [Joseph K. Bradley] fixed test suites after last commit bcb9549 [Joseph K. Bradley] Fixed issues after rebasing from master (after move from SchemaRDD to DataFrame) 9872424 [Joseph K. Bradley] fixed JavaLinearRegressionSuite.java Java sql api f542997 [Joseph K. Bradley] Added MIMA excludes for VectorUDT (now public), and added DeveloperApi annotation to it 216d199 [Joseph K. Bradley] fixed after sql datatypes PR got merged f549e34 [Joseph K. Bradley] Updates based on code review. Major ones are: * Created weakly typed Predictor.train() method which is called by fit() so that developers do not have to call schema validation or copy parameters. * Made Predictor.featuresDataType have a default value of VectorUDT. * NOTE: This could be dangerous since the FeaturesType type parameter cannot have a default value. 343e7bd [Joseph K. Bradley] added blanket mima exclude for ml package 82f340b [Joseph K. Bradley] Fixed bug in LogisticRegression (introduced in this PR). Fixed Java suites 0a16da9 [Joseph K. Bradley] Fixed Linear/Logistic RegressionSuites c3c8da5 [Joseph K. Bradley] small cleanup 934f97b [Joseph K. Bradley] Fixed bugs from previous commit. 1c61723 [Joseph K. Bradley] * Made ProbabilisticClassificationModel into a subclass of ClassificationModel. Also introduced ProbabilisticClassifier. * This was to support output column “probabilityCol” in transform(). 4e2f711 [Joseph K. Bradley] rat fix bc654e1 [Joseph K. Bradley] Added spark.ml LinearRegressionSuite 8d13233 [Joseph K. Bradley] Added methods: * Classifier: batch predictRaw() * Predictor: train() without paramMap ProbabilisticClassificationModel.predictProbabilities() * Java versions of all above batch methods + others 1680905 [Joseph K. Bradley] Added JavaLabeledPointSuite.java for spark.ml, and added constructor to LabeledPoint which defaults weight to 1.0 adbe50a [Joseph K. Bradley] * fixed LinearRegression train() to use embedded paramMap * added Predictor.predict(RDD[Vector]) method * updated Linear/LogisticRegressionSuites 58802e3 [Joseph K. Bradley] added train() to Predictor subclasses which does not take a ParamMap. 57d54ab [Joseph K. Bradley] * Changed semantics of Predictor.train() to merge the given paramMap with the embedded paramMap. * remove threshold_internal from logreg * Added Predictor.copy() * Extended LogisticRegressionSuite e433872 [Joseph K. Bradley] Updated docs. Added LabeledPointSuite to spark.ml 54b7b31 [Joseph K. Bradley] Fixed issue with logreg threshold being set correctly 0617d61 [Joseph K. Bradley] Fixed bug from last commit (sorting paramMap by parameter names in toString). Fixed bug in persisting logreg data. Added threshold_internal to logreg for faster test-time prediction (avoiding map lookup). 601e792 [Joseph K. Bradley] Modified ParamMap to sort parameters in toString. Cleaned up classes in class hierarchy, before implementing tests and examples. d705e87 [Joseph K. Bradley] Added LinearRegression and Regressor back from ml-api branch 52f4fde [Joseph K. Bradley] removing everything except for simple class hierarchy for classification d35bb5d [Joseph K. Bradley] fixed compilation issues, but have not added tests yet bfade12 [Joseph K. Bradley] Added lots of classes for new ML API:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/Estimator.scala9
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala206
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala212
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala147
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala24
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala234
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala68
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala28
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala96
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala78
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala13
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java2
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java91
-rw-r--r--mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java89
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala86
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala65
17 files changed, 1317 insertions, 135 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
index bc3defe968..eff7ef925d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Estimator.scala
@@ -34,7 +34,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with optional parameters.
*
* @param dataset input dataset
- * @param paramPairs optional list of param pairs (overwrite embedded params)
+ * @param paramPairs Optional list of param pairs.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
@varargs
@@ -47,7 +48,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Fits a single model to the input data with provided parameter map.
*
* @param dataset input dataset
- * @param paramMap parameter map
+ * @param paramMap Parameter map.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted model
*/
def fit(dataset: DataFrame, paramMap: ParamMap): M
@@ -58,7 +60,8 @@ abstract class Estimator[M <: Model[M]] extends PipelineStage with Params {
* Subclasses could overwrite this to optimize multi-model training.
*
* @param dataset input dataset
- * @param paramMaps an array of parameter maps
+ * @param paramMaps An array of parameter maps.
+ * These values override any specified in this Estimator's embedded ParamMap.
* @return fitted models, matching the input parameter maps
*/
def fit(dataset: DataFrame, paramMaps: Array[ParamMap]): Seq[M] = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
new file mode 100644
index 0000000000..1bf8eb4640
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/Classifier.scala
@@ -0,0 +1,206 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+import org.apache.spark.ml.param.{Params, ParamMap, HasRawPredictionCol}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ * Params for classification.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait ClassifierParams extends PredictorParams
+ with HasRawPredictionCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+ val map = this.paramMap ++ paramMap
+ addOutputColumn(parentSchema, map(rawPredictionCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Single-label binary or multiclass classification.
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam E Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Classifier[
+ FeaturesType,
+ E <: Classifier[FeaturesType, E, M],
+ M <: ClassificationModel[FeaturesType, M]]
+ extends Predictor[FeaturesType, E, M]
+ with ClassifierParams {
+
+ def setRawPredictionCol(value: String): E =
+ set(rawPredictionCol, value).asInstanceOf[E]
+
+ // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model produced by a [[Classifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark]
+abstract class ClassificationModel[FeaturesType, M <: ClassificationModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M] with ClassifierParams {
+
+ def setRawPredictionCol(value: String): M = set(rawPredictionCol, value).asInstanceOf[M]
+
+ /** Number of classes (values which the label can take). */
+ def numClasses: Int
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type [[Double]]
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ val (numColsOutput, outputData) =
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+ if (numColsOutput == 0) {
+ logWarning(s"$uid: ClassificationModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ outputData
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ *
+ * This default implementation for classification predicts the index of the maximum value
+ * from [[predictRaw()]].
+ */
+ @DeveloperApi
+ override protected def predict(features: FeaturesType): Double = {
+ predictRaw(features).toArray.zipWithIndex.maxBy(_._1)._2
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Raw prediction for each possible label.
+ * The meaning of a "raw" prediction may vary between algorithms, but it intuitively gives
+ * a measure of confidence in each possible label (where larger = more confident).
+ * This internal method is used to implement [[transform()]] and output [[rawPredictionCol]].
+ *
+ * @return vector where element i is the raw prediction for label i.
+ * This raw prediction may be any real number, where a larger value indicates greater
+ * confidence for that label.
+ */
+ @DeveloperApi
+ protected def predictRaw(features: FeaturesType): Vector
+
+}
+
+private[ml] object ClassificationModel {
+
+ /**
+ * Added prediction column(s). This is separated from [[ClassificationModel.transform()]]
+ * since it is used by [[org.apache.spark.ml.classification.ProbabilisticClassificationModel]].
+ * @param dataset Input dataset
+ * @param map Parameter map. This will NOT be merged with the embedded paramMap; the merge
+ * should already be done.
+ * @return (number of columns added, transformed dataset)
+ */
+ def transformColumnsImpl[FeaturesType](
+ dataset: DataFrame,
+ model: ClassificationModel[FeaturesType, _],
+ map: ParamMap): (Int, DataFrame) = {
+
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ var tmpData = dataset
+ var numColsOutput = 0
+ if (map(model.rawPredictionCol) != "") {
+ // output raw prediction
+ val features2raw: FeaturesType => Vector = model.predictRaw
+ tmpData = tmpData.select($"*",
+ callUDF(features2raw, new VectorUDT,
+ col(map(model.featuresCol))).as(map(model.rawPredictionCol)))
+ numColsOutput += 1
+ if (map(model.predictionCol) != "") {
+ val raw2pred: Vector => Double = (rawPred) => {
+ rawPred.toArray.zipWithIndex.maxBy(_._1)._2
+ }
+ tmpData = tmpData.select($"*", callUDF(raw2pred, DoubleType,
+ col(map(model.rawPredictionCol))).as(map(model.predictionCol)))
+ numColsOutput += 1
+ }
+ } else if (map(model.predictionCol) != "") {
+ // output prediction
+ val features2pred: FeaturesType => Double = model.predict
+ tmpData = tmpData.select($"*",
+ callUDF(features2pred, DoubleType,
+ col(map(model.featuresCol))).as(map(model.predictionCol)))
+ numColsOutput += 1
+ }
+ (numColsOutput, tmpData)
+ }
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index b46a5cd8bd..c146fe244c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -18,61 +18,32 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
import org.apache.spark.ml.param._
import org.apache.spark.mllib.classification.LogisticRegressionWithLBFGS
-import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT}
-import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.sql._
+import org.apache.spark.mllib.linalg.{VectorUDT, BLAS, Vector, Vectors}
+import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Dsl._
-import org.apache.spark.sql.types.{DoubleType, StructField, StructType}
+import org.apache.spark.sql.types.DoubleType
import org.apache.spark.storage.StorageLevel
+
/**
- * :: AlphaComponent ::
* Params for logistic regression.
*/
-@AlphaComponent
-private[classification] trait LogisticRegressionParams extends Params
- with HasRegParam with HasMaxIter with HasLabelCol with HasThreshold with HasFeaturesCol
- with HasScoreCol with HasPredictionCol {
+private[classification] trait LogisticRegressionParams extends ProbabilisticClassifierParams
+ with HasRegParam with HasMaxIter with HasThreshold
- /**
- * Validates and transforms the input schema with the provided param map.
- * @param schema input schema
- * @param paramMap additional parameters
- * @param fitting whether this is in fitting
- * @return output schema
- */
- protected def validateAndTransformSchema(
- schema: StructType,
- paramMap: ParamMap,
- fitting: Boolean): StructType = {
- val map = this.paramMap ++ paramMap
- val featuresType = schema(map(featuresCol)).dataType
- // TODO: Support casting Array[Double] and Array[Float] to Vector.
- require(featuresType.isInstanceOf[VectorUDT],
- s"Features column ${map(featuresCol)} must be a vector column but got $featuresType.")
- if (fitting) {
- val labelType = schema(map(labelCol)).dataType
- require(labelType == DoubleType,
- s"Cannot convert label column ${map(labelCol)} of type $labelType to a double column.")
- }
- val fieldNames = schema.fieldNames
- require(!fieldNames.contains(map(scoreCol)), s"Score column ${map(scoreCol)} already exists.")
- require(!fieldNames.contains(map(predictionCol)),
- s"Prediction column ${map(predictionCol)} already exists.")
- val outputFields = schema.fields ++ Seq(
- StructField(map(scoreCol), DoubleType, false),
- StructField(map(predictionCol), DoubleType, false))
- StructType(outputFields)
- }
-}
/**
+ * :: AlphaComponent ::
+ *
* Logistic regression.
+ * Currently, this class only supports binary classification.
*/
-class LogisticRegression extends Estimator[LogisticRegressionModel] with LogisticRegressionParams {
+@AlphaComponent
+class LogisticRegression
+ extends ProbabilisticClassifier[Vector, LogisticRegression, LogisticRegressionModel]
+ with LogisticRegressionParams {
setRegParam(0.1)
setMaxIter(100)
@@ -80,68 +51,151 @@ class LogisticRegression extends Estimator[LogisticRegressionModel] with Logisti
def setRegParam(value: Double): this.type = set(regParam, value)
def setMaxIter(value: Int): this.type = set(maxIter, value)
- def setLabelCol(value: String): this.type = set(labelCol, value)
def setThreshold(value: Double): this.type = set(threshold, value)
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
- def setPredictionCol(value: String): this.type = set(predictionCol, value)
- override def fit(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
- transformSchema(dataset.schema, paramMap, logging = true)
- val map = this.paramMap ++ paramMap
- val instances = dataset.select(map(labelCol), map(featuresCol))
- .map { case Row(label: Double, features: Vector) =>
- LabeledPoint(label, features)
- }.persist(StorageLevel.MEMORY_AND_DISK)
+ override protected def train(dataset: DataFrame, paramMap: ParamMap): LogisticRegressionModel = {
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Train model
val lr = new LogisticRegressionWithLBFGS
lr.optimizer
- .setRegParam(map(regParam))
- .setNumIterations(map(maxIter))
- val lrm = new LogisticRegressionModel(this, map, lr.run(instances).weights)
- instances.unpersist()
- // copy model params
- Params.inheritValues(map, this, lrm)
- lrm
- }
+ .setRegParam(paramMap(regParam))
+ .setNumIterations(paramMap(maxIter))
+ val oldModel = lr.run(oldDataset)
+ val lrm = new LogisticRegressionModel(this, paramMap, oldModel.weights, oldModel.intercept)
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = true)
+ if (handlePersistence) {
+ oldDataset.unpersist()
+ }
+ lrm
}
}
+
/**
* :: AlphaComponent ::
+ *
* Model produced by [[LogisticRegression]].
*/
@AlphaComponent
class LogisticRegressionModel private[ml] (
override val parent: LogisticRegression,
override val fittingParamMap: ParamMap,
- weights: Vector)
- extends Model[LogisticRegressionModel] with LogisticRegressionParams {
+ val weights: Vector,
+ val intercept: Double)
+ extends ProbabilisticClassificationModel[Vector, LogisticRegressionModel]
+ with LogisticRegressionParams {
+
+ setThreshold(0.5)
def setThreshold(value: Double): this.type = set(threshold, value)
- def setFeaturesCol(value: String): this.type = set(featuresCol, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
- def setPredictionCol(value: String): this.type = set(predictionCol, value)
- private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- validateAndTransformSchema(schema, paramMap, fitting = false)
+ private val margin: Vector => Double = (features) => {
+ BLAS.dot(features, weights) + intercept
+ }
+
+ private val score: Vector => Double = (features) => {
+ val m = margin(features)
+ 1.0 / (1.0 + math.exp(-m))
}
override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This is overridden (a) to be more efficient (avoiding re-computing values when creating
+ // multiple output columns) and (b) to handle threshold, which the abstractions do not use.
+ // TODO: We should abstract away the steps defined by UDFs below so that the abstractions
+ // can call whichever UDFs are needed to create the output columns.
+
+ // Check schema
transformSchema(dataset.schema, paramMap, logging = true)
+
val map = this.paramMap ++ paramMap
- val scoreFunction = udf { v: Vector =>
- val margin = BLAS.dot(v, weights)
- 1.0 / (1.0 + math.exp(-margin))
+
+ // Output selected columns only.
+ // This is a bit complicated since it tries to avoid repeated computation.
+ // rawPrediction (-margin, margin)
+ // probability (1.0-score, score)
+ // prediction (max margin)
+ var tmpData = dataset
+ var numColsOutput = 0
+ if (map(rawPredictionCol) != "") {
+ val features2raw: Vector => Vector = (features) => predictRaw(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2raw, new VectorUDT, col(map(featuresCol))).as(map(rawPredictionCol)))
+ numColsOutput += 1
+ }
+ if (map(probabilityCol) != "") {
+ if (map(rawPredictionCol) != "") {
+ val raw2prob: Vector => Vector = { (rawPreds: Vector) =>
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ Vectors.dense(1.0 - prob1, prob1)
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(raw2prob, new VectorUDT, col(map(rawPredictionCol))).as(map(probabilityCol)))
+ } else {
+ val features2prob: Vector => Vector = (features: Vector) => predictProbabilities(features)
+ tmpData = tmpData.select($"*",
+ callUDF(features2prob, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ }
+ numColsOutput += 1
}
- val t = map(threshold)
- val predictFunction = udf { score: Double =>
- if (score > t) 1.0 else 0.0
+ if (map(predictionCol) != "") {
+ val t = map(threshold)
+ if (map(probabilityCol) != "") {
+ val predict: Vector => Double = { probs: Vector =>
+ if (probs(1) > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(probabilityCol))).as(map(predictionCol)))
+ } else if (map(rawPredictionCol) != "") {
+ val predict: Vector => Double = { rawPreds: Vector =>
+ val prob1 = 1.0 / (1.0 + math.exp(-rawPreds(1)))
+ if (prob1 > t) 1.0 else 0.0
+ }
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(rawPredictionCol))).as(map(predictionCol)))
+ } else {
+ val predict: Vector => Double = (features: Vector) => this.predict(features)
+ tmpData = tmpData.select($"*",
+ callUDF(predict, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ }
+ numColsOutput += 1
}
- dataset
- .select($"*", scoreFunction(col(map(featuresCol))).as(map(scoreCol)))
- .select($"*", predictFunction(col(map(scoreCol))).as(map(predictionCol)))
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: LogisticRegressionModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ tmpData
+ }
+
+ override val numClasses: Int = 2
+
+ /**
+ * Predict label for the given feature vector.
+ * The behavior of this can be adjusted using [[threshold]].
+ */
+ override protected def predict(features: Vector): Double = {
+ println(s"LR.predict with threshold: ${paramMap(threshold)}")
+ if (score(features) > paramMap(threshold)) 1 else 0
+ }
+
+ override protected def predictProbabilities(features: Vector): Vector = {
+ val s = score(features)
+ Vectors.dense(1.0 - s, s)
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ val m = margin(features)
+ Vectors.dense(0.0, m)
+ }
+
+ override protected def copy(): LogisticRegressionModel = {
+ val m = new LogisticRegressionModel(parent, fittingParamMap, weights, intercept)
+ Params.inheritValues(this.paramMap, this, m)
+ m
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
new file mode 100644
index 0000000000..1202528ca6
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/ProbabilisticClassifier.scala
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.param.{HasProbabilityCol, ParamMap, Params}
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, StructType}
+
+
+/**
+ * Params for probabilistic classification.
+ */
+private[classification] trait ProbabilisticClassifierParams
+ extends ClassifierParams with HasProbabilityCol {
+
+ override protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val parentSchema = super.validateAndTransformSchema(schema, paramMap, fitting, featuresDataType)
+ val map = this.paramMap ++ paramMap
+ addOutputColumn(parentSchema, map(probabilityCol), new VectorUDT)
+ }
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label binary or multiclass classifier which can output class conditional probabilities.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam E Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassifier[
+ FeaturesType,
+ E <: ProbabilisticClassifier[FeaturesType, E, M],
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends Classifier[FeaturesType, E, M] with ProbabilisticClassifierParams {
+
+ def setProbabilityCol(value: String): E = set(probabilityCol, value).asInstanceOf[E]
+}
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[ProbabilisticClassifier]].
+ * Classes are indexed {0, 1, ..., numClasses - 1}.
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[Vector]]
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class ProbabilisticClassificationModel[
+ FeaturesType,
+ M <: ProbabilisticClassificationModel[FeaturesType, M]]
+ extends ClassificationModel[FeaturesType, M] with ProbabilisticClassifierParams {
+
+ def setProbabilityCol(value: String): M = set(probabilityCol, value).asInstanceOf[M]
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], and appending new columns as specified by
+ * parameters:
+ * - predicted labels as [[predictionCol]] of type [[Double]]
+ * - raw predictions (confidences) as [[rawPredictionCol]] of type [[Vector]]
+ * - probability of each class as [[probabilityCol]] of type [[Vector]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ val (numColsOutput, outputData) =
+ ClassificationModel.transformColumnsImpl[FeaturesType](dataset, tmpModel, map)
+
+ // Output selected columns only.
+ if (map(probabilityCol) != "") {
+ // output probabilities
+ val features2probs: FeaturesType => Vector = (features) => {
+ tmpModel.predictProbabilities(features)
+ }
+ outputData.select($"*",
+ callUDF(features2probs, new VectorUDT, col(map(featuresCol))).as(map(probabilityCol)))
+ } else {
+ if (numColsOutput == 0) {
+ this.logWarning(s"$uid: ProbabilisticClassificationModel.transform() was called as NOOP" +
+ " since no output columns were set.")
+ }
+ outputData
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict the probability of each class given the features.
+ * These predictions are also called class conditional probabilities.
+ *
+ * WARNING: Not all models output well-calibrated probability estimates! These probabilities
+ * should be treated as confidences, not precise probabilities.
+ *
+ * This internal method is used to implement [[transform()]] and output [[probabilityCol]].
+ */
+ @DeveloperApi
+ protected def predictProbabilities(features: FeaturesType): Vector
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 1979ab9eb6..f21a30627e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -18,19 +18,22 @@
package org.apache.spark.ml.evaluation
import org.apache.spark.annotation.AlphaComponent
-import org.apache.spark.ml._
+import org.apache.spark.ml.Evaluator
import org.apache.spark.ml.param._
import org.apache.spark.mllib.evaluation.BinaryClassificationMetrics
+import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.types.DoubleType
+
/**
* :: AlphaComponent ::
+ *
* Evaluator for binary classification, which expects two input columns: score and label.
*/
@AlphaComponent
class BinaryClassificationEvaluator extends Evaluator with Params
- with HasScoreCol with HasLabelCol {
+ with HasRawPredictionCol with HasLabelCol {
/** param for metric name in evaluation */
val metricName: Param[String] = new Param(this, "metricName",
@@ -38,23 +41,20 @@ class BinaryClassificationEvaluator extends Evaluator with Params
def getMetricName: String = get(metricName)
def setMetricName(value: String): this.type = set(metricName, value)
- def setScoreCol(value: String): this.type = set(scoreCol, value)
+ def setScoreCol(value: String): this.type = set(rawPredictionCol, value)
def setLabelCol(value: String): this.type = set(labelCol, value)
override def evaluate(dataset: DataFrame, paramMap: ParamMap): Double = {
val map = this.paramMap ++ paramMap
val schema = dataset.schema
- val scoreType = schema(map(scoreCol)).dataType
- require(scoreType == DoubleType,
- s"Score column ${map(scoreCol)} must be double type but found $scoreType")
- val labelType = schema(map(labelCol)).dataType
- require(labelType == DoubleType,
- s"Label column ${map(labelCol)} must be double type but found $labelType")
+ checkInputColumn(schema, map(rawPredictionCol), new VectorUDT)
+ checkInputColumn(schema, map(labelCol), DoubleType)
- val scoreAndLabels = dataset.select(map(scoreCol), map(labelCol))
- .map { case Row(score: Double, label: Double) =>
- (score, label)
+ // TODO: When dataset metadata has been implemented, check rawPredictionCol vector length = 2.
+ val scoreAndLabels = dataset.select(map(rawPredictionCol), map(labelCol))
+ .map { case Row(rawPrediction: Vector, label: Double) =>
+ (rawPrediction(1), label)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metric = map(metricName) match {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e622a5cf9e..0b1f90daa7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -29,11 +29,11 @@ import org.apache.spark.sql.types.{DataType, StringType, ArrayType}
@AlphaComponent
class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
- protected override def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
+ override protected def createTransformFunc(paramMap: ParamMap): String => Seq[String] = {
_.toLowerCase.split("\\s")
}
- protected override def validateInputType(inputType: DataType): Unit = {
+ override protected def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
new file mode 100644
index 0000000000..89b53f3890
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/impl/estimator/Predictor.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.impl.estimator
+
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.ml.param._
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.Dsl._
+import org.apache.spark.sql.types.{DataType, DoubleType, StructType}
+
+
+/**
+ * :: DeveloperApi ::
+ *
+ * Trait for parameters for prediction (regression and classification).
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait PredictorParams extends Params
+ with HasLabelCol with HasFeaturesCol with HasPredictionCol {
+
+ /**
+ * Validates and transforms the input schema with the provided param map.
+ * @param schema input schema
+ * @param paramMap additional parameters
+ * @param fitting whether this is in fitting
+ * @param featuresDataType SQL DataType for FeaturesType.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @return output schema
+ */
+ protected def validateAndTransformSchema(
+ schema: StructType,
+ paramMap: ParamMap,
+ fitting: Boolean,
+ featuresDataType: DataType): StructType = {
+ val map = this.paramMap ++ paramMap
+ // TODO: Support casting Array[Double] and Array[Float] to Vector when FeaturesType = Vector
+ checkInputColumn(schema, map(featuresCol), featuresDataType)
+ if (fitting) {
+ // TODO: Allow other numeric types
+ checkInputColumn(schema, map(labelCol), DoubleType)
+ }
+ addOutputColumn(schema, map(predictionCol), DoubleType)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for prediction problems (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam Learner Specialization of this class. If you subclass this type, use this type
+ * parameter to specify the concrete type.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Predictor[
+ FeaturesType,
+ Learner <: Predictor[FeaturesType, Learner, M],
+ M <: PredictionModel[FeaturesType, M]]
+ extends Estimator[M] with PredictorParams {
+
+ def setLabelCol(value: String): Learner = set(labelCol, value).asInstanceOf[Learner]
+ def setFeaturesCol(value: String): Learner = set(featuresCol, value).asInstanceOf[Learner]
+ def setPredictionCol(value: String): Learner = set(predictionCol, value).asInstanceOf[Learner]
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): M = {
+ // This handles a few items such as schema validation.
+ // Developers only need to implement train().
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+ val model = train(dataset, map)
+ Params.inheritValues(map, this, model) // copy params to model
+ model
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Train a model using the given dataset and parameters.
+ * Developers can implement this instead of [[fit()]] to avoid dealing with schema validation
+ * and copying parameters into the model.
+ *
+ * @param dataset Training dataset
+ * @param paramMap Parameter map. Unlike [[fit()]]'s paramMap, this paramMap has already
+ * been combined with the embedded ParamMap.
+ * @return Fitted model
+ */
+ @DeveloperApi
+ protected def train(dataset: DataFrame, paramMap: ParamMap): M
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ @DeveloperApi
+ protected def featuresDataType: DataType = new VectorUDT
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = true, featuresDataType)
+ }
+
+ /**
+ * Extract [[labelCol]] and [[featuresCol]] from the given dataset,
+ * and put it in an RDD with strong types.
+ */
+ protected def extractLabeledPoints(dataset: DataFrame, paramMap: ParamMap): RDD[LabeledPoint] = {
+ val map = this.paramMap ++ paramMap
+ dataset.select(map(labelCol), map(featuresCol))
+ .map { case Row(label: Double, features: Vector) =>
+ LabeledPoint(label, features)
+ }
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Abstraction for a model for prediction tasks (regression and classification).
+ *
+ * @tparam FeaturesType Type of features.
+ * E.g., [[org.apache.spark.mllib.linalg.VectorUDT]] for vector features.
+ * @tparam M Specialization of [[PredictionModel]]. If you subclass this type, use this type
+ * parameter to specify the concrete type for the corresponding model.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class PredictionModel[FeaturesType, M <: PredictionModel[FeaturesType, M]]
+ extends Model[M] with PredictorParams {
+
+ def setFeaturesCol(value: String): M = set(featuresCol, value).asInstanceOf[M]
+
+ def setPredictionCol(value: String): M = set(predictionCol, value).asInstanceOf[M]
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Returns the SQL DataType corresponding to the FeaturesType type parameter.
+ *
+ * This is used by [[validateAndTransformSchema()]].
+ * This workaround is needed since SQL has different APIs for Scala and Java.
+ *
+ * The default value is VectorUDT, but it may be overridden if FeaturesType is not Vector.
+ */
+ @DeveloperApi
+ protected def featuresDataType: DataType = new VectorUDT
+
+ private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap, fitting = false, featuresDataType)
+ }
+
+ /**
+ * Transforms dataset by reading from [[featuresCol]], calling [[predict()]], and storing
+ * the predictions as a new column [[predictionCol]].
+ *
+ * @param dataset input dataset
+ * @param paramMap additional parameters, overwrite embedded params
+ * @return transformed dataset with [[predictionCol]] of type [[Double]]
+ */
+ override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+ // This default implementation should be overridden as needed.
+
+ // Check schema
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = this.paramMap ++ paramMap
+
+ // Prepare model
+ val tmpModel = if (paramMap.size != 0) {
+ val tmpModel = this.copy()
+ Params.inheritValues(paramMap, parent, tmpModel)
+ tmpModel
+ } else {
+ this
+ }
+
+ if (map(predictionCol) != "") {
+ val pred: FeaturesType => Double = (features) => {
+ tmpModel.predict(features)
+ }
+ dataset.select($"*", callUDF(pred, DoubleType, col(map(featuresCol))).as(map(predictionCol)))
+ } else {
+ this.logWarning(s"$uid: Predictor.transform() was called as NOOP" +
+ " since no output columns were set.")
+ dataset
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ @DeveloperApi
+ protected def predict(features: FeaturesType): Double
+
+ /**
+ * Create a copy of the model.
+ * The copy is shallow, except for the embedded paramMap, which gets a deep copy.
+ */
+ protected def copy(): M
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 5fb4379e23..17ece897a6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -22,8 +22,10 @@ import scala.collection.mutable
import java.lang.reflect.Modifier
-import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.annotation.{AlphaComponent, DeveloperApi}
import org.apache.spark.ml.Identifiable
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+
/**
* :: AlphaComponent ::
@@ -65,37 +67,47 @@ class Param[T] (
// specialize primitive-typed params because Java doesn't recognize scala.Double, scala.Int, ...
/** Specialized version of [[Param[Double]]] for Java. */
-class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double] = None)
+class DoubleParam(parent: Params, name: String, doc: String, defaultValue: Option[Double])
extends Param[Double](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Double): ParamPair[Double] = super.w(value)
}
/** Specialized version of [[Param[Int]]] for Java. */
-class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int] = None)
+class IntParam(parent: Params, name: String, doc: String, defaultValue: Option[Int])
extends Param[Int](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Int): ParamPair[Int] = super.w(value)
}
/** Specialized version of [[Param[Float]]] for Java. */
-class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float] = None)
+class FloatParam(parent: Params, name: String, doc: String, defaultValue: Option[Float])
extends Param[Float](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Float): ParamPair[Float] = super.w(value)
}
/** Specialized version of [[Param[Long]]] for Java. */
-class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long] = None)
+class LongParam(parent: Params, name: String, doc: String, defaultValue: Option[Long])
extends Param[Long](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Long): ParamPair[Long] = super.w(value)
}
/** Specialized version of [[Param[Boolean]]] for Java. */
-class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean] = None)
+class BooleanParam(parent: Params, name: String, doc: String, defaultValue: Option[Boolean])
extends Param[Boolean](parent, name, doc, defaultValue) {
+ def this(parent: Params, name: String, doc: String) = this(parent, name, doc, None)
+
override def w(value: Boolean): ParamPair[Boolean] = super.w(value)
}
@@ -158,7 +170,7 @@ trait Params extends Identifiable with Serializable {
/**
* Sets a parameter in the embedded param map.
*/
- private[ml] def set[T](param: Param[T], value: T): this.type = {
+ protected def set[T](param: Param[T], value: T): this.type = {
require(param.parent.eq(this))
paramMap.put(param.asInstanceOf[Param[Any]], value)
this
@@ -174,7 +186,7 @@ trait Params extends Identifiable with Serializable {
/**
* Gets the value of a parameter in the embedded param map.
*/
- private[ml] def get[T](param: Param[T]): T = {
+ protected def get[T](param: Param[T]): T = {
require(param.parent.eq(this))
paramMap(param)
}
@@ -183,9 +195,40 @@ trait Params extends Identifiable with Serializable {
* Internal param map.
*/
protected val paramMap: ParamMap = ParamMap.empty
+
+ /**
+ * Check whether the given schema contains an input column.
+ * @param colName Parameter name for the input column.
+ * @param dataType SQL DataType of the input column.
+ */
+ protected def checkInputColumn(schema: StructType, colName: String, dataType: DataType): Unit = {
+ val actualDataType = schema(colName).dataType
+ require(actualDataType.equals(dataType),
+ s"Input column $colName must be of type $dataType" +
+ s" but was actually $actualDataType. Column param description: ${getParam(colName)}")
+ }
+
+ protected def addOutputColumn(
+ schema: StructType,
+ colName: String,
+ dataType: DataType): StructType = {
+ if (colName.length == 0) return schema
+ val fieldNames = schema.fieldNames
+ require(!fieldNames.contains(colName), s"Prediction column $colName already exists.")
+ val outputFields = schema.fields ++ Seq(StructField(colName, dataType, nullable = false))
+ StructType(outputFields)
+ }
}
-private[ml] object Params {
+/**
+ * :: DeveloperApi ::
+ *
+ * Helper functionality for developers.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] object Params {
/**
* Copies parameter values from the parent estimator to the child model it produced.
@@ -279,7 +322,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
def copy: ParamMap = new ParamMap(map.clone())
override def toString: String = {
- map.map { case (param, value) =>
+ map.toSeq.sortBy(_._1.name).map { case (param, value) =>
s"\t${param.parent.uid}-${param.name}: $value"
}.mkString("{\n", ",\n", "\n}")
}
@@ -310,6 +353,11 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
ParamPair(param, value)
}
}
+
+ /**
+ * Number of param pairs in this set.
+ */
+ def size: Int = map.size
}
object ParamMap {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
index ef141d3eb2..32fc74462e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/sharedParams.scala
@@ -17,6 +17,12 @@
package org.apache.spark.ml.param
+/* NOTE TO DEVELOPERS:
+ * If you mix these parameter traits into your algorithm, please add a setter method as well
+ * so that users may use a builder pattern:
+ * val myLearner = new MyLearner().setParam1(x).setParam2(y)...
+ */
+
private[ml] trait HasRegParam extends Params {
/** param for regularization parameter */
val regParam: DoubleParam = new DoubleParam(this, "regParam", "regularization parameter")
@@ -42,12 +48,6 @@ private[ml] trait HasLabelCol extends Params {
def getLabelCol: String = get(labelCol)
}
-private[ml] trait HasScoreCol extends Params {
- /** param for score column name */
- val scoreCol: Param[String] = new Param(this, "scoreCol", "score column name", Some("score"))
- def getScoreCol: String = get(scoreCol)
-}
-
private[ml] trait HasPredictionCol extends Params {
/** param for prediction column name */
val predictionCol: Param[String] =
@@ -55,6 +55,22 @@ private[ml] trait HasPredictionCol extends Params {
def getPredictionCol: String = get(predictionCol)
}
+private[ml] trait HasRawPredictionCol extends Params {
+ /** param for raw prediction column name */
+ val rawPredictionCol: Param[String] =
+ new Param(this, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name",
+ Some("rawPrediction"))
+ def getRawPredictionCol: String = get(rawPredictionCol)
+}
+
+private[ml] trait HasProbabilityCol extends Params {
+ /** param for predicted class conditional probabilities column name */
+ val probabilityCol: Param[String] =
+ new Param(this, "probabilityCol", "column name for predicted class conditional probabilities",
+ Some("probability"))
+ def getProbabilityCol: String = get(probabilityCol)
+}
+
private[ml] trait HasThreshold extends Params {
/** param for threshold in (binary) prediction */
val threshold: DoubleParam = new DoubleParam(this, "threshold", "threshold in prediction")
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
new file mode 100644
index 0000000000..d5a7bdafcb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param.{Params, ParamMap, HasMaxIter, HasRegParam}
+import org.apache.spark.mllib.linalg.{BLAS, Vector}
+import org.apache.spark.mllib.regression.LinearRegressionWithSGD
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.storage.StorageLevel
+
+
+/**
+ * Params for linear regression.
+ */
+private[regression] trait LinearRegressionParams extends RegressorParams
+ with HasRegParam with HasMaxIter
+
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Linear regression.
+ */
+@AlphaComponent
+class LinearRegression extends Regressor[Vector, LinearRegression, LinearRegressionModel]
+ with LinearRegressionParams {
+
+ setRegParam(0.1)
+ setMaxIter(100)
+
+ def setRegParam(value: Double): this.type = set(regParam, value)
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ override protected def train(dataset: DataFrame, paramMap: ParamMap): LinearRegressionModel = {
+ // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ val oldDataset = extractLabeledPoints(dataset, paramMap)
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ if (handlePersistence) {
+ oldDataset.persist(StorageLevel.MEMORY_AND_DISK)
+ }
+
+ // Train model
+ val lr = new LinearRegressionWithSGD()
+ lr.optimizer
+ .setRegParam(paramMap(regParam))
+ .setNumIterations(paramMap(maxIter))
+ val model = lr.run(oldDataset)
+ val lrm = new LinearRegressionModel(this, paramMap, model.weights, model.intercept)
+
+ if (handlePersistence) {
+ oldDataset.unpersist()
+ }
+ lrm
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by [[LinearRegression]].
+ */
+@AlphaComponent
+class LinearRegressionModel private[ml] (
+ override val parent: LinearRegression,
+ override val fittingParamMap: ParamMap,
+ val weights: Vector,
+ val intercept: Double)
+ extends RegressionModel[Vector, LinearRegressionModel]
+ with LinearRegressionParams {
+
+ override protected def predict(features: Vector): Double = {
+ BLAS.dot(features, weights) + intercept
+ }
+
+ override protected def copy(): LinearRegressionModel = {
+ val m = new LinearRegressionModel(parent, fittingParamMap, weights, intercept)
+ Params.inheritValues(this.paramMap, this, m)
+ m
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
new file mode 100644
index 0000000000..d679085eea
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/Regressor.scala
@@ -0,0 +1,78 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.apache.spark.annotation.{DeveloperApi, AlphaComponent}
+import org.apache.spark.ml.impl.estimator.{PredictionModel, Predictor, PredictorParams}
+
+/**
+ * :: DeveloperApi ::
+ * Params for regression.
+ * Currently empty, but may add functionality later.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@DeveloperApi
+private[spark] trait RegressorParams extends PredictorParams
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Single-label regression
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam Learner Concrete Estimator type
+ * @tparam M Concrete Model type
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class Regressor[
+ FeaturesType,
+ Learner <: Regressor[FeaturesType, Learner, M],
+ M <: RegressionModel[FeaturesType, M]]
+ extends Predictor[FeaturesType, Learner, M]
+ with RegressorParams {
+
+ // TODO: defaultEvaluator (follow-up PR)
+}
+
+/**
+ * :: AlphaComponent ::
+ *
+ * Model produced by a [[Regressor]].
+ *
+ * @tparam FeaturesType Type of input features. E.g., [[org.apache.spark.mllib.linalg.Vector]]
+ * @tparam M Concrete Model type.
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
+ */
+@AlphaComponent
+private[spark] abstract class RegressionModel[FeaturesType, M <: RegressionModel[FeaturesType, M]]
+ extends PredictionModel[FeaturesType, M] with RegressorParams {
+
+ /**
+ * :: DeveloperApi ::
+ *
+ * Predict real-valued label for the given features.
+ * This internal method is used to implement [[transform()]] and output [[predictionCol]].
+ */
+ @DeveloperApi
+ protected def predict(features: FeaturesType): Double
+
+}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index 77785bdbd0..480bbfb5fe 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -26,6 +26,7 @@ import scala.collection.JavaConverters._
import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException
+import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.util.NumericParser
import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
@@ -110,9 +111,14 @@ sealed trait Vector extends Serializable {
}
/**
+ * :: DeveloperApi ::
+ *
* User-defined type for [[Vector]] which allows easy interaction with SQL
* via [[org.apache.spark.sql.DataFrame]].
+ *
+ * NOTE: This is currently private[spark] but will be made public later once it is stabilized.
*/
+@DeveloperApi
private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def sqlType: StructType = {
@@ -169,6 +175,13 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
override def pyUDT: String = "pyspark.mllib.linalg.VectorUDT"
override def userClass: Class[Vector] = classOf[Vector]
+
+ override def equals(o: Any): Boolean = {
+ o match {
+ case v: VectorUDT => true
+ case _ => false
+ }
+ }
}
/**
diff --git a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
index 56a9dbdd58..50995ffef9 100644
--- a/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/JavaPipelineSuite.java
@@ -65,7 +65,7 @@ public class JavaPipelineSuite {
.setStages(new PipelineStage[] {scaler, lr});
PipelineModel model = pipeline.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
index f4ba23c445..26284023b0 100644
--- a/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
+++ b/mllib/src/test/java/org/apache/spark/ml/classification/JavaLogisticRegressionSuite.java
@@ -18,17 +18,22 @@
package org.apache.spark.ml.classification;
import java.io.Serializable;
+import java.lang.Math;
import java.util.List;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.mllib.linalg.Vector;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.sql.DataFrame;
import org.apache.spark.sql.SQLContext;
-import static org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInputAsList;
+import org.apache.spark.sql.Row;
+
public class JavaLogisticRegressionSuite implements Serializable {
@@ -36,12 +41,17 @@ public class JavaLogisticRegressionSuite implements Serializable {
private transient SQLContext jsql;
private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+ private double eps = 1e-5;
+
@Before
public void setUp() {
jsc = new JavaSparkContext("local", "JavaLogisticRegressionSuite");
jsql = new SQLContext(jsc);
List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
- dataset = jsql.applySchema(jsc.parallelize(points, 2), LabeledPoint.class);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
}
@After
@@ -51,29 +61,88 @@ public class JavaLogisticRegressionSuite implements Serializable {
}
@Test
- public void logisticRegression() {
+ public void logisticRegressionDefaultParams() {
LogisticRegression lr = new LogisticRegression();
+ assert(lr.getLabelCol().equals("label"));
LogisticRegressionModel model = lr.fit(dataset);
model.transform(dataset).registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
+ DataFrame predictions = jsql.sql("SELECT label, probability, prediction FROM prediction");
predictions.collectAsList();
+ // Check defaults
+ assert(model.getThreshold() == 0.5);
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ assert(model.getProbabilityCol().equals("probability"));
}
@Test
public void logisticRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
LogisticRegression lr = new LogisticRegression()
.setMaxIter(10)
- .setRegParam(1.0);
+ .setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability");
LogisticRegressionModel model = lr.fit(dataset);
- model.transform(dataset, model.threshold().w(0.8)) // overwrite threshold
- .registerTempTable("prediction");
- DataFrame predictions = jsql.sql("SELECT label, score, prediction FROM prediction");
- predictions.collectAsList();
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+ assert(model.fittingParamMap().apply(lr.threshold()).equals(0.6));
+ assert(model.getThreshold() == 0.6);
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0);
+ model.transform(dataset).registerTempTable("predAllZero");
+ DataFrame predAllZero = jsql.sql("SELECT prediction, myProbability FROM predAllZero");
+ for (Row r: predAllZero.collectAsList()) {
+ assert(r.getDouble(0) == 0.0);
+ }
+ // Call transform with params, and check that the params worked.
+ model.transform(dataset, model.threshold().w(0.0), model.probabilityCol().w("myProb"))
+ .registerTempTable("predNotAllZero");
+ DataFrame predNotAllZero = jsql.sql("SELECT prediction, myProb FROM predNotAllZero");
+ boolean foundNonZero = false;
+ for (Row r: predNotAllZero.collectAsList()) {
+ if (r.getDouble(0) != 0.0) foundNonZero = true;
+ }
+ assert(foundNonZero);
+
+ // Call fit() with new params, and check as many params as we can.
+ LogisticRegressionModel model2 = lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1),
+ lr.threshold().w(0.4), lr.probabilityCol().w("theProb"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.fittingParamMap().apply(lr.threshold()).equals(0.4));
+ assert(model2.getThreshold() == 0.4);
+ assert(model2.getProbabilityCol().equals("theProb"));
}
+ @SuppressWarnings("unchecked")
@Test
- public void logisticRegressionFitWithVarargs() {
+ public void logisticRegressionPredictorClassifierMethods() {
LogisticRegression lr = new LogisticRegression();
- lr.fit(dataset, lr.maxIter().w(10), lr.regParam().w(1.0));
+ LogisticRegressionModel model = lr.fit(dataset);
+ assert(model.numClasses() == 2);
+
+ model.transform(dataset).registerTempTable("transformed");
+ DataFrame trans1 = jsql.sql("SELECT rawPrediction, probability FROM transformed");
+ for (Row row: trans1.collect()) {
+ Vector raw = (Vector)row.get(0);
+ Vector prob = (Vector)row.get(1);
+ assert(raw.size() == 2);
+ assert(prob.size() == 2);
+ double probFromRaw1 = 1.0 / (1.0 + Math.exp(-raw.apply(1)));
+ assert(Math.abs(prob.apply(1) - probFromRaw1) < eps);
+ assert(Math.abs(prob.apply(0) - (1.0 - probFromRaw1)) < eps);
+ }
+
+ DataFrame trans2 = jsql.sql("SELECT prediction, probability FROM transformed");
+ for (Row row: trans2.collect()) {
+ double pred = row.getDouble(0);
+ Vector prob = (Vector)row.get(1);
+ double probOfPred = prob.apply((int)pred);
+ for (int i = 0; i < prob.size(); ++i) {
+ assert(probOfPred >= prob.apply(i));
+ }
+ }
}
}
diff --git a/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
new file mode 100644
index 0000000000..5bd616e74d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/ml/regression/JavaLinearRegressionSuite.java
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression;
+
+import java.io.Serializable;
+import java.util.List;
+
+import org.junit.After;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+import static org.apache.spark.mllib.classification.LogisticRegressionSuite
+ .generateLogisticInputAsList;
+import org.apache.spark.mllib.regression.LabeledPoint;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.SQLContext;
+
+
+public class JavaLinearRegressionSuite implements Serializable {
+
+ private transient JavaSparkContext jsc;
+ private transient SQLContext jsql;
+ private transient DataFrame dataset;
+ private transient JavaRDD<LabeledPoint> datasetRDD;
+
+ @Before
+ public void setUp() {
+ jsc = new JavaSparkContext("local", "JavaLinearRegressionSuite");
+ jsql = new SQLContext(jsc);
+ List<LabeledPoint> points = generateLogisticInputAsList(1.0, 1.0, 100, 42);
+ datasetRDD = jsc.parallelize(points, 2);
+ dataset = jsql.applySchema(datasetRDD, LabeledPoint.class);
+ dataset.registerTempTable("dataset");
+ }
+
+ @After
+ public void tearDown() {
+ jsc.stop();
+ jsc = null;
+ }
+
+ @Test
+ public void linearRegressionDefaultParams() {
+ LinearRegression lr = new LinearRegression();
+ assert(lr.getLabelCol().equals("label"));
+ LinearRegressionModel model = lr.fit(dataset);
+ model.transform(dataset).registerTempTable("prediction");
+ DataFrame predictions = jsql.sql("SELECT label, prediction FROM prediction");
+ predictions.collect();
+ // Check defaults
+ assert(model.getFeaturesCol().equals("features"));
+ assert(model.getPredictionCol().equals("prediction"));
+ }
+
+ @Test
+ public void linearRegressionWithSetters() {
+ // Set params, train, and check as many params as we can.
+ LinearRegression lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0);
+ LinearRegressionModel model = lr.fit(dataset);
+ assert(model.fittingParamMap().apply(lr.maxIter()) == 10);
+ assert(model.fittingParamMap().apply(lr.regParam()).equals(1.0));
+
+ // Call fit() with new params, and check as many params as we can.
+ LinearRegressionModel model2 =
+ lr.fit(dataset, lr.maxIter().w(5), lr.regParam().w(0.1), lr.predictionCol().w("thePred"));
+ assert(model2.fittingParamMap().apply(lr.maxIter()) == 5);
+ assert(model2.fittingParamMap().apply(lr.regParam()).equals(0.1));
+ assert(model2.getPredictionCol().equals("thePred"));
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 33e40dc741..b3d1bfcfbe 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,44 +20,108 @@ package org.apache.spark.ml.classification
import org.scalatest.FunSuite
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.{SQLContext, DataFrame}
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
class LogisticRegressionSuite extends FunSuite with MLlibTestSparkContext {
@transient var sqlContext: SQLContext = _
@transient var dataset: DataFrame = _
+ private val eps: Double = 1e-5
override def beforeAll(): Unit = {
super.beforeAll()
sqlContext = new SQLContext(sc)
dataset = sqlContext.createDataFrame(
- sc.parallelize(generateLogisticInput(1.0, 1.0, 100, 42), 2))
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
}
- test("logistic regression") {
+ test("logistic regression: default params") {
val lr = new LogisticRegression
+ assert(lr.getLabelCol == "label")
+ assert(lr.getFeaturesCol == "features")
+ assert(lr.getPredictionCol == "prediction")
+ assert(lr.getRawPredictionCol == "rawPrediction")
+ assert(lr.getProbabilityCol == "probability")
val model = lr.fit(dataset)
model.transform(dataset)
- .select("label", "prediction")
+ .select("label", "probability", "prediction", "rawPrediction")
.collect()
+ assert(model.getThreshold === 0.5)
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ assert(model.getRawPredictionCol == "rawPrediction")
+ assert(model.getProbabilityCol == "probability")
}
test("logistic regression with setters") {
+ // Set params, train, and check as many params as we can.
val lr = new LogisticRegression()
.setMaxIter(10)
.setRegParam(1.0)
+ .setThreshold(0.6)
+ .setProbabilityCol("myProbability")
val model = lr.fit(dataset)
- model.transform(dataset, model.threshold -> 0.8) // overwrite threshold
- .select("label", "score", "prediction")
+ assert(model.fittingParamMap.get(lr.maxIter) === Some(10))
+ assert(model.fittingParamMap.get(lr.regParam) === Some(1.0))
+ assert(model.fittingParamMap.get(lr.threshold) === Some(0.6))
+ assert(model.getThreshold === 0.6)
+
+ // Modify model params, and check that the params worked.
+ model.setThreshold(1.0)
+ val predAllZero = model.transform(dataset)
+ .select("prediction", "myProbability")
.collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predAllZero.forall(_ === 0),
+ s"With threshold=1.0, expected predictions to be all 0, but only" +
+ s" ${predAllZero.count(_ === 0)} of ${dataset.count()} were 0.")
+ // Call transform with params, and check that the params worked.
+ val predNotAllZero =
+ model.transform(dataset, model.threshold -> 0.0, model.probabilityCol -> "myProb")
+ .select("prediction", "myProb")
+ .collect()
+ .map { case Row(pred: Double, prob: Vector) => pred }
+ assert(predNotAllZero.exists(_ !== 0.0))
+
+ // Call fit() with new params, and check as many params as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.threshold -> 0.4,
+ lr.probabilityCol -> "theProb")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.fittingParamMap.get(lr.threshold).get === 0.4)
+ assert(model2.getThreshold === 0.4)
+ assert(model2.getProbabilityCol == "theProb")
}
- test("logistic regression fit and transform with varargs") {
+ test("logistic regression: Predictor, Classifier methods") {
+ val sqlContext = this.sqlContext
val lr = new LogisticRegression
- val model = lr.fit(dataset, lr.maxIter -> 10, lr.regParam -> 1.0)
- model.transform(dataset, model.threshold -> 0.8, model.scoreCol -> "probability")
- .select("label", "probability", "prediction")
- .collect()
+
+ val model = lr.fit(dataset)
+ assert(model.numClasses === 2)
+
+ val threshold = model.getThreshold
+ val results = model.transform(dataset)
+
+ // Compare rawPrediction with probability
+ results.select("rawPrediction", "probability").collect().map {
+ case Row(raw: Vector, prob: Vector) =>
+ assert(raw.size === 2)
+ assert(prob.size === 2)
+ val probFromRaw1 = 1.0 / (1.0 + math.exp(-raw(1)))
+ assert(prob(1) ~== probFromRaw1 relTol eps)
+ assert(prob(0) ~== 1.0 - probFromRaw1 relTol eps)
+ }
+
+ // Compare prediction with probability
+ results.select("prediction", "probability").collect().map {
+ case Row(pred: Double, prob: Vector) =>
+ val predFromProb = prob.toArray.zipWithIndex.maxBy(_._1)._2
+ assert(pred == predFromProb)
+ }
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
new file mode 100644
index 0000000000..bbb44c3e2d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.regression
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, SQLContext}
+
+class LinearRegressionSuite extends FunSuite with MLlibTestSparkContext {
+
+ @transient var sqlContext: SQLContext = _
+ @transient var dataset: DataFrame = _
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ sqlContext = new SQLContext(sc)
+ dataset = sqlContext.createDataFrame(
+ sc.parallelize(generateLogisticInput(1.0, 1.0, nPoints = 100, seed = 42), 2))
+ }
+
+ test("linear regression: default params") {
+ val lr = new LinearRegression
+ assert(lr.getLabelCol == "label")
+ val model = lr.fit(dataset)
+ model.transform(dataset)
+ .select("label", "prediction")
+ .collect()
+ // Check defaults
+ assert(model.getFeaturesCol == "features")
+ assert(model.getPredictionCol == "prediction")
+ }
+
+ test("linear regression with setters") {
+ // Set params, train, and check as many as we can.
+ val lr = new LinearRegression()
+ .setMaxIter(10)
+ .setRegParam(1.0)
+ val model = lr.fit(dataset)
+ assert(model.fittingParamMap.get(lr.maxIter).get === 10)
+ assert(model.fittingParamMap.get(lr.regParam).get === 1.0)
+
+ // Call fit() with new params, and check as many as we can.
+ val model2 = lr.fit(dataset, lr.maxIter -> 5, lr.regParam -> 0.1, lr.predictionCol -> "thePred")
+ assert(model2.fittingParamMap.get(lr.maxIter).get === 5)
+ assert(model2.fittingParamMap.get(lr.regParam).get === 0.1)
+ assert(model2.getPredictionCol == "thePred")
+ }
+}