diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala | 8 |
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala index 95bae1c8a3..a72692960f 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Word2Vec.scala @@ -27,7 +27,7 @@ import org.apache.spark.ml.param.shared._ import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature import org.apache.spark.mllib.linalg.{BLAS, Vector, Vectors, VectorUDT} -import org.apache.spark.sql.{DataFrame, SQLContext} +import org.apache.spark.sql.{DataFrame, Dataset, SQLContext} import org.apache.spark.sql.functions._ import org.apache.spark.sql.types._ @@ -135,7 +135,8 @@ final class Word2Vec(override val uid: String) extends Estimator[Word2VecModel] /** @group setParam */ def setMinCount(value: Int): this.type = set(minCount, value) - override def fit(dataset: DataFrame): Word2VecModel = { + @Since("2.0.0") + override def fit(dataset: Dataset[_]): Word2VecModel = { transformSchema(dataset.schema, logging = true) val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0)) val wordVectors = new feature.Word2Vec() @@ -219,7 +220,8 @@ class Word2VecModel private[ml] ( * Transform a sentence column to a vector column to represent the whole sentence. The transform * is performed by averaging all word vectors it contains. */ - override def transform(dataset: DataFrame): DataFrame = { + @Since("2.0.0") + override def transform(dataset: Dataset[_]): DataFrame = { transformSchema(dataset.schema, logging = true) val vectors = wordVectors.getVectors .mapValues(vv => Vectors.dense(vv.map(_.toDouble))) |