Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
Real-time collaboration for Jupyter Notebooks, Linux Terminals, LaTeX, VS Code, R IDE, and more,
all in one place. Commercial Alternative to JupyterHub.
| Download
Project: Joseph Spencer - Physics_2023_24/PHYS105 Introduction to Computational Physics/PHYS105_2023
Path: PHYS105 Introduction to Computational Physics/ComputerClassesStudent/Phys105-Week11/ libs/fitting.py
Views: 26 | Image: ubuntu2204
# libs/fitting.py
"""Least-squares fitting, plotting and small statistics helpers.

The fits minimise a chi-square in which the uncertainty on each point is the
y error combined in quadrature with the x error propagated through the
derivative of the model.
"""

import math
import sys  # not used in this module; kept so the import surface is unchanged

import numpy as np
from numpy.random import default_rng
from scipy.optimize import least_squares
from scipy.stats import chi2 as stats_chi2

try:
    import matplotlib.pyplot as plt
except ImportError:
    # Plotting is optional: the numerical helpers below do not need
    # matplotlib, so keep the module importable without it.
    plt = None


# --- Model functions ---------------------------------------------------------


def straight_line(params, xdata):
    """
    Function for a straight line: ``params[0] + params[1] * xdata``.
    - params array of function parameters (intercept, gradient)
    - xdata  array of x data points
    """
    return params[0] + params[1] * xdata


def straight_line_diff(params, xdata):
    """
    Differential (with respect to x) of the straight-line function.
    - params array of function parameters (intercept, gradient)
    - xdata  array of x data points (unused: the gradient is constant)

    Returns the scalar ``params[1]``.
    """
    return params[1]


def polynomial(params, xdata):
    """
    Function for a polynomial with ``len(params)`` coefficients:
    ``sum(params[i] * xdata**i)``.
    - params array of coefficients, in order of increasing power of x
    - xdata  array of x data points

    Returns 0 when ``params`` is empty.
    """
    # sum() builds a new object at every step, so integer inputs are promoted
    # to float when needed instead of failing on an in-place cast.
    return sum(coeff * xdata**power for power, coeff in enumerate(params))


def polynomial_diff(params, xdata):
    """
    Differential (with respect to x) of the polynomial function:
    ``sum(i * params[i] * xdata**(i - 1))`` for ``i >= 1``.
    - params array of coefficients, in order of increasing power of x
    - xdata  array of x data points

    Returns 0 when there is no term above the constant one.
    """
    return sum(params[power] * power * xdata**(power - 1)
               for power in range(1, len(params)))


# --- Chi-square residuals ----------------------------------------------------


def _chi_residuals(func, diff, params, xdata, ydata, xerr, yerr):
    """Return (y - func(x)) / sqrt(yerr^2 + (diff(x) * xerr)^2) per point."""
    sigma = np.sqrt(yerr**2 + diff(params, xdata)**2 * xerr**2)
    return (ydata - func(params, xdata)) / sigma


def minimise(params, xdata, ydata, xerr, yerr):
    """
    Residuals to minimise for a straight-line fit: difference between each
    y point and its prediction by `straight_line`, divided by the sum in
    quadrature of the y error and the x error scaled by the gradient.
    - params array of function parameters
    - xdata  array of x data points
    - ydata  array of y data points
    - xerr   array of x data errors
    - yerr   array of y data errors
    """
    return _chi_residuals(straight_line, straight_line_diff, params, xdata,
                          ydata, xerr, yerr)


def minimise_poly(params, xdata, ydata, xerr, yerr):
    """
    Residuals to minimise for a polynomial fit: difference between each
    y point and its prediction by `polynomial`, divided by the sum in
    quadrature of the y error and the x error scaled by `polynomial_diff`.
    - params array of function parameters
    - xdata  array of x data points
    - ydata  array of y data points
    - xerr   array of x data errors
    - yerr   array of y data errors
    """
    return _chi_residuals(polynomial, polynomial_diff, params, xdata, ydata,
                          xerr, yerr)


