diff options
author | Sandeep <sandeep@techaddict.me> | 2014-04-10 11:17:41 -0700 |
---|---|---|
committer | Matei Zaharia <matei@databricks.com> | 2014-04-10 11:17:41 -0700 |
commit | 3bd312940e2f5250edaf3e88d6c23de25bb1d0a9 (patch) | |
tree | 06d7fc1c38541f641f4ce835ee9ad8bd498a047d /python | |
parent | 79820fe825ed7c09d55f50503b7ab53d4585e5f7 (diff) | |
download | spark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.tar.gz spark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.tar.bz2 spark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.zip |
SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Author: Sandeep <sandeep@techaddict.me>
Closes #356 from techaddict/1428 and squashes the following commits:
3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/mllib/_common.py | 18 |
1 files changed, 14 insertions, 4 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index 20a0e309d1..7ef251d24c 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -15,8 +15,9 @@ # limitations under the License. # -from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype from pyspark import SparkContext, RDD +import numpy as np from pyspark.serializers import Serializer import struct @@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset): return ar.copy() def _serialize_double_vector(v): - """Serialize a double vector into a mutually understood format.""" + """Serialize a double vector into a mutually understood format. + + >>> x = array([1,2,3]) + >>> y = _deserialize_double_vector(_serialize_double_vector(x)) + >>> array_equal(y, array([1.0, 2.0, 3.0])) + True + """ if type(v) != ndarray: raise TypeError("_serialize_double_vector called on a %s; " "wanted ndarray" % type(v)) + """complex is only datatype that can't be converted to float64""" + if issubdtype(v.dtype, complex): + raise TypeError("_serialize_double_vector called on a %s; " + "wanted ndarray" % type(v)) if v.dtype != float64: - raise TypeError("_serialize_double_vector called on an ndarray of %s; " - "wanted ndarray of float64" % v.dtype) + v = v.astype(float64) if v.ndim != 1: raise TypeError("_serialize_double_vector called on a %ddarray; " "wanted a 1darray" % v.ndim) |