about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r-- python/pyspark/ml/feature.py 28
-rw-r--r-- python/pyspark/ml/tests.py 5
-rw-r--r-- python/pyspark/mllib/feature.py 11
-rw-r--r-- python/pyspark/mllib/tests.py 4
4 files changed, 41 insertions, 7 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 776906eaab..49a78ede37 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -2219,28 +2219,31 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
minCount = Param(Params._dummy(), "minCount",
"the minimum number of times a token must appear to be included in the " +
"word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
+ windowSize = Param(Params._dummy(), "windowSize",
+ "the window size (context words from [-window, window]). Default value is 5",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=None, inputCol=None, outputCol=None):
+ seed=None, inputCol=None, outputCol=None, windowSize=5):
"""
__init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
- seed=None, inputCol=None, outputCol=None)
+ seed=None, inputCol=None, outputCol=None, windowSize=5)
"""
super(Word2Vec, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=None)
+ seed=None, windowSize=5)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.4.0")
def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=None, inputCol=None, outputCol=None):
+ seed=None, inputCol=None, outputCol=None, windowSize=5):
"""
setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \
- inputCol=None, outputCol=None)
+ inputCol=None, outputCol=None, windowSize=5)
Sets params for this Word2Vec.
"""
kwargs = self.setParams._input_kwargs
@@ -2291,6 +2294,21 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
return self.getOrDefault(self.minCount)
+ @since("2.0.0")
+ def setWindowSize(self, value):
+ """
+ Sets the value of :py:attr:`windowSize`.
+ """
+ self._set(windowSize=value)
+ return self
+
+ @since("2.0.0")
+ def getWindowSize(self):
+ """
+ Gets the value of windowSize or its default value.
+ """
+ return self.getOrDefault(self.windowSize)
+
def _create_model(self, java_model):
return Word2VecModel(java_model)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9d6ff47b54..f1bca6ebe0 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -341,6 +341,11 @@ class ParamTests(PySparkTestCase):
params = param_store.params # should not invoke the property 'test_property'
self.assertEqual(len(params), 1)
+ def test_word2vec_param(self):
+ model = Word2Vec().setWindowSize(6)
+ # Check windowSize is set properly
+ self.assertEqual(model.getWindowSize(), 6)
+
class FeatureTests(PySparkTestCase):
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index b3dd2f63a5..90559f6cfb 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -617,6 +617,7 @@ class Word2Vec(object):
self.numIterations = 1
self.seed = random.randint(0, sys.maxsize)
self.minCount = 5
+ self.windowSize = 5
@since('1.2.0')
def setVectorSize(self, vectorSize):
@@ -669,6 +670,14 @@ class Word2Vec(object):
self.minCount = minCount
return self
+ @since('2.0.0')
+ def setWindowSize(self, windowSize):
+ """
+ Sets window size (default: 5).
+ """
+ self.windowSize = windowSize
+ return self
+
@since('1.2.0')
def fit(self, data):
"""
@@ -682,7 +691,7 @@ class Word2Vec(object):
jmodel = callMLlibFunc("trainWord2VecModel", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), int(self.seed),
- int(self.minCount))
+ int(self.minCount), int(self.windowSize))
return Word2VecModel(jmodel)
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index ac55fbf798..f272da56d1 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -1027,13 +1027,15 @@ class Word2VecTests(MLlibTestCase):
.setNumPartitions(2) \
.setNumIterations(10) \
.setSeed(1024) \
- .setMinCount(3)
+ .setMinCount(3) \
+ .setWindowSize(6)
self.assertEqual(model.vectorSize, 2)
self.assertTrue(model.learningRate < 0.02)
self.assertEqual(model.numPartitions, 2)
self.assertEqual(model.numIterations, 10)
self.assertEqual(model.seed, 1024)
self.assertEqual(model.minCount, 3)
+ self.assertEqual(model.windowSize, 6)
def test_word2vec_get_vectors(self):
data = [