diff options
-rw-r--r-- | python/pyspark/mllib/linalg.py | 11 |
1 files changed, 8 insertions, 3 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py index 040886f717..529bd75894 100644 --- a/python/pyspark/mllib/linalg.py +++ b/python/pyspark/mllib/linalg.py @@ -30,6 +30,7 @@ if sys.version >= '3': basestring = str xrange = range import copyreg as copy_reg + long = int else: from itertools import izip as zip import copy_reg @@ -770,14 +771,18 @@ class Vectors(object): return SparseVector(size, *args) @staticmethod - def dense(elements): + def dense(*elements): """ - Create a dense vector of 64-bit floats from a Python list. Always - returns a NumPy array. + Create a dense vector of 64-bit floats from a Python list or numbers. >>> Vectors.dense([1, 2, 3]) DenseVector([1.0, 2.0, 3.0]) + >>> Vectors.dense(1.0, 2.0) + DenseVector([1.0, 2.0]) """ + if len(elements) == 1 and not isinstance(elements[0], (float, int, long)): + # it's list, numpy.array or other iterable object. + elements = elements[0] return DenseVector(elements) @staticmethod |