Question
I am trying to get TensorFlow for Java to work from Scala. I am using the TensorFlow Java library directly, without any Scala wrapper.
My sbt dependencies are listed in the build.sbt further down.
If I run the HelloWorld example with the following Scala adaptations, it WORKS fine:

    import org.tensorflow.Graph
    import org.tensorflow.Session
    import org.tensorflow.Tensor
    import org.tensorflow.TensorFlow

    val g = new Graph()
    val value = "Hello from " + TensorFlow.version()
    val t = Tensor.create(value.getBytes("UTF-8"))
    // The Java API doesn't yet include convenience functions for adding operations.
    g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
    val s = new Session(g)
    val output = s.runner().fetch("MyConst").run().get(0)

However, if I use Scala reflection to compile the same function from a string, it DOES NOT WORK. Here is the snippet I ran:

    import scala.reflect.runtime.{universe => ru}
    import scala.tools.reflect.ToolBox

    val fnStr = """
    {() =>
      import org.tensorflow.Graph
      import org.tensorflow.Session
      import org.tensorflow.Tensor
      import org.tensorflow.TensorFlow
      val g = new Graph()
      val value = "Hello from " + TensorFlow.version()
      val t = Tensor.create(value.getBytes("UTF-8"))
      g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
      val s = new Session(g)
      s.runner().fetch("MyConst").run().get(0)
    }
    """

    val mirror = ru.runtimeMirror(getClass.getClassLoader)
    val tb = mirror.mkToolBox()
    // parse the source string into a tree, then compile and evaluate it
    val tree = tb.parse(fnStr)
    val fn = tb.eval(tree).asInstanceOf[() => Any]
    // and finally, executing the function
    fn()

Here is a simplified build.sbt to reproduce the error above:

    lazy val commonSettings = Seq(
      scalaVersion := "2.12.10",
      libraryDependencies ++= {
        Seq(
          // To support runtime compilation
          "org.scala-lang" % "scala-reflect" % scalaVersion.value,
          "org.scala-lang" % "scala-compiler" % scalaVersion.value,
          // for tensorflow4java
          "org.tensorflow" % "tensorflow" % "1.15.0",
          "org.tensorflow" % "proto" % "1.15.0",
          "org.tensorflow" % "libtensorflow_jni" % "1.15.0"
        )
      }
    )

    lazy val `test-proj` = project
      .in(file("."))
      .settings(commonSettings)

When running the above, for example with sbt console, I get the following error and stack trace:

    java.lang.NoSuchMethodError: org.tensorflow.Session.runner()Lorg/tensorflow/Session$$Runner;
        at __wrapper$1$f093d26a3c504d4381a37ef78b6c3d54.__wrapper$1$f093d26a3c504d4381a37ef78b6c3d54$.$anonfun$wrapper$1(<no source file>:15)

Please ignore the memory leaks in the code above; they arise because no resource-management context (calling close()) is used.
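For completeness, the leaks could be avoided with a small loan-pattern helper. This is only a sketch: `using` and `DemoResource` are hypothetical names introduced here (Scala 2.12 has no `scala.util.Using`), and the approach relies on `Graph`, `Session`, and `Tensor` implementing `AutoCloseable`, as they do in TensorFlow Java 1.x. The stand-in resource lets the helper be tried without TensorFlow on the classpath.

```scala
object UsingDemo {
  // Hypothetical loan-pattern helper: runs `body` with the resource
  // and always closes the resource afterwards, even if `body` throws.
  def using[R <: AutoCloseable, A](resource: R)(body: R => A): A =
    try body(resource)
    finally resource.close()

  // Stand-in resource, used only to exercise the helper without TensorFlow.
  final class DemoResource extends AutoCloseable {
    var closed = false
    def read(): String = "data"
    override def close(): Unit = closed = true
  }

  val demo = new DemoResource
  val result: String = using(demo)(_.read())
}
```

With TensorFlow available, the same helper would nest as `using(new Graph()) { g => using(new Session(g)) { s => ... } }`, reading the fetched tensor's contents before the session and graph are closed.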
Answer
The cause is a bug that appears when reflective compilation is combined with Scala-Java interop:
https://github.com/scala/bug/issues/8956
Toolbox can't typecheck a value (s.runner()) of a path-dependent type (s.Runner) if that type comes from a Java non-static inner class, and Runner is exactly such a class inside org.tensorflow.Session.
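To illustrate what a path-dependent type is, here is a pure-Scala analogue. The `Session` and `Runner` classes below are hypothetical stand-ins, not the TensorFlow classes, and the sketch does not reproduce the Toolbox bug itself, which concerns Java inner classes. It only shows why the type of `s.runner()` is tied to the particular instance `s`.

```scala
// Hypothetical stand-ins for org.tensorflow.Session and its inner class Runner.
class Session(val name: String) {
  // Inner class: every Session instance has its own Runner type.
  class Runner {
    def describe: String = s"runner of $name"
  }
  def runner(): Runner = new Runner
}

object PathDependentDemo {
  val s1 = new Session("s1")
  val s2 = new Session("s2")

  // The static type of r1 is the path-dependent type s1.Runner.
  val r1: s1.Runner = s1.runner()

  // Session#Runner is the type projection: a Runner belonging to any Session.
  val anyRunner: Session#Runner = s2.runner()

  // The following would not compile, because s2.Runner is not s1.Runner:
  // val wrong: s1.Runner = s2.runner()
}
```

Scala views a Java non-static inner class the same way, so the result of `s.runner()` has type `s.Runner`, which is the type the Toolbox fails to handle.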
As a workaround, you can run the compiler manually (similarly to how Toolbox runs it):

    import org.tensorflow.Tensor
    import scala.reflect.internal.util.{AbstractFileClassLoader, BatchSourceFile}
    import scala.reflect.io.{AbstractFile, VirtualDirectory}
    import scala.reflect.runtime
    import scala.reflect.runtime.universe
    import scala.reflect.runtime.universe._
    import scala.tools.nsc.{Global, Settings}

    val code: String =
      """
        |import org.tensorflow.Graph
        |import org.tensorflow.Session
        |import org.tensorflow.Tensor
        |import org.tensorflow.TensorFlow
        |
        |object Main {
        |  def foo() = () => {
        |    val g = new Graph()
        |    val value = "Hello from " + TensorFlow.version()
        |    val t = Tensor.create(value.getBytes("UTF-8"))
        |    g.opBuilder("Const", "MyConst").setAttr("dtype", t.dataType()).setAttr("value", t).build();
        |
        |    val s = new Session(g)
        |
        |    s.runner().fetch("MyConst").run().get(0)
        |  }
        |}
      """.stripMargin

    val directory = new VirtualDirectory("(memory)", None)
    val runtimeMirror = createRuntimeMirror(directory, runtime.currentMirror)
    compileCode(code, List(), directory)
    val tensor = runObjectMethod("Main", runtimeMirror, "foo").asInstanceOf[() => Tensor[_]]
    tensor() // STRING tensor with shape []

    def compileCode(code: String, classpathDirectories: List[AbstractFile], outputDirectory: AbstractFile): Unit = {
      val settings = new Settings
      classpathDirectories.foreach(dir => settings.classpath.prepend(dir.toString))
      settings.outputDirs.setSingleOutput(outputDirectory)
      settings.usejavacp.value = true
      val global = new Global(settings)
      (new global.Run).compileSources(List(new BatchSourceFile("(inline)", code)))
    }

    def runObjectMethod(objectName: String, runtimeMirror: Mirror, methodName: String, arguments: Any*): Any = {
      val objectSymbol = runtimeMirror.staticModule(objectName)
      val objectModuleMirror = runtimeMirror.reflectModule(objectSymbol)
      val objectInstance = objectModuleMirror.instance
      val objectType = objectSymbol.typeSignature
      val methodSymbol = objectType.decl(TermName(methodName)).asMethod
      val objectInstanceMirror = runtimeMirror.reflect(objectInstance)
      val methodMirror = objectInstanceMirror.reflectMethod(methodSymbol)
      methodMirror(arguments: _*)
    }

    def createRuntimeMirror(directory: AbstractFile, parentMirror: Mirror): Mirror = {
      val classLoader = new AbstractFileClassLoader(directory, parentMirror.classLoader)
      universe.runtimeMirror(classLoader)
    }

[dynamically parse json in flink map](https://stackoverflow.com/questions/64111895/dynamically-parse-json-in-flink-map)
[Dynamic compilation of multiple Scala classes at runtime](https://stackoverflow.com/questions/56922911/dynamic-compilation-of-multiple-scala-classes-at-runtime)
[How to eval code that uses InterfaceStability annotation (that fails with "illegal cyclic reference involving class InterfaceStability")?](https://stackoverflow.com/questions/53976254/how-to-eval-code-that-uses-interfacestability-annotation-that-fails-with-illeg)