aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorSandeep <sandeep@techaddict.me>2014-04-10 11:17:41 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-10 11:17:41 -0700
commit3bd312940e2f5250edaf3e88d6c23de25bb1d0a9 (patch)
tree06d7fc1c38541f641f4ce835ee9ad8bd498a047d /python/pyspark
parent79820fe825ed7c09d55f50503b7ab53d4585e5f7 (diff)
downloadspark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.tar.gz
spark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.tar.bz2
spark-3bd312940e2f5250edaf3e88d6c23de25bb1d0a9.zip
SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Author: Sandeep <sandeep@techaddict.me> Closes #356 from techaddict/1428 and squashes the following commits: 3bdf5f6 [Sandeep] SPARK-1428: MLlib should convert non-float64 NumPy arrays to float64 instead of complaining
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/mllib/_common.py18
1 files changed, 14 insertions, 4 deletions
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 20a0e309d1..7ef251d24c 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,8 +15,9 @@
# limitations under the License.
#
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD
+import numpy as np
from pyspark.serializers import Serializer
import struct
@@ -47,13 +48,22 @@ def _deserialize_byte_array(shape, ba, offset):
return ar.copy()
def _serialize_double_vector(v):
- """Serialize a double vector into a mutually understood format."""
+ """Serialize a double vector into a mutually understood format.
+
+ >>> x = array([1,2,3])
+ >>> y = _deserialize_double_vector(_serialize_double_vector(x))
+ >>> array_equal(y, array([1.0, 2.0, 3.0]))
+ True
+ """
if type(v) != ndarray:
raise TypeError("_serialize_double_vector called on a %s; "
"wanted ndarray" % type(v))
+ """complex is only datatype that can't be converted to float64"""
+ if issubdtype(v.dtype, complex):
+ raise TypeError("_serialize_double_vector called on a %s; "
+ "wanted ndarray" % type(v))
if v.dtype != float64:
- raise TypeError("_serialize_double_vector called on an ndarray of %s; "
- "wanted ndarray of float64" % v.dtype)
+ v = v.astype(float64)
if v.ndim != 1:
raise TypeError("_serialize_double_vector called on a %ddarray; "
"wanted a 1darray" % v.ndim)