From 8ecba3e86e53834413da8b4299f5791545cae12e Mon Sep 17 00:00:00 2001
From: Holden Karau
Date: Sun, 27 Sep 2015 21:16:15 +0100
Subject: [SPARK-10720] [SQL] [JAVA] Add a java wrapper to create a dataframe
 from a local list of java beans

Similar to SPARK-10630 it would be nice if Java users didn't have to
parallelize their data explicitly (as Scala users already can skip). Issue
came up in
http://stackoverflow.com/questions/32613413/apache-spark-machine-learning-cant-get-estimator-example-to-work

Author: Holden Karau

Closes #8879 from holdenk/SPARK-10720-add-a-java-wrapper-to-create-a-dataframe-from-a-local-list-of-java-beans.
---
 .../scala/org/apache/spark/sql/SQLContext.scala    | 51 +++++++++++++++++-----
 .../org/apache/spark/sql/JavaDataFrameSuite.java   | 22 +++++++---
 2 files changed, 56 insertions(+), 17 deletions(-)

(limited to 'sql/core')

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1bd4e26fb3..cb0a3e361c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql
 
-import java.beans.Introspector
+import java.beans.{BeanInfo, Introspector}
 import java.util.Properties
 import java.util.concurrent.atomic.AtomicReference
 
@@ -499,21 +499,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
    * @since 1.3.0
    */
   def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
-    val attributeSeq = getSchema(beanClass)
+    val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
     val className = beanClass.getName
     val rowRdd = rdd.mapPartitions { iter =>
       // BeanInfo is not serializable so we must rediscover it remotely for each partition.
       val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
-      val extractors =
-        localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
-      val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) =>
-        (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
-      }
-      iter.map { row =>
-        new GenericInternalRow(
-          methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any]
-        ): InternalRow
-      }
+      SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
     }
     DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
   }
@@ -530,6 +521,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
     createDataFrame(rdd.rdd, beanClass)
   }
 
+  /**
+   * Applies a schema to a List of Java Beans.
+   *
+   * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+   *          SELECT * queries will return the columns in an undefined order.
+   * @group dataframes
+   * @since 1.6.0
+   */
+  def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
+    val attrSeq = getSchema(beanClass)
+    val className = beanClass.getName
+    val beanInfo = Introspector.getBeanInfo(beanClass)
+    val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+    DataFrame(self, LocalRelation(attrSeq, rows.toSeq))
+  }
+
+
   /**
    * :: Experimental ::
    * Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]].
@@ -1229,4 +1237,23 @@ object SQLContext {
       lastInstantiatedContext.set(sqlContext)
     }
   }
+
+  /**
+   * Converts an iterator of Java Beans to InternalRow using the provided
+   * bean info & schema. This is not related to the singleton, but is a static
+   * method for internal use.
+   */
+  private def beansToRows(data: Iterator[_], beanInfo: BeanInfo, attrs: Seq[AttributeReference]):
+      Iterator[InternalRow] = {
+    val extractors =
+      beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+    val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
+      (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
+    }
+    data.map{ element =>
+      new GenericInternalRow(
+        methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }.toArray[Any]
+      ): InternalRow
+    }
+  }
 }

diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 250ac2e109..a1a3fdbb48 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -142,11 +142,7 @@ public class JavaDataFrameSuite {
     }
   }
 
-  @Test
-  public void testCreateDataFrameFromJavaBeans() {
-    Bean bean = new Bean();
-    JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
-    DataFrame df = context.createDataFrame(rdd, Bean.class);
+  void validateDataFrameWithBeans(Bean bean, DataFrame df) {
     StructType schema = df.schema();
     Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
       schema.apply("a"));
@@ -182,6 +178,22 @@ public class JavaDataFrameSuite {
     }
   }
 
+  @Test
+  public void testCreateDataFrameFromLocalJavaBeans() {
+    Bean bean = new Bean();
+    List<Bean> data = Arrays.asList(bean);
+    DataFrame df = context.createDataFrame(data, Bean.class);
+    validateDataFrameWithBeans(bean, df);
+  }
+
+  @Test
+  public void testCreateDataFrameFromJavaBeans() {
+    Bean bean = new Bean();
+    JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
+    DataFrame df = context.createDataFrame(rdd, Bean.class);
+    validateDataFrameWithBeans(bean, df);
+  }
+
   @Test
   public void testCreateDataFromFromList() {
     StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
-- 
cgit v1.2.3
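
Usage note (not part of the commit itself): a minimal sketch of how a Java user would call the new createDataFrame(List, Class) overload added by this patch, assuming the 1.6-era SQLContext and DataFrame API touched above. The LocalBeanExample and Person names are made up for illustration.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.SQLContext;

    public class LocalBeanExample {
      // A simple Java Bean; each getter/setter pair becomes a column
      // (column order is undefined, per the WARNING in the new Scaladoc).
      public static class Person implements java.io.Serializable {
        private String name;
        private int age;
        public String getName() { return name; }
        public void setName(String name) { this.name = name; }
        public int getAge() { return age; }
        public void setAge(int age) { this.age = age; }
      }

      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("LocalBeanExample").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        Person alice = new Person();
        alice.setName("Alice");
        alice.setAge(30);
        List<Person> people = Arrays.asList(alice);

        // With the new overload there is no need to call jsc.parallelize(...) first;
        // the local list becomes a LocalRelation built on the driver.
        DataFrame df = sqlContext.createDataFrame(people, Person.class);
        df.show();

        jsc.stop();
      }
    }

Compared with testCreateDataFrameFromJavaBeans in the test suite above, the only difference is that the local list is passed directly instead of going through jsc.parallelize, which is exactly what the new testCreateDataFrameFromLocalJavaBeans exercises.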