Monday, October 30, 2017

Running TensorFlow on multicore devices


I have a basic Android TensorFlowInference example that runs fine in a single thread.

public class InferenceExample {

    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "input_node0";
    private static final String OUTPUT_NODE = "output_node0";

    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    private static TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    public float[] run(float[] data) {
        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];

        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);

        return res;
    }
}

The example crashes with various exceptions, including java.lang.ArrayIndexOutOfBoundsException and java.lang.NullPointerException, when run in a thread pool as in the example below, so I assume it is not thread-safe.

InferenceExample inference = new InferenceExample(context);

ExecutorService executor = Executors.newFixedThreadPool(NUMBER_OF_CORES);
Collection<Future<?>> futures = new LinkedList<Future<?>>();

for (int i = 1; i <= 100; i++) {
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            inference.run(randomData);
        }
    });
    futures.add(result);
}

for (Future<?> future : futures) {
    try {
        future.get();
    } catch (ExecutionException | InterruptedException e) {
        Log.e("TF", e.getMessage());
    }
}

Is it possible to leverage multicore Android devices with TensorFlowInferenceInterface?

2 Answers

Answer 1

To make InferenceExample thread-safe, I changed the inferenceInterface field from static to an instance field and made the run method synchronized:

private TensorFlowInferenceInterface inferenceInterface;

public InferenceExample(final Context context) {
    inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
}

public synchronized float[] run(float[] data) { ... }
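For completeness, the whole class with those two changes applied looks like this (the body of run is unchanged from the question):

public class InferenceExample {

    private static final String MODEL_FILE = "file:///android_asset/model.pb";
    private static final String INPUT_NODE = "input_node0";
    private static final String OUTPUT_NODE = "output_node0";

    private static final int[] INPUT_SIZE = {1, 8000, 1};
    public static final int CHUNK_SIZE = 8000;
    public static final int STRIDE = 4;
    private static final int NUM_OUTPUT_STATES = 5;

    // No longer static: each instance owns its own TensorFlowInferenceInterface.
    private final TensorFlowInferenceInterface inferenceInterface;

    public InferenceExample(final Context context) {
        inferenceInterface = new TensorFlowInferenceInterface(context.getAssets(), MODEL_FILE);
    }

    // synchronized: at most one thread at a time may feed/run/fetch on this instance.
    public synchronized float[] run(float[] data) {
        float[] res = new float[CHUNK_SIZE / STRIDE * NUM_OUTPUT_STATES];

        inferenceInterface.feed(INPUT_NODE, data, INPUT_SIZE[0], INPUT_SIZE[1], INPUT_SIZE[2]);
        inferenceInterface.run(new String[]{OUTPUT_NODE});
        inferenceInterface.fetch(OUTPUT_NODE, res);

        return res;
    }
}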

Then I round robin over a list of InferenceExample instances across numThreads (building that list is sketched after the loop below):

for (int i = 1; i <= 100; i++) {
    final int id = i % numThreads;
    Future<?> result = executor.submit(new Runnable() {
        public void run() {
            list.get(id).run(data);
        }
    });
    futures.add(result);
}
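The list itself is just numThreads independent instances, each of which loads its own copy of the model. A minimal sketch of the setup (the names list, numThreads and executor match the snippets above):

final int numThreads = 2;  // the sweet spot I observed on my device
final List<InferenceExample> list = new ArrayList<>(numThreads);
for (int i = 0; i < numThreads; i++) {
    list.add(new InferenceExample(context));  // each instance loads the model once
}
ExecutorService executor = Executors.newFixedThreadPool(numThreads);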

This does increase performance; however, on an 8-core device it peaks at numThreads of 2 and shows only ~50% CPU usage in Android Studio Monitor.

Answer 2

The TensorFlowInferenceInterface class is not thread-safe (it keeps state between calls to feed, run, fetch etc.).

However, it is built on top of the TensorFlow Java API, where objects of the Session class are thread-safe.

So you might want to use the underlying Java API directly: TensorFlowInferenceInterface's constructor does nothing more than create a Session and set it up with a Graph loaded from the AssetManager, and you can do the same yourself and share the Session across threads.
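Something along these lines should work: one Graph and one Session shared by all threads, with all per-call state (tensors, the runner) kept local to each call. This is an untested sketch that assumes the same asset file and node names as in the question (note the model path is just the asset name here, without the file:///android_asset/ prefix that TensorFlowInferenceInterface uses):

import android.content.Context;

import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.Tensor;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.nio.FloatBuffer;
import java.util.List;

public class SharedSessionExample {
    // Node names copied from the question; "model.pb" is assumed to live in assets.
    private static final String INPUT_NODE = "input_node0";
    private static final String OUTPUT_NODE = "output_node0";
    private static final long[] INPUT_SHAPE = {1, 8000, 1};

    private final Graph graph = new Graph();
    private final Session session;

    public SharedSessionExample(Context context) throws IOException {
        graph.importGraphDef(readAsset(context, "model.pb"));
        session = new Session(graph);
    }

    // Safe to call from multiple threads concurrently: Session.run() is
    // thread-safe, and everything else here is local to the call.
    public float[] run(float[] data) {
        try (Tensor input = Tensor.create(INPUT_SHAPE, FloatBuffer.wrap(data))) {
            List<Tensor> outputs = session.runner()
                    .feed(INPUT_NODE, input)
                    .fetch(OUTPUT_NODE)
                    .run();
            try (Tensor output = outputs.get(0)) {
                float[] res = new float[output.numElements()];
                output.writeTo(FloatBuffer.wrap(res));
                return res;
            }
        }
    }

    private static byte[] readAsset(Context context, String name) throws IOException {
        try (InputStream in = context.getAssets().open(name)) {
            ByteArrayOutputStream out = new ByteArrayOutputStream();
            byte[] buf = new byte[16384];
            int n;
            while ((n = in.read(buf)) != -1) {
                out.write(buf, 0, n);
            }
            return out.toByteArray();
        }
    }
}

Because the session is thread-safe, every worker thread can call run() on the single shared instance concurrently; how much actual parallelism you then get depends on the graph and on TensorFlow's internal thread pools rather than on any Java-level locking.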

Hope that helps.