# --- Fitting -----------------------------------------------------------------


def _pyplot_available():
    """Return True if matplotlib was imported, else warn and return False."""
    if plt is None:
        print("WARNING: matplotlib is not installed, skipping plot")
        return False
    return True


def _run_fit(residual_func, xdata, ydata, xerror, yerror, init_params):
    """Run the least-squares fit and print its quality.

    Returns ``(final_params, param_errors)`` or ``None`` if the fit failed.
    """
    result = least_squares(residual_func,
                           init_params,
                           args=(xdata, ydata, xerror, yerror))

    # Check fit succeeds
    if not result.success or result.status < 1:
        print("ERROR: Fit failed with message {}".format(result.message))
        print("Please check the data and inital parameter estimates")
        return None
    print("Fit succeeded")

    final_params = result.x
    nparams = len(final_params)

    # result.fun holds the residuals at the solution, so their squares are
    # the per-point chi2 contributions.
    chi2_array = result.fun**2
    chi2 = np.sum(chi2_array)
    ndf = len(xdata) - nparams
    if ndf > 0:
        reduced_chi2 = chi2 / ndf
        chi2_prob = stats_chi2.sf(chi2, ndf)
    else:
        # As many parameters as points (or more): the fit quality is undefined.
        print("WARNING: no degrees of freedom, chi2/ndf is undefined")
        reduced_chi2 = float("nan")
        chi2_prob = float("nan")

    # NOTE: this changes numpy's global print precision, as the module
    # always has.
    np.set_printoptions(precision=3)
    print("\n=== Fit quality ===")
    print("chisq per point = \n", chi2_array)
    print(
        "chisq = {:7.5g}, ndf = {}, chisq/NDF = {:7.5g}, chisq prob = {:7.5g}\n"
        .format(chi2, ndf, reduced_chi2, chi2_prob))

    if reduced_chi2 < 0.25 or reduced_chi2 > 4:
        print(
            "WARNING: chi2/ndf suspiciously small or large. Please check the data and initial parameter estimates"
        )

    if chi2_prob < 0.05:
        print(
            "WARNING: chi2 probability for given degrees of freedom less than 0.05 . Please check the data and "
            "initial parameter estimates")

    # The residuals are already divided by their errors, so the covariance
    # matrix of the parameters is the inverse of J^T J.
    jacobian = result.jac
    jacobian2 = np.dot(jacobian.T, jacobian)
    determinant = np.linalg.det(jacobian2)

    if determinant < 1E-32:
        print(
            f"Matrix singular (determinant = {determinant}, error calculation failed."
        )
        param_errors = np.zeros(nparams)
    else:
        covariance = np.linalg.inv(jacobian2)
        param_errors = np.sqrt(covariance.diagonal())

    return final_params, param_errors


def _plot_fit(xdata, ydata, xerror, yerror, yfit, summary, xlabel, ylabel,
              title):
    """Show the data with error bars, the fitted curve and a text summary."""
    if not _pyplot_available():
        return

    fig = plt.figure(figsize=(5, 4))
    plt.title(title)
    plt.xlabel(xlabel, fontsize=13)
    plt.ylabel(ylabel, fontsize=13)
    plt.grid(color='grey', linestyle="--")

    plt.errorbar(xdata,
                 ydata,
                 xerr=xerror,
                 yerr=yerror,
                 fmt='k',
                 marker=".",
                 linestyle='',
                 label="Data")
    plt.plot(xdata, yfit, color='r', linestyle='-', label="Fit")

    plt.legend(loc=0, fontsize=16)

    # Axes coordinates: bottom-right corner of the plot area.
    plt.text(0.95,
             0,
             summary,
             transform=fig.axes[0].transAxes,
             ha="right",
             va="bottom",
             fontsize=12)

    plt.show()


