Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala           | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/Transformer.scala        | 18
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala  |  5
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala  |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/params.scala       | 11
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala     |  4
6 files changed, 34 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index e545df1e37..081a574bee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -162,11 +162,15 @@ class PipelineModel private[ml] (
}
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
- transformSchema(dataset.schema, paramMap, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+ // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ transformSchema(dataset.schema, map, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+ // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}
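
Note on the precedence comment above: Scala's ++ on maps keeps the right-hand operand's value when keys collide, so (fittingParamMap ++ this.paramMap) ++ paramMap makes the per-call paramMap win over the model's own params, which in turn win over the params captured at fit time. A minimal standalone sketch (plain Scala Maps standing in for ParamMap; the param names are made up for illustration):

    object ParamPrecedenceSketch {
      def main(args: Array[String]): Unit = {
        // Stand-ins for the three sources of parameters.
        val fittingParamMap = Map("maxIter" -> "10", "regParam" -> "0.1") // captured when fit() ran
        val modelParamMap   = Map("maxIter" -> "20")                      // set directly on the model
        val callParamMap    = Map("regParam" -> "0.5")                    // passed to transform()

        // ++ keeps the right operand's value on key collisions, so the
        // per-call map has the highest precedence, then the model's map,
        // then the fitting-time map.
        val merged = (fittingParamMap ++ modelParamMap) ++ callParamMap
        println(merged) // Map(maxIter -> 20, regParam -> 0.5)
      }
    }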
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 490e6609ad..23fbd228d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,16 +18,14 @@
package org.apache.spark.ml
import scala.annotation.varargs
-import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.catalyst.types._
/**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -100,6 +98,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
protected def createTransformFunc(paramMap: ParamMap): IN => OUT
/**
+ * Returns the data type of the output column.
+ */
+ protected def outputDataType: DataType
+
+ /**
* Validates the input type. Throw an exception if it is invalid.
*/
protected def validateInputType(inputType: DataType): Unit = {}
@@ -111,9 +114,8 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
if (schema.fieldNames.contains(map(outputCol))) {
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
}
- val output = ScalaReflection.schemaFor[OUT]
val outputFields = schema.fields :+
- StructField(map(outputCol), output.dataType, output.nullable)
+ StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
StructType(outputFields)
}
@@ -121,7 +123,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
transformSchema(dataset.schema, paramMap, logging = true)
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = this.createTransformFunc(map)
- dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+ dataset.select(Star(None), udf as map(outputCol))
}
}
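
The net effect of this hunk: a UnaryTransformer no longer needs a TypeTag on OUT (nor a ScalaReflection.schemaFor call); each subclass instead declares its output column's SQL type via outputDataType, and the transformation is wired up as an explicit ScalaUdf. A dependency-free sketch of the resulting subclass contract (toy types, not the actual Spark classes):

    // Toy stand-ins for Catalyst data types, just to show the shape of the API.
    sealed trait SketchDataType
    case object SketchStringType extends SketchDataType
    case class SketchArrayType(elementType: SketchDataType, containsNull: Boolean) extends SketchDataType

    abstract class UnaryTransformerSketch[IN, OUT] {
      // Same as before: the actual transformation logic.
      protected def createTransformFunc: IN => OUT
      // New in this patch: the subclass states the output column's type
      // explicitly, so no runtime reflection (and no TypeTag bound) is needed.
      protected def outputDataType: SketchDataType
    }

    class LowercaseTokenizerSketch extends UnaryTransformerSketch[String, Seq[String]] {
      protected def createTransformFunc: String => Seq[String] =
        _.toLowerCase.split("\\s+").toSeq
      protected def outputDataType: SketchDataType =
        SketchArrayType(SketchStringType, containsNull = false)
    }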
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index b98b1755a3..e0bfb1e484 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.sql.catalyst.types.DataType
/**
* :: AlphaComponent ::
@@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
}
+
+ override protected def outputDataType: DataType = new VectorUDT()
}
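
HashingTF's output is an mllib Vector, which has no built-in Catalyst type, hence the user-defined VectorUDT here. For background, a toy sketch of the hashed term-frequency computation that the underlying feature.HashingTF performs (illustrative only; the real implementation differs in hashing details and returns a sparse Vector):

    // Map each term to a bucket in [0, numFeatures) and count occurrences.
    def hashedTermFrequencies(terms: Iterable[String], numFeatures: Int): Map[Int, Double] =
      terms.foldLeft(Map.empty[Int, Double]) { (counts, term) =>
        // Non-negative modulo, since hashCode may be negative.
        val index = ((term.hashCode % numFeatures) + numFeatures) % numFeatures
        counts.updated(index, counts.getOrElse(index, 0.0) + 1.0)
      }

    // hashedTermFrequencies(Seq("spark", "ml", "spark"), 16)
    //   -> a map with the bucket for "spark" counted twice and "ml" once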
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0a6599b64c..9352f40f37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataType, StringType}
+import org.apache.spark.sql.{DataType, StringType, ArrayType}
/**
* :: AlphaComponent ::
@@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
protected override def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
+
+ override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
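
The containsNull = false in ArrayType(StringType, false) records that the token array never holds null elements, which holds for a split-based transform. A tiny standalone illustration (toy code mirroring the behavior, not the actual implementation):

    // Whitespace tokenization after lowercasing; String.split never yields
    // null elements, so an ArrayType with containsNull = false is accurate.
    val tokenize: String => Seq[String] = _.toLowerCase.split("\\s").toSeq

    // tokenize("Spark ML Pipelines")  ->  Seq("spark", "ml", "pipelines")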
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8fd46aef4b..4b4340af54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.param
-import java.lang.reflect.Modifier
-
-import org.apache.spark.annotation.AlphaComponent
-
import scala.annotation.varargs
import scala.collection.mutable
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Identifiable
/**
@@ -221,7 +220,9 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
/**
* Puts a list of param pairs (overwrites if the input params exists).
+ * Not usable from Java
*/
+ @varargs
def put(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
put(p.param.asInstanceOf[Param[Any]], p.value)
@@ -282,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* where the latter overwrites this if there exists conflicts.
*/
def ++(other: ParamMap): ParamMap = {
+ // TODO: Provide a better method name for Java users.
new ParamMap(this.map ++ other.map)
}
@@ -290,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Adds all parameters from the input param map into this param map.
*/
def ++=(other: ParamMap): this.type = {
+ // TODO: Provide a better method name for Java users.
this.map ++= other.map
this
}
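
On the @varargs annotation added to put: it makes the Scala compiler emit an additional Java-style varargs bridge alongside the Seq-based method. A self-contained sketch of the mechanism (made-up class, not ParamMap; note that, as the doc comment above says, the ParamPair[_]* signature itself remains awkward to call from Java even with the bridge):

    import scala.annotation.varargs

    class TagSetSketch {
      private val tags = scala.collection.mutable.LinkedHashSet.empty[String]

      // @varargs makes the compiler also emit a Java varargs bridge
      // (add(String... values)), so Java code can call this naturally
      // instead of building a Scala Seq by hand.
      @varargs
      def add(values: String*): this.type = {
        tags ++= values
        this
      }

      override def toString: String = tags.mkString("TagSetSketch(", ", ", ")")
    }

    // Scala: new TagSetSketch().add("a", "b")
    // Java:  new TagSetSketch().add("a", "b");   // compiles thanks to @varargs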
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 8c4c9c6cf6..9fed513bec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -96,7 +96,9 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
def dot(x: Vector, y: Vector): Double = {
- require(x.size == y.size)
+ require(x.size == y.size,
+ "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
+ " x.size = " + x.size + ", y.size = " + y.size)
(x, y) match {
case (dx: DenseVector, dy: DenseVector) =>
dot(dx, dy)
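
The only behavioral change in this last hunk is the error message: with mismatched vector sizes, the failure now names the method and both sizes instead of a bare "requirement failed". A standalone sketch of the same pattern (plain arrays rather than mllib Vectors, since BLAS is private[spark]):

    def dotSketch(x: Array[Double], y: Array[Double]): Double = {
      // require throws IllegalArgumentException("requirement failed: <message>")
      // when the predicate is false, so a descriptive message pays off.
      require(x.length == y.length,
        s"dotSketch was given arrays with non-matching sizes: x.size = ${x.length}, y.size = ${y.length}")
      var sum = 0.0
      var i = 0
      while (i < x.length) { sum += x(i) * y(i); i += 1 }
      sum
    }

    // dotSketch(Array(1.0, 2.0, 3.0), Array(1.0, 2.0)) throws:
    //   java.lang.IllegalArgumentException: requirement failed:
    //   dotSketch was given arrays with non-matching sizes: x.size = 3, y.size = 2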