diff options
author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-05-11 09:31:22 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-05-11 09:31:22 -0700 |
commit | a5f9fdbba3bbefb56ca9ab33271301a2ff0834b5 (patch) | |
tree | aa16e3718d2af592411c18fd67d28a477a3d0e38 /mllib | |
parent | 427c20dd6d84cb9de1aac322183bc6e7b72ca25d (diff) | |
download | spark-a5f9fdbba3bbefb56ca9ab33271301a2ff0834b5.tar.gz spark-a5f9fdbba3bbefb56ca9ab33271301a2ff0834b5.tar.bz2 spark-a5f9fdbba3bbefb56ca9ab33271301a2ff0834b5.zip |
[SPARK-15268][SQL] Make JavaTypeInference work with UDTRegistration
## What changes were proposed in this pull request?
We have a private `UDTRegistration` API to register user defined type. Currently `JavaTypeInference` can't work with it. So `SparkSession.createDataFrame` from a bean class will not correctly infer the schema of the bean class.
## How was this patch tested?
`VectorUDTSuite`.
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Closes #13046 from viirya/fix-udt-registry-javatypeinference.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala index 6d01d8f282..7b50876d33 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/linalg/VectorUDTSuite.scala @@ -17,9 +17,19 @@ package org.apache.spark.ml.linalg +import scala.beans.BeanInfo + import org.apache.spark.SparkFunSuite +import org.apache.spark.sql.catalyst.JavaTypeInference import org.apache.spark.sql.types._ +@BeanInfo +case class LabeledPoint(label: Double, features: Vector) { + override def toString: String = { + s"($label,$features)" + } +} + class VectorUDTSuite extends SparkFunSuite { test("preloaded VectorUDT") { @@ -36,4 +46,10 @@ class VectorUDTSuite extends SparkFunSuite { assert(udt.simpleString == "vector") } } + + test("JavaTypeInference with VectorUDT") { + val (dataType, _) = JavaTypeInference.inferDataType(classOf[LabeledPoint]) + assert(dataType.asInstanceOf[StructType].fields.map(_.dataType) + === Seq(new VectorUDT, DoubleType)) + } } |