Tuesday, January 23, 2018

How to efficiently pass function through?


Motivation

Take a look at the following picture.

[Figure: the red, blue, and green curves, with the dominating curve drawn in black]

Given are the red, blue, and green curves. At each point on the x-axis I would like to find the dominating curve, which is shown as the black graph in the picture. Because the red, green, and blue curves are all increasing and become constant after a while, this boils down to finding the dominating curve on the very right-hand side and then moving towards the left-hand side, finding all intersection points and updating the dominating curve along the way.
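As a rough illustration of the right-to-left sweep described above, here is a minimal sketch that works on a fixed grid instead of computing exact intersection points. The helper `dominating_segments` and the three example curves are hypothetical and not taken from the code further below.

```python
def dominating_segments(curves, lo, hi, n=601):
    """Return [(x_start, curve_index), ...] ordered from left to right.

    Grid-based approximation: the switch points are only accurate
    up to the grid spacing (hi - lo) / (n - 1).
    """
    step = (hi - lo) / (n - 1)
    segments = []
    # Sweep from the right-hand side towards the left-hand side.
    for k in range(n - 1, -1, -1):
        x = lo + k * step
        values = [f(x) for f in curves]
        best = max(range(len(curves)), key=values.__getitem__)
        if segments and segments[0][1] == best:
            segments[0] = (x, best)        # same curve: extend segment leftwards
        else:
            segments.insert(0, (x, best))  # a different curve takes over
    return segments


# Hypothetical curves: increasing at first, constant after a while.
red = lambda x: min(3.0 * x, 3.0)
blue = lambda x: min(2.0 * x, 4.0)
green = lambda x: min(1.0 * x, 5.0)

segments = dominating_segments([red, blue, green], lo=0.0, hi=6.0)
for x_start, index in segments:
    print(x_start, index)
```

A root finder such as brentq, as used in the actual code, locates the switch points exactly; the grid version only shows the bookkeeping of the sweep.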

This problem has to be solved T times, with one final twist: the blue, green, and red curves of the next iteration are constructed from the dominating solution of the previous iteration plus some varying parameters. In the picture above, for example, the solution is the black function. This function is used to generate new blue, green, and red curves. Then the problem starts again: find the dominating curve among these new curves, and so on.

Question in a nutshell
In each iteration I start at the fixed very right-hand side and evaluate all three functions to see which one dominates. These evaluations take longer and longer from iteration to iteration. My feeling is that I am not passing the old dominating function optimally when constructing the new blue, green, and red curves. Reason: in an earlier version I got a maximum recursion depth error. Other parts of the code that need the value of the current dominating function (which is essentially either the green, red, or blue curve) also take longer and longer with each iteration.

For 5 iterations, the time needed just to evaluate the functions at a single point on the very right-hand side grows with every iteration:

>>> test.find_all_intersections()
iteration 4
to compute function values it took
0.0102479457855
iteration 3
to compute function values it took
0.0134601593018
iteration 2
to compute function values it took
0.0294270515442
iteration 1
to compute function values it took
0.109843969345
iteration 0
to compute function values it took
0.823768854141

The results were produced via

test = A(5, 120000, 100000)

and then running

test.find_all_intersections()

I would like to know why this is the case and whether it can be programmed more efficiently.

Detailed Code explanation

I quickly summarize the most important functions. The complete code can be found further below. If there are any other questions regarding the code I'm more than happy to elaborate / clarify.

  1. Method u: For the recurring task of generating a new batch of green, red, and blue curves we need the old dominating one. u is the initialization used in the very first iteration.

  2. Method _function_template: This method generates versions of the green, blue, and red curves by using different parameters. It returns a function of a single input.

  3. Method evalv: This is the core function that generates the blue, green, and red versions every time. It takes three parameters that vary in each iteration: vfunction, which is the dominating function from the previous step, and m and s, which are two parameters (floats) affecting the shape of the resulting curve. The other parameters are the same in each iteration. The code contains sample values for m and s for each iteration. For the more geeky ones: it approximates an integral where m and s are the mean and standard deviation of the underlying normal distribution. The approximation is done via Gauss-Hermite nodes / weights.

  4. Method find_all_intersections: This is the core method that finds the dominating curve in each iteration. It constructs a dominating function via a piecewise concatenation of the blue, green, and red curves. This is achieved via the function piecewise.
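To make the Gauss-Hermite step in item 3 more concrete, here is a minimal sketch of the same change of variables applied to a plain expectation E[g(Z)] with Z ~ N(mu, sigma^2). The helper name `expect_normal` and the test function are assumptions for illustration; evalv applies the same idea with 1 + m in place of mu, s in place of sigma, and the result scaled by g and shifted by c.

```python
import math

import numpy as np


def expect_normal(g, mu, sigma, n_nodes=10):
    """Approximate E[g(Z)] for Z ~ N(mu, sigma**2) by Gauss-Hermite quadrature."""
    nodes, weights = np.polynomial.hermite.hermgauss(n_nodes)
    # Substituting z = mu + sqrt(2) * sigma * x turns the normal density
    # into the Hermite weight exp(-x**2), up to the factor 1 / sqrt(pi).
    z = mu + math.sqrt(2.0) * sigma * nodes
    return float(np.sum(weights * g(z)) / math.sqrt(math.pi))


# Check against a case with a known answer: E[Z**2] = mu**2 + sigma**2.
approx = expect_normal(lambda z: z**2, mu=0.04, sigma=0.02)
exact = 0.04**2 + 0.02**2
print(approx, exact)
```

With 10 nodes the rule is exact for polynomials up to degree 19, so the two printed numbers should agree up to rounding.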

Here is the complete code

import numpy as np
import pandas as pd
from scipy.optimize import brentq
import multiprocessing as mp
import pathos as pt
import timeit
import math


class A(object):

    def u(self, w):
        _w = np.asarray(w).copy()
        _w[_w >= 120000] = 120000
        _p = np.maximum(0, 100000 - _w)
        return _w - 1000*_p**2

    def __init__(self, T, upper_bound, lower_bound):
        self.T = T
        self.upper_bound = upper_bound
        self.lower_bound = lower_bound

    def _function_template(self, *args):
        def _f(x):
            return self.evalv(x, *args)
        return _f

    def evalv(self, w, c, vfunction, g, m, s, gauss_weights, gauss_nodes):
        _A = np.tile(1 + m + math.sqrt(2) * s * gauss_nodes, (np.size(w), 1))
        _W = (_A.T * w).T
        _W = gauss_weights * vfunction(np.ravel(_W)).reshape(np.size(w),
                                                             len(gauss_nodes))
        evalue = g*1/math.sqrt(math.pi)*np.sum(_W, axis=1)
        return c + evalue

    def find_all_intersections(self):

        # the hermite gauss weights and nodes for integration
        # and additional parameters used for function generation

        gauss = np.polynomial.hermite.hermgauss(10)
        gauss_nodes = gauss[0]
        gauss_weights = gauss[1]
        r = np.asarray([1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
                        1., 1., 1., 1., 1., 1., 1., 1., 1.])
        m = [[0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038063407778193614, 0.08475713587463352, 0.15420895520972322],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.03836174909668277, 0.08543620707856969, 0.15548297423808233],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624],
             [0.038212567720998125, 0.08509661835487026, 0.15484578903763624]]

        s = [[0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.01945441966324046, 0.04690600929081242, 0.200125178687699],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019529101011406914, 0.04708607140891122, 0.20089341636351565],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142],
             [0.019491796104351332, 0.04699612658674578, 0.20050966545654142]]

        self.solution = []

        n_cpu = mp.cpu_count()
        pool = pt.multiprocessing.ProcessPool(n_cpu)

        # this function is used for multiprocessing
        def call_f(f, x):
            return f(x)

        # this function takes differences for getting cross points
        def _diff(f_dom, f_other):
            def h(x):
                return f_dom(x) - f_other(x)
            return h

        # finds the root of the difference of two functions
        def find_roots(F, u_bound, l_bound):
            try:
                sol = brentq(F, a=l_bound,
                             b=u_bound)
                if np.absolute(sol - u_bound) > 1:
                    return sol
                else:
                    return l_bound
            except ValueError:
                return l_bound

        # piecewise function
        def piecewise(l_comp, l_f):
            def f(x):
                _ind_f = np.digitize(x, l_comp) - 1
                if np.isscalar(x):
                    return l_f[_ind_f](x)
                else:
                    return np.asarray([l_f[_ind_f[i]](x[i])
                                       for i in range(0, len(x))]).ravel()
            return f

        _u = self.u

        for t in range(self.T-1, -1, -1):
            print('iteration' + ' ' + str(t))

            l_bound, u_bound = 0.5*self.lower_bound, self.upper_bound
            l_ordered_functions = []
            l_roots = []
            l_solution = []

            # build all function variations

            l_functions = [self._function_template(0, _u, r[t], m[t][i], s[t][i],
                                                   gauss_weights, gauss_nodes)
                           for i in range(0, len(m[t]))]

            # get the best solution for the upper bound on the very
            # right hand side of wealth interval

            array_functions = np.asarray(l_functions)
            start_time = timeit.default_timer()
            functions_values = pool.map(call_f, array_functions.tolist(),
                                        len(m[t]) * [u_bound])
            elapsed = timeit.default_timer() - start_time
            print('to compute function values it took')
            print(elapsed)

            ind = np.argmax(functions_values)
            cross_points = len(m[t]) * [u_bound]
            l_roots.insert(0, u_bound)
            max_m = m[t][ind]
            l_solution.insert(0, max_m)

            # move from the upper bound towards the lower bound
            # and find the dominating solution by exploring all cross
            # points.

            test = True

            while test:
                l_ordered_functions.insert(0, array_functions[ind])
                current_max = l_ordered_functions[0]

                l_c_max = len(m[t]) * [current_max]
                l_u_cross = len(m[t]) * [cross_points[ind]]

                # Find new cross points on the smaller interval

                diff = pool.map(_diff, l_c_max, array_functions.tolist())
                cross_points = pool.map(find_roots, diff,
                                        l_u_cross, len(m[t]) * [l_bound])

                # update the solution, cross points and current
                # dominating function.

                ind = np.argmax(cross_points)
                l_roots.insert(0, cross_points[ind])
                max_m = m[t][ind]
                l_solution.insert(0, max_m)

                if cross_points[ind] <= l_bound:
                    test = False

            l_ordered_functions.insert(0, l_functions[0])
            l_roots.insert(0, 0)
            l_roots[-1] = np.inf

            l_comp = l_roots[:]
            l_f = l_ordered_functions[:]

            # build piecewise function which is used for next
            # iteration.

            _u = piecewise(l_comp, l_f)
            _sol = pd.DataFrame(data=l_solution,
                                index=np.asarray(l_roots)[0:-1])
            self.solution.insert(0, _sol)
        return self.solution

2 Answers

Answer 1

Let's start by changing the code to output the current iteration:

_u = self.u
for t in range(0, self.T):
    print(t)
    lparams = np.random.randint(self.a, self.b, 6).reshape(3, 2).tolist()
    functions = [self._function_template(_u, *lparams[i])
                 for i in range(0, 3)]
    # evaluate functions
    pairs = list(itertools.combinations(functions, 2))
    fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)]
    ind = np.sort(np.unique(np.random.randint(self.a, self.b, 10)))
    _u = _temp(ind, np.asarray(functions)[ind % 3])

Looking into the line causing the behaviour,

fval = [F(diff(*pairs[i]), self.a, self.b) for i in range(0, 3)] 

functions of interest would be F and diff. The latter being straightforward, the former:

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError:
        pass

Hmm, swallowing exceptions, let's see what happens if we:

def F(f, a, b):
    brentq(f, a=a, b=b)

Immediately, for the first function and on the first iteration, an error is thrown:

ValueError: f(a) and f(b) must have different signs

Looking at the docs this is a prerequisite of the root finding function brentq. Let's change the definition once more to monitor this condition on each iteration.

def F(f, a, b):
    try:
        brentq(f, a=a, b=b)
    except ValueError as e:
        print(e)

The output is

i
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs
f(a) and f(b) must have different signs

for i ranging from 0 to 57. Meaning, the first time the function F ever does any real work is for i=58. And it keeps doing so for higher values of i.
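The sign condition can also be checked explicitly before any root finding is attempted, instead of relying on the exception. The following is only a sketch: to stay dependency-free, a plain bisection stands in for brentq, and the names `find_root` and `max_iter` are made up for this example.

```python
def find_root(f, a, b, tol=1e-10, max_iter=200):
    """Return a root of f in [a, b], or None if [a, b] does not bracket one."""
    fa, fb = f(a), f(b)
    if fa == 0:
        return a
    if fb == 0:
        return b
    if fa * fb > 0:
        # Same sign at both ends: this is the situation in which
        # brentq raises "f(a) and f(b) must have different signs".
        return None
    for _ in range(max_iter):
        if b - a <= tol:
            break
        mid = 0.5 * (a + b)
        fm = f(mid)
        if fm == 0:
            return mid
        if fa * fm < 0:
            b = mid            # sign change is in the left half
        else:
            a, fa = mid, fm    # sign change is in the right half
    return 0.5 * (a + b)


root = find_root(lambda x: x * x - 2.0, 0.0, 2.0)    # brackets sqrt(2)
no_root = find_root(lambda x: x * x + 1.0, 0.0, 2.0)  # never changes sign
print(root, no_root)
```

Returning None makes the "no work was done" case visible to the caller, which a bare `except ValueError: pass` hides.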

Conclusion: it takes longer for these higher values, because:

  1. the root is never calculated for the lower values
  2. the number of calculations grows linearly for i>58

Answer 2

Your code is really far too complex to explain your problem - strive for something simpler. Sometimes you have to write code just to demonstrate the problem.

I'm taking a stab, based purely on your description rather than your code (although I ran the code and verified). Here's your problem:

Method evalv: This is the core function to generate the blue, green and red versions every time. It takes three varying parameters each iteration: vfunction which is the dominating function from the previous step, m and s which are two parameters (floats) affecting the shape of the resulting curve.

Your vfunction parameter is more complex on each iteration. You are passing a nested function built up over previous iterations, which causes a recursive execution. Each iteration increases the depth of the recursive call.
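A toy version of this effect, under the assumption that each new function evaluates the previous one at a fixed number of points per call (as a 10-node quadrature does). The names `base`, `wrap`, and `counts` are made up for this sketch and do not appear in the question's code.

```python
calls = {"base": 0}


def base(x):
    """Innermost function; counts how often it is evaluated."""
    calls["base"] += 1
    return x


def wrap(inner, n_nodes=10):
    """Build a function that calls `inner` n_nodes times per evaluation,
    mimicking one iteration built on the previous iteration's function."""
    def outer(x):
        return sum(inner(x * (1.0 + 0.01 * k)) for k in range(n_nodes)) / n_nodes
    return outer


counts = []
f = base
for depth in range(1, 5):
    f = wrap(f)            # one more level of nesting per iteration
    calls["base"] = 0
    f(1.0)                 # evaluate at a single point
    counts.append(calls["base"])

print(counts)  # [10, 100, 1000, 10000]
```

A single point evaluation costs n_nodes to the power of the nesting depth in this toy setup, which is the kind of growth the timings in the question hint at.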

How can you avoid this? There's no easy or built-in way. The simplest answer is - assuming the inputs to these functions are consistent - to store the functional result (i.e. the numbers) rather than the functions themselves. You can do this as long as you have a finite number of known inputs.
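One way to read "store the numbers" is to tabulate each iteration's function on a fixed grid and hand a cheap interpolant to the next iteration. This is a sketch under the assumption that linear interpolation on such a grid is accurate enough for the problem; `tabulate`, `step`, and the grid are hypothetical.

```python
import numpy as np

calls = {"base": 0}


def base(x):
    calls["base"] += 1
    return np.minimum(x, 1.0)   # increasing, then constant


def step(prev, n_nodes=10):
    """One iteration: a new function that calls `prev` n_nodes times."""
    def new(x):
        return sum(prev(x * (1.0 + 0.01 * k)) for k in range(n_nodes)) / n_nodes
    return new


def tabulate(f, grid):
    """Evaluate f once per grid point and return a linear interpolant."""
    values = np.asarray([f(x) for x in grid], dtype=float)

    def f_hat(x):
        # Outside the grid np.interp holds the first / last value.
        return np.interp(x, grid, values)
    return f_hat


grid = np.linspace(0.0, 2.0, 201)
f = tabulate(base, grid)
for _ in range(4):
    f = tabulate(step(f), grid)   # nesting depth never exceeds one level

print(calls["base"], f(2.0))
```

Each iteration now costs a fixed number of evaluations of the previous table, regardless of how many iterations came before; the price is interpolation error between grid points.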

If the inputs to the underlying functions aren't consistent then there's no shortcut. You need to repeatedly evaluate those underlying functions. I see that you're doing some piecewise splicing of the underlying functions - you can test whether the cost of doing so exceeds the cost of simply taking the max of each of the underlying functions.
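For comparison, taking the max of the underlying functions needs no intersection points at all. A minimal sketch, assuming the functions accept NumPy arrays; `upper_envelope` and the three curves are made up for illustration.

```python
import numpy as np


def upper_envelope(functions):
    """Return a function giving the pointwise maximum of `functions`."""
    def f(x):
        x = np.asarray(x, dtype=float)
        # Every function is evaluated at every point, then reduced.
        return np.maximum.reduce([g(x) for g in functions])
    return f


# Hypothetical curves: increasing at first, constant after a while.
red = lambda x: np.minimum(3.0 * x, 3.0)
blue = lambda x: np.minimum(2.0 * x, 4.0)
green = lambda x: np.minimum(1.0 * x, 5.0)

dominating = upper_envelope([red, blue, green])
vals = dominating([0.5, 1.8, 3.0, 6.0])
print(vals)
```

The trade-off is that every call evaluates all functions rather than only the one that dominates at that point, so whether this beats the piecewise construction has to be measured.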

The test that I ran (10 iterations) took a few seconds. I don't see that as a problem.
