author    Joseph K. Bradley <joseph@databricks.com>  2014-12-04 17:00:06 +0800
committer Xiangrui Meng <meng@databricks.com>        2014-12-04 17:00:06 +0800
commit    469a6e5f3bdd5593b3254bc916be8236e7c6cb74
tree      fd9756fcaf83aca60724616dd9abaa55b7e5c6dd /mllib
parent    529439bd506949f272a2b6f099ea549b097428f3
[SPARK-4575] [mllib] [docs] spark.ml pipelines doc + bug fixes
Documentation:
* Added ml-guide.md, linked from mllib-guide.md
* Updated mllib-guide.md with small section pointing to ml-guide.md

Examples:
* CrossValidatorExample
* SimpleParamsExample
* (I copied these + the SimpleTextClassificationPipeline example into the ml-guide.md)

Bug fixes:
* PipelineModel: did not use ParamMaps correctly
* UnaryTransformer: issues with TypeTag serialization (Thanks to mengxr for that fix!)

CC: mengxr shivaram etrain

Documentation for Pipelines: I know the docs are not complete, but the goal is to have enough to let interested people get started using spark.ml and to add more docs once the package is more established/complete.

Author: Joseph K. Bradley <joseph@databricks.com>
Author: jkbradley <joseph.kurata.bradley@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #3588 from jkbradley/ml-package-docs and squashes the following commits:

d393b5c [Joseph K. Bradley] fixed bug in Pipeline (typo from last commit). updated examples for CV and Params for spark.ml
c38469c [Joseph K. Bradley] Updated ml-guide with CV examples
99f88c2 [Joseph K. Bradley] Fixed bug in PipelineModel.transform* with usage of params. Updated CrossValidatorExample to use more training examples so it is less likely to get a 0-size fold.
ea34dc6 [jkbradley] Merge pull request #4 from mengxr/ml-package-docs
3b83ec0 [Xiangrui Meng] replace TypeTag with explicit datatype
41ad9b1 [Joseph K. Bradley] Added examples for spark.ml: SimpleParamsExample + Java version, CrossValidatorExample + Java version. CrossValidatorExample not working yet. Added programming guide for spark.ml, but need to add CrossValidatorExample to it once CrossValidatorExample works.
Diffstat (limited to 'mllib')
-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala             10
-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/Transformer.scala          18
-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala     5
-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala     4
-rw-r--r--   mllib/src/main/scala/org/apache/spark/ml/param/params.scala         11
-rw-r--r--   mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala        4
6 files changed, 34 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
index e545df1e37..081a574bee 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Pipeline.scala
@@ -162,11 +162,15 @@ class PipelineModel private[ml] (
}
override def transform(dataset: SchemaRDD, paramMap: ParamMap): SchemaRDD = {
- transformSchema(dataset.schema, paramMap, logging = true)
- stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, paramMap))
+ // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ transformSchema(dataset.schema, map, logging = true)
+ stages.foldLeft(dataset)((cur, transformer) => transformer.transform(cur, map))
}
private[ml] override def transformSchema(schema: StructType, paramMap: ParamMap): StructType = {
- stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, paramMap))
+ // Precedence of ParamMaps: paramMap > this.paramMap > fittingParamMap
+ val map = (fittingParamMap ++ this.paramMap) ++ paramMap
+ stages.foldLeft(schema)((cur, transformer) => transformer.transformSchema(cur, map))
}
}
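
Note: the fix above merges the three ParamMaps so that explicitly passed params win. A minimal sketch of the precedence rule, using plain Scala Maps as stand-ins for ParamMap (whose ++, per the doc comment later in this patch, also lets the right-hand operand overwrite on conflicts); the param names and string values here are invented for illustration:

object ParamPrecedenceSketch {
  def main(args: Array[String]): Unit = {
    val fittingParamMap = Map("maxIter" -> "10", "regParam" -> "0.1") // params used at fit time
    val thisParamMap    = Map("regParam" -> "0.01")                   // params embedded in the model
    val paramMap        = Map("maxIter" -> "50")                      // params passed to transform()

    // Same associativity as the patch: paramMap > this.paramMap > fittingParamMap,
    // because ++ lets the right-hand map overwrite on key conflicts.
    val merged = (fittingParamMap ++ thisParamMap) ++ paramMap
    println(merged) // Map(maxIter -> 50, regParam -> 0.01)
  }
}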
diff --git a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
index 490e6609ad..23fbd228d0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala
@@ -18,16 +18,14 @@
package org.apache.spark.ml
import scala.annotation.varargs
-import scala.reflect.runtime.universe.TypeTag
import org.apache.spark.Logging
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.param._
import org.apache.spark.sql.SchemaRDD
import org.apache.spark.sql.api.java.JavaSchemaRDD
-import org.apache.spark.sql.catalyst.ScalaReflection
import org.apache.spark.sql.catalyst.analysis.Star
-import org.apache.spark.sql.catalyst.dsl._
+import org.apache.spark.sql.catalyst.expressions.ScalaUdf
import org.apache.spark.sql.catalyst.types._
/**
@@ -86,7 +84,7 @@ abstract class Transformer extends PipelineStage with Params {
* Abstract class for transformers that take one input column, apply transformation, and output the
* result as a new column.
*/
-private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransformer[IN, OUT, T]]
+private[ml] abstract class UnaryTransformer[IN, OUT, T <: UnaryTransformer[IN, OUT, T]]
extends Transformer with HasInputCol with HasOutputCol with Logging {
def setInputCol(value: String): T = set(inputCol, value).asInstanceOf[T]
@@ -100,6 +98,11 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
protected def createTransformFunc(paramMap: ParamMap): IN => OUT
/**
+ * Returns the data type of the output column.
+ */
+ protected def outputDataType: DataType
+
+ /**
* Validates the input type. Throw an exception if it is invalid.
*/
protected def validateInputType(inputType: DataType): Unit = {}
@@ -111,9 +114,8 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
if (schema.fieldNames.contains(map(outputCol))) {
throw new IllegalArgumentException(s"Output column ${map(outputCol)} already exists.")
}
- val output = ScalaReflection.schemaFor[OUT]
val outputFields = schema.fields :+
- StructField(map(outputCol), output.dataType, output.nullable)
+ StructField(map(outputCol), outputDataType, !outputDataType.isPrimitive)
StructType(outputFields)
}
@@ -121,7 +123,7 @@ private[ml] abstract class UnaryTransformer[IN, OUT: TypeTag, T <: UnaryTransfor
transformSchema(dataset.schema, paramMap, logging = true)
import dataset.sqlContext._
val map = this.paramMap ++ paramMap
- val udf = this.createTransformFunc(map)
- dataset.select(Star(None), udf.call(map(inputCol).attr) as map(outputCol))
+ val udf = ScalaUdf(this.createTransformFunc(map), outputDataType, Seq(map(inputCol).attr))
+ dataset.select(Star(None), udf as map(outputCol))
}
}
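
Note: the net effect of this file's changes is that UnaryTransformer no longer captures a TypeTag for OUT (which did not serialize cleanly); subclasses instead declare their output schema through the new outputDataType hook, and transform wraps the function in an explicit ScalaUdf carrying that type. A hypothetical subclass sketch of the new contract (UpperCaser is invented for illustration; since UnaryTransformer is private[ml], this only compiles inside the ml package tree):

package org.apache.spark.ml.feature

import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.sql.{DataType, StringType}

// Hypothetical transformer: upper-cases one string column into another.
class UpperCaser extends UnaryTransformer[String, String, UpperCaser] {

  // The transformation is a plain serializable function; no reflection involved.
  protected def createTransformFunc(paramMap: ParamMap): String => String =
    _.toUpperCase

  // New in this patch: the subclass states its output column's type explicitly,
  // instead of relying on a TypeTag for the OUT type parameter.
  override protected def outputDataType: DataType = StringType
}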
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index b98b1755a3..e0bfb1e484 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -21,7 +21,8 @@ import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.{IntParam, ParamMap}
import org.apache.spark.mllib.feature
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{VectorUDT, Vector}
+import org.apache.spark.sql.catalyst.types.DataType
/**
* :: AlphaComponent ::
@@ -39,4 +40,6 @@ class HashingTF extends UnaryTransformer[Iterable[_], Vector, HashingTF] {
val hashingTF = new feature.HashingTF(paramMap(numFeatures))
hashingTF.transform
}
+
+ override protected def outputDataType: DataType = new VectorUDT()
}
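
Note: with the explicit VectorUDT above, HashingTF's output column carries a proper schema type. A minimal usage sketch, assuming this era's setters (setInputCol / setOutputCol from UnaryTransformer, setNumFeatures on HashingTF) and an existing SchemaRDD named dataset with a "words" column:

import org.apache.spark.ml.feature.HashingTF

val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1000)

// hashingTF.transform(dataset) appends a "features" column whose schema
// type is the VectorUDT declared by outputDataType above.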
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index 0a6599b64c..9352f40f37 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.sql.{DataType, StringType}
+import org.apache.spark.sql.{DataType, StringType, ArrayType}
/**
* :: AlphaComponent ::
@@ -36,4 +36,6 @@ class Tokenizer extends UnaryTransformer[String, Seq[String], Tokenizer] {
protected override def validateInputType(inputType: DataType): Unit = {
require(inputType == StringType, s"Input type must be string type but got $inputType.")
}
+
+ override protected def outputDataType: DataType = new ArrayType(StringType, false)
}
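
Note: the second argument of ArrayType is containsNull, so the false here asserts that the token arrays never contain null elements. A small sketch of the resulting field, with column names invented for illustration; nullable = true matches the !outputDataType.isPrimitive rule added to UnaryTransformer.transformSchema above, since ArrayType is not primitive:

import org.apache.spark.sql.catalyst.types.{ArrayType, StringType, StructField, StructType}

// The output field a Tokenizer adds, assuming an output column named "tokens":
// elements are non-null strings because containsNull = false.
val tokensField = StructField("tokens", ArrayType(StringType, containsNull = false), nullable = true)

val schema = StructType(Seq(
  StructField("text", StringType, nullable = true),
  tokensField))

schema.fields.foreach(println)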
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 8fd46aef4b..4b4340af54 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -17,13 +17,12 @@
package org.apache.spark.ml.param
-import java.lang.reflect.Modifier
-
-import org.apache.spark.annotation.AlphaComponent
-
import scala.annotation.varargs
import scala.collection.mutable
+import java.lang.reflect.Modifier
+
+import org.apache.spark.annotation.AlphaComponent
import org.apache.spark.ml.Identifiable
/**
@@ -221,7 +220,9 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
/**
* Puts a list of param pairs (overwrites if the input params exists).
+ * Not usable from Java
*/
+ @varargs
def put(paramPairs: ParamPair[_]*): this.type = {
paramPairs.foreach { p =>
put(p.param.asInstanceOf[Param[Any]], p.value)
@@ -282,6 +283,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* where the latter overwrites this if there exists conflicts.
*/
def ++(other: ParamMap): ParamMap = {
+ // TODO: Provide a better method name for Java users.
new ParamMap(this.map ++ other.map)
}
@@ -290,6 +292,7 @@ class ParamMap private[ml] (private val map: mutable.Map[Param[Any], Any]) exten
* Adds all parameters from the input param map into this param map.
*/
def ++=(other: ParamMap): this.type = {
+ // TODO: Provide a better method name for Java users.
this.map ++= other.map
this
}
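
Note: the @varargs annotation added to put makes the compiler emit an additional array-taking overload for Java callers, though as the added comment says, ParamPair[_]'s existential type still makes this particular method awkward from Java. A generic sketch of the mechanism (the Bag class is invented for illustration, not Spark code):

import scala.annotation.varargs

class Bag {
  private val items = scala.collection.mutable.ListBuffer.empty[String]

  // @varargs makes scalac also emit put(String[]) so Java callers can
  // write bag.put("a", "b") instead of building a Scala Seq by hand.
  @varargs
  def put(xs: String*): this.type = {
    items ++= xs
    this
  }

  override def toString: String = items.mkString("Bag(", ", ", ")")
}

// Scala call site:                   new Bag().put("a", "b")
// Java call site (via @varargs):     new Bag().put("a", "b");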
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
index 8c4c9c6cf6..9fed513bec 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/linalg/BLAS.scala
@@ -96,7 +96,9 @@ private[spark] object BLAS extends Serializable with Logging {
* dot(x, y)
*/
def dot(x: Vector, y: Vector): Double = {
- require(x.size == y.size)
+ require(x.size == y.size,
+ "BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:" +
+ " x.size = " + x.size + ", y.size = " + y.size)
(x, y) match {
case (dx: DenseVector, dy: DenseVector) =>
dot(dx, dy)
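
Note: the only change here is a more diagnostic failure message. A small sketch of what a caller would now see on mismatched sizes, using the public Vectors factory (BLAS itself is private[spark], so the dot call is shown as a comment):

import org.apache.spark.mllib.linalg.Vectors

val x = Vectors.dense(1.0, 2.0, 3.0)
val y = Vectors.dense(1.0, 2.0)

// Inside Spark, BLAS.dot(x, y) now fails with
//   java.lang.IllegalArgumentException: requirement failed:
//   BLAS.dot(x: Vector, y:Vector) was given Vectors with non-matching sizes:
//   x.size = 3, y.size = 2
// rather than the bare "requirement failed" produced by the old
// require(x.size == y.size).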