aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2015-12-09 16:45:13 +0000
committerSean Owen <sowen@cloudera.com>2015-12-09 16:45:13 +0000
commit22b9a8740d51289434553d19b6b1ac34aecdc09a (patch)
tree2c589b7c9434f37a2e22b18868b47b0705c59db6 /mllib
parent6e1c55eac4849669e119ce0d51f6d051830deb9f (diff)
downloadspark-22b9a8740d51289434553d19b6b1ac34aecdc09a.tar.gz
spark-22b9a8740d51289434553d19b6b1ac34aecdc09a.tar.bz2
spark-22b9a8740d51289434553d19b6b1ac34aecdc09a.zip
[SPARK-10299][ML] word2vec should allow users to specify the window size
Currently word2vec has the window hard-coded at 5, but some users may want different sizes (for example, when using it on n-gram input or similar). The user request comes from http://stackoverflow.com/questions/32231975/spark-word2vec-window-size . Author: Holden Karau <holden@us.ibm.com> Author: Holden Karau <holden@pigscanfly.ca> Closes #8513 from holdenk/SPARK-10299-word2vec-should-allow-users-to-specify-the-window-size.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala43
3 files changed, 65 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index a8d61b6dea..f105a983a3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -50,6 +50,17 @@ private[feature] trait Word2VecBase extends Params
def getVectorSize: Int = $(vectorSize)
/**
+ * The window size (context words from [-window, window]).
+ * Default: 5
+ * @group expertParam
+ */
+ // NOTE(review): no validator is passed to this IntParam; presumably non-positive
+ // sizes are not rejected when the param is set — confirm and consider adding one.
+ final val windowSize = new IntParam(
+ this, "windowSize", "the window size (context words from [-window, window])")
+ setDefault(windowSize -> 5)
+
+ /** @group expertGetParam */
+ def getWindowSize: Int = $(windowSize)
+
+ /**
* Number of partitions for sentences of words.
* Default: 1
* @group param
@@ -106,6 +117,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
/** @group setParam */
def setVectorSize(value: Int): this.type = set(vectorSize, value)
+ /**
+  * Sets the value of the `windowSize` param.
+  * @group expertSetParam
+  */
+ def setWindowSize(value: Int): this.type = set(windowSize, value)
+
/** @group setParam */
def setStepSize(value: Double): this.type = set(stepSize, value)
@@ -131,6 +145,7 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
.setNumPartitions($(numPartitions))
.setSeed($(seed))
.setVectorSize($(vectorSize))
+ .setWindowSize($(windowSize))
.fit(input)
copyValues(new Word2VecModel(uid, wordVectors).setParent(this))
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
index 23b1514e30..1f400e1430 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/Word2Vec.scala
@@ -126,6 +126,15 @@ class Word2Vec extends Serializable with Logging {
}
/**
+ * Sets the window of words, i.e. context words are taken from [-window, window]
+ * (default: 5).
+ *
+ * @param window the window size; must be positive
+ * @return this instance, so that setter calls can be chained
+ * @throws IllegalArgumentException if `window` is not positive
+ */
+ @Since("1.6.0")
+ def setWindowSize(window: Int): this.type = {
+   // Fail fast on a meaningless window instead of carrying a bad value into training.
+   require(window > 0, s"Window of words must be positive but got ${window}")
+   this.window = window
+   this
+ }
+
+ /**
* Sets minCount, the minimum number of times a token must appear to be included in the word2vec
* model's vocabulary (default: 5).
*/
@@ -141,7 +150,7 @@ class Word2Vec extends Serializable with Logging {
private val MAX_SENTENCE_LENGTH = 1000
/** context words from [-window, window] */
- private val window = 5
+ private var window = 5
private var trainWordsCount = 0
private var vocabSize = 0
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
index a773244cd7..d561bbbb25 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -35,7 +35,8 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
}
test("Word2Vec") {
- val sqlContext = new SQLContext(sc)
+
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -77,7 +78,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("getVectors") {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -118,7 +119,7 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
test("findSynonyms") {
- val sqlContext = new SQLContext(sc)
+ val sqlContext = this.sqlContext
import sqlContext.implicits._
val sentence = "a b " * 100 + "a c " * 10
@@ -141,7 +142,43 @@ class Word2VecSuite extends SparkFunSuite with MLlibTestSparkContext with Defaul
expectedSimilarity.zip(similarity).map {
case (expected, actual) => assert(math.abs((expected - actual) / expected) < 1E-5)
}
+ }
+
+ test("window size") {
+
+   val sqlContext = this.sqlContext
+   import sqlContext.implicits._
+
+   val sentence = "a q s t q s t b b b s t m s t m q " * 100 + "a c " * 10
+   val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" "))
+   val docDF = doc.zip(doc).toDF("text", "alsotext")
+
+   // Baseline model trained with a small window.
+   val model = new Word2Vec()
+     .setVectorSize(3)
+     .setWindowSize(2)
+     .setInputCol("text")
+     .setOutputCol("result")
+     .setSeed(42L)
+     .fit(docDF)
+   val (_, similarity) = model.findSynonyms("a", 6).map {
+     case Row(w: String, sim: Double) => (w, sim)
+   }.collect().unzip
+
+   // Increase the window size
+   val biggerModel = new Word2Vec()
+     .setVectorSize(3)
+     .setInputCol("text")
+     .setOutputCol("result")
+     .setSeed(42L)
+     .setWindowSize(10)
+     .fit(docDF)
+
+   // Query the model trained with the larger window, not the baseline model.
+   val (_, similarityLarger) = biggerModel.findSynonyms("a", 6).map {
+     case Row(w: String, sim: Double) => (w, sim)
+   }.collect().unzip
+   // The similarity score should be very different with the larger window.
+   // The subtraction is parenthesized so the check is on the relative difference.
+   assert(math.abs((similarity(5) - similarityLarger(5)) / similarity(5)) > 1E-5)
+ }
test("Word2Vec read/write") {