I have a basic Android TensorFlowInferenceInterface example that runs fine in a single thread.
public class InferenceExample {
    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "intput_node0";
    private static final String OUTPUT_NODE = "output_node0";
    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {
        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];
        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[] {OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);
        return res;
    }
}
The example crashes with various exceptions, including java.lang.ArrayIndexOutOfBoundsException and java.lang.NullPointerException, when run in a ThreadPool as in the example below, so I guess it's not thread safe.
InferenceExample inference = new InferenceExample(context);

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);
Collection<Future<?>> futures = new LinkedList<Future<?>>();

for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            // note: the class above defines run(), not call()
            inference.run(randomData);
        }
    });
    futures.add(result);
}

for (Future<?> future : futures) {
    try {
        future.get();
    } catch (ExecutionException | InterruptedException e) {
        Log.e("TF", e.getMessage());
    }
}
Is it possible to leverage multicore Android devices with TensorFlowInferenceInterface?
2 Answers
Answer 1
To make InferenceExample thread safe I changed the TensorFlowInferenceInterface from a static field to an instance field and made the run method synchronized:
private TensorFlowInferenceInterface inferenceInterface;

public InferenceExample(final Context context) {
    inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
}

public synchronized float[] run(float[] data) {
    ...
}
Then I round robin a list of InferenceExample instances across numThreads.
for (int i = 1; i <= 100; i++) {
    final int id = i % numThreads;
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            list.get(id).run(data);
        }
    });
    futures.add(result);
}
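For completeness, here is a minimal sketch of how the list used above could be built; the names list and numThreads match the snippet, but the setup itself is an assumption, not part of the original answer:

// Hypothetical setup for the round-robin loop above: one synchronized
// InferenceExample per slot, so at most numThreads inferences run in parallel.
int numThreads = 2;
List<InferenceExample> list = new ArrayList<InferenceExample>(numThreads);
for (int t = 0; t < numThreads; t++) {
    list.add(new InferenceExample(context));
}
ExecutorService executor = Executors.newFixedThreadPool(numThreads);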
This does increase performance; however, on an 8-core device it peaks at a numThreads of 2 and shows only ~50% CPU usage in Android Studio Monitor.
Answer 2
The TensorFlowInferenceInterface class is not thread-safe (it keeps state between calls to feed, run, fetch, etc.). However, it is built on top of the TensorFlow Java API, where objects of the Session class are thread-safe.
So you might want to use the underlying Java API directly. TensorFlowInferenceInterface's constructor creates a Session and sets it up with a Graph loaded from the AssetManager (code).
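To illustrate, here is a minimal sketch of that approach, using the shapes and node names from the question; how you read the model.pb bytes out of the APK's assets is left out and assumed to happen elsewhere. Because Session is thread-safe, a single instance can be shared across a thread pool without synchronized methods:

import java.nio.FloatBuffer;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

public class DirectInference {
    private final Graph graph = new Graph();
    private final Session session;  // Session is thread-safe, so one instance can be shared

    public DirectInference(byte[] graphDef) {  // e.g. the model.pb bytes read from assets
        graph.importGraphDef(graphDef);
        session = new Session(graph);
    }

    public float[] run(float[] data) {
        // Each call gets its own Runner and Tensors, so no state is shared between threads.
        try (Tensor input = Tensor.create(new long[] {1, 8000, 1}, FloatBuffer.wrap(data))) {
            try (Tensor output = session.runner()
                    .feed("intput_node0", input)   // node names as in the question
                    .fetch("output_node0")
                    .run()
                    .get(0)) {
                float[] res = new float[8000 / 4 * 5];  // CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES
                output.writeTo(FloatBuffer.wrap(res));
                return res;
            }
        }
    }
}

With this, the ThreadPool example from the question should scale across cores without the round-robin workaround, since concurrent calls to session.runner() do not contend on a lock.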
Hope that helps.