aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorSandeep <sandeep@techaddict.me>2014-04-15 00:19:43 -0700
committerMatei Zaharia <matei@databricks.com>2014-04-15 00:19:43 -0700
commitdf360917990ad95dde3c8e016ec42507d1566355 (patch)
tree621930a449382e043e58628a08db6e00d0843261 /python
parentc99bcb7feaa761c5826f2e1d844d0502a3b79538 (diff)
downloadspark-df360917990ad95dde3c8e016ec42507d1566355.tar.gz
spark-df360917990ad95dde3c8e016ec42507d1566355.tar.bz2
spark-df360917990ad95dde3c8e016ec42507d1566355.zip
SPARK-1426: Make MLlib work with NumPy versions older than 1.7
Currently it requires NumPy 1.7 due to using the copyto method (http://docs.scipy.org/doc/numpy/reference/generated/numpy.copyto.html) for extracting data out of an array. Replace it with a fallback Author: Sandeep <sandeep@techaddict.me> Closes #391 from techaddict/1426 and squashes the following commits: d365962 [Sandeep] SPARK-1426: Make MLlib work with NumPy versions older than 1.7 Currently it requires NumPy 1.7 due to using the copyto method (http://docs.scipy.org/doc/numpy/reference/generated/numpy.copyto.html) for extracting data out of an array. Replace it with a fallback
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/__init__.py6
-rw-r--r--python/pyspark/mllib/_common.py11
2 files changed, 9 insertions, 8 deletions
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 538ff26ce7..4149f54931 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -19,8 +19,8 @@
Python bindings for MLlib.
"""
-# MLlib currently needs and NumPy 1.7+, so complain if lower
+# MLlib currently needs and NumPy 1.4+, so complain if lower
import numpy
-if numpy.version.version < '1.7':
- raise Exception("MLlib requires NumPy 1.7+")
+if numpy.version.version < '1.4':
+ raise Exception("MLlib requires NumPy 1.4+")
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index 7ef251d24c..e19f5d2aaa 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
+from numpy import ndarray, float64, int64, int32, ones, array_equal, array, dot, shape, complex, issubdtype
from pyspark import SparkContext, RDD
import numpy as np
@@ -72,8 +72,8 @@ def _serialize_double_vector(v):
header = ndarray(shape=[2], buffer=ba, dtype="int64")
header[0] = 1
header[1] = length
- copyto(ndarray(shape=[length], buffer=ba, offset=16,
- dtype="float64"), v)
+ arr_mid = ndarray(shape=[length], buffer=ba, offset=16, dtype="float64")
+ arr_mid[...] = v
return ba
def _deserialize_double_vector(ba):
@@ -112,8 +112,9 @@ def _serialize_double_matrix(m):
header[0] = 2
header[1] = rows
header[2] = cols
- copyto(ndarray(shape=[rows, cols], buffer=ba, offset=24,
- dtype="float64", order='C'), m)
+ arr_mid = ndarray(shape=[rows, cols], buffer=ba, offset=24,
+ dtype="float64", order='C')
+ arr_mid[...] = m
return ba
else:
raise TypeError("_serialize_double_matrix called on a "