tssa_tao.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
(HTM) git clone git://src.adamsgaard.dk/pism
(DIR) Log
(DIR) Files
(DIR) Refs
(DIR) LICENSE
---
tssa_tao.py (10577B)
---
1 # Copyright (C) 2012, 2014, 2015, 2016, 2018 David Maxwell and Constantine Khroulev
2 #
3 # This file is part of PISM.
4 #
5 # PISM is free software; you can redistribute it and/or modify it under the
6 # terms of the GNU General Public License as published by the Free Software
7 # Foundation; either version 3 of the License, or (at your option) any later
8 # version.
9 #
10 # PISM is distributed in the hope that it will be useful, but WITHOUT ANY
11 # WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS
12 # FOR A PARTICULAR PURPOSE. See the GNU General Public License for more
13 # details.
14 #
15 # You should have received a copy of the GNU General Public License
16 # along with PISM; if not, write to the Free Software
17 # Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA
18
19 """Inverse SSA solvers using the TAO library."""
20
21 import PISM
22 from PISM.util import Bunch
23 from PISM.logging import logError
24 from PISM.invert.ssa import InvSSASolver
25
26 import sys
27 import traceback
28
29
class InvSSASolver_Tikhonov(InvSSASolver):

    """Inverse SSA solver based on Tikhonov iteration using TAO.

    The solver is configured with a forward problem and the name of a
    Tikhonov algorithm (a key of :attr:`tao_types`). Calling
    :meth:`solveInverse` builds the matching PISM problem/solver pair, runs
    it, and leaves the result available through :meth:`inverseSolution`.
    """

    # Dictionary converting PISM algorithm names to the corresponding
    # TAO algorithms used to implement the Tikhonov minimization.
    # The "tao_"-prefixed names are selected when the PETSc version is
    # below 3.5.0 (and PISM was not imported from Sphinx); otherwise the
    # unprefixed names are used.
    if (not PISM.imported_from_sphinx) and PISM.PETSc.Sys.getVersion() < (3, 5, 0):
        tao_types = {'tikhonov_lmvm': 'tao_lmvm',
                     'tikhonov_cg': 'tao_cg',
                     'tikhonov_lcl': 'tao_lcl',
                     'tikhonov_blmvm': 'tao_blmvm'}
    else:
        tao_types = {'tikhonov_lmvm': 'lmvm',
                     'tikhonov_cg': 'cg',
                     'tikhonov_lcl': 'lcl',
                     'tikhonov_blmvm': 'blmvm'}

    def __init__(self, ssarun, method):
        """
        :param ssarun: The :class:`PISM.invert.ssa.SSAForwardRun` defining the forward problem.
        :param method: String describing the actual algorithm to use. Must be a key in :attr:`tao_types`.
        :raises ValueError: if ``method`` is not a key in :attr:`tao_types`."""

        InvSSASolver.__init__(self, ssarun, method)
        self.listeners = []
        # Both are created by solveInverse(); they stay None until then.
        self.solver = None
        self.ip = None
        if method not in self.tao_types:
            raise ValueError("Unknown TAO Tikhonov inversion method: %s" % method)

    def addIterationListener(self, listener):
        """Add a listener to be called after each iteration. See :ref:`Listeners`."""
        self.listeners.append(listener)

    def addDesignUpdateListener(self, listener):
        """Add a listener to be called after each time the design variable is changed.

        The listener is stored in the same list as the iteration listeners,
        so in this solver it is invoked through the same per-iteration
        callback as listeners added with :meth:`addIterationListener`.
        """
        self.listeners.append(listener)

    def solveForward(self, zeta, out=None):
        r"""Given a parameterized design variable value :math:`\zeta`, solve the SSA.
        See :cpp:class:`IP_TaucParam` for a discussion of parameterizations.

        :param zeta: :cpp:class:`IceModelVec` containing :math:`\zeta`.
        :param out: optional :cpp:class:`IceModelVec` for storage of the computation result.
        :returns: An :cpp:class:`IceModelVec` containing the computation result. This is
                  ``out`` (filled with a copy of the solution) when ``out`` is given,
                  otherwise the object returned by ``ssa.solution()``.
        :raises PISM.AlgorithmFailureException: if linearizing the SSA at ``zeta`` fails.
        """
        ssa = self.ssarun.ssa

        reason = ssa.linearize_at(zeta)
        if reason.failed():
            raise PISM.AlgorithmFailureException(reason)

        if out is None:
            return ssa.solution()

        out.copy_from(ssa.solution())
        return out

    def _select_classes(self):
        """Return ``(problemClass, solverClass, listenerClass)`` for the current
        design variable and method.

        :raises RuntimeError: if the design variable is neither ``'tauc'`` nor ``'hardav'``.
        """
        design_var = self.ssarun.designVariable()
        use_lcl = (self.method == 'tikhonov_lcl')

        # The names are looked up lazily, branch by branch, so that only the
        # classes actually needed have to exist.
        if design_var == 'tauc':
            if use_lcl:
                return (PISM.IP_SSATaucTaoTikhonovProblemLCL,
                        PISM.IP_SSATaucTaoTikhonovProblemLCLSolver,
                        TaucLCLIterationListenerAdaptor)
            return (PISM.IP_SSATaucTaoTikhonovProblem,
                    PISM.IP_SSATaucTaoTikhonovSolver,
                    TaucIterationListenerAdaptor)

        if design_var == 'hardav':
            if use_lcl:
                # NOTE(review): HardavLCLIterationListenerAdaptor is not among
                # the adaptor classes visible in this module; if it is not
                # defined elsewhere this branch raises NameError. Confirm.
                return (PISM.IP_SSAHardavTaoTikhonovProblemLCL,
                        PISM.IP_SSAHardavTaoTikhonovSolverLCL,
                        HardavLCLIterationListenerAdaptor)
            return (PISM.IP_SSAHardavTaoTikhonovProblem,
                    PISM.IP_SSAHardavTaoTikhonovSolver,
                    HardavIterationListenerAdaptor)

        raise RuntimeError("Unsupported design variable '%s' for InvSSASolver_Tikhonov. Expected 'tauc' or 'hardav'" % design_var)

    def solveInverse(self, zeta0, u_obs, zeta_inv):
        r"""Executes the inversion algorithm.

        :param zeta0: The best `a-priori` guess for the value of the parameterized design variable :math:`\zeta`.
        :param u_obs: :cpp:class:`IceModelVec2V` of observed surface velocities.
        :param zeta_inv: starting value of :math:`\zeta` for minimization of the Tikhonov functional.
        :returns: A :cpp:class:`TerminationReason`.
        :raises RuntimeError: if the design variable is neither ``'tauc'`` nor ``'hardav'``.
        """
        eta = self.config.get_number("inverse.tikhonov.penalty_weight")

        problemClass, solverClass, listenerClass = self._select_classes()

        tao_type = self.tao_types[self.method]
        (stateFunctional, designFunctional) = PISM.invert.ssa.createTikhonovFunctionals(self.ssarun)

        self.ip = problemClass(self.ssarun.ssa, zeta0, u_obs, eta, stateFunctional, designFunctional)
        self.solver = solverClass(self.ssarun.grid.com, tao_type, self.ip)

        max_it = int(self.config.get_number("inverse.max_iterations"))
        self.solver.setMaximumIterations(max_it)

        # Wrap every python listener in the adaptor matching the problem class.
        # The list is deliberately held in a local for the duration of solve();
        # presumably this keeps the python adaptor objects alive while the
        # problem calls back into them -- TODO confirm ownership semantics.
        adaptors = [listenerClass(self, listener) for listener in self.listeners]

        for adaptor in adaptors:
            self.ip.addListener(adaptor)

        self.ip.setInitialGuess(zeta_inv)

        vecs = self.ssarun.modeldata.vecs
        if vecs.has('zeta_fixed_mask'):
            self.ssarun.ssa.set_tauc_fixed_locations(vecs.zeta_fixed_mask)

        return self.solver.solve()

    def inverseSolution(self):
        """Returns a tuple ``(zeta,u)`` of :cpp:class:`IceModelVec`'s corresponding to the values
        of the design and state variables at the end of inversion.

        Only meaningful after :meth:`solveInverse` has been called, since that
        is where the underlying problem object is created."""
        zeta = self.ip.designSolution()
        u = self.ip.stateSolution()
        return (zeta, u)
