import matplotlib.pyplot as plt
import numpy as np
import math
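# Example data: v_samples and x_samples are paired 1-D arrays of equal length;
# further down they become the first and second column of P_points, the curve to fit.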
v_samples = np.array([-1.71498585, -1.68068613, -1.64638642, -1.6120867 , -1.57778698, -1.54348727,
-1.50918755, -1.47488783, -1.44058812, -1.4062884 , -1.37198868, -1.33768896,
-1.30338925, -1.26908953, -1.23478981, -1.2004901 , -1.16619038, -1.13189066,
-1.09759094, -1.06329123, -1.02899151, -0.99469179, -0.96039208, -0.92609236,
-0.89179264, -0.85749293, -0.82319321, -0.78889349, -0.75459377, -0.72029406,
-0.68599434, -0.65169462, -0.61739491, -0.58309519, -0.54879547, -0.51449576,
-0.48019604, -0.44589632, -0.4115966 , -0.37729689, -0.34299717, -0.30869745,
-0.27439774, -0.24009802, -0.2057983 , -0.17149859, -0.13719887, -0.10289915,
-0.06859943, -0.03429972, 0. , 0.03429972, 0.06859943, 0.10289915,
0.13719887, 0.17149859, 0.2057983 , 0.24009802, 0.27439774, 0.30869745,
0.34299717, 0.37729689, 0.4115966 , 0.44589632, 0.48019604, 0.51449576,
0.54879547, 0.58309519, 0.61739491, 0.65169462, 0.68599434 , 0.72029406,
0.75459377, 0.78889349, 0.82319321, 0.85749293, 0.89179264, 0.92609236,
0.96039208, 0.99469179, 1.02899151, 1.06329123, 1.09759094, 1.13189066,
1.16619038, 1.2004901, 1.23478981, 1.26908953, 1.30338925, 1.33768896,
1.37198868, 1.4062884, 1.44058812, 1.47488783, 1.50918755, 1.54348727,
1.57778698, 1.6120867, 1.64638642, 1.68068613, 1.71498585, 1.71498585,
1.68068613, 1.64638642, 1.6120867 , 1.57778698, 1.54348727, 1.50918755,
1.47488783, 1.44058812, 1.4062884 , 1.37198868, 1.33768896, 1.30338925,
1.26908953, 1.23478981, 1.2004901 , 1.16619038, 1.13189066, 1.09759094,
1.06329123, 1.02899151, 0.99469179, 0.96039208, 0.92609236, 0.89179264,
0.85749293, 0.82319321, 0.78889349, 0.75459377, 0.72029406, 0.68599434,
0.65169462, 0.61739491, 0.58309519, 0.54879547, 0.51449576, 0.48019604,
0.44589632, 0.4115966, 0.37729689, 0.34299717, 0.30869745, 0.27439774,
0.24009802, 0.2057983, 0.17149859, 0.13719887, 0.10289915, 0.06859943,
0.03429972, 0., -0.03429972, -0.06859943, -0.10289915, -0.13719887,
-0.17149859, -0.2057983, -0.24009802, -0.27439774, -0.30869745, -0.34299717,
-0.37729689, -0.4115966, -0.44589632, -0.48019604, -0.51449576, -0.54879547,
-0.58309519, -0.61739491, -0.65169462, -0.68599434, -0.72029406, -0.75459377,
-0.78889349, -0.82319321, -0.85749293, -0.89179264, -0.92609236, -0.96039208,
-0.99469179, -1.02899151, -1.06329123, -1.09759094, -1.13189066, -1.16619038,
-1.2004901 , -1.23478981, -1.26908953, -1.30338925, -1.33768896, -1.37198868,
-1.4062884 , -1.44058812, -1.47488783, -1.50918755, -1.54348727, -1.57778698,
-1.6120867 , -1.64638642, -1.68068613, -1.71498585])
x_samples = np.array([ 1.43253309e+00, 1.42697026e+00, 1.42446699e+00, 1.41895978e+00,
1.40939171e+00, 1.40494145e+00, 1.39320388e+00, 1.38758542e+00,
1.38135505e+00, 1.37100818e+00, 1.36349836e+00, 1.35137139e+00,
1.34024572e+00, 1.32978760e+00, 1.31871757e+00, 1.30447672e+00,
1.29068090e+00, 1.27933272e+00, 1.26542565e+00, 1.24918218e+00,
1.23143675e+00, 1.21747404e+00, 1.19878293e+00, 1.18142690e+00,
1.16301393e+00, 1.14131889e+00, 1.11723183e+00, 1.09865197e+00,
1.07601125e+00, 1.04908715e+00, 1.02494446e+00, 9.92123757e-01,
9.67925442e-01, 9.35994792e-01, 9.04119770e-01, 8.73079173e-01,
8.40091585e-01, 8.00372971e-01, 7.67997294e-01, 7.29613760e-01,
6.89060722e-01, 6.46282551e-01, 6.02336186e-01, 5.58000422e-01,
5.11439526e-01, 4.60984649e-01, 4.12532390e-01, 3.61409973e-01,
3.06782972e-01, 2.53546679e-01, 1.98919678e-01, 1.41121864e-01,
8.21558545e-02, 1.92958637e-02, -3.90026057e-02, -1.03809587e-01,
-1.71397985e-01, -2.31476560e-01, -2.86214817e-01, -3.42677553e-01,
-3.98472748e-01, -4.60609571e-01, -5.21633827e-01, -5.85550756e-01,
-6.50802764e-01, -7.19225586e-01, -7.84032568e-01, -8.49618346e-01,
-9.13257133e-01, -9.75783354e-01, -1.03753078e+00, -1.09744247e+00,
-1.15479526e+00, -1.21303810e+00,-1.26633002e+00 ,-1.31628424e+00,
-1.36100940e+00, -1.39822474e+00, -1.42653955e+00, -1.44784520e+00,
-1.45574442e+00, -1.45168355e+00, -1.43566260e+00, -1.40512265e+00,
-1.36529278e+00, -1.31239026e+00 ,-1.25470370e+00, -1.18594711e+00,
-1.11496539e+00, -1.04070159e+00 ,-9.64157037e-01, -8.89448216e-01,
-8.14016227e-01, -7.41087512e-01, -6.70662071e-01, -6.00904170e-01,
-5.32982003e-01, -4.68119393e-01, -4.06705738e-01, -3.50521144e-01,
-2.94837206e-01, -2.68525015e-01, -2.69192554e-01, -2.74199102e-01,
-2.87939295e-01, -3.08799912e-01, -3.27435396e-01, -3.58809763e-01,
-3.91352324e-01, -4.29568974e-01 ,-4.70511410e-01, -5.20020606e-01,
-5.72311218e-01, -6.29775263e-01, -6.91522688e-01, -7.59778625e-01,
-8.31650402e-01, -9.11254514e-01, -9.96643971e-01, -1.08859757e+00,
-1.18533520e+00, -1.28724626e+00, -1.39271754e+00, -1.49813319e+00,
-1.59915420e+00 ,-1.69127468e+00, -1.76537159e+00, -1.82161181e+00,
-1.84692270e+00, -1.84369625e+00, -1.81415762e+00, -1.76537159e+00,
-1.70373542e+00, -1.63197490e+00, -1.55331647e+00, -1.46881707e+00,
-1.38187002e+00 ,-1.29514548e+00, -1.20319189e+00, -1.11579981e+00,
-1.02801834e+00 ,-9.47468544e-01, -8.68754485e-01, -7.91097363e-01,
-7.17056082e-01 ,-6.44182995e-01, -5.73813183e-01, -5.07003582e-01,
-4.40861521e-01 ,-3.79503494e-01, -3.21093768e-01, -2.58511919e-01,
-2.04385573e-01, -1.47477811e-01, -9.95262075e-02, -4.80700203e-02,
8.82892888e-04, 5.94595038e-02, 1.02793958e-01, 1.56419649e-01,
1.96193891e-01, 2.46704397e-01 , 2.82751542e-01, 3.35598437e-01,
3.72480007e-01, 4.19430301e-01, 4.55699959e-01, 4.98700643e-01,
5.28684303e-01, 5.73631978e-01, 6.02503071e-01, 6.46115666e-01,
6.73762936e-01, 7.15039143e-01, 7.43520838e-01, 7.72503187e-01,
8.10219182e-01, 8.34250612e-01, 8.62954820e-01, 8.96109293e-01,
9.18972529e-01, 9.41446366e-01, 9.75157123e-01, 9.92846925e-01,
1.01737901e+00, 1.04324617e+00, 1.06126975e+00, 1.08836073e+00,
1.10610617e+00, 1.13219584e+00, 1.15327897e+00, 1.16707479e+00,
1.18765727e+00, 1.20851788e+00, 1.22581829e+00, 1.24667891e+00,
1.25886151e+00, 1.27849830e+00, 1.29335106e+00, 1.30859322e+00,
1.32060893e+00, 1.32778498e+00])
# B-Spline fitting setup
n = 25  # index of the last control point (the script uses n + 1 control points)
s = list(range(len(x_samples) + 1))  # sample indices plus one extra end value
d = 3  # spline degree

# Knot indices run 0 .. n+d+1 inclusive, i.e. n+d+2 knots in total.
# range() excludes its stop value, which is why the stop is n+d+2.
index = list(range(n + d + 2))

# Clamped knot vector: d+1 zeros, evenly spaced interior knots, then d+1 ones.
t_knots = [
    0 if k <= d else ((k - d) / (n + 1 - d) if k <= n else 1)
    for k in index
]
# Defining functions
def N_f(d, i, j, t, t_knots):
    """Evaluate the B-spline basis function N_{i,j}(t) by Cox-de Boor recursion.

    Parameters:
        d: spline degree; not used in the computation, only passed through
            the recursive calls.
        i: index of the basis function (first knot of its support).
        j: degree of the basis function being evaluated.
        t: parameter value.
        t_knots: indexable, non-decreasing knot sequence; indices up to
            i + j + 1 are read.

    Returns:
        The value of N_{i,j}(t): 0 or 1 for j == 0, otherwise a number
        obtained from the two lower-degree basis functions.
    """
    if j == 0:
        left, right = t_knots[i], t_knots[i + 1]
        if left <= t < right:
            return 1
        # The spans are half-open, so t == t_knots[-1] would belong to no
        # span and every basis function would vanish there.  Close the last
        # non-empty span on the right so the curve reaches its end point.
        if left < right and t == right == t_knots[-1]:
            return 1
        return 0

    # A term whose denominator is zero (repeated knots) is defined as 0.
    a = 0
    if t_knots[i + j] != t_knots[i]:
        a = ((t - t_knots[i]) / (t_knots[i + j] - t_knots[i])
             * N_f(d, i, j - 1, t, t_knots))
    b = 0
    if t_knots[i + j + 1] != t_knots[i + 1]:
        b = ((t_knots[i + j + 1] - t) / (t_knots[i + j + 1] - t_knots[i + 1])
             * N_f(d, i + 1, j - 1, t, t_knots))
    return a + b
def splineint(t, t_knots, Q_points, n, d):
    """Evaluate the B-spline curve at parameter value t.

    Parameters:
        t: parameter value.
        t_knots: knot sequence handed to N_f.
        Q_points: 2-D array of control points, one per row.
        n: index of the last control point.  Kept for interface
            compatibility; the sum runs over every row of Q_points so the
            last control point is no longer dropped (the original looped
            over range(n) while Q_points holds n + 1 rows).
        d: spline degree.

    Returns:
        1-D array: the sum of N_f(d, i, d, t, t_knots) * Q_points[i, :] over
        all control points.
    """
    n_ctrl = Q_points.shape[0]
    weights = np.array(
        [N_f(d, i, d, t, t_knots) for i in range(n_ctrl)], dtype=float
    )
    return weights @ Q_points
def spline_loss(t_map, t_knots, d, Q_points, P_points):
    """Return half the sum of squared distances between spline and samples.

    Parameters:
        t_map: iterable of parameter values, one per sample point.
        t_knots: knot sequence handed to splineint.
        d: spline degree.
        Q_points: 2-D array of control points, one per row.
        P_points: 2-D array of sample points, one row per entry of t_map.

    Returns:
        0.5 * sum over all samples of |spline(t_k) - P_k|^2.
    """
    # Derive the last control-point index from Q_points instead of reading
    # the module-level n.
    last_ctrl = Q_points.shape[0] - 1
    x_spline = np.array(
        [splineint(t, t_knots, Q_points, last_ctrl, d) for t in t_map]
    ).reshape(-1, Q_points.shape[1])
    residual = x_spline - P_points
    # Sum every squared component.  The original only evaluated the first
    # sample's error and broadcast it across all entries before summing.
    return 0.5 * np.sum(residual * residual)
def spline_fit(t_map, t_knots, d, Q_points, P_points):
    """Return the gradient of the fitting loss with respect to Q_points.

    Builds the matrix A with A[r, c] = N_f(d, c, d, t_map[r], t_knots) and
    returns A.T @ A @ Q_points - A.T @ P_points, which the optimization
    loop uses as its descent direction.
    """
    n_samples = len(t_map)
    n_ctrl = Q_points.shape[0]
    basis = np.zeros((n_samples, n_ctrl))
    for row, t in enumerate(t_map):
        basis[row, :] = [N_f(d, col, d, t, t_knots) for col in range(n_ctrl)]
    basis_t = basis.T
    return basis_t @ basis @ Q_points - basis_t @ P_points
def spline_optimize(iterations, alpha, t_map, t_knots, d, Q1, P_points):
    """Fit the control points by fixed-step gradient descent.

    Parameters:
        iterations: number of descent steps.
        alpha: step size.
        t_map: parameter value of each sample point.
        t_knots: knot sequence.
        d: spline degree.
        Q1: initial control points, one per row.  Not modified; the
            function works on a float copy.
        P_points: sample points, one row per entry of t_map.

    Returns:
        (Q, J_history): the updated control points and the list of loss
        values recorded before each step.
    """
    # Copy so the caller's array is left untouched and an integer-valued
    # initial guess does not break the floating-point update.
    Q1 = np.array(Q1, dtype=float)
    J_history = []
    for it in range(iterations):
        J = spline_loss(t_map, t_knots, d, Q1, P_points)
        J_history.append(J)
        if (it + 1) % 20 == 0:
            print(J)  # progress report every 20 iterations
        gradQ = spline_fit(t_map, t_knots, d, Q1, P_points)
        Q1 = Q1 - alpha * gradQ
    return Q1, J_history
# Map each sample index onto a parameter value: (s[k] - s[0]) / (s[-1] - s[0]).
# s holds one more entry than there are samples, so the largest value stays
# below 1.
t_map = [(s[k] - s[0]) / (s[-1] - s[0]) for k in range(len(x_samples))]

# Initial control points: n + 1 points evenly spaced on [-2, 2] in the first
# coordinate, all zero in the second.
Q1_1 = np.linspace(-2, 2, num=n + 1)
Q1_2 = np.zeros((n + 1, 1))
Q1 = np.append(Q1_1[:, np.newaxis], Q1_2, axis=1)
# Visualizing the first spline
# Plot the whole sample curve; indexing with [0] plotted a single point,
# which draws nothing visible as a line.
plt.plot(v_samples, x_samples)
# Control points, coloured by the first n + 1 knot values.
plt.scatter(Q1[:, 0], Q1[:, 1], c=t_knots[0:n + 1])
plt.colorbar()
# Evaluate the spline at every mapped parameter value (one row per sample).
x_spline = np.array(
    [splineint(t, t_knots, Q1, n, d) for t in t_map]
).reshape(-1, 2)
# plt.plot(Q1[:,0],Q1[:,1])
plt.plot(x_spline[:, 0], x_spline[:, 1])
plt.show()  # must be called; a bare `plt.show` does nothing
print(x_spline.shape)
print(x_samples.shape)
# Drop the extra end value from s and stack the samples into an (N, 2) array.
s = np.array(s[:-1]).reshape(len(s[:-1]), 1)
P_points = np.append(
    np.array(v_samples).reshape(len(s), 1),
    np.array(x_samples).reshape(len(s), 1),
    axis=1,
)

# Optimization Algorithm
iterations = 300
alpha = 0.01
Q1, J_history = spline_optimize(iterations, alpha, t_map, t_knots, d, Q1, P_points)

# Fitted curve on a figure of its own: samples, control points, spline.
plt.figure()
plt.plot(P_points[:, 0], P_points[:, 1])
plt.scatter(Q1[:, 0], Q1[:, 1], c=t_knots[0:n + 1])
plt.colorbar()
x_spline = np.array(
    [splineint(t, t_knots, Q1, n, d) for t in t_map]
).reshape(-1, 2)
# plt.plot(Q1[:,0],Q1[:,1])
plt.plot(x_spline[:, 0], x_spline[:, 1])

# Curve end points.  t = 1 equals the last knot, so the second value is only
# meaningful if N_f treats the final knot span as closed on the right;
# with purely half-open spans it comes out as [0., 0.].
x_sp = splineint(0, t_knots, Q1, n, d)
print(x_sp)
x_sp = splineint(1, t_knots, Q1, n, d)
print(x_sp)
# print(Q1)

# Loss history on a separate figure so the log scale does not distort the
# curve plot above.
plt.figure()
plt.plot(J_history)
plt.yscale("log")
plt.show()