From 02208a175c76be111eeb66dc19c7499a6656a067 Mon Sep 17 00:00:00 2001 From: Tor Myklebust Date: Wed, 25 Dec 2013 00:53:48 -0500 Subject: Initial weights in Scala are ones; do that too. Also fix some errors. --- python/pyspark/mllib/_common.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/python/pyspark/mllib/_common.py b/python/pyspark/mllib/_common.py index e68bd8a9db..e74ba0fabc 100644 --- a/python/pyspark/mllib/_common.py +++ b/python/pyspark/mllib/_common.py @@ -15,7 +15,7 @@ # limitations under the License. # -from numpy import ndarray, copyto, float64, int64, int32, zeros, array_equal, array, dot, shape +from numpy import ndarray, copyto, float64, int64, int32, ones, array_equal, array, dot, shape from pyspark import SparkContext # Double vector format: @@ -143,7 +143,7 @@ def _linear_predictor_typecheck(x, coeffs): elif (type(x) == RDD): raise RuntimeError("Bulk predict not yet supported.") else: - raise TypeError("Argument of type " + type(x) + " unsupported") + raise TypeError("Argument of type " + type(x).__name__ + " unsupported") def _get_unmangled_rdd(data, serializer): dataBytes = data.map(serializer) @@ -182,11 +182,11 @@ def _get_initial_weights(initial_weights, data): initial_weights = data.first() if type(initial_weights) != ndarray: raise TypeError("At least one data element has type " - + type(initial_weights) + " which is not ndarray") + + type(initial_weights).__name__ + " which is not ndarray") if initial_weights.ndim != 1: raise TypeError("At least one data element has " + initial_weights.ndim + " dimensions, which is not 1") - initial_weights = zeros([initial_weights.shape[0] - 1]) + initial_weights = ones([initial_weights.shape[0] - 1]) return initial_weights # train_func should take two parameters, namely data and initial_weights, and @@ -200,10 +200,10 @@ def _regression_train_wrapper(sc, train_func, klass, data, initial_weights): raise RuntimeError("JVM call result had unexpected length") elif type(ans[0]) != bytearray: raise RuntimeError("JVM call result had first element of type " - + type(ans[0]) + " which is not bytearray") + + type(ans[0]).__name__ + " which is not bytearray") elif type(ans[1]) != float: raise RuntimeError("JVM call result had second element of type " - + type(ans[0]) + " which is not float") + + type(ans[0]).__name__ + " which is not float") return klass(_deserialize_double_vector(ans[0]), ans[1]) def _serialize_rating(r): -- cgit v1.2.3