Learning functions

try:
    # in colab
    import google.colab
    print('In colab, downloading LOTlib3')
    !git clone https://github.com/thelogicalgrammar/LOTlib3
except ImportError:
    # not in colab
    print('Not in colab!')
Not in colab!

Imports

First, we import the libraries used throughout this section:

import numpy as np
from pprint import pprint
import matplotlib.pyplot as plt
from LOTlib3.Miscellaneous import q, qq, random
from LOTlib3.Grammar import Grammar
from LOTlib3.DataAndObjects import FunctionData, Obj
from LOTlib3.Hypotheses.LOTHypothesis import LOTHypothesis
from LOTlib3.Hypotheses.Likelihoods.BinaryLikelihood import BinaryLikelihood
from LOTlib3.Eval import primitive
from LOTlib3.TopN import TopN
from LOTlib3.Samplers.MetropolisHastings import MetropolisHastingsSampler

The model

The aim here is to define a language of mathematical expressions and use it to recover the explicit form of the function most likely to have generated a given set of input-output pairs.

To start with, the language might include addition, multiplication, exponentiation, log, and exp.

To make it fancier, we can add recursion up to a fixed depth, which lets us express, e.g., series.
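To make the idea of a language over functions concrete, here is a minimal sketch in plain Python that does not use LOTlib3. The rule table `RULES`, the helpers `sample_expr` and `evaluate`, the weights, and the depth cap are all illustrative assumptions, not part of the library or of the final model. The sketch samples expression strings from weighted rules and evaluates them at a given `x`.

```python
import math
import random

# Hypothetical rule table: (weight, template, number of sub-expressions).
RULES = [
    (4.0, "x", 0),              # the input variable
    (2.0, "1.0", 0),            # a constant
    (1.0, "({} + {})", 2),      # addition
    (1.0, "({} * {})", 2),      # product
    (0.5, "({} ** {})", 2),     # power
    (0.5, "math.log({})", 1),   # natural log
    (0.5, "math.exp({})", 1),   # exponential
]

def sample_expr(rng, depth=0, max_depth=4):
    """Sample an expression string by recursively expanding weighted rules."""
    # At the depth cap only terminals are allowed, so sampling always ends.
    rules = RULES if depth < max_depth else [r for r in RULES if r[2] == 0]
    weights = [w for w, _, _ in rules]
    _, template, arity = rng.choices(rules, weights=weights, k=1)[0]
    children = [sample_expr(rng, depth + 1, max_depth) for _ in range(arity)]
    return template.format(*children)

def evaluate(expr, x):
    """Evaluate an expression string at x; return None if it is undefined there."""
    try:
        value = eval(expr, {"__builtins__": {}, "math": math}, {"x": float(x)})
    except (ValueError, OverflowError, ZeroDivisionError):
        return None
    # A negative base with a fractional exponent yields a complex number.
    if isinstance(value, complex):
        return None
    return value
```

For example, `evaluate(sample_expr(random.Random(0)), 2.0)` draws one random expression and evaluates it at x = 2. Giving terminals more weight than the branching rules keeps sampled expressions short, which plays the role of a simplicity prior.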

AIMS

  • Pick a set of mathematical primitives that plausibly appear in the target functions

  • Define a grammar

  • Define a likelihood, e.g., one based on the sum of squared errors

  • Run an inference algorithm on the example datasets below!
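The likelihood step in the list above can be sketched without LOTlib3. Assuming independent Gaussian noise with a known standard deviation `sigma` (an assumption made here for illustration), the log-likelihood is a constant minus the scaled sum of squared errors, so maximizing it is equivalent to least squares. The function name `gaussian_log_likelihood`, the toy data, and the candidate set are hypothetical.

```python
import math

def gaussian_log_likelihood(f, data, sigma=0.1):
    """Log-likelihood of (x, y) pairs under y = f(x) + Normal(0, sigma) noise."""
    sse = 0.0
    for x, y in data:
        try:
            prediction = f(x)
        except (ValueError, OverflowError, ZeroDivisionError):
            return -math.inf  # the candidate is undefined on this input
        sse += (y - prediction) ** 2
    n = len(data)
    # Normalizing constant minus the scaled sum of squared errors.
    return -n * math.log(sigma * math.sqrt(2 * math.pi)) - sse / (2 * sigma ** 2)

# Noise-free toy data from y = 2 * x + 1, used only for illustration.
toy_data = [(x, 2 * x + 1) for x in (1.0, 2.0, 3.0)]

# Hypothetical candidate functions to compare against the toy data.
candidates = {
    "2*x + 1": lambda x: 2 * x + 1,
    "x**2": lambda x: x ** 2,
    "log(x)": lambda x: math.log(x),
}
scores = {name: gaussian_log_likelihood(f, toy_data) for name, f in candidates.items()}
best = max(scores, key=scores.get)
print(best, scores[best])
```

An inference algorithm such as Metropolis-Hastings would combine a score like this with a prior over expressions and search the space of hypotheses instead of comparing a fixed list of candidates.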

Some hidden data to recover (Don’t look at it yet!)

x = np.linspace(1, 20, 30)

y = 1 - np.log(x) + np.random.normal(scale=0.1, size=len(x))
f1 = np.array(list(zip(x, y)))

y = 1 + 3 * x**2 + np.random.normal(scale=0.1, size=len(x))
f2 = np.array(list(zip(x, y)))

y = 4 * x - 0.5 * x**2 + np.random.normal(scale=0.1, size=len(x))
f3 = np.array(list(zip(x, y)))
for f in [f1, f2, f3]:
    print('Data: ')
    pprint(f)
    plt.scatter(*f.T)
    plt.show()
    print('\n')
Data: 
array([[ 1.        ,  0.96693628],
       [ 1.65517241,  0.44488931],
       [ 2.31034483,  0.11111997],
       [ 2.96551724, -0.05388751],
       [ 3.62068966, -0.08770159],
       [ 4.27586207, -0.34701368],
       [ 4.93103448, -0.5312597 ],
       [ 5.5862069 , -0.56654964],
       [ 6.24137931, -0.98309806],
       [ 6.89655172, -0.87024759],
       [ 7.55172414, -1.02468311],
       [ 8.20689655, -1.07690071],
       [ 8.86206897, -1.3229659 ],
       [ 9.51724138, -1.26628302],
       [10.17241379, -1.26076366],
       [10.82758621, -1.29232939],
       [11.48275862, -1.47545188],
       [12.13793103, -1.53730544],
       [12.79310345, -1.72121669],
       [13.44827586, -1.59430107],
       [14.10344828, -1.6440229 ],
       [14.75862069, -1.74346587],
       [15.4137931 , -1.89063419],
       [16.06896552, -1.87860771],
       [16.72413793, -1.73337155],
       [17.37931034, -1.74285824],
       [18.03448276, -1.82462328],
       [18.68965517, -2.13664592],
       [19.34482759, -1.87158473],
       [20.        , -1.96221768]])
[Figure: scatter plot of the dataset above]
Data: 
array([[1.00000000e+00, 3.71191585e+00],
       [1.65517241e+00, 9.24602588e+00],
       [2.31034483e+00, 1.71111810e+01],
       [2.96551724e+00, 2.74461376e+01],
       [3.62068966e+00, 4.04248405e+01],
       [4.27586207e+00, 5.58883272e+01],
       [4.93103448e+00, 7.39602616e+01],
       [5.58620690e+00, 9.45716098e+01],
       [6.24137931e+00, 1.17768463e+02],
       [6.89655172e+00, 1.43745608e+02],
       [7.55172414e+00, 1.72014815e+02],
       [8.20689655e+00, 2.03113066e+02],
       [8.86206897e+00, 2.36586296e+02],
       [9.51724138e+00, 2.72592478e+02],
       [1.01724138e+01, 3.11500578e+02],
       [1.08275862e+01, 3.52680831e+02],
       [1.14827586e+01, 3.96347343e+02],
       [1.21379310e+01, 4.42937224e+02],
       [1.27931034e+01, 4.91952665e+02],
       [1.34482759e+01, 5.43415913e+02],
       [1.41034483e+01, 5.97790718e+02],
       [1.47586207e+01, 6.54449584e+02],
       [1.54137931e+01, 7.13693393e+02],
       [1.60689655e+01, 7.75636386e+02],
       [1.67241379e+01, 8.39996211e+02],
       [1.73793103e+01, 9.07024122e+02],
       [1.80344828e+01, 9.76727694e+02],
       [1.86896552e+01, 1.04889184e+03],
       [1.93448276e+01, 1.12381860e+03],
       [2.00000000e+01, 1.20102847e+03]])
[Figure: scatter plot of the dataset above]
Data: 
array([[   1.        ,    3.4938151 ],
       [   1.65517241,    5.44628788],
       [   2.31034483,    6.53998577],
       [   2.96551724,    7.67421996],
       [   3.62068966,    7.89516258],
       [   4.27586207,    8.02252067],
       [   4.93103448,    7.55957521],
       [   5.5862069 ,    6.68554562],
       [   6.24137931,    5.49661335],
       [   6.89655172,    3.78650889],
       [   7.55172414,    1.76357834],
       [   8.20689655,   -0.95288503],
       [   8.86206897,   -3.56660826],
       [   9.51724138,   -7.30943747],
       [  10.17241379,  -11.0152518 ],
       [  10.82758621,  -15.41306886],
       [  11.48275862,  -20.02607231],
       [  12.13793103,  -25.05703701],
       [  12.79310345,  -30.66109506],
       [  13.44827586,  -36.58291196],
       [  14.10344828,  -43.09300853],
       [  14.75862069,  -49.80759178],
       [  15.4137931 ,  -57.20869912],
       [  16.06896552,  -64.76773707],
       [  16.72413793,  -72.90556438],
       [  17.37931034,  -81.42918673],
       [  18.03448276,  -90.50159589],
       [  18.68965517,  -99.98277275],
       [  19.34482759, -109.61216676],
       [  20.        , -120.07684039]])
[Figure: scatter plot of the dataset above]