aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala16
1 files changed, 16 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
index 6d01d8f282..7b50876d33 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala
@@ -17,9 +17,19 @@
package org.apache.spark.ml.linalg
+import scala.beans.BeanInfo
+
import org.apache.spark.SparkFunSuite
+import org.apache.spark.sql.catalyst.JavaTypeInference
import org.apache.spark.sql.types._
+@BeanInfo
+case class LabeledPoint(label: Double, features: Vector) {
+ override def toString: String = {
+ s"($label,$features)"
+ }
+}
+
class VectorUDTSuite extends SparkFunSuite {
test("preloaded VectorUDT") {
@@ -36,4 +46,10 @@ class VectorUDTSuite extends SparkFunSuite {
assert(udt.simpleString == "vector")
}
}
+
+ test("JavaTypeInference with VectorUDT") {
+ val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint])
+ assert(dataType.asInstanceOf[StructType].fields.map(_.dataType)
+ === Seq(new VectorUDT, DoubleType))
+ }
}