aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2015-04-29 14:55:32 -0700
committerXiangrui Meng <meng@databricks.com>2015-04-29 14:55:32 -0700
commitc9d530e2e5123dbd4fd13fc487c890d6076b24bf (patch)
treeed767d35090ace95b7ea77df5d58519bff653ec7 /mllib
parent15995c883aa248235fdebf0cbeeaa3ef12c97e9c (diff)
downloadspark-c9d530e2e5123dbd4fd13fc487c890d6076b24bf.tar.gz
spark-c9d530e2e5123dbd4fd13fc487c890d6076b24bf.tar.bz2
spark-c9d530e2e5123dbd4fd13fc487c890d6076b24bf.zip
[SPARK-6529] [ML] Add Word2Vec transformer
See JIRA issue [here](https://issues.apache.org/jira/browse/SPARK-6529). Some notes: 1. I added `learningRate` to sharedParams since it is a common parameter for ML algorithms. 2. We do not yet support the transform that finds synonyms from a `Vector`; it will be supported in a future JIRA issue. 3. Word2Vec differs from other ML models in that its training set and transformed set are different. Its training set is an `RDD[Iterable[String]]` which represents documents, but the transformed set we want is an `RDD[String]` that represents unique words. So you have to switch your `inputCol` between these two stages. Author: Xusen Yin <yinxusen@gmail.com> Closes #5596 from yinxusen/SPARK-6529 and squashes the following commits: ee2b37a [Xusen Yin] merge with former HEAD 4945462 [Xusen Yin] merge with #5626 3bc2cbd [Xusen Yin] change foldLeft to for loop and use blas 5dd4ee7 [Xusen Yin] fix scala style 743e0d5 [Xusen Yin] fix comments and code style 04c48e9 [Xusen Yin] ensure the functionality a190f2c [Xusen Yin] fix code style and refine the transform function of word2vec 02848fa [Xusen Yin] refine comments 34a55c0 [Xusen Yin] fix errors 109d124 [Xusen Yin] add test suite and pass it 04dde06 [Xusen Yin] add shared params c594095 [Xusen Yin] add word2vec transformer 23d77fa [Xusen Yin] merge with #5626 e8cfaf7 [Xusen Yin] fix conflict with master 66e7bd3 [Xusen Yin] change foldLeft to for loop and use blas 566ec20 [Xusen Yin] fix scala style b54399f [Xusen Yin] fix comments and code style 1211e86 [Xusen Yin] ensure the functionality 6b97ec8 [Xusen Yin] fix code style and refine the transform function of word2vec 7cde18f [Xusen Yin] rm sharedParams 618abd0 [Xusen Yin] refine comments e29680a [Xusen Yin] fix errors fe3afe9 [Xusen Yin] add test suite and pass it 02767fb [Xusen Yin] add shared params 6a514f1 [Xusen Yin] add word2vec transformer
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala185
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala3
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala17
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala63
4 files changed, 267 insertions, 1 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
new file mode 100644
index 0000000000..0163fa8bd8
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -0,0 +1,185 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.AlphaComponent
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared._
+import org.apache.spark.ml.util.SchemaUtils
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.feature
+import org.apache.spark.mllib.linalg.BLAS._
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.{DataFrame, Row}
+
+/**
+ * Params shared by [[Word2Vec]] and [[Word2VecModel]]: the mixed-in input/output column,
+ * `maxIter`, `stepSize` and `seed` params, plus `vectorSize`, `numPartitions` and `minCount`
+ * declared here. Defaults set in this trait: `vectorSize` = 100, `numPartitions` = 1,
+ * `minCount` = 5, `stepSize` = 0.025, `maxIter` = 1 and `seed` = 42.
+ */
+private[feature] trait Word2VecBase extends Params
+ with HasInputCol with HasOutputCol with HasMaxIter with HasStepSize with HasSeed {
+
+ /**
+ * The dimension of the vector (code) that each word is transformed into. Default: 100.
+ */
+ final val vectorSize = new IntParam(
+ this, "vectorSize", "the dimension of codes after transforming from words")
+ setDefault(vectorSize -> 100)
+
+ /** @group getParam */
+ def getVectorSize: Int = getOrDefault(vectorSize)
+
+ /**
+ * Number of partitions for sentences of words. Default: 1.
+ */
+ final val numPartitions = new IntParam(
+ this, "numPartitions", "number of partitions for sentences of words")
+ setDefault(numPartitions -> 1)
+
+ /** @group getParam */
+ def getNumPartitions: Int = getOrDefault(numPartitions)
+
+ /**
+ * The minimum number of times a token must appear to be included in the word2vec model's
+ * vocabulary. Default: 5.
+ */
+ final val minCount = new IntParam(this, "minCount", "the minimum number of times a token must " +
+ "appear to be included in the word2vec model's vocabulary")
+ setDefault(minCount -> 5)
+
+ /** @group getParam */
+ def getMinCount: Int = getOrDefault(minCount)
+
+ setDefault(stepSize -> 0.025)
+ setDefault(maxIter -> 1)
+ setDefault(seed -> 42L)
+
+ /**
+ * Validate and transform the input schema: the input column is checked against the
+ * array-of-strings type, and the output column is added with the vector type.
+ *
+ * @param schema schema of the dataset to be fitted or transformed
+ * @param paramMap extra params, combined with this instance's params through `extractParamMap`
+ * @return the schema returned by `SchemaUtils.appendColumn` for the output column
+ */
+ protected def validateAndTransformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ val map = extractParamMap(paramMap)
+ SchemaUtils.checkColumnType(schema, map(inputCol), new ArrayType(StringType, true))
+ SchemaUtils.appendColumn(schema, map(outputCol), new VectorUDT)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Word2Vec trains a model of `Map(String, Vector)`, i.e. it transforms a word into a code for
+ * further natural language processing or machine learning processes.
+ *
+ * `fit` validates the schema, selects the input column, reads each row as a sequence of words and
+ * trains MLlib's `feature.Word2Vec` with the configured `stepSize` (passed as the learning rate),
+ * `minCount`, `maxIter` (passed as the number of iterations), `numPartitions`, `seed` and
+ * `vectorSize`. The trained word vectors are wrapped in a [[Word2VecModel]], to which
+ * `Params.inheritValues` is then applied.
+ *
+ * NOTE(review): the row pattern `Row(v: Seq[String])` can only check `Seq` at runtime because the
+ * element type is erased; the schema validation is what guards the element type.
+ */
+@AlphaComponent
+final class Word2Vec extends Estimator[Word2VecModel] with Word2VecBase {
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ def setVectorSize(value: Int): this.type = set(vectorSize, value)
+
+ /** @group setParam */
+ def setStepSize(value: Double): this.type = set(stepSize, value)
+
+ /** @group setParam */
+ def setNumPartitions(value: Int): this.type = set(numPartitions, value)
+
+ /** @group setParam */
+ def setMaxIter(value: Int): this.type = set(maxIter, value)
+
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
+ /** @group setParam */
+ def setMinCount(value: Int): this.type = set(minCount, value)
+
+ override def fit(dataset: DataFrame, paramMap: ParamMap): Word2VecModel = {
+ transformSchema(dataset.schema, paramMap, logging = true)
+ val map = extractParamMap(paramMap)
+ val input = dataset.select(map(inputCol)).map { case Row(v: Seq[String]) => v }
+ val wordVectors = new feature.Word2Vec()
+ .setLearningRate(map(stepSize))
+ .setMinCount(map(minCount))
+ .setNumIterations(map(maxIter))
+ .setNumPartitions(map(numPartitions))
+ .setSeed(map(seed))
+ .setVectorSize(map(vectorSize))
+ .fit(input)
+ val model = new Word2VecModel(this, map, wordVectors)
+ Params.inheritValues(map, this, model)
+ model
+ }
+
+ override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+ validateAndTransformSchema(schema, paramMap)
+ }
+}
+
+/**
+ * :: AlphaComponent ::
+ * Model fitted by [[Word2Vec]].
+ *
+ * @param parent the [[Word2Vec]] estimator that produced this model
+ * @param fittingParamMap the param map that was used to fit this model
+ * @param wordVectors the trained MLlib word2vec model that supplies the word vectors
+ */
+@AlphaComponent
+class Word2VecModel private[ml] (
+    override val parent: Word2Vec,
+    override val fittingParamMap: ParamMap,
+    wordVectors: feature.Word2VecModel)
+  extends Model[Word2VecModel] with Word2VecBase {
+
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
+  /**
+   * Transform a sentence column to a vector column to represent the whole sentence. The sentence
+   * vector is the sum of the vectors of the words found in the model's vocabulary, divided by the
+   * total number of words in the sentence (words missing from the vocabulary are skipped but
+   * still counted in the divisor). An empty sentence maps to an all-zero sparse vector.
+   *
+   * @param dataset input dataset containing the sentence column `inputCol`
+   * @param paramMap extra params, combined with this model's params through `extractParamMap`
+   * @return the dataset with the sentence vectors added as column `outputCol`
+   */
+  override def transform(dataset: DataFrame, paramMap: ParamMap): DataFrame = {
+    transformSchema(dataset.schema, paramMap, logging = true)
+    val map = extractParamMap(paramMap)
+    val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+    // Resolve the vector size once on the driver: reading it inside the UDF would capture the
+    // whole ParamMap in the serialized closure and repeat the lookup for every row.
+    val size = map(vectorSize)
+    val word2Vec = udf { sentence: Seq[String] =>
+      if (sentence.size == 0) {
+        Vectors.sparse(size, Array.empty[Int], Array.empty[Double])
+      } else {
+        val cum = Vectors.zeros(size)
+        val model = bWordVectors.value.getVectors
+        for (word <- sentence) {
+          // Words that do not belong to the model's vocabulary are skipped.
+          if (model.contains(word)) {
+            axpy(1.0, bWordVectors.value.transform(word), cum)
+          }
+        }
+        scal(1.0 / sentence.size, cum)
+        cum
+      }
+    }
+    dataset.withColumn(map(outputCol), word2Vec(col(map(inputCol))))
+  }
+
+  /** Delegates to `validateAndTransformSchema` to check the input column and add the output. */
+  override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
+    validateAndTransformSchema(schema, paramMap)
+  }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 3f7e8f5a0b..654cd72d53 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -48,7 +48,8 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[Long]("seed", "random seed", Some("Utils.random.nextLong()")),
ParamDesc[Double]("elasticNetParam", "the ElasticNet mixing parameter"),
- ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"))
+ ParamDesc[Double]("tol", "the convergence tolerance for iterative algorithms"),
+ ParamDesc[Double]("stepSize", "Step size to be used for each iteration of optimization."))
val code = genSharedParams(params)
val file = "src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala"
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index 7d2c76d6c6..96d11ed76f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -310,4 +310,21 @@ trait HasTol extends Params {
/** @group getParam */
final def getTol: Double = getOrDefault(tol)
}
+
+/**
+ * :: DeveloperApi ::
+ * Trait for shared param stepSize.
+ */
+@DeveloperApi
+trait HasStepSize extends Params {
+
+ /**
+ * Param for the step size to be used for each iteration of optimization.
+ * @group param
+ */
+ final val stepSize: DoubleParam = new DoubleParam(this, "stepSize", "Step size to be used for each iteration of optimization.")
+
+ /** @group getParam */
+ final def getStepSize: Double = getOrDefault(stepSize)
+}
// scalastyle:on
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
new file mode 100644
index 0000000000..03ba86670d
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/Word2VecSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.{Row, SQLContext}
+
+class Word2VecSuite extends FunSuite with MLlibTestSparkContext {
+
+ test("Word2Vec") {
+ val sqlContext = new SQLContext(sc)
+ import sqlContext.implicits._
+
+ val sentence = "a b " * 100 + "a c " * 10 // 220 words: "a b" x 100 followed by "a c" x 10
+ val numOfWords = sentence.split(" ").size
+ val doc = sc.parallelize(Seq(sentence, sentence)).map(line => line.split(" ")) // two equal docs
+
+ val codes = Map( // NOTE(review): presumably the vectors a default-seed fit yields; confirm
+ "a" -> Array(-0.2811822295188904,-0.6356269121170044,-0.3020961284637451),
+ "b" -> Array(1.0309048891067505,-1.29472815990448,0.22276712954044342),
+ "c" -> Array(-0.08456747233867645,0.5137411952018738,0.11731560528278351)
+ )
+
+ val expected = doc.map { sentence => // element-wise sum of the word codes / number of words
+ Vectors.dense(sentence.map(codes.apply).reduce((word1, word2) =>
+ word1.zip(word2).map { case (v1, v2) => v1 + v2 }
+ ).map(_ / numOfWords))
+ }
+
+ val docDF = doc.zip(expected).toDF("text", "expected")
+
+ val model = new Word2Vec()
+ .setVectorSize(3)
+ .setInputCol("text")
+ .setOutputCol("result")
+ .fit(docDF)
+
+ model.transform(docDF).select("result", "expected").collect().foreach {
+ case Row(vector1: Vector, vector2: Vector) =>
+ assert(vector1 ~== vector2 absTol 1E-5, "Transformed vector is different with expected.")
+ }
+ }
+}
+