From 4a11d029dc6abeb98fef5725d3d446a3eb5deddf Mon Sep 17 00:00:00 2001
From: Yuhao
Date: Mon, 23 Jan 2017 12:18:06 -0800
Subject: [SPARK-14709][ML] spark.ml API for linear SVM

## What changes were proposed in this pull request?

jira: https://issues.apache.org/jira/browse/SPARK-14709

Provide an API for the SVM algorithm on DataFrames. As discussed in the jira, the
initial implementation uses OWL-QN with the hinge loss function. The API should mimic
the existing spark.ml.classification APIs. Currently only binary classification is
supported; multinomial support can be added in this or a following release.

## How was this patch tested?

New unit tests and a simple manual test.

Author: Yuhao
Author: Yuhao Yang

Closes #15211 from hhbyyh/mlsvm.
---
 .../apache/spark/ml/classification/LinearSVC.scala | 546 +++++++++++++++++++++
 .../ml/classification/LogisticRegression.scala     |   4 +-
 .../spark/ml/classification/LinearSVCSuite.scala   | 241 +++++++++
 3 files changed, 789 insertions(+), 2 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
 create mode 100644 mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
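As a quick illustration of the API this patch adds (a sketch, not part of the patch
itself; `training` stands for any DataFrame with the default "label" and "features"
columns):

    import org.apache.spark.ml.classification.LinearSVC

    val svc = new LinearSVC()
      .setMaxIter(100)
      .setRegParam(0.1)
    val model = svc.fit(training)
    model.transform(training).select("prediction", "rawPrediction").show(5)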
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
new file mode 100644
index 0000000000..c4e93bf5e8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LinearSVC.scala
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.classification
+
+import scala.collection.mutable
+
+import breeze.linalg.{DenseVector => BDV}
+import breeze.optimize.{CachedDiffFunction, DiffFunction, OWLQN => BreezeOWLQN}
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.SparkException
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.Logging
+import org.apache.spark.ml.feature.Instance
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.linalg.BLAS._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.linalg.VectorImplicits._
+import org.apache.spark.mllib.stat.MultivariateOnlineSummarizer
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.{col, lit}
+
+/** Params for linear SVM Classifier. */
+private[classification] trait LinearSVCParams extends ClassifierParams with HasRegParam
+  with HasMaxIter with HasFitIntercept with HasTol with HasStandardization with HasWeightCol
+  with HasThreshold with HasAggregationDepth
+
+/**
+ * :: Experimental ::
+ *
+ * Linear SVM Classifier (https://en.wikipedia.org/wiki/Support_vector_machine#Linear_SVM)
+ *
+ * This binary classifier optimizes the hinge loss using the OWLQN optimizer.
+ */
+@Since("2.2.0")
+@Experimental
+class LinearSVC @Since("2.2.0") (
+    @Since("2.2.0") override val uid: String)
+  extends Classifier[Vector, LinearSVC, LinearSVCModel]
+  with LinearSVCParams with DefaultParamsWritable {
+
+  @Since("2.2.0")
+  def this() = this(Identifiable.randomUID("linearsvc"))
+
+  /**
+   * Set the regularization parameter.
+   * Default is 0.0.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setRegParam(value: Double): this.type = set(regParam, value)
+  setDefault(regParam -> 0.0)
+
+  /**
+   * Set the maximum number of iterations.
+   * Default is 100.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setMaxIter(value: Int): this.type = set(maxIter, value)
+  setDefault(maxIter -> 100)
+
+  /**
+   * Whether to fit an intercept term.
+   * Default is true.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setFitIntercept(value: Boolean): this.type = set(fitIntercept, value)
+  setDefault(fitIntercept -> true)
+
+  /**
+   * Set the convergence tolerance of iterations.
+   * Smaller values will lead to higher accuracy at the cost of more iterations.
+   * Default is 1E-6.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setTol(value: Double): this.type = set(tol, value)
+  setDefault(tol -> 1E-6)
+
+  /**
+   * Whether to standardize the training features before fitting the model.
+   * Default is true.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setStandardization(value: Boolean): this.type = set(standardization, value)
+  setDefault(standardization -> true)
+
+  /**
+   * Sets the value of param [[weightCol]].
+   * If this is not set or empty, we treat all instance weights as 1.0.
+   * Default is not set, so all instances have weight one.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setWeightCol(value: String): this.type = set(weightCol, value)
+
+  /**
+   * Set the threshold on the margin f(x) = w . x + b used in binary prediction:
+   * an instance is predicted as class 1.0 when its margin exceeds the threshold.
+   * Note that this is a threshold on the raw margin, not a probability in [0, 1].
+   * Default is 0.0.
+   *
+   * @group setParam
+   */
+  @Since("2.2.0")
+  def setThreshold(value: Double): this.type = set(threshold, value)
+  setDefault(threshold -> 0.0)
+
+  /**
+   * Suggested depth for treeAggregate (greater than or equal to 2).
+   * If the dimensions of features or the number of partitions are large,
+   * this param could be adjusted to a larger size.
+   * Default is 2.
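+   * For example (an illustrative value, not from the patch):
+   * {{{
+   *   new LinearSVC().setAggregationDepth(3)  // e.g. very wide features or many partitions
+   * }}}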
+   *
+   * @group expertSetParam
+   */
+  @Since("2.2.0")
+  def setAggregationDepth(value: Int): this.type = set(aggregationDepth, value)
+  setDefault(aggregationDepth -> 2)
+
+  @Since("2.2.0")
+  override def copy(extra: ParamMap): LinearSVC = defaultCopy(extra)
+
+  override protected[classification] def train(dataset: Dataset[_]): LinearSVCModel = {
+    val w = if (!isDefined(weightCol) || $(weightCol).isEmpty) lit(1.0) else col($(weightCol))
+    val instances: RDD[Instance] =
+      dataset.select(col($(labelCol)), w, col($(featuresCol))).rdd.map {
+        case Row(label: Double, weight: Double, features: Vector) =>
+          Instance(label, weight, features)
+      }
+
+    val instr = Instrumentation.create(this, instances)
+    instr.logParams(regParam, maxIter, fitIntercept, tol, standardization, threshold,
+      aggregationDepth)
+
+    val (summarizer, labelSummarizer) = {
+      val seqOp = (c: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+        instance: Instance) =>
+          (c._1.add(instance.features, instance.weight), c._2.add(instance.label, instance.weight))
+
+      val combOp = (c1: (MultivariateOnlineSummarizer, MultiClassSummarizer),
+        c2: (MultivariateOnlineSummarizer, MultiClassSummarizer)) =>
+          (c1._1.merge(c2._1), c1._2.merge(c2._2))
+
+      instances.treeAggregate(
+        new MultivariateOnlineSummarizer, new MultiClassSummarizer
+      )(seqOp, combOp, $(aggregationDepth))
+    }
+
+    val histogram = labelSummarizer.histogram
+    val numInvalid = labelSummarizer.countInvalid
+    val numFeatures = summarizer.mean.size
+    val numFeaturesPlusIntercept = if (getFitIntercept) numFeatures + 1 else numFeatures
+
+    val numClasses = MetadataUtils.getNumClasses(dataset.schema($(labelCol))) match {
+      case Some(n: Int) =>
+        require(n >= histogram.length, s"Specified number of classes $n was " +
+          s"less than the number of unique labels ${histogram.length}.")
+        n
+      case None => histogram.length
+    }
+    require(numClasses == 2, s"LinearSVC only supports binary classification." +
+      s" $numClasses classes detected in column ${$(labelCol)}.")
+    instr.logNumClasses(numClasses)
+    instr.logNumFeatures(numFeatures)
+
+    val (coefficientVector, interceptVector, objectiveHistory) = {
+      if (numInvalid != 0) {
+        val msg = s"Classification labels should be in {0, ..., ${numClasses - 1}}. " +
+          s"Found $numInvalid invalid labels."
+        logError(msg)
+        throw new SparkException(msg)
+      }
+
+      val featuresStd = summarizer.variance.toArray.map(math.sqrt)
+      val regParamL2 = $(regParam)
+      val bcFeaturesStd = instances.context.broadcast(featuresStd)
+      val costFun = new LinearSVCCostFun(instances, $(fitIntercept),
+        $(standardization), bcFeaturesStd, regParamL2, $(aggregationDepth))
+
+      def regParamL1Fun = (index: Int) => 0D
+      val optimizer = new BreezeOWLQN[Int, BDV[Double]]($(maxIter), 10, regParamL1Fun, $(tol))
+      val initialCoefWithIntercept = Vectors.zeros(numFeaturesPlusIntercept)
+
+      val states = optimizer.iterations(new CachedDiffFunction(costFun),
+        initialCoefWithIntercept.asBreeze.toDenseVector)
+
+      val scaledObjectiveHistory = mutable.ArrayBuilder.make[Double]
+      var state: optimizer.State = null
+      while (states.hasNext) {
+        state = states.next()
+        scaledObjectiveHistory += state.adjustedValue
+      }
+
+      bcFeaturesStd.destroy(blocking = false)
+      if (state == null) {
+        val msg = s"${optimizer.getClass.getName} failed."
+        logError(msg)
+        throw new SparkException(msg)
+      }
+
+      /*
+         The coefficients are trained in the scaled space; we're converting them back to
+         the original space.
+         Note that the intercept in scaled space and original space is the same;
+         as a result, no scaling is needed.
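+         Concretely, for each feature i the code below computes
+         coefficients(i) = rawCoefficients(i) / featuresStd(i) when featuresStd(i) != 0,
+         and 0.0 otherwise, while the intercept is taken from rawCoefficients as is.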
+ */ + val rawCoefficients = state.x.toArray + val coefficientArray = Array.tabulate(numFeatures) { i => + if (featuresStd(i) != 0.0) { + rawCoefficients(i) / featuresStd(i) + } else { + 0.0 + } + } + + val intercept = if ($(fitIntercept)) { + rawCoefficients(numFeaturesPlusIntercept - 1) + } else { + 0.0 + } + (Vectors.dense(coefficientArray), intercept, scaledObjectiveHistory.result()) + } + + val model = copyValues(new LinearSVCModel(uid, coefficientVector, interceptVector)) + instr.logSuccess(model) + model + } +} + +@Since("2.2.0") +object LinearSVC extends DefaultParamsReadable[LinearSVC] { + + @Since("2.2.0") + override def load(path: String): LinearSVC = super.load(path) +} + +/** + * :: Experimental :: + * SVM Model trained by [[LinearSVC]] + */ +@Since("2.2.0") +@Experimental +class LinearSVCModel private[classification] ( + @Since("2.2.0") override val uid: String, + @Since("2.2.0") val coefficients: Vector, + @Since("2.2.0") val intercept: Double) + extends ClassificationModel[Vector, LinearSVCModel] + with LinearSVCParams with MLWritable { + + @Since("2.2.0") + override val numClasses: Int = 2 + + @Since("2.2.0") + override val numFeatures: Int = coefficients.size + + @Since("2.2.0") + def setThreshold(value: Double): this.type = set(threshold, value) + + @Since("2.2.0") + def setWeightCol(value: Double): this.type = set(threshold, value) + + private val margin: Vector => Double = (features) => { + BLAS.dot(features, coefficients) + intercept + } + + override protected def predict(features: Vector): Double = { + if (margin(features) > $(threshold)) 1.0 else 0.0 + } + + override protected def predictRaw(features: Vector): Vector = { + val m = margin(features) + Vectors.dense(-m, m) + } + + @Since("2.2.0") + override def copy(extra: ParamMap): LinearSVCModel = { + copyValues(new LinearSVCModel(uid, coefficients, intercept), extra).setParent(parent) + } + + @Since("2.2.0") + override def write: MLWriter = new LinearSVCModel.LinearSVCWriter(this) + +} + + +@Since("2.2.0") +object LinearSVCModel extends MLReadable[LinearSVCModel] { + + @Since("2.2.0") + override def read: MLReader[LinearSVCModel] = new LinearSVCReader + + @Since("2.2.0") + override def load(path: String): LinearSVCModel = super.load(path) + + /** [[MLWriter]] instance for [[LinearSVCModel]] */ + private[LinearSVCModel] + class LinearSVCWriter(instance: LinearSVCModel) + extends MLWriter with Logging { + + private case class Data(coefficients: Vector, intercept: Double) + + override protected def saveImpl(path: String): Unit = { + // Save metadata and Params + DefaultParamsWriter.saveMetadata(instance, path, sc) + val data = Data(instance.coefficients, instance.intercept) + val dataPath = new Path(path, "data").toString + sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath) + } + } + + private class LinearSVCReader extends MLReader[LinearSVCModel] { + + /** Checked against metadata when loading model */ + private val className = classOf[LinearSVCModel].getName + + override def load(path: String): LinearSVCModel = { + val metadata = DefaultParamsReader.loadMetadata(path, sc, className) + val dataPath = new Path(path, "data").toString + val data = sparkSession.read.format("parquet").load(dataPath) + val Row(coefficients: Vector, intercept: Double) = + data.select("coefficients", "intercept").head() + val model = new LinearSVCModel(metadata.uid, coefficients, intercept) + DefaultParamsReader.getAndSetParams(model, metadata) + model + } + } +} + +/** + * LinearSVCCostFun implements Breeze's 
+
+/**
+ * LinearSVCCostFun implements Breeze's DiffFunction[T] for the hinge loss function.
+ */
+private class LinearSVCCostFun(
+    instances: RDD[Instance],
+    fitIntercept: Boolean,
+    standardization: Boolean,
+    bcFeaturesStd: Broadcast[Array[Double]],
+    regParamL2: Double,
+    aggregationDepth: Int) extends DiffFunction[BDV[Double]] {
+
+  override def calculate(coefficients: BDV[Double]): (Double, BDV[Double]) = {
+    val coeffs = Vectors.fromBreeze(coefficients)
+    val bcCoeffs = instances.context.broadcast(coeffs)
+    val featuresStd = bcFeaturesStd.value
+    val numFeatures = featuresStd.length
+
+    val svmAggregator = {
+      val seqOp = (c: LinearSVCAggregator, instance: Instance) => c.add(instance)
+      val combOp = (c1: LinearSVCAggregator, c2: LinearSVCAggregator) => c1.merge(c2)
+
+      instances.treeAggregate(
+        new LinearSVCAggregator(bcCoeffs, bcFeaturesStd, fitIntercept)
+      )(seqOp, combOp, aggregationDepth)
+    }
+
+    val totalGradientArray = svmAggregator.gradient.toArray
+    // regVal is the sum of squared coefficients, excluding the intercept, used for
+    // L2 regularization.
+    val regVal = if (regParamL2 == 0.0) {
+      0.0
+    } else {
+      var sum = 0.0
+      coeffs.foreachActive { case (index, value) =>
+        // We do not apply regularization to the intercept.
+        if (index != numFeatures) {
+          // The following code computes both the loss and the gradient of the
+          // regularization term, adding the gradient contribution back to
+          // totalGradientArray.
+          sum += {
+            if (standardization) {
+              totalGradientArray(index) += regParamL2 * value
+              value * value
+            } else {
+              if (featuresStd(index) != 0.0) {
+                // If `standardization` is false, we still standardize the data
+                // to improve the rate of convergence; as a result, we have to
+                // perform this reverse standardization by penalizing each component
+                // differently to get effectively the same objective function when
+                // the training dataset is not standardized.
+                val temp = value / (featuresStd(index) * featuresStd(index))
+                totalGradientArray(index) += regParamL2 * temp
+                value * temp
+              } else {
+                0.0
+              }
+            }
+          }
+        }
+      }
+      0.5 * regParamL2 * sum
+    }
+    bcCoeffs.destroy(blocking = false)
+
+    (svmAggregator.loss + regVal, new BDV(totalGradientArray))
+  }
+}
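+
+/*
+ * For reference, the per-instance quantities accumulated by LinearSVCAggregator below,
+ * restated from the comments inside its add() method: with labels y in {0, 1} mapped
+ * to s = 2y - 1 in {-1, +1}, and margin = dot(coefficients, x / featuresStd) + intercept,
+ * each instance contributes
+ *
+ *   loss     = weight * max(0, 1 - s * margin)
+ *   gradient = -weight * s * (x / featuresStd)   when s * margin < 1, else 0.
+ */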
+
+/**
+ * LinearSVCAggregator computes the gradient and loss for the hinge loss function, as
+ * used in binary classification, for instances in sparse or dense vector format, in
+ * an online fashion.
+ *
+ * Two LinearSVCAggregators can be merged together to have a summary of loss and
+ * gradient of the corresponding joint dataset.
+ *
+ * This class standardizes feature values during computation using bcFeaturesStd.
+ *
+ * @param bcCoefficients The coefficients corresponding to the features.
+ * @param fitIntercept Whether to fit an intercept term.
+ * @param bcFeaturesStd The standard deviation values of the features.
+ */
+private class LinearSVCAggregator(
+    bcCoefficients: Broadcast[Vector],
+    bcFeaturesStd: Broadcast[Array[Double]],
+    fitIntercept: Boolean) extends Serializable {
+
+  private val numFeatures: Int = bcFeaturesStd.value.length
+  private val numFeaturesPlusIntercept: Int = if (fitIntercept) numFeatures + 1 else numFeatures
+  private val coefficients: Vector = bcCoefficients.value
+  private var weightSum: Double = 0.0
+  private var lossSum: Double = 0.0
+  require(numFeaturesPlusIntercept == coefficients.size, s"Dimension mismatch. Coefficients " +
+    s"length ${coefficients.size}, FeaturesStd length ${numFeatures}, fitIntercept: $fitIntercept")
+
+  private val coefficientsArray = coefficients match {
+    case dv: DenseVector => dv.values
+    case _ =>
+      throw new IllegalArgumentException(
+        s"coefficients only supports dense vector but got type ${coefficients.getClass}.")
+  }
+  private val gradientSumArray = Array.fill[Double](coefficientsArray.length)(0)
+
+  /**
+   * Add a new training instance to this LinearSVCAggregator, and update the loss and
+   * gradient of the objective function.
+   *
+   * @param instance The data point to be added.
+   * @return This LinearSVCAggregator object.
+   */
+  def add(instance: Instance): this.type = {
+    instance match { case Instance(label, weight, features) =>
+      if (weight == 0.0) return this
+      val localFeaturesStd = bcFeaturesStd.value
+      val localCoefficients = coefficientsArray
+      val localGradientSumArray = gradientSumArray
+
+      val dotProduct = {
+        var sum = 0.0
+        features.foreachActive { (index, value) =>
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            sum += localCoefficients(index) * value / localFeaturesStd(index)
+          }
+        }
+        if (fitIntercept) sum += localCoefficients(numFeaturesPlusIntercept - 1)
+        sum
+      }
+      // Our loss function with {0, 1} labels is max(0, 1 - (2y - 1) * f_w(x)), so the
+      // gradient is -(2y - 1) * x when the margin is violated, and 0 otherwise.
+      val labelScaled = 2 * label - 1.0
+      if (1.0 > labelScaled * dotProduct) {
+        lossSum += weight * (1.0 - labelScaled * dotProduct)
+        val gradientScale = -labelScaled * weight
+        features.foreachActive { (index, value) =>
+          if (localFeaturesStd(index) != 0.0 && value != 0.0) {
+            localGradientSumArray(index) += value * gradientScale / localFeaturesStd(index)
+          }
+        }
+        if (fitIntercept) {
+          localGradientSumArray(localGradientSumArray.length - 1) += gradientScale
+        }
+      }
+      weightSum += weight
+      this
+    }
+  }
+
+  /**
+   * Merge another LinearSVCAggregator, and update the loss and gradient
+   * of the objective function.
+   * (Note that this is in-place merging; as a result, `this` object will be modified.)
+   *
+   * @param other The other LinearSVCAggregator to be merged.
+   * @return This LinearSVCAggregator object.
+   */
+  def merge(other: LinearSVCAggregator): this.type = {
+    if (other.weightSum != 0.0) {
+      weightSum += other.weightSum
+      lossSum += other.lossSum
+
+      var i = 0
+      val localThisGradientSumArray = this.gradientSumArray
+      val localOtherGradientSumArray = other.gradientSumArray
+      val len = localThisGradientSumArray.length
+      while (i < len) {
+        localThisGradientSumArray(i) += localOtherGradientSumArray(i)
+        i += 1
+      }
+    }
+    this
+  }
+
+  def loss: Double = {
+    if (weightSum != 0) {
+      lossSum / weightSum
+    } else 0.0
+  }
+
+  def gradient: Vector = {
+    if (weightSum != 0) {
+      val result = Vectors.dense(gradientSumArray.clone())
+      scal(1.0 / weightSum, result)
+      result
+    } else Vectors.dense(Array.fill[Double](coefficientsArray.length)(0))
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5e1d6eec96..d2b0f2a9e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -233,7 +233,7 @@ class LogisticRegression @Since("1.2.0") (
 
   /**
    * Set the convergence tolerance of iterations.
- * Smaller value will lead to higher accuracy with the cost of more iterations. + * Smaller value will lead to higher accuracy at the cost of more iterations. * Default is 1E-6. * * @group setParam @@ -1431,7 +1431,7 @@ private class LogisticAggregator( private var weightSum = 0.0 private var lossSum = 0.0 - private val gradientSumArray = Array.ofDim[Double](coefficientSize) + private val gradientSumArray = Array.fill[Double](coefficientSize)(0.0D) if (multinomial && numClasses <= 2) { logInfo(s"Multinomial logistic regression for binary classification yields separate " + diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala new file mode 100644 index 0000000000..ee2aefee7a --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -0,0 +1,241 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.classification + +import scala.util.Random + +import breeze.linalg.{DenseVector => BDV} + +import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.classification.LinearSVCSuite._ +import org.apache.spark.ml.feature.LabeledPoint +import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.param.ParamsSuite +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} +import org.apache.spark.ml.util.TestingUtils._ +import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.sql.{Dataset, Row} + + +class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { + + import testImplicits._ + + private val nPoints = 50 + @transient var smallBinaryDataset: Dataset[_] = _ + @transient var smallValidationDataset: Dataset[_] = _ + @transient var binaryDataset: Dataset[_] = _ + + override def beforeAll(): Unit = { + super.beforeAll() + + // NOTE: Intercept should be small for generating equal 0s and 1s + val A = 0.01 + val B = -1.5 + val C = 1.0 + smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() + smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() + binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + } + + /** + * Enable the ignored test to export the dataset into CSV format, + * so we can validate the training accuracy compared with R's e1071 package. 
+ */ + ignore("export test data into CSV format") { + binaryDataset.rdd.map { case Row(label: Double, features: Vector) => + label + "," + features.toArray.mkString(",") + }.repartition(1).saveAsTextFile("target/tmp/LinearSVC/binaryDataset") + } + + test("Linear SVC binary classification") { + val svm = new LinearSVC() + val model = svm.fit(smallBinaryDataset) + assert(model.transform(smallValidationDataset) + .where("prediction=label").count() > nPoints * 0.8) + } + + test("Linear SVC binary classification with regularization") { + val svm = new LinearSVC() + val model = svm.setRegParam(0.1).fit(smallBinaryDataset) + assert(model.transform(smallValidationDataset) + .where("prediction=label").count() > nPoints * 0.8) + } + + test("params") { + ParamsSuite.checkParams(new LinearSVC) + val model = new LinearSVCModel("linearSVC", Vectors.dense(0.0), 0.0) + ParamsSuite.checkParams(model) + } + + test("linear svc: default params") { + val lsvc = new LinearSVC() + assert(lsvc.getRegParam === 0.0) + assert(lsvc.getMaxIter === 100) + assert(lsvc.getFitIntercept) + assert(lsvc.getTol === 1E-6) + assert(lsvc.getStandardization) + assert(!lsvc.isDefined(lsvc.weightCol)) + assert(lsvc.getThreshold === 0.0) + assert(lsvc.getAggregationDepth === 2) + assert(lsvc.getLabelCol === "label") + assert(lsvc.getFeaturesCol === "features") + assert(lsvc.getPredictionCol === "prediction") + assert(lsvc.getRawPredictionCol === "rawPrediction") + val model = lsvc.setMaxIter(5).fit(smallBinaryDataset) + model.transform(smallBinaryDataset) + .select("label", "prediction", "rawPrediction") + .collect() + assert(model.getThreshold === 0.0) + assert(model.getFeaturesCol === "features") + assert(model.getPredictionCol === "prediction") + assert(model.getRawPredictionCol === "rawPrediction") + assert(model.intercept !== 0.0) + assert(model.hasParent) + assert(model.numFeatures === 2) + + // copied model must have the same parent. + MLTestingUtils.checkCopy(model) + } + + test("linear svc doesn't fit intercept when fitIntercept is off") { + val lsvc = new LinearSVC().setFitIntercept(false).setMaxIter(5) + val model = lsvc.fit(smallBinaryDataset) + assert(model.intercept === 0.0) + + val lsvc2 = new LinearSVC().setFitIntercept(true).setMaxIter(5) + val model2 = lsvc2.fit(smallBinaryDataset) + assert(model2.intercept !== 0.0) + } + + test("linearSVC with sample weights") { + def modelEquals(m1: LinearSVCModel, m2: LinearSVCModel): Unit = { + assert(m1.coefficients ~== m2.coefficients absTol 0.05) + assert(m1.intercept ~== m2.intercept absTol 0.05) + } + + val estimator = new LinearSVC().setRegParam(0.01).setTol(0.01) + val dataset = smallBinaryDataset + MLTestingUtils.testArbitrarilyScaledWeights[LinearSVCModel, LinearSVC]( + dataset.as[LabeledPoint], estimator, modelEquals) + MLTestingUtils.testOutliersWithSmallWeights[LinearSVCModel, LinearSVC]( + dataset.as[LabeledPoint], estimator, 2, modelEquals) + MLTestingUtils.testOversamplingVsWeighting[LinearSVCModel, LinearSVC]( + dataset.as[LabeledPoint], estimator, modelEquals, 42L) + } + + test("linearSVC comparison with R e1071 and scikit-learn") { + val trainer1 = new LinearSVC() + .setRegParam(0.00002) // set regParam = 2.0 / datasize / c + .setMaxIter(200) + .setTol(1e-4) + val model1 = trainer1.fit(binaryDataset) + + /* + Use the following R code to load the data and train the model using glmnet package. 
+
+  test("linearSVC comparison with R e1071 and scikit-learn") {
+    val trainer1 = new LinearSVC()
+      .setRegParam(0.00002) // set regParam = 2.0 / datasize / c
+      .setMaxIter(200)
+      .setTol(1e-4)
+    val model1 = trainer1.fit(binaryDataset)
+
+    /*
+      Use the following R code to load the data and train the model using the e1071
+      package.
+
+      library(e1071)
+      data <- read.csv("path/target/tmp/LinearSVC/binaryDataset/part-00000", header=FALSE)
+      label <- factor(data$V1)
+      features <- as.matrix(data.frame(data$V2, data$V3, data$V4, data$V5))
+      svm_model <- svm(features, label, type='C', kernel='linear', cost=10, scale=F, tolerance=1e-4)
+      w <- -t(svm_model$coefs) %*% svm_model$SV
+      w
+      svm_model$rho
+
+      > w
+            data.V2  data.V3  data.V4  data.V5
+      [1,] 7.310338 14.89741 22.21005 29.83508
+      > svm_model$rho
+      [1] 7.440177
+     */
+    val coefficientsR = Vectors.dense(7.310338, 14.89741, 22.21005, 29.83508)
+    val interceptR = 7.440177
+    assert(model1.intercept ~== interceptR relTol 1E-2)
+    assert(model1.coefficients ~== coefficientsR relTol 1E-2)
+
+    /*
+      Use the following Python code to load the data and train the model using the
+      scikit-learn package.
+
+      import numpy as np
+      from sklearn import svm
+      f = open("path/target/tmp/LinearSVC/binaryDataset/part-00000")
+      data = np.loadtxt(f, delimiter=",")
+      X = data[:, 1:]  # select columns 1 through end
+      y = data[:, 0]   # select column 0 as label
+      clf = svm.LinearSVC(fit_intercept=True, C=10, loss='hinge', tol=1e-4, random_state=42)
+      m = clf.fit(X, y)
+      print m.coef_
+      print m.intercept_
+
+      [[  7.24690165  14.77029087  21.99924004  29.5575729 ]]
+      [ 7.36947518]
+     */
+    val coefficientsSK = Vectors.dense(7.24690165, 14.77029087, 21.99924004, 29.5575729)
+    val interceptSK = 7.36947518
+    assert(model1.intercept ~== interceptSK relTol 1E-3)
+    assert(model1.coefficients ~== coefficientsSK relTol 4E-3)
+  }
+
+  test("read/write: SVM") {
+    def checkModelData(model: LinearSVCModel, model2: LinearSVCModel): Unit = {
+      assert(model.intercept === model2.intercept)
+      assert(model.coefficients === model2.coefficients)
+      assert(model.numFeatures === model2.numFeatures)
+    }
+    val svm = new LinearSVC()
+    testEstimatorAndModelReadWrite(svm, smallBinaryDataset, LinearSVCSuite.allParamSettings,
+      checkModelData)
+  }
+}
+
+object LinearSVCSuite {
+
+  val allParamSettings: Map[String, Any] = Map(
+    "regParam" -> 0.01,
+    "maxIter" -> 2,  // intentionally small
+    "fitIntercept" -> true,
+    "tol" -> 0.8,
+    "standardization" -> false,
+    "threshold" -> 0.6,
+    "predictionCol" -> "myPredict",
+    "rawPredictionCol" -> "myRawPredict",
+    "aggregationDepth" -> 3
+  )
+
+  // Generate noisy input: the label is 1.0 when x.dot(weights) + intercept + noise > 0,
+  // and 0.0 otherwise.
+  def generateSVMInput(
+      intercept: Double,
+      weights: Array[Double],
+      nPoints: Int,
+      seed: Int): Seq[LabeledPoint] = {
+    val rnd = new Random(seed)
+    val weightsMat = new BDV(weights)
+    val x = Array.fill[Array[Double]](nPoints)(
+      Array.fill[Double](weights.length)(rnd.nextDouble() * 2.0 - 1.0))
+    val y = x.map { xi =>
+      val yD = new BDV(xi).dot(weightsMat) + intercept + 0.01 * rnd.nextGaussian()
+      if (yD > 0) 1.0 else 0.0
+    }
+    y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
+  }
+}