# libs/fitting.py

import numpy as np
import matplotlib.pyplot as plt
import math
from scipy.optimize import least_squares
from numpy.random import default_rng
from scipy.stats import chi2 as stats_chi2
import sys


def straight_line(params, xdata):
    """
    Function for a straight line
    - params array of function parameters
    - xdata array of x data points
    """

    f = params[0] + params[1] * xdata
    return f


def straight_line_diff(params, xdata):
    """
    Differential of function for a straight line
    - params array of function parameters
    - xdata array of x data points
    """

    df = params[1]
    return df


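# Usage sketch with made-up values (any NumPy array of x points would do):
# a line with intercept 1 and gradient 2 gives y = 1 + 2x at each x point,
# and its derivative is the constant gradient.
def _demo_straight_line():
    x = np.array([0.0, 1.0, 2.0])
    params = [1.0, 2.0]                   # intercept c = 1, gradient m = 2
    print(straight_line(params, x))       # -> [1. 3. 5.]
    print(straight_line_diff(params, x))  # -> 2.0 (constant gradient)

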
def polynomial(params, xdata):
    """
    Function for a polynomial of degree `len(params) - 1`
    - params array of function parameters, ordered by increasing power of x
    - xdata array of x data points
    """

    f = 0

    npol = len(params)
    for i in range(npol):
        f += params[i] * xdata**i

    return f


# <!-- Polynomial Fit -- START -->


def polynomial_diff(params, xdata):
    """
    Differential of function for a polynomial of degree `len(params) - 1`
    - params array of function parameters, ordered by increasing power of x
    - xdata array of x data points
    """

    df = 0

    npol = len(params)
    for i in range(1, npol):
        df += params[i] * i * xdata**(i - 1)

    return df


# <!-- Polynomial Fit -- END -->


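# Usage sketch with made-up values: parameters are ordered by increasing
# power of x, so [1, 0, 2] means 1 + 0*x + 2*x**2, with derivative 4*x.
def _demo_polynomial():
    x = np.array([0.0, 1.0, 2.0])
    params = [1.0, 0.0, 2.0]
    print(polynomial(params, x))       # -> [1. 3. 9.]
    print(polynomial_diff(params, x))  # -> [0. 4. 8.]

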
def minimise(params, xdata, ydata, xerr, yerr):
    """
    Calculation to minimise to find the best fit using chi2: the difference
    between each y point and its prediction by the straight-line function,
    divided by the quadrature sum of the error on y and the error propagated
    from x via the gradient of the function.
    - params array of function parameters
    - xdata array of x data points
    - ydata array of y data points
    - xerr array of x data errors
    - yerr array of y data errors
    """

    residuals = (ydata - straight_line(params, xdata)) / (
        np.sqrt(yerr**2 + straight_line_diff(params, xdata)**2 * xerr**2))

    return residuals


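# The residual above implements the "effective variance" form of chi-squared,
# which folds the x error into an equivalent y error using the gradient of
# the model:
#
#     residual_i = (y_i - f(x_i)) / sqrt(yerr_i**2 + f'(x_i)**2 * xerr_i**2)
#     chi2       = sum_i residual_i**2
#
# scipy.optimize.least_squares minimises the sum of squared residuals, so
# passing `minimise` (or `minimise_poly` below) as its first argument
# minimises this chi-squared.

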
# Minimise Poly -- START
def minimise_poly(params, xdata, ydata, xerr, yerr):
    """
    Calculation to minimise to find the best fit using chi2: the difference
    between each y point and its prediction by the polynomial function,
    divided by the quadrature sum of the error on y and the error propagated
    from x via the gradient of the function.
    - params array of function parameters
    - xdata array of x data points
    - ydata array of y data points
    - xerr array of x data errors
    - yerr array of y data errors
    """

    residuals = (ydata - polynomial(params, xdata)) / (
        np.sqrt(yerr**2 + polynomial_diff(params, xdata)**2 * xerr**2))

    return residuals


def fit(xdata,
        ydata,
        xerror,
        yerror,
        init_params=None,
        xlabel="X",
        ylabel="Y",
        title=None,
        fname=None):
    """
    Function to perform a least-squares fit for a straight line.
    - xdata array of x data points
    - ydata array of y data points
    - xerror array of x data errors
    - yerror array of y data errors
    - init_params array of initial parameter estimates [intercept, gradient]
    - xlabel string for the x axis
    - ylabel string for the y axis
    - title string for the title of the graph
    - fname string for the output file name (used if the savefig line is enabled)
    """

    # Run fit
    if init_params is None:
        init_params = [0.0, 10.0]

    if title is None:
        title = ""

    if fname is None:
        fname = "unknown"

    result = least_squares(minimise,
                           init_params,
                           args=(xdata, ydata, xerror, yerror))

    # Check fit succeeded
    if not result.success or result.status < 1:
        print("ERROR: Fit failed with message {}".format(result.message))
        print("Please check the data and initial parameter estimates")
        return 0, 0
    else:
        print("Fit succeeded")

    # Get fitted parameters
    final_params = result.x
    c = final_params[0]
    m = final_params[1]
    nparams = len(final_params)

    # Calculate chi2
    chi2_array = result.fun**2
    chi2 = sum(chi2_array)
    npoints = len(xdata)
    reduced_chi2 = chi2 / (npoints - nparams)
    chi2_prob = stats_chi2.sf(chi2, (npoints - nparams))

    # Print chi2
    np.set_printoptions(precision=3)
    print("\n=== Fit quality ===")
    print("chisq per point = \n", chi2_array)
    print(
        "chisq = {:7.5g}, ndf = {}, chisq/NDF = {:7.5g}, chisq prob = {:7.5g}\n"
        .format(chi2, npoints - nparams, reduced_chi2, chi2_prob))

    if reduced_chi2 < 0.25 or reduced_chi2 > 4:
        print(
            "WARNING: chi2/ndf suspiciously small or large. Please check the data and initial parameter estimates"
        )

    if chi2_prob < 0.05:
        print(
            "WARNING: chi2 probability for the given degrees of freedom is less than 0.05. Please check the data and "
            "initial parameter estimates")

    # Calculate errors
    jacobian = result.jac
    jacobian2 = np.dot(jacobian.T, jacobian)
    determinant = np.linalg.det(jacobian2)

    if determinant < 1E-32:
        print(
            f"Matrix singular (determinant = {determinant}), error calculation failed."
        )
        param_errors = np.zeros(nparams)
    else:
        covariance = np.linalg.inv(jacobian2)
        param_errors = np.sqrt(covariance.diagonal())

    print("=== Fitted parameters ===")
    print("c = {:7.5g} +- {:7.5g}".format(final_params[0], param_errors[0]))
    print("m = {:7.5g} +- {:7.5g}".format(final_params[1], param_errors[1]))

    # Calculate fitted function values
    yfit = straight_line(final_params, xdata)

    # Visualise result
    fig = plt.figure(figsize=(5, 4))
    plt.title(title)
    plt.xlabel(xlabel, fontsize=13)
    plt.ylabel(ylabel, fontsize=13)
    plt.grid(color='grey', linestyle="--")

    plt.errorbar(xdata,
                 ydata,
                 xerr=xerror,
                 yerr=yerror,
                 fmt='k',
                 marker=".",
                 linestyle='',
                 label="Data")
    plt.plot(xdata, yfit, color='r', linestyle='-', label="Fit")

    plt.legend(loc=0, fontsize=16)

    text = "c: {:7.5g} +- {:7.5g}\n".format(final_params[0], param_errors[0])
    text += "m: {:7.5g} +- {:7.5g}\n".format(final_params[1], param_errors[1])
    plt.text(0.95,
             0,
             text,
             transform=fig.axes[0].transAxes,
             ha="right",
             va="bottom",
             fontsize=12)

    plt.show()
    # plt.savefig(fname)

    # Return arrays of parameters and associated errors
    return final_params, param_errors


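# Usage sketch with made-up data and errors: fit a straight line to a handful
# of points and get back the fitted [intercept, gradient] and their errors.
def _demo_fit():
    x = np.array([1.0, 2.0, 3.0, 4.0, 5.0])
    y = np.array([2.1, 3.9, 6.2, 7.8, 10.1])
    xerr = np.full_like(x, 0.1)
    yerr = np.full_like(y, 0.3)
    params, errors = fit(x, y, xerr, yerr,
                         init_params=[0.0, 2.0],
                         xlabel="x", ylabel="y",
                         title="Straight-line fit (demo)")
    return params, errors

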
# <!-- Fit Data as a linear graph - END -->
# Fit Poly


def fit_poly(xdata,
             ydata,
             xerror,
             yerror,
             init_params=None,
             xlabel="X",
             ylabel="Y",
             title=None,
             fname=None):
    """
    Function to perform a least-squares fit for a polynomial of degree
    `len(init_params) - 1`.
    - xdata array of x data points
    - ydata array of y data points
    - xerror array of x data errors
    - yerror array of y data errors
    - init_params array of initial parameter estimates, ordered by increasing power of x
    - xlabel string for the x axis
    - ylabel string for the y axis
    - title string for the title of the graph
    - fname string for the output file name (used if the savefig line is enabled)
    """

    # Run fit
    if init_params is None:
        init_params = [0.0, 10.0]

    if title is None:
        title = ""

    if fname is None:
        fname = "unknown"

    result = least_squares(minimise_poly,
                           init_params,
                           args=(xdata, ydata, xerror, yerror))

    # Check fit succeeded
    if not result.success or result.status < 1:
        print("ERROR: Fit failed with message {}".format(result.message))
        print("Please check the data and initial parameter estimates")
        return 0, 0
    else:
        print("Fit succeeded")

    # Get fitted parameters
    final_params = result.x
    nparams = len(final_params)

    # Calculate chi2
    chi2_array = result.fun**2
    chi2 = sum(chi2_array)
    npoints = len(xdata)
    reduced_chi2 = chi2 / (npoints - nparams)
    chi2_prob = stats_chi2.sf(chi2, (npoints - nparams))

    # Print chi2
    np.set_printoptions(precision=3)
    print("\n=== Fit quality ===")
    print("chisq per point = \n", chi2_array)
    print(
        "chisq = {:7.5g}, ndf = {}, chisq/NDF = {:7.5g}, chisq prob = {:7.5g}\n"
        .format(chi2, npoints - nparams, reduced_chi2, chi2_prob))

    if reduced_chi2 < 0.25 or reduced_chi2 > 4:
        print(
            "WARNING: chi2/ndf suspiciously small or large. Please check the data and initial parameter estimates"
        )

    if chi2_prob < 0.05:
        print(
            "WARNING: chi2 probability for the given degrees of freedom is less than 0.05. Please check the data and "
            "initial parameter estimates")

    # Calculate errors
    jacobian = result.jac
    jacobian2 = np.dot(jacobian.T, jacobian)
    determinant = np.linalg.det(jacobian2)

    if determinant < 1E-32:
        print(
            f"Matrix singular (determinant = {determinant}), error calculation failed."
        )
        param_errors = np.zeros(nparams)
    else:
        covariance = np.linalg.inv(jacobian2)
        param_errors = np.sqrt(covariance.diagonal())

    finalText = ""

    if nparams == 2:
        print("=== Fitted parameters ===")
        print("Intercept (c) = {:7.5g} +- {:7.5g}".format(
            final_params[0], param_errors[0]))
        finalText += "c: {:7.5g} +- {:7.5g}\n".format(final_params[0],
                                                      param_errors[0])
        print("Gradient (m) = {:7.5g} +- {:7.5g}".format(
            final_params[1], param_errors[1]))
        finalText += "m: {:7.5g} +- {:7.5g}\n".format(final_params[1],
                                                      param_errors[1])
    else:
        print("(Parameters in order of increasing power of x)")
        for x in range(nparams):
            print("param {} = {:7.5g} +- {:7.5g}".format(
                x, final_params[x], param_errors[x]))
            finalText += "p{}: {:7.5g} +- {:7.5g}\n".format(
                x, final_params[x], param_errors[x])

    # Calculate fitted function values
    yfit = polynomial(final_params, xdata)

    # Visualise result
    fig = plt.figure(figsize=(5, 4))
    plt.title(title)
    plt.xlabel(xlabel, fontsize=13)
    plt.ylabel(ylabel, fontsize=13)
    plt.grid(color='grey', linestyle="--")

    plt.errorbar(xdata,
                 ydata,
                 xerr=xerror,
                 yerr=yerror,
                 fmt='k',
                 marker=".",
                 linestyle='',
                 label="Data")
    plt.plot(xdata, yfit, color='r', linestyle='-', label="Fit")

    plt.legend(loc=0, fontsize=16)

    plt.text(0.95,
             0,
             finalText,
             transform=fig.axes[0].transAxes,
             ha="right",
             va="bottom",
             fontsize=12)

    plt.show()
    # plt.savefig(fname)

    # Return arrays of parameters and associated errors
    return final_params, param_errors


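# Usage sketch with made-up data: the length of init_params sets the
# polynomial degree, so three initial values fit a quadratic c0 + c1*x + c2*x**2.
def _demo_fit_poly():
    x = np.array([0.0, 1.0, 2.0, 3.0, 4.0, 5.0])
    y = np.array([1.2, 2.9, 9.1, 19.2, 32.8, 51.1])
    xerr = np.full_like(x, 0.1)
    yerr = np.full_like(y, 0.5)
    params, errors = fit_poly(x, y, xerr, yerr,
                              init_params=[1.0, 0.0, 2.0],
                              xlabel="x", ylabel="y",
                              title="Quadratic fit (demo)")
    return params, errors

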
# <!-- Fit data as a histogram - START -->


def plotHistogram(array,
                  title,
                  xLabel,
                  yLabel,
                  binBot=0,
                  binTop=1000,
                  binNumber=10):
    """
    Plot a histogram of `array` and return its mean, standard deviation and
    error on the mean.
    - array array of values to histogram
    - title string for the title of the graph
    - xLabel string for the x axis
    - yLabel string for the y axis
    - binBot lower edge of the first bin (defaults to min(array))
    - binTop upper edge of the last bin (defaults to max(array))
    - binNumber number of bins
    """

    # If the bin range is left at its defaults, use the range of the data
    if binBot == 0 and binTop == 1000:
        binBot = min(array)
        binTop = max(array)

    bins, binWidth = np.linspace(binBot, binTop, binNumber + 1, retstep=True)

    # Calculate mean and standard deviation
    nEvents = len(array)
    mu = np.mean(array)  # arithmetic mean of the values in array
    sigma = np.std(array)  # standard deviation (error on a single value)
    muError = sigma / np.sqrt(nEvents)  # error on the mean, sigma / sqrt(n)

    # Construct a figure with a title and axis labels
    plt.figure(figsize=(7, 5))
    plt.title(title, fontsize=14)
    plt.xlabel(xLabel)
    plt.ylabel(yLabel)

    # Make a histogram and display it
    plt.hist(array, bins=bins, color='b')
    plt.grid(color='grey')
    plt.show()

    print("Histogram bins start at", binBot, "finish at", binTop)
    print("Number of bins is", binNumber, "and width of bins is", binWidth)

    return mu, sigma, muError


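# Usage sketch with made-up data: histogram 1000 Gaussian-distributed values
# and report their mean, spread and error on the mean.
def _demo_plot_histogram():
    rng = default_rng(0)
    values = rng.normal(loc=10.0, scale=2.0, size=1000)
    mu, sigma, muError = plotHistogram(values,
                                       "Gaussian toy data (demo)",
                                       "Measured value",
                                       "Entries",
                                       binBot=2,
                                       binTop=18,
                                       binNumber=32)
    return mu, sigma, muError

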
# <!-- Fit data as a histogram - END -->
# <!-- Consistency Check Function - START -->


def consistencycheck(m1, m2, mErr1, mErr2):
    """
    Check whether two measurements m1 +- mErr1 and m2 +- mErr2 are consistent,
    i.e. whether they differ by less than 3 times their combined error.
    """
    leftSide = np.abs(m1 - m2)
    rightSide = 3 * math.sqrt(mErr1**2 + mErr2**2)

    if leftSide <= rightSide:
        print("The results are consistent ",
              f"{round(leftSide, 4)} <= {round(rightSide, 4)}")
    else:
        print("The results are inconsistent",
              f"{round(leftSide, 4)} > {round(rightSide, 4)}")


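# Usage sketch with made-up measurements: 2.05 +- 0.04 and 1.98 +- 0.05 differ
# by 0.07, while 3 * sqrt(0.04**2 + 0.05**2) is about 0.19, so they are
# reported as consistent.
def _demo_consistency_check():
    consistencycheck(2.05, 1.98, 0.04, 0.05)

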
# <!-- Consistency Check Function - END -->


# <!-- Weighted Mean Function -- START -->
def weightedMean(data, errors):
    """
    Weighted mean of `data`, weighting each value by 1 / error**2, and the
    corresponding error on the weighted mean, 1 / sqrt(sum of weights).
    """
    # Weights are the inverse variances of the individual measurements
    weights = 1 / np.array(errors)**2

    numer = np.sum(np.array(data) * weights)
    denom = np.sum(weights)

    newdata = numer / denom
    newerror = 1 / math.sqrt(denom)

    return newdata, newerror


# <!-- Weighted Mean Function -- END -->


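# Usage sketch with made-up measurements: combining 10.0 +- 0.1 and
# 10.4 +- 0.2, the more precise value dominates, giving a weighted mean of
# 10.08 with an error of about 0.089.
def _demo_weighted_mean():
    mean, error = weightedMean([10.0, 10.4], [0.1, 0.2])
    print(mean, error)  # -> 10.08  0.0894...

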
def generate_toy_data(xarray, slope, intercept):
    """
    Generate toy data, randomly Poisson fluctuated around y = slope * xarray + intercept.
    Take a simple Gaussian error i.e. sqrt(N) on each y point.
    """

    model = intercept + slope * xarray

    rng = default_rng(3)
    yarray = rng.poisson(model)
    yerr = np.sqrt(yarray)

    return yarray, yerr
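

# End-to-end usage sketch with made-up parameters: generate Poisson-fluctuated
# toy data around a known straight line, assume a small x error, and fit it
# back with `fit`.
if __name__ == "__main__":
    x = np.linspace(1.0, 10.0, 10)
    y, yerr = generate_toy_data(x, slope=5.0, intercept=20.0)
    xerr = np.full_like(x, 0.1)
    fit(x, y, xerr, yerr,
        init_params=[10.0, 1.0],
        xlabel="x", ylabel="Counts",
        title="Toy-data straight-line fit (demo)")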