-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala | 45
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                     |  6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala                   |  7
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java          | 17
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java            | 11
5 files changed, 54 insertions(+), 32 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
index 8b53d988cb..e9d9508e5a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/JavaTypeInference.scala
@@ -117,11 +117,10 @@ object JavaTypeInference {
val (valueDataType, nullable) = inferDataType(valueType)
(MapType(keyDataType, valueDataType, nullable), true)
- case _ =>
+ case other =>
// TODO: we should only collect properties that have getter and setter. However, some tests
// pass in scala case class as java bean class which doesn't have getter and setter.
- val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
- val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val properties = getJavaBeanReadableProperties(other)
val fields = properties.map { property =>
val returnType = typeToken.method(property.getReadMethod).getReturnType
val (dataType, nullable) = inferDataType(returnType)
@@ -131,10 +130,15 @@ object JavaTypeInference {
}
}
- private def getJavaBeanProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
+ def getJavaBeanReadableProperties(beanClass: Class[_]): Array[PropertyDescriptor] = {
val beanInfo = Introspector.getBeanInfo(beanClass)
- beanInfo.getPropertyDescriptors
- .filter(p => p.getReadMethod != null && p.getWriteMethod != null)
+ beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ .filter(_.getReadMethod != null)
+ }
+
+ private def getJavaBeanReadableAndWritableProperties(
+ beanClass: Class[_]): Array[PropertyDescriptor] = {
+ getJavaBeanReadableProperties(beanClass).filter(_.getWriteMethod != null)
}
private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
@@ -298,9 +302,7 @@ object JavaTypeInference {
keyData :: valueData :: Nil)
case other =>
- val properties = getJavaBeanProperties(other)
- assert(properties.length > 0)
-
+ val properties = getJavaBeanReadableAndWritableProperties(other)
val setters = properties.map { p =>
val fieldName = p.getName
val fieldType = typeToken.method(p.getReadMethod).getReturnType
@@ -417,21 +419,16 @@ object JavaTypeInference {
)
case other =>
- val properties = getJavaBeanProperties(other)
- if (properties.length > 0) {
- CreateNamedStruct(properties.flatMap { p =>
- val fieldName = p.getName
- val fieldType = typeToken.method(p.getReadMethod).getReturnType
- val fieldValue = Invoke(
- inputObject,
- p.getReadMethod.getName,
- inferExternalType(fieldType.getRawType))
- expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
- })
- } else {
- throw new UnsupportedOperationException(
- s"Cannot infer type for class ${other.getName} because it is not bean-compliant")
- }
+ val properties = getJavaBeanReadableAndWritableProperties(other)
+ CreateNamedStruct(properties.flatMap { p =>
+ val fieldName = p.getName
+ val fieldType = typeToken.method(p.getReadMethod).getReturnType
+ val fieldValue = Invoke(
+ inputObject,
+ p.getReadMethod.getName,
+ inferExternalType(fieldType.getRawType))
+ expressions.Literal(fieldName) :: serializerFor(fieldValue, fieldType) :: Nil
+ })
}
}
}
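
The two helpers introduced above differ only in whether a setter is required: readable properties drive schema inference and serialization, while readable-and-writable properties drive deserialization. Below is a minimal, self-contained Java sketch of that distinction using java.beans.Introspector; the Person bean and the class names are made up for illustration and are not Spark code.

import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.util.Arrays;

public class BeanPropertySketch {

  // Hypothetical bean: "name" has a getter and a setter, "age" has only a getter.
  public static class Person {
    private String name;
    public String getName() { return name; }
    public void setName(String name) { this.name = name; }
    public int getAge() { return 42; }
  }

  public static void main(String[] args) throws Exception {
    PropertyDescriptor[] all =
        Introspector.getBeanInfo(Person.class).getPropertyDescriptors();

    // Mirrors getJavaBeanReadableProperties: drop the synthetic "class"
    // property and keep anything with a getter -> "age" and "name".
    PropertyDescriptor[] readable = Arrays.stream(all)
        .filter(p -> !"class".equals(p.getName()))
        .filter(p -> p.getReadMethod() != null)
        .toArray(PropertyDescriptor[]::new);

    // Mirrors getJavaBeanReadableAndWritableProperties: additionally require
    // a setter -> only "name" survives.
    PropertyDescriptor[] readableAndWritable = Arrays.stream(readable)
        .filter(p -> p.getWriteMethod() != null)
        .toArray(PropertyDescriptor[]::new);

    System.out.println(readable.length);            // 2
    System.out.println(readableAndWritable.length); // 1
  }
}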
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index dbe55090ea..234ef2dffc 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1090,14 +1090,14 @@ object SQLContext {
*/
private[sql] def beansToRows(
data: Iterator[_],
- beanInfo: BeanInfo,
+ beanClass: Class[_],
attrs: Seq[AttributeReference]): Iterator[InternalRow] = {
val extractors =
- beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+ JavaTypeInference.getJavaBeanReadableProperties(beanClass).map(_.getReadMethod)
val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
(e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
}
- data.map{ element =>
+ data.map { element =>
new GenericInternalRow(
methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }
): InternalRow
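
After this change, beansToRows discovers the getters from the bean class itself instead of from a pre-built BeanInfo. The following standalone Java sketch shows the same reflective extraction in simplified form, leaving out the Catalyst type conversion that the real method applies; Point and the beansToRows method here are illustrative names, not Spark API.

import java.beans.Introspector;
import java.beans.PropertyDescriptor;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

public class BeansToRowsSketch {

  // Hypothetical bean with two readable properties.
  public static class Point {
    private int x;
    private int y;
    public int getX() { return x; }
    public void setX(int x) { this.x = x; }
    public int getY() { return y; }
    public void setY(int y) { this.y = y; }
  }

  // Simplified analogue of SQLContext.beansToRows: discover the getters once
  // from the bean class, then invoke them on every element to build a row.
  static Iterator<Object[]> beansToRows(Iterator<?> data, Class<?> beanClass) throws Exception {
    Method[] extractors = Arrays.stream(
            Introspector.getBeanInfo(beanClass).getPropertyDescriptors())
        .filter(p -> !"class".equals(p.getName()) && p.getReadMethod() != null)
        .map(PropertyDescriptor::getReadMethod)
        .toArray(Method[]::new);

    List<Object[]> rows = new ArrayList<>();
    while (data.hasNext()) {
      Object element = data.next();
      Object[] row = new Object[extractors.length];
      for (int i = 0; i < extractors.length; i++) {
        // The real code also converts each value to Catalyst's internal format.
        row[i] = extractors[i].invoke(element);
      }
      rows.add(row);
    }
    return rows.iterator();
  }
}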
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
index 72af55c1fa..afc1827e7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SparkSession.scala
@@ -17,7 +17,6 @@
package org.apache.spark.sql
-import java.beans.Introspector
import java.io.Closeable
import java.util.concurrent.atomic.AtomicReference
@@ -347,8 +346,7 @@ class SparkSession private(
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
- val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
- SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
+ SQLContext.beansToRows(iter, Utils.classForName(className), attributeSeq)
}
Dataset.ofRows(self, LogicalRDD(attributeSeq, rowRdd)(self))
}
@@ -374,8 +372,7 @@ class SparkSession private(
*/
def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
val attrSeq = getSchema(beanClass)
- val beanInfo = Introspector.getBeanInfo(beanClass)
- val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+ val rows = SQLContext.beansToRows(data.asScala.iterator, beanClass, attrSeq)
Dataset.ofRows(self, LocalRelation(attrSeq, rows.toSeq))
}
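
The SparkSession change is driven by serializability: a java.beans.BeanInfo cannot be captured in a task closure, but a class name can, so each partition re-runs introspection locally. A small Java sketch of that constraint follows; String.class stands in for an arbitrary bean class, and this is illustration rather than Spark code.

import java.beans.BeanInfo;
import java.beans.Introspector;
import java.io.Serializable;

public class ReintrospectSketch {
  public static void main(String[] args) throws Exception {
    // BeanInfo instances are not Serializable on typical JDKs, so they cannot
    // be shipped to executors inside a closure.
    BeanInfo info = Introspector.getBeanInfo(String.class);
    System.out.println(info instanceof Serializable);  // expected: false

    // A class name is just a String, so it serializes trivially; the receiving
    // side can rediscover the BeanInfo locally, as the mapPartitions body above does.
    String className = String.class.getName();
    BeanInfo rediscovered = Introspector.getBeanInfo(Class.forName(className));
    System.out.println(rediscovered.getPropertyDescriptors().length > 0);  // true
  }
}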
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index c3b94a44c2..a8f814bfae 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -397,4 +397,21 @@ public class JavaDataFrameSuite {
Assert.assertTrue(filter4.mightContain(i * 3));
}
}
+
+ public static class BeanWithoutGetter implements Serializable {
+ private String a;
+
+ public void setA(String a) {
+ this.a = a;
+ }
+ }
+
+ @Test
+ public void testBeanWithoutGetter() {
+ BeanWithoutGetter bean = new BeanWithoutGetter();
+ List<BeanWithoutGetter> data = Arrays.asList(bean);
+ Dataset<Row> df = spark.createDataFrame(data, BeanWithoutGetter.class);
+ Assert.assertEquals(df.schema().length(), 0);
+ Assert.assertEquals(df.collectAsList().size(), 1);
+ }
}
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index 577672ca8e..4581c6ebe9 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -1276,4 +1276,15 @@ public class JavaDatasetSuite implements Serializable {
spark.createDataset(data, Encoders.bean(NestedComplicatedJavaBean.class));
ds.collectAsList();
}
+
+ public static class EmptyBean implements Serializable {}
+
+ @Test
+ public void testEmptyBean() {
+ EmptyBean bean = new EmptyBean();
+ List<EmptyBean> data = Arrays.asList(bean);
+ Dataset<EmptyBean> df = spark.createDataset(data, Encoders.bean(EmptyBean.class));
+ Assert.assertEquals(df.schema().length(), 0);
+ Assert.assertEquals(df.collectAsList().size(), 1);
+ }
}