Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala        | 36
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala     | 82
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala                     | 28
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala  | 34
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala             |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala | 25
6 files changed, 179 insertions(+), 28 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index c98a78a515..9b2340a1f1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -247,15 +247,27 @@ class LogisticRegression @Since("1.2.0") (
@Since("1.5.0")
override def getThresholds: Array[Double] = super.getThresholds
- override protected def train(dataset: DataFrame): LogisticRegressionModel = {
- // Extract columns from data. If dataset is persisted, do not persist oldDataset.
+ private var optInitialModel: Option[LogisticRegressionModel] = None
+
+ /** @group setParam */
+ private[spark] def setInitialModel(model: LogisticRegressionModel): this.type = {
+ this.optInitialModel = Some(model)
+ this
+ }
+
+ override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
+ val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
+ train(dataset, handlePersistence)
+ }
+
+ protected[spark] def train(dataset: DataFrame, handlePersistence: Boolean):
+ LogisticRegressionModel = {
val w = if ($(weightCol).isEmpty) lit(1.0) else col($(weightCol))
val instances: RDD[Instance] = dataset.select(col($(labelCol)), w, col($(featuresCol))).map {
case Row(label: Double, weight: Double, features: Vector) =>
Instance(label, weight, features)
}
- val handlePersistence = dataset.rdd.getStorageLevel == StorageLevel.NONE
if (handlePersistence) instances.persist(StorageLevel.MEMORY_AND_DISK)
val (summarizer, labelSummarizer) = {
@@ -343,7 +355,21 @@ class LogisticRegression @Since("1.2.0") (
val initialCoefficientsWithIntercept =
Vectors.zeros(if ($(fitIntercept)) numFeatures + 1 else numFeatures)
- if ($(fitIntercept)) {
+ if (optInitialModel.isDefined && optInitialModel.get.coefficients.size != numFeatures) {
+ val vec = optInitialModel.get.coefficients
+ logWarning(
+ s"Initial coefficients provided ${vec} did not match the expected size ${numFeatures}")
+ }
+
+ if (optInitialModel.isDefined && optInitialModel.get.coefficients.size == numFeatures) {
+ val initialCoefficientsWithInterceptArray = initialCoefficientsWithIntercept.toArray
+ optInitialModel.get.coefficients.foreachActive { case (index, value) =>
+ initialCoefficientsWithInterceptArray(index) = value
+ }
+ if ($(fitIntercept)) {
+ initialCoefficientsWithInterceptArray(numFeatures) = optInitialModel.get.intercept
+ }
+ } else if ($(fitIntercept)) {
/*
For binary logistic regression, when we initialize the coefficients as zeros,
it will converge faster if we initialize the intercept such that
@@ -434,7 +460,7 @@ object LogisticRegression extends DefaultParamsReadable[LogisticRegression] {
*/
@Since("1.4.0")
@Experimental
-class LogisticRegressionModel private[ml] (
+class LogisticRegressionModel private[spark] (
@Since("1.4.0") override val uid: String,
@Since("1.6.0") val coefficients: Vector,
@Since("1.3.0") val intercept: Double)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
index 2a7697b5a7..bf68e3edd7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/LogisticRegression.scala
@@ -19,15 +19,18 @@ package org.apache.spark.mllib.classification
import org.apache.spark.SparkContext
import org.apache.spark.annotation.Since
+import org.apache.spark.ml.util.Identifiable
import org.apache.spark.mllib.classification.impl.GLMClassificationModel
-import org.apache.spark.mllib.linalg.{DenseVector, Vector}
+import org.apache.spark.mllib.linalg.{DenseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.BLAS.dot
import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.pmml.PMMLExportable
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{DataValidators, Loader, Saveable}
+import org.apache.spark.mllib.util.MLUtils.appendBias
import org.apache.spark.rdd.RDD
-
+import org.apache.spark.sql.SQLContext
+import org.apache.spark.storage.StorageLevel
/**
* Classification model trained using Multinomial/Binary Logistic Regression.
@@ -332,6 +335,13 @@ object LogisticRegressionWithSGD {
* Limited-memory BFGS. Standard feature scaling and L2 regularization are used by default.
* NOTE: Labels used in Logistic Regression should be {0, 1, ..., k - 1}
* for k classes multi-label classification problem.
+ *
+ * Earlier implementations of LogisticRegressionWithLBFGS applied a regularization
+ * penalty to all elements, including the intercept. If this is called with one of
+ * the standard updaters (L1Updater or SquaredL2Updater), the call is translated
+ * into a call to ml.LogisticRegression; otherwise the existing mllib
+ * GeneralizedLinearAlgorithm trainer is used, which still applies the regularization
+ * penalty to the intercept.
*/
@Since("1.1.0")
class LogisticRegressionWithLBFGS
@@ -374,4 +384,72 @@ class LogisticRegressionWithLBFGS
new LogisticRegressionModel(weights, intercept, numFeatures, numOfLinearPredictor + 1)
}
}
+
+ /**
+ * Run Logistic Regression with the configured parameters on an input RDD
+ * of LabeledPoint entries.
+ *
+ * If a known updater is used, this calls the ml implementation to avoid
+ * applying a regularization penalty to the intercept; otherwise it
+ * defaults to the mllib implementation. If there are more than two classes
+ * or feature scaling is disabled, the mllib implementation is always used.
+ * When the ml implementation is used, ml code generates the initial weights.
+ */
+ override def run(input: RDD[LabeledPoint]): LogisticRegressionModel = {
+ run(input, generateInitialWeights(input), userSuppliedWeights = false)
+ }
+
+ /**
+ * Run Logistic Regression with the configured parameters on an input RDD
+ * of LabeledPoint entries starting from the initial weights provided.
+ *
+ * If a known updater is used, this calls the ml implementation to avoid
+ * applying a regularization penalty to the intercept; otherwise it
+ * defaults to the mllib implementation. If there are more than two classes
+ * or feature scaling is disabled, the mllib implementation is always used.
+ * The user-provided initial weights are used as the starting point.
+ */
+ override def run(input: RDD[LabeledPoint], initialWeights: Vector): LogisticRegressionModel = {
+ run(input, initialWeights, userSuppliedWeights = true)
+ }
+
+ private def run(input: RDD[LabeledPoint], initialWeights: Vector, userSuppliedWeights: Boolean):
+ LogisticRegressionModel = {
+ // ml's logistic regression only supports binary classification currently.
+ if (numOfLinearPredictor == 1) {
+ def runWithMlLogisticRegression(elasticNetParam: Double) = {
+ // Prepare the ml LogisticRegression based on our settings
+ val lr = new org.apache.spark.ml.classification.LogisticRegression()
+ lr.setRegParam(optimizer.getRegParam())
+ lr.setElasticNetParam(elasticNetParam)
+ lr.setStandardization(useFeatureScaling)
+ if (userSuppliedWeights) {
+ val uid = Identifiable.randomUID("logreg-static")
+ lr.setInitialModel(new org.apache.spark.ml.classification.LogisticRegressionModel(
+ uid, initialWeights, 1.0))
+ }
+ lr.setFitIntercept(addIntercept)
+ lr.setMaxIter(optimizer.getNumIterations())
+ lr.setTol(optimizer.getConvergenceTol())
+ // Convert our input into a DataFrame
+ val sqlContext = new SQLContext(input.context)
+ import sqlContext.implicits._
+ val df = input.toDF()
+ // Determine if we should cache the DF
+ val handlePersistence = input.getStorageLevel == StorageLevel.NONE
+ // Train our model
+ val mlLogisticRegressionModel = lr.train(df, handlePersistence)
+ // Convert the ml model back to an mllib model
+ val weights = Vectors.dense(mlLogisticRegressionModel.coefficients.toArray)
+ createModel(weights, mlLogisticRegressionModel.intercept)
+ }
+ optimizer.getUpdater() match {
+ case x: SquaredL2Updater => runWithMlLogisticRegression(0.0)
+ case x: L1Updater => runWithMlLogisticRegression(1.0)
+ case _ => super.run(input, initialWeights)
+ }
+ } else {
+ super.run(input, initialWeights)
+ }
+ }
}
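Seen from the caller's side nothing changes: the routing described in the scaladoc above is driven entirely by the updater already configured on the optimizer. A minimal sketch of a binary training call that would now take the ml path (the object name is illustrative only):

import org.apache.spark.mllib.classification.{LogisticRegressionModel, LogisticRegressionWithLBFGS}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD

object LBFGSRoutingSketch {
  def trainBinary(data: RDD[LabeledPoint]): LogisticRegressionModel = {
    // The default updater is SquaredL2Updater, so for this binary problem run()
    // delegates to ml.LogisticRegression and the intercept is left unregularized.
    val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
    lr.optimizer
      .setRegParam(0.1)
      .setNumIterations(100)
      .setConvergenceTol(1e-6)
    lr.run(data)
  }
}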
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
index efedc112d3..a5bd77e6be 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/optimization/LBFGS.scala
@@ -69,6 +69,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
this
}
+ /**
+ * Get the convergence tolerance of iterations.
+ */
+ private[mllib] def getConvergenceTol(): Double = {
+ this.convergenceTol
+ }
+
/**
* Set the maximal number of iterations for L-BFGS. Default 100.
* @deprecated use [[LBFGS#setNumIterations]] instead
@@ -87,6 +94,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
}
/**
+ * Get the maximum number of iterations for L-BFGS. Defaults to 100.
+ */
+ private[mllib] def getNumIterations(): Int = {
+ this.maxNumIterations
+ }
+
+ /**
* Set the regularization parameter. Default 0.0.
*/
def setRegParam(regParam: Double): this.type = {
@@ -95,6 +109,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
}
/**
+ * Get the regularization parameter.
+ */
+ private[mllib] def getRegParam(): Double = {
+ this.regParam
+ }
+
+ /**
* Set the gradient function (of the loss function of one single data example)
* to be used for L-BFGS.
*/
@@ -113,6 +134,13 @@ class LBFGS(private var gradient: Gradient, private var updater: Updater)
this
}
+ /**
+ * Returns the updater, limited to internal use.
+ */
+ private[mllib] def getUpdater(): Updater = {
+ updater
+ }
+
override def optimize(data: RDD[(Double, Vector)], initialWeights: Vector): Vector = {
val (weights, _) = LBFGS.runLBFGS(
data,
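The getters above are deliberately private[mllib]: they only exist so that mllib-side code, such as LogisticRegressionWithLBFGS above, can copy the optimizer's configuration onto an ml estimator. A small sketch of that pattern, assumed to live inside the org.apache.spark.mllib package; the package suffix and object name are hypothetical.

package org.apache.spark.mllib.examples // hypothetical, needed for private[mllib] access

import org.apache.spark.mllib.optimization.{LBFGS, LogisticGradient, SquaredL2Updater}

object LBFGSSettingsSketch {
  def describe(): String = {
    val optimizer = new LBFGS(new LogisticGradient, new SquaredL2Updater)
      .setRegParam(0.01)
      .setNumIterations(50)
      .setConvergenceTol(1e-4)
    // These are the values LogisticRegressionWithLBFGS forwards to ml.LogisticRegression.
    s"regParam=${optimizer.getRegParam()}, " +
      s"maxIter=${optimizer.getNumIterations()}, " +
      s"tol=${optimizer.getConvergenceTol()}, " +
      s"updater=${optimizer.getUpdater().getClass.getSimpleName}"
  }
}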
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
index e60edc675c..73da899a0e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/regression/GeneralizedLinearAlgorithm.scala
@@ -140,7 +140,7 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* translated back to resulting model weights, so it's transparent to users.
* Note: This technique is used in both libsvm and glmnet packages. Default false.
*/
- private var useFeatureScaling = false
+ private[mllib] var useFeatureScaling = false
/**
* The dimension of training features.
@@ -196,12 +196,9 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
}
/**
- * Run the algorithm with the configured parameters on an input
- * RDD of LabeledPoint entries.
- *
+ * Generate the initial weights when the user does not supply them.
*/
- @Since("0.8.0")
- def run(input: RDD[LabeledPoint]): M = {
+ protected def generateInitialWeights(input: RDD[LabeledPoint]): Vector = {
if (numFeatures < 0) {
numFeatures = input.map(_.features.size).first()
}
@@ -217,16 +214,23 @@ abstract class GeneralizedLinearAlgorithm[M <: GeneralizedLinearModel]
* TODO: See if we can deprecate `intercept` in `GeneralizedLinearModel`, and always
* have the intercept as part of weights to have consistent design.
*/
- val initialWeights = {
- if (numOfLinearPredictor == 1) {
- Vectors.zeros(numFeatures)
- } else if (addIntercept) {
- Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
- } else {
- Vectors.zeros(numFeatures * numOfLinearPredictor)
- }
+ if (numOfLinearPredictor == 1) {
+ Vectors.zeros(numFeatures)
+ } else if (addIntercept) {
+ Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
+ } else {
+ Vectors.zeros(numFeatures * numOfLinearPredictor)
}
- run(input, initialWeights)
+ }
+
+ /**
+ * Run the algorithm with the configured parameters on an input
+ * RDD of LabeledPoint entries.
+ *
+ */
+ @Since("0.8.0")
+ def run(input: RDD[LabeledPoint]): M = {
+ run(input, generateInitialWeights(input))
}
/**
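The behavioral content of this refactor is just the sizing rule that was previously inlined in run(); restated as a standalone helper (hypothetical, for illustration only) it reads:

import org.apache.spark.mllib.linalg.{Vector, Vectors}

object InitialWeightsSketch {
  // Mirrors the zero-vector sizing performed by generateInitialWeights.
  def zeros(numFeatures: Int, numOfLinearPredictor: Int, addIntercept: Boolean): Vector = {
    if (numOfLinearPredictor == 1) {
      // Binary case: one weight per feature; the intercept is tracked separately.
      Vectors.zeros(numFeatures)
    } else if (addIntercept) {
      // Multinomial case with intercept: each linear predictor also carries its own intercept.
      Vectors.zeros((numFeatures + 1) * numOfLinearPredictor)
    } else {
      Vectors.zeros(numFeatures * numOfLinearPredictor)
    }
  }
}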
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
index d7983f92a3..445e50d867 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/OneVsRestSuite.scala
@@ -168,7 +168,7 @@ private class MockLogisticRegression(uid: String) extends LogisticRegression(uid
setMaxIter(1)
- override protected def train(dataset: DataFrame): LogisticRegressionModel = {
+ override protected[spark] def train(dataset: DataFrame): LogisticRegressionModel = {
val labelSchema = dataset.schema($(labelCol))
// check for label attribute propagation.
assert(MetadataUtils.getNumClasses(labelSchema).forall(_ == 2))
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
index 8d14bb6572..8fef1316cd 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/LogisticRegressionSuite.scala
@@ -25,6 +25,7 @@ import org.scalatest.Matchers
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.optimization._
import org.apache.spark.mllib.regression._
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
import org.apache.spark.mllib.util.TestingUtils._
@@ -215,6 +216,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
// Test if we can correctly learn A, B where Y = logistic(A + B*X)
test("logistic regression with LBFGS") {
+ val updaters: List[Updater] = List(new SquaredL2Updater(), new L1Updater())
+ updaters.foreach(testLBFGS)
+ }
+
+ private def testLBFGS(myUpdater: Updater): Unit = {
val nPoints = 10000
val A = 2.0
val B = -1.5
@@ -223,7 +229,15 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val lr = new LogisticRegressionWithLBFGS().setIntercept(true)
+
+ // Override the updater
+ class LogisticRegressionWithLBFGSCustomUpdater
+ extends LogisticRegressionWithLBFGS {
+ override val optimizer =
+ new LBFGS(new LogisticGradient, myUpdater)
+ }
+
+ val lr = new LogisticRegressionWithLBFGSCustomUpdater().setIntercept(true)
val model = lr.run(testRDD)
@@ -396,10 +410,11 @@ class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext w
assert(modelA1.weights(0) ~== modelA3.weights(0) * 1.0E6 absTol 0.01)
// Training data with different scales without feature standardization
- // will not yield the same result in the scaled space due to poor
- // convergence rate.
- assert(modelB1.weights(0) !~== modelB2.weights(0) * 1.0E3 absTol 0.1)
- assert(modelB1.weights(0) !~== modelB3.weights(0) * 1.0E6 absTol 0.1)
+ // should still converge quickly, since the model internally still standardizes the
+ // features and only modifies the regularization function. See regParamL1Fun and
+ // related code inside ml.LogisticRegression.
+ assert(modelB1.weights(0) ~== modelB2.weights(0) * 1.0E3 absTol 0.1)
+ assert(modelB1.weights(0) ~== modelB3.weights(0) * 1.0E6 absTol 0.1)
}
test("multinomial logistic regression with LBFGS") {