author     DB Tsai <dbt@netflix.com>  2015-10-07 15:56:57 -0700
committer  Xiangrui Meng <meng@databricks.com>  2015-10-07 15:56:57 -0700
commit     dd36ec6bc5844aaa045a4bd9ba49113528e1740c (patch)
tree       42a0aa25a707ad9a30ad9dd0fd77f3015e0e4043 /mllib
parent     7e2e268289828ae664622c59b90d82938d957ff3 (diff)
[SPARK-10738] [ML] Refactoring `Instance` out from LOR and LIR, and also cleaning up some code
Refactor the `Instance` case class out of LOR (logistic regression) and LIR (linear regression), and clean up some related code.

Author: DB Tsai <dbt@netflix.com>

Closes #8853 from dbtsai/refactoring.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala       | 116
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala                        |  29
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala             |  82
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala  |   1
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala        |   1
5 files changed, 125 insertions(+), 104 deletions(-)
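Seen end to end, the patch deletes the two private copies of `Instance` (one `private[classification]`, one `private[regression]`) and replaces them with a single `private[ml]` case class under `org.apache.spark.ml.feature`; the logistic aggregator's `add` is also narrowed to take that case class instead of a loose (label, features, weight) triple. A minimal sketch of the shape change, reduced to signatures (the trait names here are illustrative, not Spark's):

    import org.apache.spark.mllib.linalg.Vector

    // Before the patch: LogisticRegression.scala and LinearRegression.scala each
    // carried its own private Instance, and LogisticAggregator exposed a loose add.
    trait OldStyleAggregator {
      def add(label: Double, data: Vector, weight: Double = 1.0): this.type
    }

    // After the patch: one shared case class and one narrow entry point.
    case class Instance(label: Double, weight: Double, features: Vector)
    trait NewStyleAggregator {
      def add(instance: Instance): this.type
    }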
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c17a7b0c36..6f839ff4d7 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -24,6 +24,7 @@ import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS,
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.Identifiable
@@ -147,17 +148,6 @@ private[classification] trait LogisticRegressionParams extends ProbabilisticClas
}
/**
- * Class that represents an instance of weighted data point with label and features.
- *
- * TODO: Refactor this class to proper place.
- *
- * @param label Label for this data point.
- * @param weight The weight of this instance.
- * @param features The vector of features for this data point.
- */
-private[classification] case class Instance(label: Double, weight: Double, features: Vector)
-
-/**
* :: Experimental ::
* Logistic regression.
* Currently, this class only supports binary classification. It will support multiclass
@@ -322,7 +312,7 @@ class LogisticRegression(override val uid: String)
if ($(fitIntercept)) {
/*
- For binary logistic regression, when we initialize the weights as zeros,
+ For binary logistic regression, when we initialize the coefficients as zeros,
it will converge faster if we initialize the intercept such that
it follows the distribution of the labels.
@@ -757,62 +747,63 @@ private class LogisticAggregator(
private val gradientSumArray = Array.ofDim[Double](coefficientsArray.length)
/**
- * Add a new training data to this LogisticAggregator, and update the loss and gradient
+ * Add a new training instance to this LogisticAggregator, and update the loss and gradient
* of the objective function.
*
- * @param label The label for this data point.
- * @param data The features for one data point in dense/sparse vector format to be added
- * into this aggregator.
- * @param weight The weight for over-/undersamples each of training instance. Default is one.
+ * @param instance The instance of data point to be added.
* @return This LogisticAggregator object.
*/
- def add(label: Double, data: Vector, weight: Double = 1.0): this.type = {
- require(dim == data.size, s"Dimensions mismatch when adding new instance." +
- s" Expecting $dim but got ${data.size}.")
- require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
-
- if (weight == 0.0) return this
+ def add(instance: Instance): this.type = {
+ instance match { case Instance(label, weight, features) =>
+ require(dim == features.size, s"Dimensions mismatch when adding new instance." +
+ s" Expecting $dim but got ${features.size}.")
+ require(weight >= 0.0, s"instance weight, ${weight} has to be >= 0.0")
+
+ if (weight == 0.0) return this
+
+ val localCoefficientsArray = coefficientsArray
+ val localGradientSumArray = gradientSumArray
+
+ numClasses match {
+ case 2 =>
+ // For Binary Logistic Regression.
+ val margin = - {
+ var sum = 0.0
+ features.foreachActive { (index, value) =>
+ if (featuresStd(index) != 0.0 && value != 0.0) {
+ sum += localCoefficientsArray(index) * (value / featuresStd(index))
+ }
+ }
+ sum + {
+ if (fitIntercept) localCoefficientsArray(dim) else 0.0
+ }
+ }
- val localCoefficientsArray = coefficientsArray
- val localGradientSumArray = gradientSumArray
+ val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
- numClasses match {
- case 2 =>
- // For Binary Logistic Regression.
- val margin = - {
- var sum = 0.0
- data.foreachActive { (index, value) =>
+ features.foreachActive { (index, value) =>
if (featuresStd(index) != 0.0 && value != 0.0) {
- sum += localCoefficientsArray(index) * (value / featuresStd(index))
+ localGradientSumArray(index) += multiplier * (value / featuresStd(index))
}
}
- sum + { if (fitIntercept) localCoefficientsArray(dim) else 0.0 }
- }
-
- val multiplier = weight * (1.0 / (1.0 + math.exp(margin)) - label)
- data.foreachActive { (index, value) =>
- if (featuresStd(index) != 0.0 && value != 0.0) {
- localGradientSumArray(index) += multiplier * (value / featuresStd(index))
+ if (fitIntercept) {
+ localGradientSumArray(dim) += multiplier
}
- }
- if (fitIntercept) {
- localGradientSumArray(dim) += multiplier
- }
-
- if (label > 0) {
- // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
- lossSum += weight * MLUtils.log1pExp(margin)
- } else {
- lossSum += weight * (MLUtils.log1pExp(margin) - margin)
- }
- case _ =>
- new NotImplementedError("LogisticRegression with ElasticNet in ML package only supports " +
- "binary classification for now.")
+ if (label > 0) {
+ // The following is equivalent to log(1 + exp(margin)) but more numerically stable.
+ lossSum += weight * MLUtils.log1pExp(margin)
+ } else {
+ lossSum += weight * (MLUtils.log1pExp(margin) - margin)
+ }
+ case _ =>
+ new NotImplementedError("LogisticRegression with ElasticNet in ML package " +
+ "only supports binary classification for now.")
+ }
+ weightSum += weight
+ this
}
- weightSum += weight
- this
}
/**
@@ -861,11 +852,11 @@ private class LogisticAggregator(
/**
* LogisticCostFun implements Breeze's DiffFunction[T] for a multinomial logistic loss function,
* as used in multi-class classification (it is also used in binary logistic regression).
- * It returns the loss and gradient with L2 regularization at a particular point (weights).
+ * It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
private class LogisticCostFun(
- data: RDD[Instance],
+ instances: RDD[Instance],
numClasses: Int,
fitIntercept: Boolean,
standardization: Boolean,
@@ -875,15 +866,14 @@ private class LogisticCostFun(
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
val numFeatures = featuresStd.length
- val w = Vectors.fromBreeze(coefficients)
+ val coeffs = Vectors.fromBreeze(coefficients)
val logisticAggregator = {
- val seqOp = (c: LogisticAggregator, instance: Instance) =>
- c.add(instance.label, instance.features, instance.weight)
+ val seqOp = (c: LogisticAggregator, instance: Instance) => c.add(instance)
val combOp = (c1: LogisticAggregator, c2: LogisticAggregator) => c1.merge(c2)
- data.treeAggregate(
- new LogisticAggregator(w, numClasses, fitIntercept, featuresStd, featuresMean)
+ instances.treeAggregate(
+ new LogisticAggregator(coeffs, numClasses, fitIntercept, featuresStd, featuresMean)
)(seqOp, combOp)
}
@@ -894,7 +884,7 @@ private class LogisticCostFun(
0.0
} else {
var sum = 0.0
- w.foreachActive { (index, value) =>
+ coeffs.foreachActive { (index, value) =>
// If `fitIntercept` is true, the last term which is intercept doesn't
// contribute to the regularization.
if (index != numFeatures) {
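Both cost functions end up with the same aggregation skeleton after this change: name the per-partition fold (`seqOp`) and the partial-result merge (`combOp`) explicitly, then hand them to `treeAggregate`. A hedged, self-contained sketch of that skeleton, with a simplified stand-in for `LogisticAggregator` (the real one also tracks coefficients, feature standardization statistics, and the gradient array):

    import org.apache.spark.rdd.RDD

    // Local stand-in for org.apache.spark.ml.feature.Instance (private[ml] in Spark).
    case class Instance(label: Double, weight: Double, features: Array[Double])

    class SimpleAggregator extends Serializable {
      var lossSum = 0.0
      var weightSum = 0.0

      def add(instance: Instance): this.type = {
        require(instance.weight >= 0.0, s"instance weight, ${instance.weight} has to be >= 0.0")
        if (instance.weight == 0.0) return this
        // Placeholder loss term; the real aggregator adds the binary logistic loss,
        // weight * log1pExp(margin) for label > 0, weight * (log1pExp(margin) - margin) otherwise.
        lossSum += instance.weight * instance.label
        weightSum += instance.weight
        this
      }

      def merge(other: SimpleAggregator): this.type = {
        lossSum += other.lossSum
        weightSum += other.weightSum
        this
      }
    }

    def aggregate(instances: RDD[Instance]): SimpleAggregator = {
      val seqOp = (c: SimpleAggregator, instance: Instance) => c.add(instance)
      val combOp = (c1: SimpleAggregator, c2: SimpleAggregator) => c1.merge(c2)
      instances.treeAggregate(new SimpleAggregator)(seqOp, combOp)
    }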
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
new file mode 100644
index 0000000000..12176757ae
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Instance.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.mllib.linalg.Vector
+
+/**
+ * Class that represents an instance of weighted data point with label and features.
+ *
+ * @param label Label for this data point.
+ * @param weight The weight of this instance.
+ * @param features The vector of features for this data point.
+ */
+private[ml] case class Instance(label: Double, weight: Double, features: Vector)
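Because `Instance` is a plain case class, callers can destructure it with the generated extractor, which is exactly what the refactored `add` bodies do. A small usage sketch (hypothetical demo code; since the class is `private[ml]`, real callers must live inside the `org.apache.spark.ml` package):

    package org.apache.spark.ml.demo  // hypothetical package, inside the ml namespace

    import org.apache.spark.ml.feature.Instance
    import org.apache.spark.mllib.linalg.Vectors

    object InstanceDemo {
      def main(args: Array[String]): Unit = {
        val instance = Instance(label = 1.0, weight = 2.0, features = Vectors.dense(0.5, -1.2))
        instance match {
          case Instance(label, weight, features) =>
            // Same guard the aggregators apply before updating loss and gradient.
            require(weight >= 0.0, s"instance weight, $weight has to be >= 0.0")
            println(s"label=$label, weight=$weight, numFeatures=${features.size}")
        }
      }
    }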
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
index a77e702141..0dc084fdd1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/LinearRegression.scala
@@ -19,11 +19,12 @@ package org.apache.spark.ml.regression
import scala.collection.mutable
-import breeze.linalg.{DenseVector => BDV, norm => brzNorm}
+import breeze.linalg.{DenseVector => BDV}
import breeze.optimize.{CachedDiffFunction, DiffFunction, LBFGS => BreezeLBFGS, OWLQN => BreezeOWLQN}
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared._
@@ -45,23 +46,12 @@ private[regression] trait LinearRegressionParams extends PredictorParams
with HasFitIntercept with HasStandardization with HasWeightCol
/**
- * Class that represents an instance of weighted data point with label and features.
- *
- * TODO: Refactor this class to proper place.
- *
- * @param label Label for this data point.
- * @param weight The weight of this instance.
- * @param features The vector of features for this data point.
- */
-private[regression] case class Instance(label: Double, weight: Double, features: Vector)
-
-/**
* :: Experimental ::
* Linear regression.
*
* The learning objective is to minimize the squared error, with regularization.
* The specific squared error loss function used is:
- * L = 1/2n ||A weights - y||^2^
+ * L = 1/2n ||A coefficients - y||^2^
*
* This supports multiple types of regularization:
* - none (a.k.a. ordinary least squares)
@@ -172,13 +162,14 @@ class LinearRegression(override val uid: String)
// If the yStd is zero, then the intercept is yMean with zero weights;
// as a result, training is not needed.
if (yStd == 0.0) {
- logWarning(s"The standard deviation of the label is zero, so the weights will be zeros " +
- s"and the intercept will be the mean of the label; as a result, training is not needed.")
+ logWarning(s"The standard deviation of the label is zero, so the coefficients will be " +
+ s"zeros and the intercept will be the mean of the label; as a result, " +
+ s"training is not needed.")
if (handlePersistence) instances.unpersist()
- val weights = Vectors.sparse(numFeatures, Seq())
+ val coefficients = Vectors.sparse(numFeatures, Seq())
val intercept = yMean
- val model = new LinearRegressionModel(uid, weights, intercept)
+ val model = new LinearRegressionModel(uid, coefficients, intercept)
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
@@ -218,11 +209,11 @@ class LinearRegression(override val uid: String)
new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, effectiveL1RegFun, $(tol))
}
- val initialWeights = Vectors.zeros(numFeatures)
+ val initialCoefficients = Vectors.zeros(numFeatures)
val states = optimizer.iterations(new CachedDiffFunction(costFun),
- initialWeights.toBreeze.toDenseVector)
+ initialCoefficients.toBreeze.toDenseVector)
- val (weights, objectiveHistory) = {
+ val (coefficients, objectiveHistory) = {
/*
Note that in Linear Regression, the objective history (loss + regularization) returned
from optimizer is computed in the scaled space given by the following formula.
@@ -243,18 +234,18 @@ class LinearRegression(override val uid: String)
}
/*
- The weights are trained in the scaled space; we're converting them back to
+ The coefficients are trained in the scaled space; we're converting them back to
the original space.
*/
- val rawWeights = state.x.toArray.clone()
+ val rawCoefficients = state.x.toArray.clone()
var i = 0
- val len = rawWeights.length
+ val len = rawCoefficients.length
while (i < len) {
- rawWeights(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
+ rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
i += 1
}
- (Vectors.dense(rawWeights).compressed, arrayBuilder.result())
+ (Vectors.dense(rawCoefficients).compressed, arrayBuilder.result())
}
/*
@@ -262,11 +253,15 @@ class LinearRegression(override val uid: String)
converged. See the following discussion for detail.
http://stats.stackexchange.com/questions/13617/how-is-the-intercept-computed-in-glmnet
*/
- val intercept = if ($(fitIntercept)) yMean - dot(weights, Vectors.dense(featuresMean)) else 0.0
+ val intercept = if ($(fitIntercept)) {
+ yMean - dot(coefficients, Vectors.dense(featuresMean))
+ } else {
+ 0.0
+ }
if (handlePersistence) instances.unpersist()
- val model = copyValues(new LinearRegressionModel(uid, weights, intercept))
+ val model = copyValues(new LinearRegressionModel(uid, coefficients, intercept))
val trainingSummary = new LinearRegressionTrainingSummary(
model.transform(dataset),
$(predictionCol),
@@ -425,7 +420,7 @@ class LinearRegressionSummary private[regression] (
* For improving the convergence rate during the optimization process, and also preventing against
* features with very large variances exerting an overly large influence during model training,
* package like R's GLMNET performs the scaling to unit variance and removing the mean to reduce
- * the condition number, and then trains the model in scaled space but returns the weights in
+ * the condition number, and then trains the model in scaled space but returns the coefficients in
* the original scale. See page 9 in http://cran.r-project.org/web/packages/glmnet/glmnet.pdf
*
* However, we don't want to apply the `StandardScaler` on the training dataset, and then cache
@@ -456,7 +451,7 @@ class LinearRegressionSummary private[regression] (
* + \bar{y} / \hat{y}||^2
* = 1/2n ||\sum_i w_i^\prime x_i - y / \hat{y} + offset||^2 = 1/2n diff^2
* }}}
- * where w_i^\prime^ is the effective weights defined by w_i/\hat{x_i}, offset is
+ * where w_i^\prime^ is the effective coefficients defined by w_i/\hat{x_i}, offset is
* {{{
* - \sum_i (w_i/\hat{x_i})\bar{x_i} + \bar{y} / \hat{y}.
* }}}, and diff is
@@ -465,7 +460,7 @@ class LinearRegressionSummary private[regression] (
* }}}
*
*
- * Note that the effective weights and offset don't depend on training dataset,
+ * Note that the effective coefficients and offset don't depend on training dataset,
* so they can be precomputed.
*
* Now, the first derivative of the objective function in scaled space is
@@ -543,13 +538,13 @@ private class LeastSquaresAggregator(
private val gradientSumArray = Array.ofDim[Double](dim)
/**
- * Add a new training data to this LeastSquaresAggregator, and update the loss and gradient
+ * Add a new training instance to this LeastSquaresAggregator, and update the loss and gradient
* of the objective function.
*
- * @param instance The data point instance to be added.
+ * @param instance The instance of data point to be added.
* @return This LeastSquaresAggregator object.
*/
- def add(instance: Instance): this.type =
+ def add(instance: Instance): this.type = {
instance match { case Instance(label, weight, features) =>
require(dim == features.size, s"Dimensions mismatch when adding new sample." +
s" Expecting $dim but got ${features.size}.")
@@ -573,6 +568,7 @@ private class LeastSquaresAggregator(
weightSum += weight
this
}
+ }
/**
* Merge another LeastSquaresAggregator, and update the loss and gradient
@@ -621,11 +617,11 @@ private class LeastSquaresAggregator(
/**
* LeastSquaresCostFun implements Breeze's DiffFunction[T] for Least Squares cost.
- * It returns the loss and gradient with L2 regularization at a particular point (weights).
+ * It returns the loss and gradient with L2 regularization at a particular point (coefficients).
* It's used in Breeze's convex optimization routines.
*/
private class LeastSquaresCostFun(
- data: RDD[Instance],
+ instances: RDD[Instance],
labelStd: Double,
labelMean: Double,
fitIntercept: Boolean,
@@ -635,12 +631,16 @@ private class LeastSquaresCostFun(
effectiveL2regParam: Double) extends DiffFunction[BDV[Double]] {
override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
- val coeff = Vectors.fromBreeze(coefficients)
+ val coeffs = Vectors.fromBreeze(coefficients)
- val leastSquaresAggregator = data.treeAggregate(new LeastSquaresAggregator(coeff, labelStd,
- labelMean, fitIntercept, featuresStd, featuresMean))(
- seqOp = (aggregator, instance) => aggregator.add(instance),
- combOp = (aggregator1, aggregator2) => aggregator1.merge(aggregator2))
+ val leastSquaresAggregator = {
+ val seqOp = (c: LeastSquaresAggregator, instance: Instance) => c.add(instance)
+ val combOp = (c1: LeastSquaresAggregator, c2: LeastSquaresAggregator) => c1.merge(c2)
+
+ instances.treeAggregate(
+ new LeastSquaresAggregator(coeffs, labelStd, labelMean, fitIntercept, featuresStd,
+ featuresMean))(seqOp, combOp)
+ }
val totalGradientArray = leastSquaresAggregator.gradient.toArray
@@ -648,7 +648,7 @@ private class LeastSquaresCostFun(
0.0
} else {
var sum = 0.0
- coeff.foreachActive { (index, value) =>
+ coeffs.foreachActive { (index, value) =>
// The following code will compute the loss of the regularization; also
// the gradient of the regularization, and add back to totalGradientArray.
sum += {
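One detail worth pulling out of the hunks above: the coefficients are trained against standardized features and a standardized label, so each one is mapped back to the original scale by `yStd / featuresStd(i)` (zero for constant features), after which the intercept is recovered from the label and feature means. A small self-contained sketch with made-up statistics (plain arrays stand in for the `Vectors`/`dot` calls in the patch):

    // Hypothetical statistics from the training data; the second feature is
    // constant, so its standard deviation is zero.
    val yStd = 2.0
    val yMean = 5.0
    val featuresStd = Array(4.0, 0.0)
    val featuresMean = Array(1.0, 3.0)

    // Coefficients as trained in the scaled space.
    val rawCoefficients = Array(0.8, 0.3)

    var i = 0
    while (i < rawCoefficients.length) {
      // Back to the original space; a zero-variance feature gets a zero coefficient.
      rawCoefficients(i) *= { if (featuresStd(i) != 0.0) yStd / featuresStd(i) else 0.0 }
      i += 1
    }
    // rawCoefficients is now Array(0.4, 0.0)

    // With fitIntercept, the intercept comes from the means, mirroring
    // yMean - dot(coefficients, Vectors.dense(featuresMean)) in the patch.
    val intercept = yMean - rawCoefficients.zip(featuresMean).map { case (c, m) => c * m }.sum
    // intercept == 5.0 - 0.4 * 1.0 == 4.6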
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index ec01998601..5186c4e2be 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.classification
import scala.util.Random
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.classification.LogisticRegressionSuite._
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
index 7cb9471e69..32729470d5 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/LinearRegressionSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.ml.regression
import scala.util.Random
import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.feature.Instance
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.MLTestingUtils
import org.apache.spark.mllib.regression.LabeledPoint