author    Davies Liu <davies@databricks.com>    2015-07-08 18:22:53 -0700
committer Davies Liu <davies.liu@gmail.com>    2015-07-08 18:22:53 -0700
commit    74d8d3d928cc9a7386b68588ac89ae042847d146 (patch)
tree      0248cc711322eb4a7a6966e9cfbf3a90ca886733
parent    2a4f88b6c16f2991e63b17c0e103bcd79f04dbbc (diff)
[SPARK-8450] [SQL] [PYSPARK] cleanup type converter for Python DataFrame
This PR fixes the converter for Python DataFrame, especially for DecimalType.

Closes #7106

Author: Davies Liu <davies@databricks.com>

Closes #7131 from davies/decimal_python and squashes the following commits:

4d3c234 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python
20531d6 [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python
7d73168 [Davies Liu] fix conflict
6cdd86a [Davies Liu] Merge branch 'master' of github.com:apache/spark into decimal_python
7104e97 [Davies Liu] improve type infer
9cd5a21 [Davies Liu] run python tests with SPARK_PREPEND_CLASSES
829a05b [Davies Liu] fix UDT in python
c99e8c5 [Davies Liu] fix mima
c46814a [Davies Liu] convert decimal for Python DataFrames
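At its core, the fix wires up a two-way mapping between java.math.BigDecimal (what Pyrolite produces and consumes on the JVM side) and Catalyst's internal Decimal. A minimal sketch of that mapping, using Spark's Decimal type directly rather than going through a DataFrame:

    import org.apache.spark.sql.types.Decimal

    // fromJava direction: a pickled Python Decimal arrives as java.math.BigDecimal
    val internal = Decimal(new java.math.BigDecimal("3.14159"))
    // toJava direction: the value is handed back to Pyrolite as a BigDecimal
    val external: java.math.BigDecimal = internal.toJavaBigDecimal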
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala       | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala        | 16
-rw-r--r--  project/MimaExcludes.scala                                               |  5
-rw-r--r--  python/pyspark/sql/tests.py                                              | 13
-rw-r--r--  python/pyspark/sql/types.py                                              |  4
-rwxr-xr-x  python/run-tests.py                                                      |  3
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala            |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala           | 28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala | 95
9 files changed, 84 insertions(+), 94 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
index 75e7004464..0df0766340 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Matrices.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.{ArrayBuilder => MArrayBuilder, HashSet => MHash
import breeze.linalg.{CSCMatrix => BSM, DenseMatrix => BDM, Matrix => BM}
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types._
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.types._
/**
* Trait for a local matrix.
@@ -147,7 +147,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
))
}
- override def serialize(obj: Any): Row = {
+ override def serialize(obj: Any): InternalRow = {
val row = new GenericMutableRow(7)
obj match {
case sm: SparseMatrix =>
@@ -173,9 +173,7 @@ private[spark] class MatrixUDT extends UserDefinedType[Matrix] {
override def deserialize(datum: Any): Matrix = {
datum match {
- // TODO: something wrong with UDT serialization, should never happen.
- case m: Matrix => m
- case row: Row =>
+ case row: InternalRow =>
require(row.length == 7,
s"MatrixUDT.deserialize given row with length ${row.length} but requires length == 7")
val tpe = row.getByte(0)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
index c9c27425d2..e048b01d92 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/Vectors.scala
@@ -28,7 +28,7 @@ import breeze.linalg.{DenseVector => BDV, SparseVector => BSV, Vector => BV}
import org.apache.spark.SparkException
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.util.NumericParser
-import org.apache.spark.sql.Row
+import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.GenericMutableRow
import org.apache.spark.sql.types._
@@ -175,7 +175,7 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
StructField("values", ArrayType(DoubleType, containsNull = false), nullable = true)))
}
- override def serialize(obj: Any): Row = {
+ override def serialize(obj: Any): InternalRow = {
obj match {
case SparseVector(size, indices, values) =>
val row = new GenericMutableRow(4)
@@ -191,17 +191,12 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
row.setNullAt(2)
row.update(3, values.toSeq)
row
- // TODO: There are bugs in UDT serialization because we don't have a clear separation between
- // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
- // TODO: deserialize may get called twice. See SPARK-7186.
- case row: Row =>
- row
}
}
override def deserialize(datum: Any): Vector = {
datum match {
- case row: Row =>
+ case row: InternalRow =>
require(row.length == 4,
s"VectorUDT.deserialize given row with length ${row.length} but requires length == 4")
val tpe = row.getByte(0)
@@ -215,11 +210,6 @@ private[spark] class VectorUDT extends UserDefinedType[Vector] {
val values = row.getAs[Iterable[Double]](3).toArray
new DenseVector(values)
}
- // TODO: There are bugs in UDT serialization because we don't have a clear separation between
- // TODO: internal SQL types and language specific types (including UDT). UDT serialize and
- // TODO: deserialize may get called twice. See SPARK-7186.
- case v: Vector =>
- v
}
}
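After this change, both MatrixUDT and VectorUDT speak only Catalyst's InternalRow: the defensive fallbacks for Row, Matrix, and Vector inputs (the SPARK-7186 double-serialization workaround) are gone. A toy illustration of the tightened contract, with a stand-in row type rather than Spark's real InternalRow:

    // Illustrative only; ToyRow stands in for Catalyst's InternalRow.
    final case class ToyRow(values: Array[Any]) { def length: Int = values.length }
    final case class Point(x: Double, y: Double)

    object PointUDT {
      def serialize(p: Point): ToyRow = ToyRow(Array(p.x, p.y))
      def deserialize(datum: Any): Point = datum match {
        case row: ToyRow =>
          require(row.length == 2, s"expected length 2, got ${row.length}")
          Point(row.values(0).asInstanceOf[Double], row.values(1).asInstanceOf[Double])
        // deliberately no `case p: Point` fallback: deserialize now only ever sees rows
      }
    }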
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 57a86bf8de..821aadd477 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -63,7 +63,10 @@ object MimaExcludes {
// SQL execution is considered private.
excludePackage("org.apache.spark.sql.execution"),
// Parquet support is considered private.
- excludePackage("org.apache.spark.sql.parquet")
+ excludePackage("org.apache.spark.sql.parquet"),
+ // local function inside a method
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.sql.SQLContext.org$apache$spark$sql$SQLContext$$needsConversion$1")
) ++ Seq(
// SPARK-8479 Add numNonzeros and numActives to Matrix.
ProblemFilters.exclude[MissingMethodProblem](
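The exclusion is needed because scalac lifts a local def into a synthetic method on the enclosing class, so deleting needsConversion changes the class's binary interface even though no public API moved. Roughly (illustrative names):

    class Outer {
      def applySchema(): Boolean = {
        // scalac compiles this local def to a synthetic method (roughly
        // Outer#needsConversion$1), which is what MiMa flags as missing
        // once the local def is removed
        def needsConversion(x: Int): Boolean = x > 0
        needsConversion(1)
      }
    }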
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 333378c7f1..66827d4885 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -700,6 +700,19 @@ class SQLTests(ReusedPySparkTestCase):
self.assertTrue(now - now1 < datetime.timedelta(0.001))
self.assertTrue(now - utcnow1 < datetime.timedelta(0.001))
+ def test_decimal(self):
+ from decimal import Decimal
+ schema = StructType([StructField("decimal", DecimalType(10, 5))])
+ df = self.sqlCtx.createDataFrame([(Decimal("3.14159"),)], schema)
+ row = df.select(df.decimal + 1).first()
+ self.assertEqual(row[0], Decimal("4.14159"))
+ tmpPath = tempfile.mkdtemp()
+ shutil.rmtree(tmpPath)
+ df.write.parquet(tmpPath)
+ df2 = self.sqlCtx.read.parquet(tmpPath)
+ row = df2.first()
+ self.assertEqual(row[0], Decimal("3.14159"))
+
def test_dropna(self):
schema = StructType([
StructField("name", StringType(), True),
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 160df40d65..7e64cb0b54 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1069,6 +1069,10 @@ def _verify_type(obj, dataType):
if obj is None:
return
+ # StringType can work with any type
+ if isinstance(dataType, StringType):
+ return
+
if isinstance(dataType, UserDefinedType):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
raise ValueError("%r is not an instance of type %r" % (obj, dataType))
diff --git a/python/run-tests.py b/python/run-tests.py
index 7638854def..cc56077937 100755
--- a/python/run-tests.py
+++ b/python/run-tests.py
@@ -72,7 +72,8 @@ LOGGER = logging.getLogger()
def run_individual_python_test(test_name, pyspark_python):
- env = {'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)}
+ env = dict(os.environ)
+ env.update({'SPARK_TESTING': '1', 'PYSPARK_PYTHON': which(pyspark_python)})
LOGGER.debug("Starting test(%s): %s", pyspark_python, test_name)
start_time = time.time()
try:
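The old code replaced the child process environment outright, which dropped caller-set variables such as SPARK_PREPEND_CLASSES; the fix overlays the test variables on an inherited copy instead. The same idiom in JVM terms (a sketch, not Spark code):

    val pb = new ProcessBuilder("python3", "--version")
    pb.environment().put("SPARK_TESTING", "1") // overlay; inherited variables survive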
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index eeefc85255..d9f987ae02 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -1549,8 +1549,8 @@ class DataFrame private[sql](
* Converts a JavaRDD to a PythonRDD.
*/
protected[sql] def javaToPython: JavaRDD[Array[Byte]] = {
- val fieldTypes = schema.fields.map(_.dataType)
- val jrdd = rdd.map(EvaluatePython.rowToArray(_, fieldTypes)).toJavaRDD()
+ val structType = schema // capture it for closure
+ val jrdd = queryExecution.toRdd.map(EvaluatePython.toJava(_, structType)).toJavaRDD()
SerDeUtil.javaToPython(jrdd)
}
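The "capture it for closure" comment is the standard Spark idiom: copying schema into a local val keeps the map closure from capturing the enclosing DataFrame (and dragging its SQLContext along for serialization). A generic sketch of the pattern, with illustrative names:

    import org.apache.spark.rdd.RDD

    class Owner(@transient val rdd: RDD[Int], val tag: String) extends Serializable {
      def tagged: RDD[String] = {
        val t = tag             // local capture: the closure serializes only this String
        rdd.map(x => s"$t:$x")  // writing `tag` here would pull `this` into the closure
      }
    }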
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 079f31ab8f..477dea9164 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -1044,33 +1044,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
rdd: RDD[Array[Any]],
schema: StructType): DataFrame = {
- def needsConversion(dataType: DataType): Boolean = dataType match {
- case ByteType => true
- case ShortType => true
- case LongType => true
- case FloatType => true
- case DateType => true
- case TimestampType => true
- case StringType => true
- case ArrayType(_, _) => true
- case MapType(_, _, _) => true
- case StructType(_) => true
- case udt: UserDefinedType[_] => needsConversion(udt.sqlType)
- case other => false
- }
-
- val convertedRdd = if (schema.fields.exists(f => needsConversion(f.dataType))) {
- rdd.map(m => m.zip(schema.fields).map {
- case (value, field) => EvaluatePython.fromJava(value, field.dataType)
- })
- } else {
- rdd
- }
-
- val rowRdd = convertedRdd.mapPartitions { iter =>
- iter.map { m => new GenericInternalRow(m): InternalRow}
- }
-
+ val rowRdd = rdd.map(r => EvaluatePython.fromJava(r, schema).asInstanceOf[InternalRow])
DataFrame(this, LogicalRDD(schema.toAttributes, rowRdd)(self))
}
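With fromJava handling structs recursively, the whole per-field needsConversion/zip dance collapses into one call per row. A compact sketch of why a single recursive entry point suffices, using toy types rather than Catalyst's:

    sealed trait DT
    case object IntT extends DT
    final case class StructT(fields: Seq[DT]) extends DT

    def fromJava(obj: Any, dt: DT): Any = (obj, dt) match {
      case (null, _)       => null
      case (l: Long, IntT) => l.toInt
      case (arr: Array[Any], StructT(fs)) =>
        arr.zip(fs).map { case (v, f) => fromJava(v, f) } // recursion covers nesting for free
      case (c, IntT)       => c
      case _               => null
    }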
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
index 6946e798b7..1c8130b07c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUDFs.scala
@@ -24,20 +24,19 @@ import scala.collection.JavaConverters._
import net.razorvine.pickle.{Pickler, Unpickler}
-import org.apache.spark.{Accumulator, Logging => SparkLogging}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.api.python.{PythonBroadcast, PythonRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.Row
import org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.catalyst.util.DateTimeUtils
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
+import org.apache.spark.{Accumulator, Logging => SparkLogging}
/**
* A serialized version of a Python lambda function. Suitable for use in a [[PythonRDD]].
@@ -125,59 +124,86 @@ object EvaluatePython {
new EvaluatePython(udf, child, AttributeReference("pythonUDF", udf.dataType)())
/**
- * Helper for converting a Scala object to a java suitable for pyspark serialization.
+ * Helper for converting from a Catalyst type to a Java type suitable for Pyrolite.
*/
def toJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
case (null, _) => null
- case (row: Row, struct: StructType) =>
+ case (row: InternalRow, struct: StructType) =>
val fields = struct.fields.map(field => field.dataType)
- row.toSeq.zip(fields).map {
- case (obj, dataType) => toJava(obj, dataType)
- }.toArray
+ rowToArray(row, fields)
case (seq: Seq[Any], array: ArrayType) =>
seq.map(x => toJava(x, array.elementType)).asJava
- case (list: JList[_], array: ArrayType) =>
- list.map(x => toJava(x, array.elementType)).asJava
- case (arr, array: ArrayType) if arr.getClass.isArray =>
- arr.asInstanceOf[Array[Any]].map(x => toJava(x, array.elementType))
case (obj: Map[_, _], mt: MapType) => obj.map {
case (k, v) => (toJava(k, mt.keyType), toJava(v, mt.valueType))
}.asJava
- case (ud, udt: UserDefinedType[_]) => toJava(udt.serialize(ud), udt.sqlType)
+ case (ud, udt: UserDefinedType[_]) => toJava(ud, udt.sqlType)
case (date: Int, DateType) => DateTimeUtils.toJavaDate(date)
case (t: Long, TimestampType) => DateTimeUtils.toJavaTimestamp(t)
+
+ case (d: Decimal, _) => d.toJavaBigDecimal
+
case (s: UTF8String, StringType) => s.toString
- // Pyrolite can handle Timestamp and Decimal
case (other, _) => other
}
/**
* Convert a Row into a Java Array (to be pickled to Python)
*/
- def rowToArray(row: Row, fields: Seq[DataType]): Array[Any] = {
+ def rowToArray(row: InternalRow, fields: Seq[DataType]): Array[Any] = {
// TODO: this is slow!
row.toSeq.zip(fields).map {case (obj, dt) => toJava(obj, dt)}.toArray
}
- // Converts value to the type specified by the data type.
- // Because Python does not have data types for TimestampType, FloatType, ShortType, and
- // ByteType, we need to explicitly convert values in columns of these data types to the desired
- // JVM data types.
+ /**
+ * Converts `obj` to the type specified by the data type, or returns null if the type of obj is
+ * unexpected, since Python does not enforce types.
+ */
def fromJava(obj: Any, dataType: DataType): Any = (obj, dataType) match {
- // TODO: We should check nullable
case (null, _) => null
+ case (c: Boolean, BooleanType) => c
+
+ case (c: Int, ByteType) => c.toByte
+ case (c: Long, ByteType) => c.toByte
+
+ case (c: Int, ShortType) => c.toShort
+ case (c: Long, ShortType) => c.toShort
+
+ case (c: Int, IntegerType) => c
+ case (c: Long, IntegerType) => c.toInt
+
+ case (c: Int, LongType) => c.toLong
+ case (c: Long, LongType) => c
+
+ case (c: Double, FloatType) => c.toFloat
+
+ case (c: Double, DoubleType) => c
+
+ case (c: java.math.BigDecimal, dt: DecimalType) => Decimal(c)
+
+ case (c: Int, DateType) => c
+
+ case (c: Long, TimestampType) => c
+
+ case (c: String, StringType) => UTF8String.fromString(c)
+ case (c, StringType) =>
+ // If we get here, c is not a string. Call toString on it.
+ UTF8String.fromString(c.toString)
+
+ case (c: String, BinaryType) => c.getBytes("utf-8")
+ case (c, BinaryType) if c.getClass.isArray && c.getClass.getComponentType.getName == "byte" => c
+
case (c: java.util.List[_], ArrayType(elementType, _)) =>
- c.map { e => fromJava(e, elementType)}: Seq[Any]
+ c.map { e => fromJava(e, elementType)}.toSeq
case (c, ArrayType(elementType, _)) if c.getClass.isArray =>
- c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)): Seq[Any]
+ c.asInstanceOf[Array[_]].map(e => fromJava(e, elementType)).toSeq
case (c: java.util.Map[_, _], MapType(keyType, valueType, _)) => c.map {
case (key, value) => (fromJava(key, keyType), fromJava(value, valueType))
@@ -188,30 +214,11 @@ object EvaluatePython {
case (e, f) => fromJava(e, f.dataType)
})
- case (c: java.util.Calendar, DateType) =>
- DateTimeUtils.fromJavaDate(new java.sql.Date(c.getTimeInMillis))
-
- case (c: java.util.Calendar, TimestampType) =>
- c.getTimeInMillis * 10000L
- case (t: java.sql.Timestamp, TimestampType) =>
- DateTimeUtils.fromJavaTimestamp(t)
-
- case (_, udt: UserDefinedType[_]) =>
- fromJava(obj, udt.sqlType)
-
- case (c: Int, ByteType) => c.toByte
- case (c: Long, ByteType) => c.toByte
- case (c: Int, ShortType) => c.toShort
- case (c: Long, ShortType) => c.toShort
- case (c: Long, IntegerType) => c.toInt
- case (c: Int, LongType) => c.toLong
- case (c: Double, FloatType) => c.toFloat
- case (c: String, StringType) => UTF8String.fromString(c)
- case (c, StringType) =>
- // If we get here, c is not a string. Call toString on it.
- UTF8String.fromString(c.toString)
+ case (_, udt: UserDefinedType[_]) => fromJava(obj, udt.sqlType)
- case (c, _) => c
+ // all other unexpected types are converted to null, or we would hit a runtime exception
+ // TODO(davies): we could improve this by trying to cast the object to the expected type
+ case (c, _) => null
}
}
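Note the doubled numeric cases (Int and Long for ByteType, ShortType, and so on): Pyrolite unpickles a Python int as either a JVM Int or a Long depending on its magnitude, so each narrow column type has to accept both encodings. A self-contained sketch of the idea:

    // both JVM encodings of a Python int must narrow to the declared column type
    def toByteColumn(v: Any): Any = v match {
      case i: Int  => i.toByte
      case l: Long => l.toByte
      case _       => null    // mirrors the patch's "unexpected type => null" policy
    }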