aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala34
1 file changed, 16 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9edab3af91..5c64cb09d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,18 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
@Experimental
class Word2VecModel private[ml] (
override val uid: String,
- wordVectors: feature.Word2VecModel)
+ @transient wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
-
/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String
* and the vector the DenseVector that it is mapped to.
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+ val vectors = wordVectors.getVectors
+ .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+ .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+ val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+ val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
- Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+ Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
} else {
- val cum = Vectors.zeros($(vectorSize))
- val model = bWordVectors.value.getVectors
- for (word <- sentence) {
- if (model.contains(word)) {
- axpy(1.0, bWordVectors.value.transform(word), cum)
- } else {
- // pass words which not belong to model
+ val sum = Vectors.zeros(d)
+ sentence.foreach { word =>
+ bVectors.value.get(word).foreach { v =>
+ BLAS.axpy(1.0, v, sum)
}
}
- scal(1.0 / sentence.size, cum)
- cum
+ BLAS.scal(1.0 / sentence.size, sum)
+ sum
}
}
dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))