formatting improved

This commit is contained in:
Fabian Joswig 2021-09-24 16:49:00 +01:00
parent 63037fd98f
commit caa1330ab8
2 changed files with 24 additions and 24 deletions

View file

@ -20,14 +20,14 @@ class Corr:
""" """
def __init__(self, data_input,padding_front=0,padding_back=0): def __init__(self, data_input, padding_front=0, padding_back=0):
#All data_input should be a list of things at different timeslices. This needs to be verified #All data_input should be a list of things at different timeslices. This needs to be verified
if not (isinstance(data_input,list)): if not (isinstance(data_input, list)):
raise TypeError('Corr__init__ expects a list of timeslices.') raise TypeError('Corr__init__ expects a list of timeslices.')
# data_input can have multiple shapes. The simplest one is a list of Obs. # data_input can have multiple shapes. The simplest one is a list of Obs.
#We check, if this is the case #We check, if this is the case
if all([isinstance(item,Obs) for item in data_input]): if all([isinstance(item, Obs) for item in data_input]):
self.content=[np.asarray([item]) for item in data_input] self.content=[np.asarray([item]) for item in data_input]
#Wrapping the Obs in an array ensures that the data structure is consistent with smearing matrices. #Wrapping the Obs in an array ensures that the data structure is consistent with smearing matrices.
self.N = 1 # number of smearings self.N = 1 # number of smearings
@ -70,7 +70,7 @@ class Corr:
#The method can use one or two vectors. #The method can use one or two vectors.
#If two are specified it returns v1@G@v2 (the order might be very important.) #If two are specified it returns v1@G@v2 (the order might be very important.)
#By default it will return the lowest source, which usually means unsmeared-unsmeared (0,0), but it does not have to #By default it will return the lowest source, which usually means unsmeared-unsmeared (0,0), but it does not have to
def projected(self,vector_l=None,vector_r=None): def projected(self, vector_l=None, vector_r=None):
if self.N == 1: if self.N == 1:
raise Exception("Trying to project a Corr, that already has N=1.") raise Exception("Trying to project a Corr, that already has N=1.")
#This Exception is in no way necessary. One could just return self #This Exception is in no way necessary. One could just return self
@ -224,7 +224,7 @@ class Corr:
#We want to apply a pe.standard_fit directly to the Corr using an arbitrary function and range. #We want to apply a pe.standard_fit directly to the Corr using an arbitrary function and range.
def fit(self, function, fitrange=None): def fit(self, function, fitrange=None, silent=False):
if self.N != 1: if self.N != 1:
raise Exception("Correlator must be projected before fitting") raise Exception("Correlator must be projected before fitting")
@ -233,7 +233,7 @@ class Corr:
xs = [x for x in range(fitrange[0], fitrange[1]) if not self.content[x] is None] xs = [x for x in range(fitrange[0], fitrange[1]) if not self.content[x] is None]
ys = [self.content[x][0] for x in range(fitrange[0], fitrange[1]) if not self.content[x] is None] ys = [self.content[x][0] for x in range(fitrange[0], fitrange[1]) if not self.content[x] is None]
result = standard_fit(xs, ys, function, silent=True) result = standard_fit(xs, ys, function, silent=silent)
[item.gamma_method() for item in result if isinstance(item,Obs)] [item.gamma_method() for item in result if isinstance(item,Obs)]
return result return result
@ -245,10 +245,10 @@ class Corr:
raise Exception("plateau is undefined at all timeslices in plateaurange.") raise Exception("plateau is undefined at all timeslices in plateaurange.")
if method == "fit": if method == "fit":
def const_func(a, t): def const_func(a, t):
return a[0] + a[1] * 0 # At some point pe.standard fit had an issue with single parameter fits. Being careful does not hurt return a[0] # At some point pe.standard fit had an issue with single parameter fits. Being careful does not hurt
return self.fit(const_func,plateau_range)[0] return self.fit(const_func,plateau_range)[0]
elif method in ["avg","average","mean"]: elif method in ["avg","average","mean"]:
returnvalue= np.mean([item[0] for item in self.content if not item is None]) returnvalue= np.mean([item[0] for item in self.content[plateau_range[0]:plateau_range[1]+1] if not item is None])
returnvalue.gamma_method() returnvalue.gamma_method()
return returnvalue return returnvalue
@ -258,28 +258,28 @@ class Corr:
#quick and dirty plotting function to view Correlator inside Jupyter #quick and dirty plotting function to view Correlator inside Jupyter
#If one would not want to import pyplot, this could easily be replaced by a call to pe.plot_corrs #If one would not want to import pyplot, this could easily be replaced by a call to pe.plot_corrs
#This might be a bit more flexible later #This might be a bit more flexible later
def show(self,xrange=None,logscale=False): def show(self, x_range=None, logscale=False):
if self.N!=1: if self.N!=1:
raise Exception("Correlator must be projected before plotting") raise Exception("Correlator must be projected before plotting")
if xrange is None: if x_range is None:
xrange=[0,self.T] x_range=[0, self.T]
x,y,y_err=self.plottable() x,y,y_err=self.plottable()
plt.errorbar(x,y,y_err) plt.errorbar(x,y,y_err)
if logscale: if logscale:
plt.yscale("log") plt.yscale('log')
else: else:
# we generate ylim instead of using autoscaling. # we generate ylim instead of using autoscaling.
y_min=min([ (x[0].value-x[0].dvalue) for x in self.content[xrange[0]:xrange[1]] if(not x is None)]) y_min=min([(x[0].value - x[0].dvalue) for x in self.content[x_range[0]:x_range[1]] if(not x is None)])
y_max=max([ (x[0].value+x[0].dvalue) for x in self.content[xrange[0]:xrange[1]] if(not x is None)]) y_max=max([(x[0].value + x[0].dvalue) for x in self.content[x_range[0]:x_range[1]] if(not x is None)])
plt.ylim([y_min-0.1*(y_max-y_min),y_max+0.1*(y_max-y_min)]) plt.ylim([y_min - 0.1 * (y_max - y_min), y_max + 0.1 * (y_max - y_min)])
plt.xlabel(r"$an_t$") plt.xlabel(r'$x_0 / a$')
plt.xlim([xrange[0] - 0.5, xrange[1] + 0.5]) plt.xlim([x_range[0] - 0.5, x_range[1] + 0.5])
plt.title("Quickplot") #plt.title("Quickplot")
plt.show() plt.show()
plt.clf() #plt.clf()
return return
def dump(self,filename): def dump(self,filename):

