aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/mllib/linalg.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 040886f717..529bd75894 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -30,6 +30,7 @@ if sys.version >= '3':
basestring = str
xrange = range
import copyreg as copy_reg
+ long = int
else:
from itertools import izip as zip
import copy_reg
@@ -770,14 +771,18 @@ class Vectors(object):
return SparseVector(size, *args)
@staticmethod
- def dense(elements):
+ def dense(*elements):
"""
- Create a dense vector of 64-bit floats from a Python list. Always
- returns a NumPy array.
+ Create a dense vector of 64-bit floats from a Python list or numbers.
>>> Vectors.dense([1, 2, 3])
DenseVector([1.0, 2.0, 3.0])
+ >>> Vectors.dense(1.0, 2.0)
+ DenseVector([1.0, 2.0])
"""
+ if len(elements) == 1 and not isinstance(elements[0], (float, int, long)):
+ # it's list, numpy.array or other iterable object.
+ elements = elements[0]
return DenseVector(elements)
@staticmethod