// path: root/src/scala/spark/repl/ExecutorClassLoader.scala
// blob: 13d81ec1cf096befaff8b7e9901b783faa17212e
package spark.repl

import java.io.{ByteArrayOutputStream, InputStream}
import java.net.{URI, URL, URLClassLoader, URLEncoder}
import java.util.concurrent.{Executors, ExecutorService}

import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}

import org.objectweb.asm._
import org.objectweb.asm.commons.EmptyVisitor
import org.objectweb.asm.Opcodes._


/**
 * A ClassLoader that reads classes from a Hadoop FileSystem or HTTP URI,
 * used to load classes defined by the interpreter when the REPL is used.
 *
 * A class named `a.b.C` is looked up as `a/b/C.class` relative to `classUri`.
 * When the URI scheme is exactly "http" the bytes are fetched through
 * java.net.URL; for every other scheme they are read through the Hadoop
 * FileSystem obtained for the URI.
 *
 * @param classUri base location of the class files
 * @param parent   parent class loader, passed on to ClassLoader
 */
class ExecutorClassLoader(classUri: String, parent: ClassLoader)
extends ClassLoader(parent) {
  /** The class location parsed as a URI. */
  val uri = new URI(classUri)

  /** Path component of `uri`; used as the base directory for Hadoop lookups. */
  val directory = uri.getPath

  // Hadoop FileSystem object for our URI, if it isn't using HTTP
  // (null means "fetch over HTTP with java.net.URL instead").
  var fileSystem: FileSystem = {
    if (uri.getScheme() == "http")
      null
    else
      FileSystem.get(uri, new Configuration())
  }

  /**
   * Locate, read and define the class with the given binary name.
   *
   * @param name fully qualified class name, e.g. "line1.$read$$iw$"
   * @return the class defined from the bytes that were read
   * @throws ClassNotFoundException wrapping any Exception raised while
   *         opening, reading, transforming or defining the class
   */
  override def findClass(name: String): Class[_] = {
    try {
      val pathInDirectory = name.replace('.', '/') + ".class"
      val inputStream: InputStream = {
        if (fileSystem != null)
          fileSystem.open(new Path(directory, pathInDirectory))
        else
          new URL(classUri + "/" + urlEncode(pathInDirectory)).openStream()
      }
      // Close the stream even when reading or transforming fails, so a bad
      // class file does not leak an open file handle or HTTP connection.
      val bytes = try {
        readAndTransformClass(name, inputStream)
      } finally {
        inputStream.close()
      }
      defineClass(name, bytes, 0, bytes.length)
    } catch {
      case e: Exception => throw new ClassNotFoundException(name, e)
    }
  }

  /**
   * Read the bytes of a class from `in`, rewriting them when the class looks
   * like an interpreter wrapper object. The stream is not closed here; that
   * is the caller's responsibility.
   *
   * @param name fully qualified name of the class being read
   * @param in   stream positioned at the start of the class file
   * @return the class file bytes, transformed or unmodified
   */
  def readAndTransformClass(name: String, in: InputStream): Array[Byte] = {
    if (name.startsWith("line") && name.endsWith("$iw$")) {
      // Class seems to be an interpreter "wrapper" object storing a val or var.
      // Replace its constructor with a dummy one that does not run the
      // initialization code placed there by the REPL. The val or var will
      // be initialized later through reflection when it is used in a task.
      val cr = new ClassReader(in)
      val cw = new ClassWriter(
        ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS)
      cr.accept(new ConstructorCleaner(name, cw), 0)
      cw.toByteArray
    } else {
      // Pass the class through unmodified: copy the stream in 4 KB chunks
      // until read() reports end of stream with a negative count.
      val bos = new ByteArrayOutputStream
      val buffer = new Array[Byte](4096)
      var num = in.read(buffer)
      while (num >= 0) {
        bos.write(buffer, 0, num)
        num = in.read(buffer)
      }
      bos.toByteArray
    }
  }

  /**
   * URL-encode a string, preserving only slashes: each '/'-separated part is
   * encoded with URLEncoder (UTF-8) and the parts are joined with '/' again.
   */
  def urlEncode(str: String): String = {
    str.split('/').map(part => URLEncoder.encode(part, "UTF-8")).mkString("/")
  }
}

/**
 * Class adapter that swaps the body of every non-static `<init>` method for
 * a stub which only invokes java.lang.Object's no-argument constructor.
 * All other methods are forwarded to the wrapped visitor untouched.
 */
class ConstructorCleaner(className: String, cv: ClassVisitor)
extends ClassAdapter(cv) {
  override def visitMethod(access: Int, name: String, desc: String,
      sig: String, exceptions: Array[String]): MethodVisitor = {
    val target = cv.visitMethod(access, name, desc, sig, exceptions)
    val isConstructor = name == "<init>" && (access & ACC_STATIC) == 0
    if (!isConstructor) {
      // Not a constructor: let the caller write the original body to target.
      target
    } else {
      // Constructor: emit the replacement body ourselves. It calls the
      // java.lang.Object constructor and returns, doing nothing otherwise.
      target.visitCode()
      target.visitVarInsn(ALOAD, 0) // push this for the super call
      target.visitMethodInsn(INVOKESPECIAL, "java/lang/Object", "<init>", "()V")
      // Second load of this was the operand for the PUTSTATIC below; with
      // that store commented out, nothing consumes it before RETURN.
      target.visitVarInsn(ALOAD, 0)
      //val classType = className.replace('.', '/')
      //mv.visitFieldInsn(PUTSTATIC, classType, "MODULE$", "L" + classType + ";")
      target.visitInsn(RETURN)
      target.visitMaxs(-1, -1) // stack size and local vars will be auto-computed
      target.visitEnd()
      // Returning null means the original constructor body is not visited.
      null
    }
  }
}