Package ivs :: Package sed :: Module decorators
[hide private]
[frames] | no frames]

Source Code for Module ivs.sed.decorators

  1  # -*- coding: utf-8 -*- 
  2  """ 
  3  Decorators specifically for SED package 
  4  """ 
  5  import functools 
  6  import logging 
  7  import numpy as np 
  8  import pylab as pl 
  9  from multiprocessing import Manager,Process,cpu_count 
 10  import model 
 11  from ivs.units import conversions 
 12  from ivs.units import constants 
 13   
 14  logger = logging.getLogger('SED.DEC') 
def parallel_gridsearch(fctn):
    """
    Decorator to run SED grid fitting in parallel.

    The decorated function is started in ``threads`` separate processes:

        - the first three positional arguments are passed unchanged to every
          process;
        - every positional argument from the fourth onwards is split
          round-robin (``arg[i::threads]``), so each process handles an
          interleaved subset of the grid points;
        - a shared (manager) list is appended as extra last positional
          argument, and an ``index`` keyword holds the grid indices
          (``arange(len(args[-1]))[i::threads]``) of the subset.

    The decorated function is expected to append to the shared list an item
    indexable as ``(chisqs, scales, e_scales, lumis, index)``. The pieces of
    all processes are stacked and put back in the original grid order using
    ``index``.

    This decorator is meant to be applied on top of a 'make_parallel'
    decorator (not visible here).

    Extra keyword:
        threads: number of processes. An int, or 'max' (cpu_count()),
                 'half' (cpu_count()//2) or 'safe' (cpu_count()-1). At least
                 one process is always used. Defaults to 1.

    Returns:
        chisqs, scales, e_scales, lumis: stacked arrays in grid order.
    """
    @functools.wraps(fctn)
    def globpar(*args, **kwargs):
        #-- get information on threading
        threads = kwargs.pop('threads', 1)
        if threads == 'max':
            threads = cpu_count()
        elif threads == 'half':
            threads = cpu_count() // 2
        elif threads == 'safe':
            threads = cpu_count() - 1
        # 'half' and 'safe' yield 0 on a single-core machine; with zero
        # processes nothing is computed and np.hstack below would fail.
        threads = max(1, int(threads))
        index = np.arange(len(args[-1]))

        #-- construct a manager to collect all calculations; it runs its own
        #   server process, so always shut it down when we are done.
        manager = Manager()
        try:
            arr = manager.list([])
            all_processes = []
            #-- distribute the grid calculations over processes, and wait
            for i in range(threads):
                #-- keep the first three arguments, split up all remaining
                #   ones and append the shared result list
                myargs = tuple(list(args[:3])
                               + [arg[i::threads] for arg in args[3:]]
                               + [arr])
                kwargs['index'] = index[i::threads]
                logger.debug("parallel: starting process %s" % (i))
                p = Process(target=fctn, args=myargs, kwargs=kwargs)
                p.start()
                all_processes.append(p)

            for i, p in enumerate(all_processes):
                p.join()
                if p.exitcode != 0:
                    # a crashed worker leaves its grid points out of the result
                    logger.error("parallel: process %s exited with code %s; "
                                 "its results are missing" % (i, p.exitcode))

            logger.debug("parallel: all processes ended")
            #-- copy the results out of the manager before shutting it down
            results = list(arr)
        finally:
            manager.shutdown()

        #-- join all pieces and restore the original grid order
        chisqs = np.hstack([output[0] for output in results])
        scales = np.hstack([output[1] for output in results])
        e_scales = np.hstack([output[2] for output in results])
        lumis = np.hstack([output[3] for output in results])
        indices = np.hstack([output[4] for output in results])
        sa = np.argsort(indices)
        return chisqs[sa], scales[sa], e_scales[sa], lumis[sa]

    return globpar
def iterate_gridsearch(fctn):
    """
    Decorator to run an SED grid search iteratively, zooming in on the minimum.

    After every stage, the keyword ranges 'teffrange', 'loggrange',
    'ebvrange' and 'zrange' for the next call of the decorated function are
    narrowed to the grid points with

        chisq <= chisq_min * (1 + speed * 0.5**stage)      (stage = 0, 1, ...)

    and 'points' is set to ``increase**(stage+1)`` times the number of points
    returned by the first stage. The results of all stages are concatenated
    into one record array.

    The decorated function must return a record array with (at least) the
    fields 'chisq', 'teff', 'logg', 'ebv' and 'z'.

    Extra keywords:
        iterations: number of iterative zoom-ins (default 1, must be >= 1)
        increase: increase in number of grid points in each search
                  (1 means no increase; default 1)
        speed: speed of the zoom-in: the higher, the slower, i.e. the wider
               the chi-square cutoff around the minimum (default 2)

    Returns:
        record array containing the grid points of all stages.

    Raises:
        ValueError: if iterations < 1 (there would be nothing to return).
    """
    @functools.wraps(fctn)
    def globpar(*args, **kwargs):
        iterations = kwargs.pop('iterations', 1)
        increase = kwargs.pop('increase', 1)
        speed = kwargs.pop('speed', 2)
        if iterations < 1:
            raise ValueError("iterations must be at least 1, got %r" % (iterations,))

        for nr_iter in range(iterations):
            data_ = fctn(*args, **kwargs)
            #-- append results to record array
            if nr_iter == 0:
                data = data_
                startN = len(data)
            else:
                data = np.rec.fromrecords(data.tolist() + data_.tolist(),
                                          dtype=data.dtype)

            #-- select next stage: zoom in on the points within the chi2 cutoff
            best = np.argmin(data['chisq'])
            limit = data['chisq'][best] + speed * 0.5**nr_iter * data['chisq'][best]
            keep = data['chisq'] <= limit
            for name in ('teff', 'logg', 'ebv', 'z'):
                selected = data[name][keep]
                kwargs[name + 'range'] = selected.min(), selected.max()
            kwargs['points'] = increase**(nr_iter + 1) * startN

            logger.info('Best parameters (stage %d/%d): teff=%.0f logg=%.3f E(B-V)=%.3f Z=%.2f (CHI2=%g, cutoff=%g)'\
                        %(nr_iter+1,iterations,data['teff'][best],data['logg'][best],\
                          data['ebv'][best],data['z'][best],data['chisq'][best],limit))

        return data

    return globpar
