author     Punya Biswal <pbiswal@palantir.com>          2015-04-21 14:50:02 -0700
committer  Michael Armbrust <michael@databricks.com>    2015-04-21 14:50:02 -0700
commit     2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a (patch)
tree       c49d92745450245227daa05c2b86f4cf51352bb9 /sql
parent     6265cba00f6141575b4be825735d77d4cea500ab (diff)
[SPARK-6996][SQL] Support map types in java beans
liancheng mengxr this is similar to #5146.

Author: Punya Biswal <pbiswal@palantir.com>

Closes #5578 from punya/feature/SPARK-6996 and squashes the following commits:

d56c3e0 [Punya Biswal] Fix imports
c7e308b [Punya Biswal] Support java iterable types in POJOs
5e00685 [Punya Biswal] Support map types in java beans
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala |  20
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala                   | 110
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                          |  52
-rw-r--r--  sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java               |  57
4 files changed, 180 insertions(+), 59 deletions(-)
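
Editor's note: a minimal, hedged sketch of what this change enables. The bean below is illustrative only (it is not the one added to the test suite); with this patch, a Java bean whose properties are java.util.Map or java.lang.Iterable subtypes can be reflected into a Catalyst schema instead of failing inference.

    // Illustrative bean only; names are hypothetical. Under the new inference,
    // the Map property maps to MapType(StringType, ArrayType(IntegerType,
    // containsNull = false), valueContainsNull = true) and the List property
    // maps to ArrayType(StringType, containsNull = true).
    import java.util.{List => JList, Map => JMap}
    import scala.beans.BeanProperty

    class ScoresBean extends Serializable {
      @BeanProperty var scores: JMap[String, Array[Int]] = _
      @BeanProperty var tags: JList[String] = _
    }
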
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
index d4f9fdacda..a13e2f36a1 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/CatalystTypeConverters.scala
@@ -17,6 +17,7 @@
package org.apache.spark.sql.catalyst
+import java.lang.{Iterable => JavaIterable}
import java.util.{Map => JavaMap}
import scala.collection.mutable.HashMap
@@ -49,6 +50,16 @@ object CatalystTypeConverters {
case (s: Seq[_], arrayType: ArrayType) =>
s.map(convertToCatalyst(_, arrayType.elementType))
+ case (jit: JavaIterable[_], arrayType: ArrayType) => {
+ val iter = jit.iterator
+ var listOfItems: List[Any] = List()
+ while (iter.hasNext) {
+ val item = iter.next()
+ listOfItems :+= convertToCatalyst(item, arrayType.elementType)
+ }
+ listOfItems
+ }
+
case (s: Array[_], arrayType: ArrayType) =>
s.toSeq.map(convertToCatalyst(_, arrayType.elementType))
@@ -124,6 +135,15 @@ object CatalystTypeConverters {
extractOption(item) match {
case a: Array[_] => a.toSeq.map(elementConverter)
case s: Seq[_] => s.map(elementConverter)
+ case i: JavaIterable[_] => {
+ val iter = i.iterator
+ var convertedIterable: List[Any] = List()
+ while (iter.hasNext) {
+ val item = iter.next()
+ convertedIterable :+= elementConverter(item)
+ }
+ convertedIterable
+ }
case null => null
}
}
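
Editor's note: the two new JavaIterable cases above walk the iterator by hand and append to an immutable List. A functionally equivalent formulation using scala.collection.JavaConverters is sketched below purely to illustrate what the conversion does; it is not what the patch ships, and the converter parameter stands in for the surrounding elementConverter/convertToCatalyst call.

    import java.lang.{Iterable => JavaIterable}
    import scala.collection.JavaConverters._

    // Wrap the Java iterable as a Scala one, convert each element with the same
    // per-element converter, and materialize the result as a Seq.
    def convertJavaIterable(jit: JavaIterable[_], elementConverter: Any => Any): Seq[Any] =
      jit.asScala.map(elementConverter).toSeq
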
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
new file mode 100644
index 0000000000..db484c5f50
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/JavaTypeInference.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql
+
+import java.beans.Introspector
+import java.lang.{Iterable => JIterable}
+import java.util.{Iterator => JIterator, Map => JMap}
+
+import com.google.common.reflect.TypeToken
+
+import org.apache.spark.sql.types._
+
+import scala.language.existentials
+
+/**
+ * Type-inference utilities for POJOs and Java collections.
+ */
+private [sql] object JavaTypeInference {
+
+ private val iterableType = TypeToken.of(classOf[JIterable[_]])
+ private val mapType = TypeToken.of(classOf[JMap[_, _]])
+ private val iteratorReturnType = classOf[JIterable[_]].getMethod("iterator").getGenericReturnType
+ private val nextReturnType = classOf[JIterator[_]].getMethod("next").getGenericReturnType
+ private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
+ private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType
+
+ /**
+ * Infers the corresponding SQL data type of a Java type.
+ * @param typeToken Java type
+ * @return (SQL data type, nullable)
+ */
+ private [sql] def inferDataType(typeToken: TypeToken[_]): (DataType, Boolean) = {
+ // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
+ typeToken.getRawType match {
+ case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
+ (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
+
+ case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
+ case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
+ case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
+ case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
+ case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
+ case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
+ case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
+ case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
+
+ case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
+ case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
+ case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
+ case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
+ case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
+ case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
+ case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
+
+ case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
+ case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
+ case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
+
+ case _ if typeToken.isArray =>
+ val (dataType, nullable) = inferDataType(typeToken.getComponentType)
+ (ArrayType(dataType, nullable), true)
+
+ case _ if iterableType.isAssignableFrom(typeToken) =>
+ val (dataType, nullable) = inferDataType(elementType(typeToken))
+ (ArrayType(dataType, nullable), true)
+
+ case _ if mapType.isAssignableFrom(typeToken) =>
+ val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JMap[_, _]]]
+ val mapSupertype = typeToken2.getSupertype(classOf[JMap[_, _]])
+ val keyType = elementType(mapSupertype.resolveType(keySetReturnType))
+ val valueType = elementType(mapSupertype.resolveType(valuesReturnType))
+ val (keyDataType, _) = inferDataType(keyType)
+ val (valueDataType, nullable) = inferDataType(valueType)
+ (MapType(keyDataType, valueDataType, nullable), true)
+
+ case _ =>
+ val beanInfo = Introspector.getBeanInfo(typeToken.getRawType)
+ val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
+ val fields = properties.map { property =>
+ val returnType = typeToken.method(property.getReadMethod).getReturnType
+ val (dataType, nullable) = inferDataType(returnType)
+ new StructField(property.getName, dataType, nullable)
+ }
+ (new StructType(fields), true)
+ }
+ }
+
+ private def elementType(typeToken: TypeToken[_]): TypeToken[_] = {
+ val typeToken2 = typeToken.asInstanceOf[TypeToken[_ <: JIterable[_]]]
+ val iterableSupertype = typeToken2.getSupertype(classOf[JIterable[_]])
+ val iteratorType = iterableSupertype.resolveType(iteratorReturnType)
+ val itemType = iteratorType.resolveType(nextReturnType)
+ itemType
+ }
+}
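
Editor's note: for readers unfamiliar with Guava's TypeToken, the resolution trick the new file relies on can be exercised in isolation. The snippet below is a standalone illustration, not part of the patch; the MapHolder class and its getter are hypothetical.

    import java.util.{Map => JMap}
    import com.google.common.reflect.TypeToken

    object TypeTokenDemo {
      // The same generic return types that JavaTypeInference caches for JMap.
      private val keySetReturnType = classOf[JMap[_, _]].getMethod("keySet").getGenericReturnType
      private val valuesReturnType = classOf[JMap[_, _]].getMethod("values").getGenericReturnType

      // Hypothetical bean-style holder with a parameterized Map property.
      class MapHolder {
        def getScores: JMap[String, java.lang.Integer] = java.util.Collections.emptyMap()
      }

      def main(args: Array[String]): Unit = {
        // Resolve the getter's return type against the holder, as inferDataType does.
        val returnType = TypeToken.of(classOf[MapHolder])
          .method(classOf[MapHolder].getMethod("getScores"))
          .getReturnType
        val mapSupertype =
          returnType.asInstanceOf[TypeToken[_ <: JMap[_, _]]].getSupertype(classOf[JMap[_, _]])
        // Resolving keySet()/values() against the supertype recovers key and value types.
        println(mapSupertype.resolveType(keySetReturnType))  // java.util.Set<java.lang.String>
        println(mapSupertype.resolveType(valuesReturnType))  // java.util.Collection<java.lang.Integer>
      }
    }
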
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index f9f3eb2e03..bcd20c06c6 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -25,6 +25,8 @@ import scala.collection.immutable
import scala.language.implicitConversions
import scala.reflect.runtime.universe.TypeTag
+import com.google.common.reflect.TypeToken
+
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.rdd.RDD
@@ -1222,56 +1224,12 @@ class SQLContext(@transient val sparkContext: SparkContext)
* Returns a Catalyst Schema for the given java bean class.
*/
protected def getSchema(beanClass: Class[_]): Seq[AttributeReference] = {
- val (dataType, _) = inferDataType(beanClass)
+ val (dataType, _) = JavaTypeInference.inferDataType(TypeToken.of(beanClass))
dataType.asInstanceOf[StructType].fields.map { f =>
AttributeReference(f.name, f.dataType, f.nullable)()
}
}
- /**
- * Infers the corresponding SQL data type of a Java class.
- * @param clazz Java class
- * @return (SQL data type, nullable)
- */
- private def inferDataType(clazz: Class[_]): (DataType, Boolean) = {
- // TODO: All of this could probably be moved to Catalyst as it is mostly not Spark specific.
- clazz match {
- case c: Class[_] if c.isAnnotationPresent(classOf[SQLUserDefinedType]) =>
- (c.getAnnotation(classOf[SQLUserDefinedType]).udt().newInstance(), true)
-
- case c: Class[_] if c == classOf[java.lang.String] => (StringType, true)
- case c: Class[_] if c == java.lang.Short.TYPE => (ShortType, false)
- case c: Class[_] if c == java.lang.Integer.TYPE => (IntegerType, false)
- case c: Class[_] if c == java.lang.Long.TYPE => (LongType, false)
- case c: Class[_] if c == java.lang.Double.TYPE => (DoubleType, false)
- case c: Class[_] if c == java.lang.Byte.TYPE => (ByteType, false)
- case c: Class[_] if c == java.lang.Float.TYPE => (FloatType, false)
- case c: Class[_] if c == java.lang.Boolean.TYPE => (BooleanType, false)
-
- case c: Class[_] if c == classOf[java.lang.Short] => (ShortType, true)
- case c: Class[_] if c == classOf[java.lang.Integer] => (IntegerType, true)
- case c: Class[_] if c == classOf[java.lang.Long] => (LongType, true)
- case c: Class[_] if c == classOf[java.lang.Double] => (DoubleType, true)
- case c: Class[_] if c == classOf[java.lang.Byte] => (ByteType, true)
- case c: Class[_] if c == classOf[java.lang.Float] => (FloatType, true)
- case c: Class[_] if c == classOf[java.lang.Boolean] => (BooleanType, true)
-
- case c: Class[_] if c == classOf[java.math.BigDecimal] => (DecimalType(), true)
- case c: Class[_] if c == classOf[java.sql.Date] => (DateType, true)
- case c: Class[_] if c == classOf[java.sql.Timestamp] => (TimestampType, true)
-
- case c: Class[_] if c.isArray =>
- val (dataType, nullable) = inferDataType(c.getComponentType)
- (ArrayType(dataType, nullable), true)
-
- case _ =>
- val beanInfo = Introspector.getBeanInfo(clazz)
- val properties = beanInfo.getPropertyDescriptors.filterNot(_.getName == "class")
- val fields = properties.map { property =>
- val (dataType, nullable) = inferDataType(property.getPropertyType)
- new StructField(property.getName, dataType, nullable)
- }
- (new StructType(fields), true)
- }
- }
}
+
+
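
Editor's note: the delegation to JavaTypeInference keeps the existing nullability convention, where primitive getters become non-nullable fields and boxed or reference getters become nullable ones. The small bean below is hypothetical and only illustrates that convention.

    // Illustrative only. Schema inference for PointBean would yield
    //   x: DoubleType, nullable = false   (primitive double getter)
    //   y: DoubleType, nullable = true    (boxed java.lang.Double getter)
    class PointBean extends Serializable {
      def getX: Double = 0.0
      def getY: java.lang.Double = 0.0
    }
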
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
index 6d0fbe83c2..fc3ed4a708 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDataFrameSuite.java
@@ -17,23 +17,28 @@
package test.org.apache.spark.sql;
-import java.io.Serializable;
-import java.util.Arrays;
-
-import scala.collection.Seq;
-
-import org.junit.After;
-import org.junit.Assert;
-import org.junit.Before;
-import org.junit.Ignore;
-import org.junit.Test;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.primitives.Ints;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
-import org.apache.spark.sql.*;
+import org.apache.spark.sql.DataFrame;
+import org.apache.spark.sql.Row;
+import org.apache.spark.sql.SQLContext;
+import org.apache.spark.sql.TestData$;
import org.apache.spark.sql.test.TestSQLContext;
import org.apache.spark.sql.test.TestSQLContext$;
import org.apache.spark.sql.types.*;
+import org.junit.*;
+
+import scala.collection.JavaConversions;
+import scala.collection.Seq;
+import scala.collection.mutable.Buffer;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.Map;
import static org.apache.spark.sql.functions.*;
@@ -106,6 +111,8 @@ public class JavaDataFrameSuite {
public static class Bean implements Serializable {
private double a = 0.0;
private Integer[] b = new Integer[]{0, 1};
+ private Map<String, int[]> c = ImmutableMap.of("hello", new int[] { 1, 2 });
+ private List<String> d = Arrays.asList("floppy", "disk");
public double getA() {
return a;
@@ -114,6 +121,14 @@ public class JavaDataFrameSuite {
public Integer[] getB() {
return b;
}
+
+ public Map<String, int[]> getC() {
+ return c;
+ }
+
+ public List<String> getD() {
+ return d;
+ }
}
@Test
@@ -127,7 +142,15 @@ public class JavaDataFrameSuite {
Assert.assertEquals(
new StructField("b", new ArrayType(IntegerType$.MODULE$, true), true, Metadata.empty()),
schema.apply("b"));
- Row first = df.select("a", "b").first();
+ ArrayType valueType = new ArrayType(DataTypes.IntegerType, false);
+ MapType mapType = new MapType(DataTypes.StringType, valueType, true);
+ Assert.assertEquals(
+ new StructField("c", mapType, true, Metadata.empty()),
+ schema.apply("c"));
+ Assert.assertEquals(
+ new StructField("d", new ArrayType(DataTypes.StringType, true), true, Metadata.empty()),
+ schema.apply("d"));
+ Row first = df.select("a", "b", "c", "d").first();
Assert.assertEquals(bean.getA(), first.getDouble(0), 0.0);
// Now Java lists and maps are converted to Scala Seqs and Maps. Once we get a Seq below,
// verify that it has the expected length, and contains expected elements.
@@ -136,5 +159,15 @@ public class JavaDataFrameSuite {
for (int i = 0; i < result.length(); i++) {
Assert.assertEquals(bean.getB()[i], result.apply(i));
}
+ Buffer<Integer> outputBuffer = (Buffer<Integer>) first.getJavaMap(2).get("hello");
+ Assert.assertArrayEquals(
+ bean.getC().get("hello"),
+ Ints.toArray(JavaConversions.asJavaList(outputBuffer)));
+ Seq<String> d = first.getAs(3);
+ Assert.assertEquals(bean.getD().size(), d.length());
+ for (int i = 0; i < d.length(); i++) {
+ Assert.assertEquals(bean.getD().get(i), d.apply(i));
+ }
}
+
}
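
Editor's note: a hedged Scala analogue of the new assertions, for readers following along from the Scala side. After conversion, the bean's java.util.Map and java.util.List values come back inside the Row as Scala collections and can be read with getAs; the helper below is illustrative only and assumes the same df.select("a", "b", "c", "d") projection used in the test, so the map is at index 2 and the list at index 3.

    import org.apache.spark.sql.Row

    // Illustrative helper, not part of the patch: read the map and list columns
    // back as Scala collections from a projected Row.
    def readMapAndList(first: Row): (Option[Seq[Int]], Seq[String]) = {
      val c = first.getAs[scala.collection.Map[String, Seq[Int]]](2)
      val d = first.getAs[Seq[String]](3)
      (c.get("hello"), d)
    }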