149
150
class TaucLCLIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemLCLListener):

    """Adaptor forwarding calls made on a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemLCLListener`
    to a standard python-based listener. Instances are created internally by
    :class:`InvSSASolver_Tikhonov`; there is no need to construct one yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
           :param listener: The python-based listener.
        """
        PISM.IP_SSATaucTaoTikhonovProblemLCLListener.__init__(self)
        self.listener = listener
        self.owner = owner

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, constraints):
        """Called during IP_SSATaucTaoTikhonovProblemLCL iterations. Packs the long argument
        list into a single :class:`Bunch` and hands it to the python listener as
        ``listener(owner, it, data)``.

        If the listener raises, the error is logged, the traceback is printed to
        standard output, and the exception is re-raised."""
        fields = {
            'tikhonov_penalty': eta,
            'JDesign': objVal,
            'JState': penaltyVal,
            'zeta': d,
            'zeta_step': diff_d,
            'grad_JDesign': grad_d,
            'u': u,
            'residual': diff_u,
            'grad_JState': grad_u,
            'constraints': constraints,
        }
        data = Bunch(**fields)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise
179
180
class TaucIterationListenerAdaptor(PISM.IP_SSATaucTaoTikhonovProblemListener):

    """Adaptor forwarding calls made on a C++ :cpp:class:`IP_SSATaucTaoTikhonovProblemListener`
    to a standard python-based listener. Instances are created internally by
    :class:`InvSSASolver_Tikhonov`; there is no need to construct one yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
           :param listener: The python-based listener.
        """
        PISM.IP_SSATaucTaoTikhonovProblemListener.__init__(self)
        self.listener = listener
        self.owner = owner

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Called during IP_SSATaucTaoTikhonovProblem iterations. Packs the long argument
        list into a single :class:`Bunch` and hands it to the python listener as
        ``listener(owner, it, data)``.

        If the listener raises, the error is logged, the traceback is printed to
        standard output, and the exception is re-raised."""
        fields = {
            'tikhonov_penalty': eta,
            'JDesign': objVal,
            'JState': penaltyVal,
            'zeta': d,
            'zeta_step': diff_d,
            'grad_JDesign': grad_d,
            'u': u,
            'residual': diff_u,
            'grad_JState': grad_u,
            'grad_JTikhonov': grad,
        }
        data = Bunch(**fields)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise
207
208
class HardavIterationListenerAdaptor(PISM.IP_SSAHardavTaoTikhonovProblemListener):

    """Adaptor forwarding calls made on a C++ :cpp:class:`IP_SSAHardavTaoTikhonovProblemListener`
    to a standard python-based listener. Instances are created internally by
    :class:`InvSSASolver_Tikhonov`; there is no need to construct one yourself."""

    def __init__(self, owner, listener):
        """:param owner: The :class:`InvSSASolver_Tikhonov` that constructed us
           :param listener: The python-based listener.
        """
        PISM.IP_SSAHardavTaoTikhonovProblemListener.__init__(self)
        self.listener = listener
        self.owner = owner

    def iteration(self, problem, eta, it, objVal, penaltyVal, d, diff_d, grad_d, u, diff_u, grad_u, grad):
        """Called during iterations of the hardav Tikhonov problem. Packs the long argument
        list into a single :class:`Bunch` and hands it to the python listener as
        ``listener(owner, it, data)``.

        If the listener raises, the error is logged, the traceback is printed to
        standard output, and the exception is re-raised."""
        fields = {
            'tikhonov_penalty': eta,
            'JDesign': objVal,
            'JState': penaltyVal,
            'zeta': d,
            'zeta_step': diff_d,
            'grad_JDesign': grad_d,
            'u': u,
            'residual': diff_u,
            'grad_JState': grad_u,
            'grad_JTikhonov': grad,
        }
        data = Bunch(**fields)
        try:
            self.listener(self.owner, it, data)
        except Exception:
            logError("\nERROR: Exception occured during an inverse solver listener callback:\n\n")
            traceback.print_exc(file=sys.stdout)
            raise