def standalone_figure(fctn):
    """
    Accept 'savefig' as an extra keyword. If it is given, start a new figure,
    save it to the filename given, and close it!

    If ``savefig`` is True, a filename is generated from ``args[0].ID``,
    ``args[0].label``, ``model.defaults2str()`` and the function name (plus
    '_colors' when the 'colors' keyword is true). In the ID/label/model part,
    '/' and '*' are replaced by '_' and '.' by 'p'. If ``savefig`` is
    None/False/empty, the decorated function is simply called.

    The opened figure is always closed again, also when the plotting function
    or the saving raises, so no figures are left dangling.
    """
    @functools.wraps(fctn)
    def dofig(*args, **kwargs):
        savefig = kwargs.pop('savefig', None)
        #-- if savefig is not a string but a boolean, make name: it consists
        #   of the ID, the model used and the function name
        if savefig is True:
            savefig = args[0].ID + '__%s_' % (args[0].label) + model.defaults2str()
            for char in ['/', '*']:
                savefig = savefig.replace(char, '_')
            savefig = savefig.replace('.', 'p')
            savefig = savefig + '_' + fctn.__name__
            if kwargs.get('colors', False):
                savefig = savefig + '_colors'
        #-- no figure handling requested
        if not savefig:
            return fctn(*args, **kwargs)
        #-- start figure, and make sure it is closed whatever happens
        pl.figure()
        try:
            out = fctn(*args, **kwargs)
            pl.savefig(savefig)
        finally:
            pl.close()
        return out

    return dofig
def blackbody_input(fctn):
    """
    Prepare input and output for blackbody-like functions.

    If the user gives wavelength units and Flambda units, we only need to convert
    everything to SI (and back to the desired units in the end).

    If the user gives frequency units and Fnu units, we only need to convert
    everything to SI (and back to the desired units in the end).

    If the user gives wavelength units and Fnu units, we need to convert
    the wavelengths first to frequency; for frequency units with Flambda
    units, the frequencies are converted to wavelength.

    The wrapped function is called as ``fctn((x, x_unit_type), T)``, with x
    in the current unit convention (frequencies divided by 2*pi to undo the
    rad/s convention) and T in K.

    Keywords of the resulting function:
        wave_units: units of x (default 'AA')
        flux_units: units of the output (default 'erg/s/cm2/AA')
        disc_integrated: if True (default), multiply the output by sqrt(2*pi)
        ang_diam: None, or a (value, unit) tuple; if given, the output is
                  multiplied by ``conversions.convert(unit, 'sr', value/2.)``

    T may be a number in K, or a (value, unit) tuple that is converted to K.
    """
    @functools.wraps(fctn)
    def dobb(x, T, **kwargs):
        wave_units = kwargs.get('wave_units', 'AA')
        flux_units = kwargs.get('flux_units', 'erg/s/cm2/AA')
        #-- prepare input
        #-- what kind of units did we receive?
        curr_conv = constants._current_convention
        # X: wavelength/frequency
        x_unit_type = conversions.get_type(wave_units)
        x = conversions.convert(wave_units, curr_conv, x)
        # T: temperature
        if isinstance(T, tuple):
            T = conversions.convert(T[1], 'K', T[0])
        # Y: flux (its SI unit string tells Fnu from Flambda)
        y_unit_type = conversions.change_convention('SI', flux_units)
        #-- if you give Jy vs micron, we need to first convert wavelength to frequency
        if y_unit_type == 'kg1 rad-1 s-2' and x_unit_type == 'length':
            x = conversions.convert(conversions._conventions[curr_conv]['length'], 'rad/s', x)
            x_unit_type = 'frequency'
        #-- and vice versa: Flambda units with frequency input
        elif y_unit_type == 'kg1 m-1 s-3' and x_unit_type == 'frequency':
            x = conversions.convert('rad/s', conversions._conventions[curr_conv]['length'], x)
            x_unit_type = 'length'
        #-- correct for rad. Not in-place: in case convert handed back the
        #   caller's own array it must not be modified, and in-place division
        #   fails on integer arrays.
        if x_unit_type == 'frequency':
            x = x / (2 * np.pi)
        #-- run function
        I = fctn((x, x_unit_type), T)

        #-- prepare output
        disc_integrated = kwargs.get('disc_integrated', True)
        ang_diam = kwargs.get('ang_diam', None)
        if disc_integrated:
            # NOTE(review): disc integration of an isotropic intensity usually
            # gives a factor pi; confirm sqrt(2*pi) is intended.
            I *= np.sqrt(2 * np.pi)
        if ang_diam is not None:
            scale = conversions.convert(ang_diam[1], 'sr', ang_diam[0] / 2.)
            I *= scale
        I = conversions.convert(curr_conv, flux_units, I)
        return I

    return dobb