diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala | 10 |
1 files changed, 9 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala index 57d0278e03..0db27607bc 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/VectorAssemblerSuite.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.scalatest.FunSuite import org.apache.spark.SparkException -import org.apache.spark.mllib.linalg.{Vector, Vectors} +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.sql.{Row, SQLContext} @@ -48,6 +48,14 @@ class VectorAssemblerSuite extends FunSuite with MLlibTestSparkContext { } } + test("assemble should compress vectors") { + import org.apache.spark.ml.feature.VectorAssembler.assemble + val v1 = assemble(0.0, 0.0, 0.0, Vectors.dense(4.0)) + assert(v1.isInstanceOf[SparseVector]) + val v2 = assemble(1.0, 2.0, 3.0, Vectors.sparse(1, Array(0), Array(4.0))) + assert(v2.isInstanceOf[DenseVector]) + } + test("VectorAssembler") { val df = sqlContext.createDataFrame(Seq( (0, 0.0, Vectors.dense(1.0, 2.0), "a", Vectors.sparse(2, Array(1), Array(3.0)), 10L) |