author     Holden Karau <holden@pigscanfly.ca>    2015-09-27 21:16:15 +0100
committer  Sean Owen <sowen@cloudera.com>         2015-09-27 21:16:15 +0100
commit     8ecba3e86e53834413da8b4299f5791545cae12e
tree       3abdaf5fb6cdc748c70039960e87c9acbbf7ea0e /sql
parent     418e5e4cbdaab87addb91ac0bb2245ff0213ac81
[SPARK-10720] [SQL] [JAVA] Add a java wrapper to create a dataframe from a local list of java beans
Similar to SPARK-10630, it would be nice if Java users didn't have to parallelize their data explicitly, as Scala users can already skip that step. The issue came up in http://stackoverflow.com/questions/32613413/apache-spark-machine-learning-cant-get-estimator-example-to-work

Author: Holden Karau <holden@pigscanfly.ca>

Closes #8879 from holdenk/SPARK-10720-add-a-java-wrapper-to-create-a-dataframe-from-a-local-list-of-java-beans.
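For context, a minimal usage sketch of the overload this patch adds (targeted at 1.6.0). The Person bean and its fields are illustrative, not part of the patch; the point is that a local java.util.List of beans goes straight to createDataFrame with no explicit parallelize step.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.SparkConf;
    import org.apache.spark.api.java.JavaSparkContext;
    import org.apache.spark.sql.DataFrame;
    import org.apache.spark.sql.SQLContext;

    public class LocalBeanExample {
      // A plain Java bean; its getters/setters drive the inferred schema.
      public static class Person implements java.io.Serializable {
        private String name;
        private int age;
        public String getName() { return name; }
        public void setName(String name) { this.name = name; }
        public int getAge() { return age; }
        public void setAge(int age) { this.age = age; }
      }

      public static void main(String[] args) {
        SparkConf conf = new SparkConf().setAppName("local-beans").setMaster("local[*]");
        JavaSparkContext jsc = new JavaSparkContext(conf);
        SQLContext sqlContext = new SQLContext(jsc);

        Person p = new Person();
        p.setName("alice");
        p.setAge(30);
        List<Person> data = Arrays.asList(p);

        // The new overload: build a DataFrame from a local list of beans directly,
        // instead of going through jsc.parallelize(data) first.
        DataFrame df = sqlContext.createDataFrame(data, Person.class);
        df.show();

        jsc.stop();
      }
    }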
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala            | 51
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java | 22
2 files changed, 56 insertions(+), 17 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 1bd4e26fb3..cb0a3e361c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import java.beans.Introspector
+import java.beans.{BeanInfo, Introspector}
import java.util.Properties
import java.util.concurrent.atomic.AtomicReference
@@ -499,21 +499,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
* @since 1.3.0
*/
def createDataFrame(rdd: RDD[_], beanClass: Class[_]): DataFrame = {
- val attributeSeq = getSchema(beanClass)
+ val attributeSeq: Seq[AttributeReference] = getSchema(beanClass)
val className = beanClass.getName
val rowRdd = rdd.mapPartitions { iter =>
// BeanInfo is not serializable so we must rediscover it remotely for each partition.
val localBeanInfo = Introspector.getBeanInfo(Utils.classForName(className))
- val extractors =
- localBeanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
- val methodsToConverts = extractors.zip(attributeSeq).map { case (e, attr) =>
- (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
- }
- iter.map { row =>
- new GenericInternalRow(
- methodsToConverts.map { case (e, convert) => convert(e.invoke(row)) }.toArray[Any]
- ): InternalRow
- }
+ SQLContext.beansToRows(iter, localBeanInfo, attributeSeq)
}
DataFrame(this, LogicalRDD(attributeSeq, rowRdd)(this))
}
@@ -531,6 +522,23 @@ class SQLContext(@transient val sparkContext: SparkContext)
}
/**
+ * Applies a schema to a List of Java Beans.
+ *
+ * WARNING: Since there is no guaranteed ordering for fields in a Java Bean,
+ * SELECT * queries will return the columns in an undefined order.
+ * @group dataframes
+ * @since 1.6.0
+ */
+ def createDataFrame(data: java.util.List[_], beanClass: Class[_]): DataFrame = {
+ val attrSeq = getSchema(beanClass)
+ val className = beanClass.getName
+ val beanInfo = Introspector.getBeanInfo(beanClass)
+ val rows = SQLContext.beansToRows(data.asScala.iterator, beanInfo, attrSeq)
+ DataFrame(self, LocalRelation(attrSeq, rows.toSeq))
+ }
+
+
+ /**
* :: Experimental ::
* Returns a [[DataFrameReader]] that can be used to read data in as a [[DataFrame]].
* {{{
@@ -1229,4 +1237,23 @@ object SQLContext {
lastInstantiatedContext.set(sqlContext)
}
}
+
+ /**
+ * Converts an iterator of Java Beans to InternalRow using the provided
+ * bean info & schema. This is not related to the singleton, but is a static
+ * method for internal use.
+ */
+ private def beansToRows(data: Iterator[_], beanInfo: BeanInfo, attrs: Seq[AttributeReference]):
+ Iterator[InternalRow] = {
+ val extractors =
+ beanInfo.getPropertyDescriptors.filterNot(_.getName == "class").map(_.getReadMethod)
+ val methodsToConverts = extractors.zip(attrs).map { case (e, attr) =>
+ (e, CatalystTypeConverters.createToCatalystConverter(attr.dataType))
+ }
+ data.map{ element =>
+ new GenericInternalRow(
+ methodsToConverts.map { case (e, convert) => convert(e.invoke(element)) }.toArray[Any]
+ ): InternalRow
+ }
+ }
}
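
The extracted SQLContext.beansToRows helper relies on standard JavaBeans introspection: enumerate the property descriptors, drop the synthetic "class" property, invoke each read method, and feed the results through Catalyst converters into a GenericInternalRow. A small standalone sketch of just that introspection mechanism follows; the Point bean is illustrative and the Catalyst conversion step is only noted in comments.

    import java.beans.BeanInfo;
    import java.beans.Introspector;
    import java.beans.PropertyDescriptor;
    import java.lang.reflect.Method;

    public class BeanIntrospectionSketch {
      public static class Point {
        private double x;
        private double y;
        public double getX() { return x; }
        public void setX(double x) { this.x = x; }
        public double getY() { return y; }
        public void setY(double y) { this.y = y; }
      }

      public static void main(String[] args) throws Exception {
        Point p = new Point();
        p.setX(1.0);
        p.setY(2.0);

        // Same discovery step as beansToRows: list the bean's properties and keep
        // every getter except the one for the synthetic "class" property.
        BeanInfo info = Introspector.getBeanInfo(Point.class);
        for (PropertyDescriptor pd : info.getPropertyDescriptors()) {
          if ("class".equals(pd.getName())) continue;
          Method getter = pd.getReadMethod();
          // Spark would pass this value through a Catalyst type converter and store
          // it in a GenericInternalRow; here we just print it. Property order is not
          // guaranteed, which is why the doc comment warns about SELECT * ordering.
          System.out.println(pd.getName() + " = " + getter.invoke(p));
        }
      }
    }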
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 250ac2e109..a1a3fdbb48 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -142,11 +142,7 @@ public class JavaDataFrameSuite {
}
}
- @Test
- public void testCreateDataFrameFromJavaBeans() {
- Bean bean = new Bean();
- JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
- DataFrame df = context.createDataFrame(rdd, Bean.class);
+ void validateDataFrameWithBeans(Bean bean, DataFrame df) {
StructType schema = df.schema();
Assert.assertEquals(new StructField("a", DoubleType$.MODULE$, false, Metadata.empty()),
schema.apply("a"));
@@ -183,6 +179,22 @@ public class JavaDataFrameSuite {
}
@Test
+ public void testCreateDataFrameFromLocalJavaBeans() {
+ Bean bean = new Bean();
+ List<Bean> data = Arrays.asList(bean);
+ DataFrame df = context.createDataFrame(data, Bean.class);
+ validateDataFrameWithBeans(bean, df);
+ }
+
+ @Test
+ public void testCreateDataFrameFromJavaBeans() {
+ Bean bean = new Bean();
+ JavaRDD<Bean> rdd = jsc.parallelize(Arrays.asList(bean));
+ DataFrame df = context.createDataFrame(rdd, Bean.class);
+ validateDataFrameWithBeans(bean, df);
+ }
+
+ @Test
public void testCreateDataFromFromList() {
StructType schema = createStructType(Arrays.asList(createStructField("i", IntegerType, true)));
List<Row> rows = Arrays.asList(RowFactory.create(0));