def fit(xdata,
        ydata,
        xerror,
        yerror,
        init_params=None,
        xlabel="X",
        ylabel="Y",
        title=None,
        fname=None):
    """
    Function to perform least-squares fit for a straight line and plot it.
    - xdata       array of x data points
    - ydata       array of y data points
    - xerror      array of x data errors
    - yerror      array of y data errors
    - init_params initial [intercept, gradient]; defaults to [0.0, 10.0]
    - xlabel      string of the x axis
    - ylabel      string of the y axis
    - title       string of the title of graph (empty if None)
    - fname       accepted for compatibility; the figure is shown, not saved

    Returns (final_params, param_errors) as arrays, or (0, 0) if the fit
    failed.
    """
    if init_params is None:
        init_params = [0.0, 10.0]

    outcome = _run_fit(minimise, xdata, ydata, xerror, yerror, init_params)
    if outcome is None:
        return 0, 0
    final_params, param_errors = outcome

    print("=== Fitted parameters ===")
    print("c = {:7.5g} +- {:7.5g}".format(final_params[0], param_errors[0]))
    print("m = {:7.5g} +- {:7.5g}".format(final_params[1], param_errors[1]))

    summary = "c: {:7.5g} +- {:7.5g}\n".format(final_params[0],
                                                param_errors[0])
    summary += "m: {:7.5g} +- {:7.5g}\n".format(final_params[1],
                                                 param_errors[1])

    _plot_fit(xdata, ydata, xerror, yerror,
              straight_line(final_params, xdata), summary, xlabel, ylabel,
              "" if title is None else title)

    # Return arrays of parameters and associated errors
    return final_params, param_errors


def fit_poly(xdata,
             ydata,
             xerror,
             yerror,
             init_params=None,
             xlabel="X",
             ylabel="Y",
             title=None,
             fname=None):
    """
    Function to perform least-squares fit for a polynomial and plot it.
    The number of coefficients fitted is ``len(init_params)``.
    - xdata       array of x data points
    - ydata       array of y data points
    - xerror      array of x data errors
    - yerror      array of y data errors
    - init_params initial coefficients in order of increasing power of x;
                  defaults to [0.0, 10.0] (a straight line)
    - xlabel      string of the x axis
    - ylabel      string of the y axis
    - title       string of the title of graph (empty if None)
    - fname       accepted for compatibility; the figure is shown, not saved

    Returns (final_params, param_errors) as arrays, or (0, 0) if the fit
    failed.
    """
    if init_params is None:
        init_params = [0.0, 10.0]

    outcome = _run_fit(minimise_poly, xdata, ydata, xerror, yerror,
                       init_params)
    if outcome is None:
        return 0, 0
    final_params, param_errors = outcome
    nparams = len(final_params)

    summary = ""
    if nparams == 2:
        print("=== Fitted parameters ===")
        print("Intercept (c) = {:7.5g} +- {:7.5g}".format(
            final_params[0], param_errors[0]))
        summary += "c: {:7.5g} +- {:7.5g}\n".format(final_params[0],
                                                     param_errors[0])
        print("gradient (m) = {:7.5g} +- {:7.5g}".format(
            final_params[1], param_errors[1]))
        summary += "m: {:7.5g} +- {:7.5g}\n".format(final_params[1],
                                                     param_errors[1])
    else:
        print("(Parameters in order of increasing power of x)")
        for index in range(nparams):
            print("param {} = {:7.5g} +- {:7.5g}".format(
                index, final_params[index], param_errors[index]))
            summary += "p{}: {:7.5g} +- {:7.5g}\n".format(
                index, final_params[index], param_errors[index])

    # Evaluate the full polynomial so that terms above first order are drawn.
    yfit = polynomial(final_params, xdata)

    _plot_fit(xdata, ydata, xerror, yerror, yfit, summary, xlabel, ylabel,
              "" if title is None else title)

    # Return arrays of parameters and associated errors
    return final_params, param_errors


# --- Histogram ---------------------------------------------------------------


