aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-07-17 12:43:58 -0700
committerXiangrui Meng <meng@databricks.com>2015-07-17 12:43:58 -0700
commitf9a82a884e7cb2a466a33ab64912924ce7ee30c1 (patch)
tree43fbbe83838eeb7d28c03a454f101f98079529c7 /python
parent587c315b204f1439f696620543c38166d95f8a3d (diff)
downloadspark-f9a82a884e7cb2a466a33ab64912924ce7ee30c1.tar.gz
spark-f9a82a884e7cb2a466a33ab64912924ce7ee30c1.tar.bz2
spark-f9a82a884e7cb2a466a33ab64912924ce7ee30c1.zip
[SPARK-9138] [MLLIB] fix Vectors.dense
Vectors.dense() should accept numbers directly, like the one in Scala. We already use it in doctests, it worked by luck. cc mengxr jkbradley Author: Davies Liu <davies@databricks.com> Closes #7476 from davies/fix_vectors_dense and squashes the following commits: e0fd292 [Davies Liu] fix Vectors.dense
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/mllib/linalg.py11
1 files changed, 8 insertions, 3 deletions
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index 040886f717..529bd75894 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -30,6 +30,7 @@ if sys.version >= '3':
basestring = str
xrange = range
import copyreg as copy_reg
+ long = int
else:
from itertools import izip as zip
import copy_reg
@@ -770,14 +771,18 @@ class Vectors(object):
return SparseVector(size, *args)
@staticmethod
- def dense(elements):
+ def dense(*elements):
"""
- Create a dense vector of 64-bit floats from a Python list. Always
- returns a NumPy array.
+ Create a dense vector of 64-bit floats from a Python list or numbers.
>>> Vectors.dense([1, 2, 3])
DenseVector([1.0, 2.0, 3.0])
+ >>> Vectors.dense(1.0, 2.0)
+ DenseVector([1.0, 2.0])
"""
+ if len(elements) == 1 and not isinstance(elements[0], (float, int, long)):
+ # it's list, numpy.array or other iterable object.
+ elements = elements[0]
return DenseVector(elements)
@staticmethod