View file

@ -356,7 +356,7 @@ class Obs:
plt.ylabel('tauint') plt.ylabel('tauint')
length = int(len(self.e_n_tauint[e_name])) length = int(len(self.e_n_tauint[e_name]))
plt.errorbar(np.arange(length), self.e_n_tauint[e_name][:], yerr=self.e_n_dtauint[e_name][:], linewidth=1, capsize=2) plt.errorbar(np.arange(length), self.e_n_tauint[e_name][:], yerr=self.e_n_dtauint[e_name][:], linewidth=1, capsize=2)
plt.axvline(x=self.e_windowsize[e_name], color='r', alpha=0.25) plt.axvline(x=self.e_windowsize[e_name], color='r', alpha=0.25, marker=',')
if self.tau_exp[e_name] > 0: if self.tau_exp[e_name] > 0:
base = self.e_n_tauint[e_name][self.e_windowsize[e_name]] base = self.e_n_tauint[e_name][self.e_windowsize[e_name]]
x_help = np.arange(2 * self.tau_exp[e_name]) x_help = np.arange(2 * self.tau_exp[e_name])
@ -1174,17 +1174,17 @@ def plot_corrs(observables, **kwargs):
if 'prange' in kwargs: if 'prange' in kwargs:
prange = kwargs.get('prange') prange = kwargs.get('prange')
plt.axvline(x=prange[0] - 0.5, ls='--', c='k', lw=1, alpha=0.5) plt.axvline(x=prange[0] - 0.5, ls='--', c='k', lw=1, alpha=0.5, marker=',')
plt.axvline(x=prange[1] + 0.5, ls='--', c='k', lw=1, alpha=0.5) plt.axvline(x=prange[1] + 0.5, ls='--', c='k', lw=1, alpha=0.5, marker=',')
if 'plateau' in kwargs: if 'plateau' in kwargs:
plateau = kwargs.get('plateau') plateau = kwargs.get('plateau')
if isinstance(plateau, Obs): if isinstance(plateau, Obs):
plt.axhline(y=plateau.value, linewidth=2, color='k', alpha=0.6, label='Plateau') plt.axhline(y=plateau.value, linewidth=2, color='k', alpha=0.6, label='Plateau', marker=',')
plt.axhspan(plateau.value - plateau.dvalue, plateau.value + plateau.dvalue, alpha=0.25, color='k') plt.axhspan(plateau.value - plateau.dvalue, plateau.value + plateau.dvalue, alpha=0.25, color='k')
elif isinstance(plateau, list): elif isinstance(plateau, list):
for i in range(len(plateau)): for i in range(len(plateau)):
plt.axhline(y=plateau[i].value, linewidth=2, color='C' + str(i), alpha=0.6, label='Plateau' + str(i + 1)) plt.axhline(y=plateau[i].value, linewidth=2, color='C' + str(i), alpha=0.6, label='Plateau' + str(i + 1), marker=',')
plt.axhspan(plateau[i].value - plateau[i].dvalue, plateau[i].value + plateau[i].dvalue, plt.axhspan(plateau[i].value - plateau[i].dvalue, plateau[i].value + plateau[i].dvalue,
color='C' + str(i), alpha=0.25) color='C' + str(i), alpha=0.25)
else: else: