 mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala | 24 ++++++++++++++++++++----
 1 file changed, 22 insertions(+), 2 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
index a165d8a934..fe47176a4a 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LinearSVCSuite.scala
@@ -24,12 +24,13 @@ import breeze.linalg.{DenseVector => BDV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.classification.LinearSVCSuite._
import org.apache.spark.ml.feature.{Instance, LabeledPoint}
-import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
+import org.apache.spark.sql.functions.udf
class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -41,6 +42,9 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
@transient var smallValidationDataset: Dataset[_] = _
@transient var binaryDataset: Dataset[_] = _
+ @transient var smallSparseBinaryDataset: Dataset[_] = _
+ @transient var smallSparseValidationDataset: Dataset[_] = _
+
override def beforeAll(): Unit = {
super.beforeAll()
@@ -51,6 +55,13 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
smallBinaryDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 42).toDF()
smallValidationDataset = generateSVMInput(A, Array[Double](B, C), nPoints, 17).toDF()
binaryDataset = generateSVMInput(1.0, Array[Double](1.0, 2.0, 3.0, 4.0), 10000, 42).toDF()
+
+ // Dataset for testing SparseVector
+ val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
+ val sparse = udf(toSparse)
+ smallSparseBinaryDataset = smallBinaryDataset.withColumn("features", sparse('features))
+ smallSparseValidationDataset = smallValidationDataset.withColumn("features", sparse('features))
+
}
/**
@@ -68,6 +79,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
+ val sparseModel = svm.fit(smallSparseBinaryDataset)
+ checkModels(model, sparseModel)
}
test("Linear SVC binary classification with regularization") {
@@ -75,6 +88,8 @@ class LinearSVCSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
val model = svm.setRegParam(0.1).fit(smallBinaryDataset)
assert(model.transform(smallValidationDataset)
.where("prediction=label").count() > nPoints * 0.8)
+ val sparseModel = svm.fit(smallSparseBinaryDataset)
+ checkModels(model, sparseModel)
}
test("params") {
@@ -235,7 +250,7 @@ object LinearSVCSuite {
"aggregationDepth" -> 3
)
-    // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
+  // Generate noisy input of the form Y = signum(x.dot(weights) + intercept + noise)
def generateSVMInput(
intercept: Double,
weights: Array[Double],
@@ -252,5 +267,10 @@ object LinearSVCSuite {
y.zip(x).map(p => LabeledPoint(p._1, Vectors.dense(p._2)))
}
+ def checkModels(model1: LinearSVCModel, model2: LinearSVCModel): Unit = {
+ assert(model1.intercept == model2.intercept)
+ assert(model1.coefficients.equals(model2.coefficients))
+ }
+
}
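
For readers skimming the change: below is a minimal, self-contained sketch (not part of the commit) of the dense-to-sparse column conversion that the new beforeAll() code performs. The SparkSession setup, the object name SparseConversionSketch, and the toy rows are illustrative assumptions; only the UDF itself mirrors the diff.

    import org.apache.spark.ml.linalg.{DenseVector, SparseVector, Vector, Vectors}
    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.functions.{col, udf}

    object SparseConversionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[2]")
          .appName("sparse-sketch")
          .getOrCreate()
        import spark.implicits._

        // Toy dense-feature rows standing in for smallBinaryDataset.
        val df = Seq(
          (1.0, Vectors.dense(0.0, 1.5)),
          (0.0, Vectors.dense(2.0, 0.0))
        ).toDF("label", "features")

        // The same conversion the test's beforeAll() applies: cast each
        // Vector to DenseVector and re-encode it as a SparseVector via a UDF,
        // so the fitted model is trained on sparse rather than dense input.
        val toSparse: Vector => SparseVector = _.asInstanceOf[DenseVector].toSparse
        val sparse = udf(toSparse)
        val sparseDf = df.withColumn("features", sparse(col("features")))

        sparseDf.show(truncate = false)
        spark.stop()
      }
    }

One note on why checkModels can compare the two fits directly: ml.linalg.Vector.equals compares vectors by size and element values rather than by representation, so coefficients from a dense-trained and a sparse-trained model compare equal whenever the learned values match.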