def plotHistogram(array,
                  title,
                  xLabel,
                  yLabel,
                  binBot=0,
                  binTop=1000,
                  binNumber=10):
    """
    Plot a histogram of `array` and return its mean, spread and error on mean.
    - array     values to histogram
    - title     string of the title of graph
    - xLabel    string of the x axis
    - yLabel    string of the y axis
    - binBot    lower edge of the first bin
    - binTop    upper edge of the last bin
    - binNumber number of equal-width bins

    When binBot and binTop are both left at their defaults (0 and 1000) the
    range is taken from the minimum and maximum of `array` instead.

    Returns (mu, sigma, muError): the arithmetic mean, the standard deviation
    (np.std with its default ddof=0) and sigma / sqrt(len(array)).
    """
    # The default pair (0, 1000) acts as a sentinel meaning "auto range".
    if binBot == 0 and binTop == 1000:
        binBot = min(array)
        binTop = max(array)

    bins, binWidth = np.linspace(binBot, binTop, binNumber + 1, retstep=True)

    # Calculate mean and standard deviation
    nEvents = len(array)
    mu = np.mean(array)
    sigma = np.std(array)
    muError = sigma / np.sqrt(nEvents)

    if _pyplot_available():
        # Construct a figure with a title and axis labels
        plt.figure(figsize=(7, 5))
        plt.title(title, fontsize=14)
        plt.xlabel(xLabel)
        plt.ylabel(yLabel)

        # Make a histogram and display it
        plt.hist(array, bins=bins, color='b')
        plt.grid(color='grey')
        plt.show()

    print("Histogram bins start at", binBot, "finish at", binTop)
    print("Number of bins is", binNumber, "and width of bins is", binWidth)

    return mu, sigma, muError


# --- Consistency check -------------------------------------------------------


def consistencycheck(m1, m2, mErr1, mErr2):
    """
    Print whether two measurements agree within three combined errors, i.e.
    whether ``|m1 - m2| <= 3 * sqrt(mErr1**2 + mErr2**2)``.
    - m1, m2       the two measured values (scalars)
    - mErr1, mErr2 their errors (scalars)

    Returns None; the verdict is only printed.
    """
    leftSide = np.abs(m1 - m2)
    rightSide = 3 * math.sqrt(mErr1**2 + mErr2**2)

    if leftSide <= rightSide:
        print("The results are consistent ",
              f"{round(leftSide, 4)} <= {round(rightSide, 4)}")
    else:
        print("The results are inconsistent",
              f"{round(leftSide, 4)} > {round(rightSide, 4)}")


# --- Weighted mean -----------------------------------------------------------


def weightedMean(data, errors):
    """
    Inverse-variance weighted mean of `data`.
    - data   sequence of measured values
    - errors sequence of their errors, same shape as `data`

    Each value gets weight ``1 / error**2``. Returns (mean, error) with
    ``mean = sum(w * x) / sum(w)`` and ``error = 1 / sqrt(sum(w))``.

    Raises ValueError if the inputs are empty, differ in shape, or contain a
    zero error (which would give an infinite weight).
    """
    values = np.asarray(data, dtype=float)
    sigmas = np.asarray(errors, dtype=float)

    if values.shape != sigmas.shape:
        raise ValueError("data and errors must have the same shape")
    if values.size == 0:
        raise ValueError("data and errors must not be empty")
    if np.any(sigmas == 0):
        raise ValueError("errors must be non-zero")

    weights = 1.0 / sigmas**2
    denom = np.sum(weights)

    newdata = np.sum(values * weights) / denom
    newerror = 1 / math.sqrt(denom)

    return newdata, newerror


# --- Toy data ----------------------------------------------------------------


def generate_toy_data(xarray, slope, intercept):
    """
    Generate toy data, randomly Poisson fluctuated around
    y = slope * xarray + intercept.
    Take a simple Gaussian error i.e. sqrt(N) on each y point.

    The generator is seeded with a fixed value (3), so the same inputs always
    give the same output. A point that fluctuates to zero gets a zero error.

    Returns (yarray, yerr).
    """
    model = intercept + slope * xarray

    rng = default_rng(3)
    yarray = rng.poisson(model)
    yerr = np.sqrt(yarray)

    return yarray, yerr