about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorYuming Wang <q79969786@gmail.com>2015-11-11 09:43:26 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-11 09:43:26 -0800
commit27524a3a9ccee6fbe56149180ebfb3f74e0957e7 (patch)
treee7ea7e631c3375c2cb1848234cc2ae0791a78d65 /mllib
parent1510c527b4f5ee0953ae42313ef9e16d2f5864c4 (diff)
downloadspark-27524a3a9ccee6fbe56149180ebfb3f74e0957e7.tar.gz
spark-27524a3a9ccee6fbe56149180ebfb3f74e0957e7.tar.bz2
spark-27524a3a9ccee6fbe56149180ebfb3f74e0957e7.zip
[SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow
org.apache.spark.ml.feature.Word2Vec.transform() is very slow; we should not read the broadcast variable for every sentence. Author: Yuming Wang <q79969786@gmail.com> Author: yuming.wang <q79969786@gmail.com> Author: Xiangrui Meng <meng@databricks.com> Closes #9592 from 979969786/master.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala34
1 file changed, 16 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
index 9edab3af91..5c64cb09d5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala
@@ -17,18 +17,16 @@
package org.apache.spark.ml.feature
-import org.apache.spark.annotation.Experimental
import org.apache.spark.SparkContext
+import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors}
-import org.apache.spark.mllib.linalg.BLAS._
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors}
+import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.functions._
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.types._
/**
@@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel]
@Experimental
class Word2VecModel private[ml] (
override val uid: String,
- wordVectors: feature.Word2VecModel)
+ @transient wordVectors: feature.Word2VecModel)
extends Model[Word2VecModel] with Word2VecBase {
-
/**
* Returns a dataframe with two fields, "word" and "vector", with "word" being a String and
* and the vector the DenseVector that it is mapped to.
@@ -197,22 +194,23 @@ class Word2VecModel private[ml] (
*/
override def transform(dataset: DataFrame): DataFrame = {
transformSchema(dataset.schema, logging = true)
- val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors)
+ val vectors = wordVectors.getVectors
+ .mapValues(vv => Vectors.dense(vv.map(_.toDouble)))
+ .map(identity) // mapValues doesn't return a serializable map (SI-7005)
+ val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors)
+ val d = $(vectorSize)
val word2Vec = udf { sentence: Seq[String] =>
if (sentence.size == 0) {
- Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double])
+ Vectors.sparse(d, Array.empty[Int], Array.empty[Double])
} else {
- val cum = Vectors.zeros($(vectorSize))
- val model = bWordVectors.value.getVectors
- for (word <- sentence) {
- if (model.contains(word)) {
- axpy(1.0, bWordVectors.value.transform(word), cum)
- } else {
- // pass words which not belong to model
+ val sum = Vectors.zeros(d)
+ sentence.foreach { word =>
+ bVectors.value.get(word).foreach { v =>
+ BLAS.axpy(1.0, v, sum)
}
}
- scal(1.0 / sentence.size, cum)
- cum
+ BLAS.scal(1.0 / sentence.size, sum)
+ sum
}
}
dataset.withColumn($(outputCol), word2Vec(col($(inputCol))))