author    Tor Myklebust <tmyklebu@gmail.com>    2013-12-25 00:53:48 -0500
committer Tor Myklebust <tmyklebu@gmail.com>    2013-12-25 00:53:48 -0500
commit    02208a175c76be111eeb66dc19c7499a6656a067 (patch)
tree      73345f6072aeb94072c1986045fdb70f6d4d5bb2 /python/pyspark/mllib
parent    4e821390bca0d1f40b6f2f011bdc71476a1d3aa4 (diff)
Initial weights in Scala are ones; do that too. Also fix some errors.
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--  python/pyspark/mllib/_common.py  |  12
1 file changed, 6 insertions(+), 6 deletions(-)
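
Most of the ".__name__" hunks in the diff below fix the same bug: concatenating a str with a type object raises its own TypeError, so the intended error message never surfaces, whereas type(x).__name__ is a plain string and concatenates cleanly. A minimal standalone sketch (the variable x here is only an illustrative stand-in, not taken from the patch):

x = [1.0, 2.0]

try:
    # Pre-patch style: building the message itself fails, because str + type
    # is not a valid concatenation.
    msg = "Argument of type " + type(x) + " unsupported"
except TypeError as e:
    print(e)  # e.g. "cannot concatenate 'str' and 'type' objects"

# Post-patch style: the intended message comes through.
print("Argument of type " + type(x).__name__ + " unsupported")
# Argument of type list unsupported
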
diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py
index e68bd8a9db..e74ba0fabc 100644
--- a/python/pyspark/mllib/_common.py
+++ b/python/pyspark/mllib/_common.py
@@ -15,7 +15,7 @@
# limitations under the License.
#
-from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape
+from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape
from pyspark import SparkContext
# Double vector format:
@@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs):
elif (type(x) == RDD):
raise RuntimeError("Bulk predict not yet supported.")
else:
- raise TypeError("Argument of type " + type(x) + " unsupported")
+ raise TypeError("Argument of type " + type(x).__name__ + " unsupported")
def _get_unmangled_rdd(data, serializer):
dataBytes = data.map(serializer)
@@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data):
initial_weights = data.first()
if type(initial_weights) != ndarray:
raise TypeError("At least one data element has type "
- + type(initial_weights) + " which is not ndarray")
+ + type(initial_weights).__name__ + " which is not ndarray")
if initial_weights.ndim != 1:
raise TypeError("At least one data element has "
+ initial_weights.ndim + " dimensions, which is not 1")
- initial_weights = zeros([initial_weights.shape[0] - 1])
+ initial_weights = ones([initial_weights.shape[0] - 1])
return initial_weights
# train_func should take two parameters, namely data and initial_weights, and
@@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights):
raise RuntimeError("JVM call result had unexpected length")
elif type(ans[0]) != bytearray:
raise RuntimeError("JVM call result had first element of type "
- + type(ans[0]) + " which is not bytearray")
+ + type(ans[0]).__name__ + " which is not bytearray")
elif type(ans[1]) != float:
raise RuntimeError("JVM call result had second element of type "
- + type(ans[0]) + " which is not float")
+ + type(ans[0]).__name__ + " which is not float")
return klass(_deserialize_double_vector(ans[0]), ans[1])
def _serialize_rating(r):
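
The remaining change implements the commit title: when no initial weights are supplied, _get_initial_weights now seeds the optimizer with a vector of ones, matching the Scala GLM implementations, instead of zeros. A rough illustration, assuming a first RDD element laid out as a label together with its feature values (which is why one slot is dropped via the "- 1"):

from numpy import array, ones, zeros

# Hypothetical first element of the training RDD: a label plus three features.
first_row = array([1.0, 0.5, -2.0, 3.0])

# Pre-patch seeding: one 0.0 per feature.
old_weights = zeros([first_row.shape[0] - 1])   # array([0., 0., 0.])

# Post-patch seeding: one 1.0 per feature, as the Scala side does.
new_weights = ones([first_row.shape[0] - 1])    # array([1., 1., 1.])

print(old_weights, new_weights)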