Tuesday, February 20, 2018

Slicing tensors in tensorflow using argmax


I want to make a dynamic loss function in TensorFlow. I want to calculate the energy of a signal's FFT, more specifically only a window of size 3 around the most dominant peak. I am unable to implement this in TF, as it throws strided-slice errors such as: InvalidArgumentError (see above for traceback): Expected begin, end, and strides to be 1D equal size tensors, but got shapes [1,64], [1,64], and [1] instead.

My code is this:

self.spec = tf.fft(self.signal)
self.spec_mag = tf.complex_abs(self.spec[:, 1:33])
self.argm = tf.cast(tf.argmax(self.spec_mag, 1), dtype=tf.int32)
self.frac = tf.reduce_sum(self.spec_mag[self.argm-1:self.argm+2], 1)

Since I compute in batches of 64 and the data dimension is also 64, the shape of self.signal is (64, 64). I wish to calculate only the AC components of the FFT. As the signal is real-valued, half the spectrum carries all the information. Hence the shape of self.spec_mag is (64, 32).
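As an aside, the conjugate symmetry that justifies keeping only half the spectrum of a real signal is easy to verify in NumPy (a standalone sketch, not part of the TF graph):

```python
import numpy as np

x = np.random.rand(64)        # real-valued signal of length N = 64
spec = np.fft.fft(x)

# For a real signal, bin k is the conjugate of bin N-k, so bins 1..N/2
# carry all the AC information (bin 0 is the DC component).
assert np.allclose(spec[1:32], np.conj(spec[-1:-32:-1]))
```

This is why slicing spec[:, 1:33] out of a length-64 spectrum loses nothing except the DC term.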

The index of the maximum in each row of this FFT is self.argm, which has shape (64,).

Now I want to calculate the energy of 3 elements around the max peak via: self.spec_mag[self.argm-1:self.argm+2].
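In plain NumPy the intended computation can be sketched as follows (a standalone illustration of the goal, not the TF fix; the explicit Python loop is exactly what the TF one-liner cannot express):

```python
import numpy as np

batch, dim = 64, 64
signal = np.random.rand(batch, dim)

spec = np.fft.fft(signal, axis=1)      # full complex spectrum, shape (64, 64)
spec_mag = np.abs(spec[:, 1:33])       # AC half-spectrum magnitudes, shape (64, 32)
argm = spec_mag.argmax(axis=1)         # dominant bin per row, shape (64,)

# Desired quantity: per-row energy of a size-3 window around each row's peak.
frac = np.array([spec_mag[i, max(a - 1, 0):a + 2].sum()
                 for i, a in enumerate(argm)])
print(frac.shape)  # (64,)
```

Because self.argm is a tensor of 64 different indices, a single Python slice cannot express this; that mismatch is what produces the strided-slice error.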

However, when I run the code and try to evaluate self.frac, multiple errors are thrown.

3 Answers

Answer 1

It seems like you were missing an index when accessing argm. Here is the fixed version for the (1, 64) case.

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][(argm[0]-1):(argm[0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

And here is the expanded version for shape (batch, m, n):

import tensorflow as tf
import numpy as np

x = np.random.rand(1, 1, 64)
xt = tf.constant(value=x, dtype=tf.complex64)

signal = xt
print('signal', signal.shape)
print('signal', signal.eval())

spec = tf.fft(signal)
print('spec', spec.shape)
print('spec', spec.eval())

spec_mag = tf.abs(spec[:, :, 1:33])
print('spec_mag', spec_mag.shape)
print('spec_mag', spec_mag.eval())

argm = tf.cast(tf.argmax(spec_mag, 2), dtype=tf.int32)
print('argm', argm.shape)
print('argm', argm.eval())

frac = tf.reduce_sum(spec_mag[0][0][(argm[0][0]-1):(argm[0][0]+2)], 0)
print('frac', frac.shape)
print('frac', frac.eval())

You may want to fix the function names, since I edited this code on a newer version of TensorFlow.

Answer 2

TensorFlow indexing uses tf.Tensor.__getitem__:

"This operation extracts the specified region from the tensor. The notation is similar to NumPy with the restriction that currently only basic indexing is supported. That means that using a tensor as input is not currently allowed."

So slicing with a tensor index is the root of the error, and tf.slice and tf.strided_slice are out of the question as well.

Whereas in tf.gather indices defines slices into the first dimension of Tensor, in tf.gather_nd, indices defines slices into the first N dimensions of the Tensor, where N = indices.shape[-1]
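The batch-indexing pattern that tf.gather_nd provides corresponds to NumPy fancy indexing with explicit (row, column) pairs. A tiny standalone sketch of picking one column per row:

```python
import numpy as np

# A toy (4, 5) matrix and one column index per row, mimicking argm.
x = np.arange(20).reshape(4, 5)
cols = np.array([1, 3, 0, 4])
rows = np.arange(4)

# tf.gather_nd with indices of shape (4, 2) behaves like this fancy index:
# element i of the result is x[rows[i], cols[i]].
picked = x[rows, cols]
print(picked)  # [ 1  8 10 19]
```

Stacking rows and cols into pairs is exactly what the tf.transpose(tf.stack([...])) construction below does.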

Since you wanted the 3 values around the max, I manually extract the first, second and third element using a list comprehension, followed by a tf.stack.

import tensorflow as tf

signal = tf.placeholder(shape=(64, 64), dtype=tf.complex64)
spec = tf.fft(signal)
spec_mag = tf.abs(spec[:, 1:33])
argm = tf.cast(tf.argmax(spec_mag, 1), dtype=tf.int32)

# Gather spec_mag[i, argm[i] + offset] for every row i and offset in {-1, 0, 1}.
frac = tf.stack([tf.gather_nd(spec_mag, tf.transpose(tf.stack(
    [tf.range(64), argm + i]))) for i in [-1, 0, 1]])

# frac has shape (3, 64); sum over the window axis to get one energy per row.
frac = tf.reduce_sum(frac, 0)

This will fail for the corner case where argm is the first or last element in the row, but it should be easy to resolve.
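One way to resolve that corner case, sketched here in NumPy for clarity, is to clamp the window centre so the window never runs off either end of the row. Note this shifts the window inward at the boundaries rather than truncating it, which is one of several reasonable choices:

```python
import numpy as np

spec_mag = np.random.rand(64, 32)
argm = spec_mag.argmax(axis=1)            # peak index per row, shape (64,)

# Clamp the centre into [1, 30] so centre-1 and centre+1 stay inside [0, 31].
centre = np.clip(argm, 1, spec_mag.shape[1] - 2)
rows = np.arange(spec_mag.shape[0])

# Sum the three window elements per row via fancy indexing (gather_nd-style).
frac = sum(spec_mag[rows, centre + i] for i in (-1, 0, 1))
print(frac.shape)  # (64,)
```

In TF the same clamping could be done with tf.clip_by_value on argm before building the gather indices.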

