From c98408adf2d96928fe227a740631a8efd8e0c339 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Wed, 21 Feb 2018 21:05:37 -0800 Subject: Clean up the provisional client-server code with unit tests and proper file-sockets Seems to work well enough for interactive scala consoles, though still not Ammonite Also Added ScalaModule#launcher and re-worked our build.sc file to use it --- clientserver/src/mill/clientserver/Client.scala | 57 +++++++++ .../src/mill/clientserver/ClientServer.scala | 139 +++++++++++++++++++++ clientserver/src/mill/clientserver/Locks.scala | 103 +++++++++++++++ clientserver/src/mill/clientserver/Server.scala | 104 +++++++++++++++ .../src/mill/clientserver/ClientServerTests.scala | 118 +++++++++++++++++ 5 files changed, 521 insertions(+) create mode 100644 clientserver/src/mill/clientserver/Client.scala create mode 100644 clientserver/src/mill/clientserver/ClientServer.scala create mode 100644 clientserver/src/mill/clientserver/Locks.scala create mode 100644 clientserver/src/mill/clientserver/Server.scala create mode 100644 clientserver/test/src/mill/clientserver/ClientServerTests.scala (limited to 'clientserver') diff --git a/clientserver/src/mill/clientserver/Client.scala b/clientserver/src/mill/clientserver/Client.scala new file mode 100644 index 00000000..dcf65271 --- /dev/null +++ b/clientserver/src/mill/clientserver/Client.scala @@ -0,0 +1,57 @@ +package mill.clientserver + +import java.io._ + +import org.scalasbt.ipcsocket.UnixDomainSocket + +object Client{ + def WithLock[T](index: Int)(f: String => T): T = { + val lockBase = "out/mill-worker-" + index + new java.io.File(lockBase).mkdirs() + val lockFile = new RandomAccessFile(lockBase+ "/clientLock", "rw") + val channel = lockFile.getChannel + channel.tryLock() match{ + case null => + lockFile.close() + channel.close() + if (index < 5) WithLock(index + 1)(f) + else throw new Exception("Reached max process limit: " + 5) + case locked => + try f(lockBase) + finally{ + locked.release() + lockFile.close() + channel.close() + } + } + } +} + +class Client(lockBase: String, + initServer: () => Unit, + locks: Locks, + stdin: InputStream, + stdout: OutputStream, + stderr: OutputStream) extends ClientServer(lockBase){ + def run(args: Array[String]) = { + + val f = new FileOutputStream(runFile) + ClientServer.writeArgs(System.console() != null, args, f) + f.close() + if (locks.processLock.probe()) initServer() + while(locks.processLock.probe()) Thread.sleep(3) + + val ioSocket = ClientServer.retry(1000, new UnixDomainSocket(ioPath)) + val outErr = ioSocket.getInputStream + val in = ioSocket.getOutputStream + val outPump = new ClientOutputPumper(outErr, stdout, stderr) + val inPump = new ClientInputPumper(stdin, in) + val outThread = new Thread(outPump) + outThread.setDaemon(true) + val inThread = new Thread(inPump) + inThread.setDaemon(true) + outThread.start() + inThread.start() + locks.serverLock.await() + } +} diff --git a/clientserver/src/mill/clientserver/ClientServer.scala b/clientserver/src/mill/clientserver/ClientServer.scala new file mode 100644 index 00000000..2cc38859 --- /dev/null +++ b/clientserver/src/mill/clientserver/ClientServer.scala @@ -0,0 +1,139 @@ +package mill.clientserver + +import java.io.{FileInputStream, InputStream, OutputStream, RandomAccessFile} +import java.nio.channels.FileChannel + +import scala.annotation.tailrec + +class ClientServer(lockBase: String){ + val ioPath = lockBase + "/io" + val logFile = new java.io.File(lockBase + "/log") + val runFile = new java.io.File(lockBase + "/run") +} + +object ClientServer{ + def parseArgs(argStream: InputStream) = { + val interactive = argStream.read() != 0 + val argsLength = argStream.read() + val args = Array.fill(argsLength){ + val n = argStream.read() + val arr = new Array[Byte](n) + argStream.read(arr) + new String(arr) + } + (interactive, args) + } + def writeArgs(interactive: Boolean, args: Array[String], argStream: OutputStream) = { + argStream.write(if (interactive) 1 else 0) + argStream.write(args.length) + var i = 0 + while (i < args.length){ + argStream.write(args(i).length) + argStream.write(args(i).getBytes) + i += 1 + } + } + @tailrec def retry[T](millis: Long, t: => T): T = { + val current = System.currentTimeMillis() + val res = + try Some(t) + catch{case e: Throwable if System.currentTimeMillis() < current + millis => + None + } + res match{ + case Some(t) => t + case None => + Thread.sleep(1) + retry(millis - (System.currentTimeMillis() - current), t) + } + } + + def interruptWith[T](millis: Int, close: => Unit)(t: => T): T = { + var int = true + new Thread(() => { + Thread.sleep(millis) + if (int) close + }).start() + + try t + finally { + + int = false + } + } + + def polling[T](probe: => Boolean, cb: () => Unit)(t: => T): T = { + var probing = true + val probeThread = new Thread(() => while(probing){ + if (probe){ + probing = false + cb() + } + Thread.sleep(1000) + }) + probeThread.start() + try t + finally probing = false + } +} +object ProxyOutputStream{ + val lock = new Object +} +class ProxyOutputStream(x: => java.io.OutputStream, + key: Int) extends java.io.OutputStream { + override def write(b: Int) = ProxyOutputStream.lock.synchronized{ + x.write(key) + x.write(b) + } +} +class ProxyInputStream(x: => java.io.InputStream) extends java.io.InputStream{ + def read() = x.read() + override def read(b: Array[Byte], off: Int, len: Int) = x.read(b, off, len) + override def read(b: Array[Byte]) = x.read(b) +} + +class ClientInputPumper(src: InputStream, dest: OutputStream) extends Runnable{ + var running = true + def run() = { + val buffer = new Array[Byte](1024) + while(running){ + val n = src.read(buffer) + if (n == -1) running = false + else { + dest.write(buffer, 0, n) + dest.flush() + } + } + } + +} +class ClientOutputPumper(src: InputStream, dest1: OutputStream, dest2: OutputStream) extends Runnable{ + var running = true + def run() = { + val buffer = new Array[Byte](1024) + var state = 0 + while(running){ + val n = src.read(buffer) + if (n == -1) running = false + else { + var i = 0 + while (i < n){ + state match{ + case 0 => state = buffer(i) + 1 + case 1 => + dest1.write(buffer(i)) + state = 0 + case 2 => + dest2.write(buffer(i)) + state = 0 + } + + i += 1 + } + dest1.flush() + dest2.flush() + } + } + } + +} \ No newline at end of file diff --git a/clientserver/src/mill/clientserver/Locks.scala b/clientserver/src/mill/clientserver/Locks.scala new file mode 100644 index 00000000..d1644719 --- /dev/null +++ b/clientserver/src/mill/clientserver/Locks.scala @@ -0,0 +1,103 @@ +package mill.clientserver + +import java.io.RandomAccessFile +import java.nio.channels.FileChannel +import java.util.concurrent.locks.{ReadWriteLock, ReentrantLock} + + +trait Lock{ + def lock(): Locked + def lockBlock[T](t: => T): T = { + val l = lock() + try t + finally l.release() + } + def tryLockBlock[T](t: => T): Option[T] = { + tryLock() match{ + case None => + None + case Some(l) => + try Some(t) + finally l.release() + } + + } + def tryLock(): Option[Locked] + def await(): Unit = { + val l = lock() + l.release() + } + + /** + * Returns `true` if the lock is *available for taking* + */ + def probe(): Boolean +} +trait Locked{ + def release(): Unit +} +trait Locks{ + val processLock: Lock + val serverLock: Lock + val clientLock: Lock +} +class FileLocked(lock: java.nio.channels.FileLock) extends Locked{ + def release() = { + lock.release() + } +} + +class FileLock(path: String) extends Lock{ + + val raf = new RandomAccessFile(path, "rw") + val chan = raf.getChannel + def lock() = { + val lock = chan.lock() + new FileLocked(lock) + } + def tryLock() = { + chan.tryLock() match{ + case null => None + case lock => Some(new FileLocked(lock)) + } + } + def probe(): Boolean = tryLock() match{ + case None => false + case Some(locked) => + locked.release() + true + } +} +class FileLocks(lockBase: String) extends Locks{ + val processLock = new FileLock(lockBase + "/pid") + + val serverLock = new FileLock(lockBase + "/serverLock") + + val clientLock = new FileLock(lockBase + "/clientLock") +} +class MemoryLocked(l: java.util.concurrent.locks.Lock) extends Locked{ + def release() = l.unlock() +} + +class MemoryLock() extends Lock{ + val innerLock = new ReentrantLock(true) + + def probe() = !innerLock.isLocked + def lock() = { + innerLock.lock() + new MemoryLocked(innerLock) + } + def tryLock() = { + innerLock.tryLock() match{ + case false => None + case true => Some(new MemoryLocked(innerLock)) + } + } +} +class MemoryLocks() extends Locks{ + val processLock = new MemoryLock() + + val serverLock = new MemoryLock() + + val clientLock = new MemoryLock() +} \ No newline at end of file diff --git a/clientserver/src/mill/clientserver/Server.scala b/clientserver/src/mill/clientserver/Server.scala new file mode 100644 index 00000000..ad2e35e4 --- /dev/null +++ b/clientserver/src/mill/clientserver/Server.scala @@ -0,0 +1,104 @@ +package mill.clientserver + +import java.io._ +import java.net.Socket + +import org.scalasbt.ipcsocket.UnixDomainServerSocket + +trait ServerMain[T]{ + def main(args0: Array[String]): Unit = { + new Server( + args0(0), + this, + () => System.exit(0), + () => System.currentTimeMillis(), + new FileLocks(args0(0)) + ).run() + } + var stateCache = Option.empty[T] + def main0(args: Array[String], + stateCache: Option[T], + mainInteractive: Boolean, + watchInterrupted: () => Boolean, + stdin: InputStream, + stdout: PrintStream, + stderr: PrintStream): (Boolean, Option[T]) +} + + +class Server[T](lockBase: String, + sm: ServerMain[T], + interruptServer: () => Unit, + currentTimeMillis: () => Long, + locks: Locks) extends ClientServer(lockBase){ + + val originalStdout = System.out + def run() = { + locks.processLock.tryLockBlock{ + var lastRun = currentTimeMillis() + while (currentTimeMillis() - lastRun < 60000) locks.serverLock.lockBlock{ + new File(ioPath).delete() + val ioSocket = new UnixDomainServerSocket(ioPath) + val sockOpt = ClientServer.interruptWith( + 1000, + ioSocket.close() + ){ + try Some(ioSocket.accept()) + catch{ case e: IOException => None} + } + + sockOpt.foreach{sock => + try handleRun(sock) + catch{case e: Throwable => e.printStackTrace(originalStdout) } + finally lastRun = currentTimeMillis() + } + } + }.getOrElse(throw new Exception("PID already present")) + } + + def handleRun(clientSocket: Socket) = { + + val currentOutErr = clientSocket.getOutputStream + val socketIn = clientSocket.getInputStream + val argStream = new FileInputStream(runFile) + val (interactive, args) = ClientServer.parseArgs(argStream) + argStream.close() + + var done = false + val t = new Thread(() => + + try { + val stdout = new PrintStream(new ProxyOutputStream(currentOutErr, 0), true) + val stderr = new PrintStream(new ProxyOutputStream(currentOutErr, 1), true) + val (_, newStateCache) = sm.main0( + args, + sm.stateCache, + interactive, + () => !locks.clientLock.probe(), + socketIn, + stdout, stderr + ) + + sm.stateCache = newStateCache + } catch{case WatchInterrupted(sc: Option[T]) => + sm.stateCache = sc + } finally{ + done = true + } + ) + + t.start() + + // We cannot simply use Lock#await here, because the filesystem doesn't + // realize the clientLock/serverLock are held by different threads in the + // two processes and gives a spurious deadlock error + while(!done && !locks.clientLock.probe()) { + Thread.sleep(3) + } + + t.interrupt() + t.stop() + clientSocket.close() + } +} +case class WatchInterrupted[T](stateCache: Option[T]) extends Exception \ No newline at end of file diff --git a/clientserver/test/src/mill/clientserver/ClientServerTests.scala b/clientserver/test/src/mill/clientserver/ClientServerTests.scala new file mode 100644 index 00000000..ecf09ab3 --- /dev/null +++ b/clientserver/test/src/mill/clientserver/ClientServerTests.scala @@ -0,0 +1,118 @@ +package mill.clientserver +import java.io._ +import java.nio.file.Path + +import utest._ +class EchoServer extends ServerMain[Int]{ + def main0(args: Array[String], + stateCache: Option[Int], + mainInteractive: Boolean, + watchInterrupted: () => Boolean, + stdin: InputStream, + stdout: PrintStream, + stderr: PrintStream) = { + + val reader = new BufferedReader(new InputStreamReader(stdin)) + val str = reader.readLine() + stdout.println(str + args(0)) + stdout.flush() + stderr.println(str.toUpperCase + args(0)) + stderr.flush() + (true, None) + } +} + +object ClientServerTests extends TestSuite{ + def initStreams() = { + val in = new ByteArrayInputStream("hello\n".getBytes()) + val out = new ByteArrayOutputStream() + val err = new ByteArrayOutputStream() + (in, out, err) + } + def init() = { + val tmpDir = java.nio.file.Files.createTempDirectory("") + val locks = new MemoryLocks() + + (tmpDir, locks) + } + + def tests = Tests{ + 'hello - { + var currentTimeMillis = 0 + val (tmpDir, locks) = init() + + def spawnEchoServer() = { + new Thread(() => new Server( + tmpDir.toString, + new EchoServer(), + () => (), + () => currentTimeMillis, + locks + ).run()).start() + } + + + def runClient(arg: String) = { + val (in, out, err) = initStreams() + locks.clientLock.lockBlock{ + val c = new Client( + tmpDir.toString, + () => spawnEchoServer(), + locks, + in, + out, + err + ) + c.run(Array(arg)) + (new String(out.toByteArray), new String(err.toByteArray)) + } + } + + // Make sure the simple "have the client start a server and + // exchange one message" workflow works from end to end. + + assert( + locks.clientLock.probe(), + locks.serverLock.probe(), + locks.processLock.probe() + ) + + val (out1, err1) = runClient("world") + + assert( + out1 == "helloworld\n", + err1 == "HELLOworld\n" + ) + + assert( + locks.clientLock.probe(), + !locks.serverLock.probe(), + !locks.processLock.probe() + ) + + // A seecond client in sequence connect to the same server + val (out2, err2) = runClient(" WORLD") + + assert( + out2 == "hello WORLD\n", + err2 == "HELLO WORLD\n" + ) + + // Make sure the server times out of not used for a while + currentTimeMillis += 60001 + Thread.sleep(2000) + assert( + locks.clientLock.probe(), + locks.serverLock.probe(), + locks.processLock.probe() + ) + + // Have a third client spawn/connect-to a new server at the same path + val (out3, err3) = runClient(" World") + assert( + out3 == "hello World\n", + err3 == "HELLO World\n" + ) + } + } +} -- cgit v1.2.3