From 27524a3a9ccee6fbe56149180ebfb3f74e0957e7 Mon Sep 17 00:00:00 2001 From: Yuming Wang Date: Wed, 11 Nov 2015 09:43:26 -0800 Subject: [SPARK-11626][ML] ml.feature.Word2Vec.transform() function very slow org.apache.spark.ml.feature.Word2Vec.transform() very slow. we should not read broadcast every sentence. Author: Yuming Wang Author: yuming.wang Author: Xiangrui Meng Closes #9592 from 979969786/master. --- .../org/apache/spark/ml/feature/Word2Vec.scala | 34 ++++++++++------------ 1 file changed, 16 insertions(+), 18 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 9edab3af91..5c64cb09d5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -17,18 +17,16 @@ package org.apache.spark.ml.feature -import org.apache.spark.annotation.Experimental import org.apache.spark.SparkContext +import org.apache.spark.annotation.Experimental import org.apache.spark.ml.{Estimator, Model} import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util.{Identifiable, SchemaUtils} import org.apache.spark.mllib.feature -import org.apache.spark.mllib.linalg.{VectorUDT, Vector, Vectors} -import org.apache.spark.mllib.linalg.BLAS._ -import org.apache.spark.sql.DataFrame +import org.apache.spark.mllib.linalg.{BLAS, Vector, VectorUDT, Vectors} +import org.apache.spark.sql.{DataFrame, SQLContext} import org.apache.spark.sql.functions._ -import org.apache.spark.sql.SQLContext import org.apache.spark.sql.types._ /** @@ -148,10 +146,9 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] @Experimental class Word2VecModel private[ml] ( override val uid: String, - wordVectors: feature.Word2VecModel) + @transient wordVectors: feature.Word2VecModel) extends Model[Word2VecModel] with Word2VecBase { - /** * Returns a dataframe with two fields, "word" and "vector", with "word" being a String and * and the vector the DenseVector that it is mapped to. @@ -197,22 +194,23 @@ class Word2VecModel private[ml] ( */ override def transform(dataset: DataFrame): DataFrame = { transformSchema(dataset.schema, logging = true) - val bWordVectors = dataset.sqlContext.sparkContext.broadcast(wordVectors) + val vectors = wordVectors.getVectors + .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) + .map(identity) // mapValues doesn't return a serializable map (SI-7005) + val bVectors = dataset.sqlContext.sparkContext.broadcast(vectors) + val d = $(vectorSize) val word2Vec = udf { sentence: Seq[String] => if (sentence.size == 0) { - Vectors.sparse($(vectorSize), Array.empty[Int], Array.empty[Double]) + Vectors.sparse(d, Array.empty[Int], Array.empty[Double]) } else { - val cum = Vectors.zeros($(vectorSize)) - val model = bWordVectors.value.getVectors - for (word <- sentence) { - if (model.contains(word)) { - axpy(1.0, bWordVectors.value.transform(word), cum) - } else { - // pass words which not belong to model + val sum = Vectors.zeros(d) + sentence.foreach { word => + bVectors.value.get(word).foreach { v => + BLAS.axpy(1.0, v, sum) } } - scal(1.0 / sentence.size, cum) - cum + BLAS.scal(1.0 / sentence.size, sum) + sum } } dataset.withColumn($(outputCol), word2Vec(col($(inputCol)))) -- cgit v1.2.3