author     Xiangrui Meng <meng@databricks.com>    2015-03-24 10:11:27 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-03-24 10:11:27 -0700
commit     a1d1529daebee30b76b954d16a30849407f795d1 (patch)
tree       121a7655df38cef5bced7d4b7c7b1e717c162182
parent     08d452801195cc6cf0697a594e98cd4778f358ee (diff)
[SPARK-6475][SQL] recognize array types when inferring data types from JavaBeans
Right now, if there is an array field in a JavaBean, the user would see an exception in `createDataFrame`. liancheng

Author: Xiangrui Meng <meng@databricks.com>

Closes #5146 from mengxr/SPARK-6475 and squashes the following commits:

51e87e5 [Xiangrui Meng] validate schemas
4f2df5e [Xiangrui Meng] recognize array types when infer data types from JavaBeans
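For context: the old getSchema match in SQLContext (visible in the removed lines below) had no case for array classes, so any array-typed bean property would fail with a scala.MatchError at runtime. A minimal sketch of the case this patch enables; the bean and variable names here are hypothetical, not part of the commit:

    // Hypothetical bean with an array-typed property, the shape that used to fail.
    public static class Record implements java.io.Serializable {
        private Integer[] values = new Integer[]{1, 2, 3};
        public Integer[] getValues() { return values; }
    }

    // With this patch, schema inference maps the property to array<int>
    // (java.lang.Integer elements are boxed, hence containsNull = true).
    JavaRDD<Record> rdd = jsc.parallelize(Arrays.asList(new Record()));
    DataFrame df = context.createDataFrame(rdd, Record.class);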
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala            | 80
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 41
2 files changed, 89 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dc9912b52d..e59cf9b9e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1210,38 +1210,56 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
+ val (dataType, _) = inferDataType(beanClass)
+ dataType.asInstanceOf[StructType].fields.map { f =>
+ AttributeReference(f.name, f.dataType, f.nullable)()
+ }
+ }
+
+ /**
+ * Infers the corresponding SQL data type of a Java class.
+ * @param clazz Java class
+ * @return (SQL data type, nullable)
+ */
+ private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
// TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
- val beanInfo = Introspector.getBeanInfo(beanClass)
-
- // Note: The ordering of elements may differ from when the schema is inferred in Scala.
- // This is because beanInfo.getPropertyDescriptors gives no guarantees about
- // element ordering.
- val fields = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
- fields.map { property =>
- val (dataType, nullable) = property.getPropertyType match {
- case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
- case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
- case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
- case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
- case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
- case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
- case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
- case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
- case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
- case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
- case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
- case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
- case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
- case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
- case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
- case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
- case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
- case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
- }
- AttributeReference(property.getName, dataType, nullable)()
+ clazz match {
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+ case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+ case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+ case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+ case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+ case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+ case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+ case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+ case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+ case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+ case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+ case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+ case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+ case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+ case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+ case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+ case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+ case c: Class[_] if c.isArray =>
+ val (dataType, nullable) = inferDataType(c.getComponentType)
+ (ArrayType(dataType, nullable), true)
+
+ case _ =>
+ val beanInfo = Introspector.getBeanInfo(clazz)
+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val fields = properties.map { property =>
+ val (dataType, nullable) = inferDataType(property.getPropertyType)
+ new StructField(property.getName, dataType, nullable)
+ }
+ (new StructType(fields), true)
}
}
}
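A note on the refactoring above: inferDataType is recursive, so the new array case composes for free; nested arrays and arrays of beans route back through the same match. A hedged sketch of what that implies for a two-dimensional array property (hypothetical bean, not part of this commit):

    // Hypothetical: String[][] recurses twice through the isArray case.
    public static class Grid implements java.io.Serializable {
        private String[][] rows = new String[][]{{"a", "b"}};
        public String[][] getRows() { return rows; }
    }
    // inferDataType yields, from the inside out:
    //   String      -> (StringType, nullable = true)
    //   String[]    -> (ArrayType(StringType, containsNull = true), true)
    //   String[][]  -> (ArrayType(ArrayType(StringType, true), true), true)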
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 2d586f784a..1ff2d5a190 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,29 +17,39 @@
package test.org.apache.spark.sql;
+import java.io.Serializable;
+import java.util.Arrays;
+
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Ignore;
import org.junit.Test;
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.*;
+import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
-import static org.apache.spark.sql.functions.*;
+import org.apache.spark.sql.types.*;
+import static org.apache.spark.sql.functions.*;
public class JavaDataFrameSuite {
+ private transient JavaSparkContext jsc;
private transient SQLContext context;
@Before
public void setUp() {
// Trigger static initializer of TestData
TestData$.MODULE$.testData();
+ jsc = new JavaSparkContext(TestSQLContext.sparkContext());
context = TestSQLContext$.MODULE$;
}
@After
public void tearDown() {
+ jsc = null;
context = null;
}
@@ -90,4 +100,33 @@ public class JavaDataFrameSuite {
df.show();
df.show(1000);
}
+
+ public static class Bean implements Serializable {
+ private double a = 0.0;
+ private Integer[] b = new Integer[]{0, 1};
+
+ public double getA() {
+ return a;
+ }
+
+ public Integer[] getB() {
+ return b;
+ }
+ }
+
+ @Test
+ public void testCreateDataFrameFromJavaBeans() {
+ Bean bean = new Bean();
+ JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
+ DataFrame df = context.createDataFrame(rdd, Bean.class);
+ StructType schema = df.schema();
+ Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
+ schema.apply("a"));
+ Assert.assertEquals(
+ new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
+ schema.apply("b"));
+ Row first = df.select("a", "b").first();
+ Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
+ Assert.assertArrayEquals(bean.getB(), first.<Integer[]>getAs(1));
+ }
}