aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala6
1 files changed, 4 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
index 1b0a9a12e8..91989c3d2f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/ElementwiseProduct.scala
@@ -19,10 +19,12 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.Param
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.{Vector, VectorUDT}
+import org.apache.spark.mllib.linalg.{Vectors => OldVectors}
+import org.apache.spark.mllib.linalg.VectorImplicits._
import org.apache.spark.sql.types.DataType
/**
@@ -52,7 +54,7 @@ class ElementwiseProduct(override val uid: String)
override protected def createTransformFunc: Vector => Vector = {
require(params.contains(scalingVec), s"transformation requires a weight vector")
val elemScaler = new feature.ElementwiseProduct($(scalingVec))
- elemScaler.transform
+ v => elemScaler.transform(v)
}
override protected def outputDataType: DataType = new VectorUDT()