잡다한 코드

파이썬으로 curve fitting 및 R^2 값 구하기

아끌 2024. 10. 6. 16:53

여러가지 실험보고서(특히 물리)를 쓸 때 curve fitting을 해야하는 경우가 많다. 보통 엑셀의 추세선 기능을 이용하지만, 추세선의 종류는 직선, 지수함수 등으로 한정되어 있다. 데이터가 이론적으로 간단한 함수를 따른다면 x축과 y축의 물리량 간 관계가 선형이 되도록 y축에 해당하는 물리량을 변형한 후 fitting하면 되지만, 만약 데이터가 e^(-bt)sin(wt+d) 꼴의 복잡한 함수를 따른다면 엑셀에서는 그 추세선을 도저히 그려낼 방법이 없다. 보통 curve fitting 프로그램을 다운받아 그리지만(ex: PASCO Capstone, Origin), 프로그램을 다운받고 그 사용방법까지 익히는데는 적잖은 노력과 시간이 필요하다.

이러한 경우 파이썬을 이용하여 원하는 함수 꼴로 fitting을 진행할 수 있다. 파이썬을 아는 사람이라면, 이 코드에서 사용되는 라이브러리와 그 원리를 자세히 모르더라도 눈치껏 따라할 수 있을 것이다.

편리한 코드 사용을 위해 google colab에서 진행하기를 추천한다.

 

# fit a line to the economic data
from numpy import *
import pandas
from pandas import read_excel
from scipy.optimize import curve_fit
from matplotlib import pyplot

# define the true objective function
# 여기에 fitting을 원하는 함수 꼴을 입력하면 된다.

def objective(x, 미정계수들):
	return 함수식

# 예시 
# def objective(x, a, b, c, d, T):
#	return a + b*sin(2*pi/T*x + d) * exp(-x/c)

# load the dataset
# Google colab에서 진행하는 경우 drive mount 이후 파일 업로드 후 진행.
url = '데이터 파일 경로(파일 이름까지 포함)'

# 만약 엑셀 파일이 아니라 csv 등 기타 파일을 사용할 경우
# pandas의 dataframe 형태로 읽어주는 다른 알맞은 함수 사용할 것.
# Sheet1 부분에 데이터가 포함된 엑셀 시트 이름을 적으면 된다
dataframe = read_excel(url, sheet_name = 'Sheet1')


# choose the input and output variables
# 데이터 형태는 다음과 같이 생긴 한 시트여야 한다.
#      A           B             C       D     E
# 1   Time(s)   x-Position(cm)
# 2   0            1
# 3   0.02         1.09
# 4   0.04         1.27
# ...  ...         ...
#
# 기본적으로 A, B열 두 개의 열에만 데이터가 존재하여야 한다. 그 외의 열은 무시가 된다.
# 반드시 1행에는 각 축의 이름이, 2행부터 데이터가 와야한다. 
# A열이 x축, B열이 y축으로 쓰인다.

x, y= dataframe['1행 1열 내용'], dataframe['1행 2열 내용']
# 예시
# x, y= dataframe['Time(s)'], dataframe['x-Position(cm)']

# curve fit
# 추가로 p0, bounds 등의 parameter를 넣어 계산 속도를 향상시키거나 뻗는 걸 방지할 수 있다.
# 특히 주기함수로 피팅하는 경우, p0나 bounds를 넣어주지 않으면 fitting이 아예 잘못되는 경우도 있으니
# 미정계수의 예상되는 초기값이나 구간을 꼭 넣어줘야 한다.
# 참고로, popt에는 sum of the squared residuals of f(xdata, *popt) - ydata가 최소가 되는
# 미정계수들이 array 형태로 담긴다.
popt, _ = curve_fit(objective, x, y)

# plot input vs output
pyplot.scatter(x, y)
# define a sequence of inputs between the smallest and largest known inputs
x_line = arange(min(x), max(x), (max(x)-min(x))/len(x))
# calculate the output for the range
y_line = objective(x_line, *popt)
# create a line plot for the mapping function
pyplot.plot(x_line, y_line, '--', color='red')
pyplot.xlabel('x축 제목', fontsize=14) # 참고로, 제목을 한글로 넣을 거면 추가로 한글을 python에 설치해줘야 한다. 
pyplot.ylabel('y축 제목', fontsize=14)
pyplot.title('제목', fontsize = 16)
pyplot.show()

residuals = y- objective(x, *popt)
ss_res = sum(residuals**2)
ss_tot = sum((y-mean(y))**2)
r_squared = 1 - (ss_res / ss_tot)
print(r_squared) # r^2 값 출력
print(popt) # 찾아낸 파라메터 값 출력

 

'잡다한 코드' 카테고리의 다른 글

그래프 그리는 코드  (0) 2024.10.14
wolfram alpha의 행렬 지수 연산의 문제  (0) 2024.10.06