tlinesearchHZ.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
(HTM) git clone git://src.adamsgaard.dk/pism
(DIR) Log
(DIR) Files
(DIR) Refs
(DIR) LICENSE
---
tlinesearchHZ.py (11504B)
---
1 ############################################################################
2 #
3 # This file is a part of siple.
4 #
5 # Copyright 2010, 2014 David Maxwell
6 #
7 # siple is free software: you can redistribute it and/or modify
8 # it under the terms of the GNU General Public License as published by
9 # the Free Software Foundation, either version 2 of the License, or
10 # (at your option) any later version.
11 #
12 ############################################################################
13
14 # %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
15 # %%
16 # %% linesearchHZ
17 # %%
18 # %% Finds an approximate minimizer of F(t) for t in [0,infinity) satisfying Wolfe conditions.
19 # %%
20 # %% Algorithm from: Hager, W. and Zhang, H. CG DESCENT, a Conjugate Gradient Method with Guaranteed Descent
21 # %% Algorithm 851. ACM Transactions on Mathematical Software. 2006;32(1):113-137.
22 # %%
23 # %% Usage: status = linesearchHZ(F0, F0p, F, t, params)
24 # %%
25 # %% In:
26 # %% F0 - F(0)
27 # %% F0p - F'(0)
28 # %% F - Function to minimize. See requirements below.
29 # %% t - Initial guess for location of minimizer.
30 # %% params - Optional struct of control parameters. See below.
31 # %%
32 # %% Out:
33 # %% status - Structure with the following fields:
34 # %% code - nonnegative integer with 0 indicating no error
35 # %% msg - if code>0, a descriptive error message about why the algorithm failed
36 # %% val - if code==0, structure containing information about the minimizer
37 # %% val.t - location of minimizer
38 # %% val.F - value of F(val.t)
39 # %% val.Fp - value of F'(val.t)
40 # %% val.data - additional data at minimizer. See below.
41 # %%
42 # %% The function F must have the following signature: function [f, fp, fdata] = F(t).
43 # %% The f and fp are the values of the function and derivative at t. For some functions,
44 # %% there is expensive data that are computed along the way that might be needed by the end user.
45 # %% The data field allows this excess data to be saved and returned. In the end it will show up
46 # %% in status.val.data.
47 # %%
48 # %% The various parameters that control the algorithm are described in the above reference and
49 # %% have the same name. The principal ones you might want to change:
50 # %%
51 # %% delta: Controls sufficient decrease Wolfe condition: delta*F'(0) >= (F(t)-F(0))/t
52 # %% sigma: Controls sufficient shallowness Wolfe condition: F'(t) >= sigma*F'(0)
53 # %% rho: Expansion factor for initial bracket search (bracket expands by multiples of rho)
54 # %% nsecant: Maximum number of outermost loops
55 # %% nshrink: Maximum number of interval shrinks or expansions in a single loop
56 # %% verbose: Print lots of messages to track the algorithm
57 # %% debug: Do extra computations to verify the state is consistent as the algorithm progresses.
58 # %%
59 # %%
60 #
61
62 from siple.params import Bunch, Parameters
63 from siple.reporting import msg, pause
64 import numpy
65
class LinesearchHZ:
    """Hager-Zhang line search for an approximate minimizer of F(t) on t >= 0.

    The search looks for a step ``t`` satisfying the Wolfe conditions

        F(t) - F(0) <= delta * t * F'(0)      (sufficient decrease)
        F'(t)       >= sigma * F'(0)          (sufficient shallowness)

    The objective ``F`` is a callable returning a triple ``(f, fp, data)``:
    the function value, the derivative, and arbitrary extra data that is
    carried along with the evaluation.

    Results are stored on the instance rather than returned:

    * ``code``   -- -1 while no result is available, 0 on success, 1 on error.
    * ``errMsg`` -- description of the failure when ``code`` is 1.
    * ``value``  -- on success, a bunch with fields ``t``, ``F``, ``Fp``, ``data``.
    """

    @staticmethod
    def defaultParameters():
        """Return a ``Parameters`` object holding the default control values."""
        return Parameters('linesearchHZ', delta=.1, sigma=.9, epsilon=0, theta=.5, gamma=.66, rho=5,
                          nsecant=50, nshrink=50, verbose=False, debug=True)

    def __init__(self, params=None):
        """Create a line search; ``params`` (optional) overrides the defaults."""
        self.params = self.defaultParameters()
        if params is not None:
            self.params.update(params)

    def error(self):
        """Return True if the most recent search ended with an error."""
        return self.code > 0

    def ezsearch(self, F, t0=None):
        """Run a search, evaluating F(0) and F'(0) by calling ``F`` at t=0.

        If ``t0`` is omitted, the initial guess is ``1/(1 - F'(0))``.
        """
        self.F = F
        z = self.eval(0)
        if t0 is None:
            # The evaluation bunch stores the slope as 'Fp' (see eval()).
            t0 = 1. / (1. - z.Fp)
        return self.search(F, z.F, z.Fp, t0)

    def search(self, F, F0, F0p, t0):
        """Search for a Wolfe point starting from the initial guess ``t0``.

        F   -- objective, called as ``F(t)`` and returning ``(f, fp, data)``.
        F0  -- F(0).
        F0p -- F'(0); must be non-positive (checked with an assert).
        t0  -- initial guess for the location of the minimizer.

        Returns None; inspect ``code``, ``errMsg`` and ``value`` (or call
        ``error()``) afterwards.
        """
        self.code = -1
        self.errMsg = 'no error'
        self.F = F

        params = self.params

        z = Bunch(F=F0, Fp=F0p, t=0, data=None)
        assert F0p <= 0

        # Constants used when checking the Wolfe conditions.
        self.wolfe_lo = params.sigma * z.Fp
        self.wolfe_hi = params.delta * z.Fp
        self.awolfe_hi = (2 * params.delta - 1) * z.Fp
        self.fpert = z.F + params.epsilon
        self.f0 = z.F

        if params.verbose:
            msg('starting at z=%g (%g,%g)', z.t, z.F, z.Fp)

        # Halve the initial guess until F evaluates to something other than NaN.
        while True:
            c = self.eval(t0)
            if not numpy.isnan(c.F):
                break
            msg('Hit a NaN in initial evaluation at t=%g', t0)
            t0 *= 0.5

        if params.verbose:
            msg('initial guess c=%g (%g,%g)', c.t, c.F, c.Fp)

        if self.wolfe(c):
            if params.verbose:
                msg('done at init')
            self.setDone(c)
            return

        (aj, bj) = self.bracket(z, c)
        if params.verbose:
            msg('initial bracket %g %g', aj.t, bj.t)

        # A non-negative code means bracket() either succeeded or failed outright.
        if self.code >= 0:
            self.doneMsg('initial bracket')
            return

        if params.debug:
            self.verifyBracket(aj, bj)

        count = 0

        while True:
            count += 1

            if count > params.nsecant:
                self.setError('too many bisections in main loop')
                return

            (a, b) = self.secantsq(aj, bj)
            if params.verbose:
                msg('secantsq a %g b %g', a.t, b.t)
                self.printBracket(a, b)
            if self.code >= 0:
                self.doneMsg('secant')
                return

            # If the double secant step did not shrink the bracket by the
            # factor gamma, fall back to a bisection step.
            if (b.t - a.t) > params.gamma * (bj.t - aj.t):
                (a, b) = self.update(a, b, (a.t + b.t) / 2)
                if params.verbose:
                    # Report the bracket produced by the bisection.
                    msg('update to a %g b %g', a.t, b.t)
                    self.printBracket(a, b)
                if self.code >= 0:
                    self.doneMsg('bisect')
                    return
            aj = a
            bj = b

    def printBracket(self, a, b):
        """Print the endpoints, values and slopes of the bracket [a, b]."""
        msg('a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)

    def doneMsg(self, where):
        """Report how the search ended; ``where`` names the stage that finished it."""
        if self.code > 0:
            msg('done at %s with error status: %s', where, self.errMsg)
        else:
            if self.params.verbose:
                msg('done at %s with val=%g (%g, %g)', where, self.value.t, self.value.F, self.value.Fp)

    def verifyBracket(self, a, b):
        """Debug check that [a, b] is a consistent bracket.

        Expected: F'(a) <= 0, F'(b) >= 0, F(a) <= fpert and a.t < b.t.
        Violations are reported with ``msg``; an inconsistent bracket also
        calls ``pause()``.
        """
        good = (a.Fp <= 0) and (b.Fp >= 0) and (a.F <= self.fpert)
        if not good:
            msg('bracket inconsistent: a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)
            pause()
        if (a.t >= b.t):
            msg('bracket not a bracket (a>=b): a %g b %g f(a) %g fp(a) %g f(b) %g fp(b) %g fpert %g', a.t, b.t, a.F, a.Fp, b.F, b.Fp, self.fpert)

    def setDone(self, c):
        """Record the evaluation ``c`` as the successful result of the search."""
        self.code = 0
        self.value = c

    def setError(self, msg):
        """Record failure of the search with the description ``msg``."""
        self.code = 1
        self.errMsg = msg

    def update(self, a, b, ct):
        """Refine the bracket [a, b] using the trial point ``ct``.

        Returns the new bracket ``(abar, bbar)``.  If ``ct`` lies outside the
        open interval (a.t, b.t) the bracket is returned unchanged and F is
        not evaluated.  If the trial point satisfies the Wolfe conditions the
        search is marked done (``code`` becomes 0) and the bracket is
        returned unchanged.
        """
        abar = a
        bbar = b

        params = self.params

        if params.verbose:
            msg('update %g %g %g', a.t, b.t, ct)
        if (ct <= a.t) or (ct >= b.t):
            if params.verbose:
                msg('midpoint out of interval')
            return (abar, bbar)

        c = self.eval(ct)

        if self.wolfe(c):
            self.setDone(c)
            return (abar, bbar)

        if c.Fp >= 0:
            if params.verbose:
                msg('midpoint with non-negative slope. Becomes b.')
            abar = a
            bbar = c
            if params.debug:
                self.verifyBracket(abar, bbar)
            return (abar, bbar)

        if c.F <= self.fpert:
            if params.verbose:
                msg('midpoint with negative slope, small value. Becomes a.')
            abar = c
            bbar = b
            if params.debug:
                self.verifyBracket(abar, bbar)
            return (abar, bbar)

        if params.verbose:
            msg('midpoint with negative slope, large value. Shrinking to left.')
        (abar, bbar) = self.ushrink(a, c)
        if params.debug:
            self.verifyBracket(abar, bbar)

        return (abar, bbar)

    def ushrink(self, a, b):
        """Contract [a, b] until the right endpoint has non-negative slope.

        Each pass evaluates F at ``(1-theta)*a + theta*b``.  The loop ends when
        that point satisfies the Wolfe conditions (search marked done), when
        it has non-negative slope (it becomes the right endpoint), or after
        ``nshrink`` passes (search marked as failed).
        """
        abar = a
        bbar = b

        count = 0
        while True:
            count += 1

            if self.params.verbose:
                msg('in ushrink')
                self.printBracket(abar, bbar)
            if count > self.params.nshrink:
                self.setError('too many contractions in ushrink')
                return (abar, bbar)

            d = self.eval((1 - self.params.theta) * abar.t + self.params.theta * bbar.t)
            if self.wolfe(d):
                self.setDone(d)
                return (abar, bbar)

            if d.Fp >= 0:
                bbar = d
                return (abar, bbar)

            if d.F <= self.fpert:
                abar = d
            else:
                bbar = d

    def plotInterval(self, a, b, N=20):
        """Plot F and F' at ``N`` evenly spaced points between a.t and b.t."""
        from matplotlib import pyplot as pp
        import numpy as np
        T = np.linspace(a.t, b.t, N)
        FT = []
        FpT = []
        for t in T:
            c = self.eval(t)
            FT.append(c.F)
            FpT.append(c.Fp)
        pp.subplot(1, 2, 1)
        pp.plot(T, np.array(FT))
        pp.subplot(1, 2, 2)
        pp.plot(T, np.array(FpT))
        pp.draw()

    def secant(self, a, b):
        """Return the secant estimate of the zero of F' from the points a and b.

        When F'(a) == F'(b) the secant line has no zero; ``inf`` is returned,
        which ``update`` treats as a point outside any bracket.
        """
        if self.params.verbose:
            msg('secant: a %g fp(a) %4.8g b %g fp(b) %4.8g', a.t, a.Fp, b.t, b.Fp)
        if (a.t == b.t):
            msg('a=b, inconcievable!')

        slope_diff = a.Fp - b.Fp
        if slope_diff == 0:
            # Dividing Python floats by zero raises ZeroDivisionError rather
            # than producing +/-Inf, so produce the out-of-interval value
            # explicitly.
            return float('inf')

        # Both branches are algebraically the same point; the form used is
        # chosen by comparing the sizes of the two slopes.
        if -a.Fp <= b.Fp:
            return a.t - (a.t - b.t) * (a.Fp / slope_diff)
        else:
            return b.t - (a.t - b.t) * (b.Fp / slope_diff)

    def secantsq(self, a, b):
        """Perform a double secant step on [a, b] and return the new bracket.

        A first secant point updates the bracket.  If that point became one of
        the endpoints, a second secant point is computed from the old and new
        endpoint on that side and the bracket is updated again.
        """
        ct = self.secant(a, b)
        if self.params.verbose:
            msg('first secant to %g', ct)
        (A, B) = self.update(a, b, ct)
        if self.code >= 0:
            return (A, B)

        if B.t == ct:
            ct2 = self.secant(b, B)
            if self.params.verbose:
                msg('second secant on left half A %g B %g with c=%g', A.t, B.t, ct2)
            (abar, bbar) = self.update(A, B, ct2)
        elif A.t == ct:
            ct2 = self.secant(a, A)
            if self.params.verbose:
                msg('second secant on right half A %g B %g with c=%g', A.t, B.t, ct2)
            (abar, bbar) = self.update(A, B, ct2)
        else:
            if self.params.verbose:
                msg('first secant gave a shrink in update. Keeping A %g B %g', A.t, B.t)
            abar = A
            bbar = B

        return (abar, bbar)

    def bracket(self, z, c):
        """Find an initial bracket starting from ``z`` (t=0) and the guess ``c``.

        The right endpoint is expanded by multiples of ``rho`` until it has
        non-negative slope, or until its value exceeds ``fpert`` (in which
        case the interval is contracted with ``ushrink``).  If an expansion
        point evaluates to NaN, the expansion factor is halved and retried.
        The search is marked done if an expansion point satisfies the Wolfe
        conditions, and marked failed once the iteration count exceeds
        ``nshrink``.
        """
        a = z
        b = c

        count = 0
        while True:
            if count > self.params.nshrink:
                self.setError('Too many expansions in bracket')
                return (a, b)
            count += 1

            if b.Fp >= 0:
                if self.params.verbose:
                    msg('initial bracket ends with expansion: b has positive slope')
                return (a, b)

            if b.F > self.fpert:
                if self.params.verbose:
                    msg('initial bracket contraction')
                return self.ushrink(a, b)

            if self.params.verbose:
                msg('initial bracket expanding')
            a = b
            rho = self.params.rho
            while True:
                if count > self.params.nshrink:
                    self.setError('Unable to find a valid input')
                    return (a, b)
                c = self.eval(rho * b.t)
                if not numpy.isnan(c.F):
                    b = c
                    break
                msg('Hit a NaN at t=%g', rho * b.t)
                rho *= 0.5
                count += 1

            if self.wolfe(b):
                self.setDone(b)
                return (a, b)

    def wolfe(self, c):
        """Return True if the evaluation ``c`` satisfies both Wolfe conditions.

        Only the standard Wolfe conditions are tested; ``awolfe_hi`` is
        computed in ``search`` but not used here.
        """
        if self.params.verbose:
            msg('checking wolfe of c=%g (%g,%g)', c.t, c.F, c.Fp)

        if c.Fp >= self.wolfe_lo:
            if (c.F - self.f0) <= c.t * self.wolfe_hi:
                return True

            if self.params.verbose:
                msg('failed sufficient decrease')
        else:
            if self.params.verbose:
                msg('failed slope flatness')

        return False

    def eval(self, t):
        """Evaluate the objective at ``t``.

        Returns a bunch with fields ``t``, ``F``, ``Fp`` (both converted to
        ``float``) and ``data`` (the third item returned by the objective).
        """
        c = Bunch(F=0, Fp=0, data=None, t=t)
        (c.F, c.Fp, c.data) = self.F(t)
        c.F = float(c.F)
        c.Fp = float(c.Fp)
        return c
361
if __name__ == '__main__':

    # Demo: minimize F(t) = -t*(1-t), whose derivative is -1 + 2t, starting
    # the search from the initial guess t=5 with verbose/debug output on.
    def quadratic(t):
        return (-t*(1-t), -1+2*t, None)

    searcher = LinesearchHZ(params=Parameters('tmp', verbose=True, debug=True))
    searcher.ezsearch(quadratic, 5)

    if not searcher.error():
        best = searcher.value
        print('minimum of %g at t=%g' % (best.F, best.t))
    else:
        print(searcher.errMsg)