diff options
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala | 24 |
1 files changed, 22 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala index a165d8a934..fe47176a4a 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala @@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.classification.LinearSVCSuite._ import org.apache.spark.ml.feature.{Instance, LabeledPoint} -import org.apache.spark.ml.linalg.{Vector, Vectors} +import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.ml.util.TestingUtils._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Dataset, Row} +import org.apache.spark.sql.functions.udf class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { @@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau @transient var smallValidationDataset: Dataset[_] = _ @transient var binaryDataset: Dataset[_] = _ + @transient var smallSparseBinaryDataset: Dataset[_] = _ + @transient var smallSparseValidationDataset: Dataset[_] = _ + override def beforeAll(): Unit = { super.beforeAll() @@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF() smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF() binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF() + + // Dataset for testing SparseVector + val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse + val sparse = udf(toSparse) + smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features)) + smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features)) + } /** @@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("Linear SVC binary classification with regularization") { @@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau val model = svm.setRegParam(0.1).fit(smallBinaryDataset) assert(model.transform(smallValidationDataset) .where("prediction=label").count() > nPoints * 0.8) + val sparseModel = svm.fit(smallSparseBinaryDataset) + checkModels(model, sparseModel) } test("params") { @@ -235,7 +250,7 @@ object LinearSVCSuite { "aggregationDepth" -> 3 ) - // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) + // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise) def generateSVMInput( intercept: Double, weights: Array[Double], @@ -252,5 +267,10 @@ object LinearSVCSuite { y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2))) } + def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = { + assert(model1.intercept == model2.intercept) + assert(model1.coefficients.equals(model2.coefficients)) + } + } |