aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorCheng Lian <lian@databricks.com>2016-01-24 19:40:34 -0800
committerYin Huai <yhuai@databricks.com>2016-01-24 19:40:34 -0800
commit3327fd28170b549516fee1972dc6f4c32541591b (patch)
tree9d5a2d7a0fd49bd77907f3199d46b360b9fc42c7
parente789b1d2c1eab6187f54424ed92697ca200c3101 (diff)
downloadspark-3327fd28170b549516fee1972dc6f4c32541591b.tar.gz
spark-3327fd28170b549516fee1972dc6f4c32541591b.tar.bz2
spark-3327fd28170b549516fee1972dc6f4c32541591b.zip
[SPARK-12624][PYSPARK] Checks row length when converting Java arrays to Python rows
When actual row length doesn't conform to specified schema field length, we should give a better error message instead of throwing an unintuitive `ArrayOutOfBoundsException`. Author: Cheng Lian <lian@databricks.com> Closes #10886 from liancheng/spark-12624.
-rw-r--r--python/pyspark/sql/tests.py9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala9
2 files changed, 17 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index ae8620274d..7593b991a7 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -364,6 +364,15 @@ class SQLTests(ReusedPySparkTestCase):
df3 = self.sqlCtx.createDataFrame(rdd, df.schema)
self.assertEqual(10, df3.count())
+ def test_create_dataframe_schema_mismatch(self):
+ input = [Row(a=1)]
+ rdd = self.sc.parallelize(range(3)).map(lambda i: Row(a=i))
+ schema = StructType([StructField("a", IntegerType()), StructField("b", StringType())])
+ df = self.sqlCtx.createDataFrame(rdd, schema)
+ message = ".*Input row doesn't have expected number of values required by the schema.*"
+ with self.assertRaisesRegexp(Exception, message):
+ df.show()
+
def test_serialize_nested_array_and_map(self):
d = [Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})]
rdd = self.sc.parallelize(d)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
index 41e35fd724..e3a016e18d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python.scala
@@ -220,7 +220,14 @@ object EvaluatePython {
ArrayBasedMapData(keys, values)
case (c, StructType(fields)) if c.getClass.isArray =>
- new GenericInternalRow(c.asInstanceOf[Array[_]].zip(fields).map {
+ val array = c.asInstanceOf[Array[_]]
+ if (array.length != fields.length) {
+ throw new IllegalStateException(
+ s"Input row doesn't have expected number of values required by the schema. " +
+ s"${fields.length} fields are required while ${array.length} values are provided."
+ )
+ }
+ new GenericInternalRow(array.zip(fields).map {
case (e, f) => fromJava(e, f.dataType)
})