tfill_missing_petsc.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
(HTM) git clone git://src.adamsgaard.dk/pism
(DIR) Log
(DIR) Files
(DIR) Refs
(DIR) LICENSE
---
tfill_missing_petsc.py (12203B)
---
1 #!/usr/bin/env python3
2
3 # @package fill_missing
4 # @brief This script solves the Laplace equation as a method of filling holes in map-plane data.
5 #
6 # Uses an approximation to Laplace's equation
7 # @f[ \nabla^2 u = 0 @f]
8 # to smoothly replace missing values in two-dimensional NetCDF
9 # variables with the average of the ``nearby'' non-missing values.
10 #
11 # Here is a hypothetical example that fills the missing values in the variables
12 # `topg` and `usurf` in `data.nc`:
13 # \code
14 # fill_missing.py -v topg,usurf data.nc data_smoothed.nc
15 # \endcode
16 # Generally variables should be filled one at a time.
17 #
18 # Each of the requested variables must have missing values specified
19 # using the _FillValue attribute.
20
21 import petsc4py
22 import sys
23 petsc4py.init(sys.argv)
24
25 from petsc4py import PETSc
26 import numpy as np
27
28
def assemble_matrix(mask):
    """Build the sparse matrix of the 5-point Laplacian stencil on a periodic grid.

    mask is a 2D NumPy array. Rows corresponding to grid points where
    mask is true get the full stencil (4 on the diagonal, -1 for each
    of the four neighbors); all other rows contain only the diagonal
    entry. Neighbor indices wrap around in both directions, so the
    operator approximates the Laplacian on a torus.

    The grid spacing is not used, which amounts to assuming the same
    spacing in the x and y directions.

    Returns the assembled PETSc matrix.
    """
    PETSc.Sys.Print("Assembling the matrix...")

    n_rows, n_cols = mask.shape
    n_points = n_rows * n_cols

    # sparse (AIJ) matrix with at most 5 non-zero entries per row
    A = PETSc.Mat()
    A.create(PETSc.COMM_WORLD)
    A.setSizes([n_points, n_points])
    A.setType('aij')
    A.setPreallocationNNZ(5)

    center = 4.0
    neighbor = -1.0

    # each processor fills in only the rows it owns
    first, last = A.getOwnershipRange()
    for row in range(first, last):
        A[row, row] = center

        # grid coordinates of this linear row number
        i, j = divmod(row, n_cols)

        if not mask[i, j]:
            # only the diagonal entry is set for this row
            continue

        # periodic neighbors: previous/next row, then previous/next column
        stencil = (((i - 1) % n_rows, j),
                   ((i + 1) % n_rows, j),
                   (i, (j - 1) % n_cols),
                   (i, (j + 1) % n_cols))
        for ni, nj in stencil:
            A[row, ni * n_cols + nj] = neighbor

    # exchange off-processor entries and finalize internal data structures
    A.assemblyBegin()
    A.assemblyEnd()

    PETSc.Sys.Print("done.")
    return A
114
115
def assemble_rhs(rhs, X):
    """Fill in the right-hand side vector of the discretized Laplace equation.

    rhs is modified in place: it is zeroed, and then every locally
    owned entry whose grid point is not masked in X (X.mask is false
    there) is set to 4 times the value of X at that point. This
    imposes the present values of X as Dirichlet boundary conditions.
    """
    # scale factor applied to the known values
    scale = 4.0

    _, n_cols = X.shape
    first, last = rhs.getOwnershipRange()

    # zero everywhere except at Dirichlet nodes
    rhs.set(0.0)

    missing = X.mask
    for row in range(first, last):
        # grid coordinates of this linear row number
        i, j = divmod(row, n_cols)

        if not missing[i, j]:
            rhs[row] = scale * X[i, j]

    rhs.assemble()
140
141
def create_solver():
    """Create and configure the KSP linear solver.

    The preconditioner type is set to GAMG (algebraic multigrid)
    before setFromOptions() is called, and the solver is told to use a
    non-zero initial guess.

    Returns the configured KSP object.
    """
    solver = PETSc.KSP().create(PETSc.COMM_WORLD)

    # algebraic multigrid preconditioner
    solver.getPC().setType(PETSc.PC.Type.GAMG)
    solver.setFromOptions()

    solver.setInitialGuessNonzero(True)

    return solver
156
157
def fill_missing(field, matrix=None):
    """Fill missing values in the 2D array 'field' by solving a Laplace problem.

    field is a NumPy masked array; masked entries are the ones to be
    filled. A field without a full boolean mask (a plain array, or a
    masked array whose mask is the scalar 'nomask') is accepted too:
    it is treated as having no missing values.

    matrix, if given, is a matrix previously returned by this function
    (or by assemble_matrix) for a field with the same mask; it is
    reused instead of assembling a new one.

    Returns the pair (vec0, A): vec0 is the vector produced by
    scattering the solution to processor 0, and A is the matrix used.
    """
    # Make sure field.mask is a full 2D boolean array: assemble_matrix
    # needs mask.shape and assemble_rhs indexes the mask point by point,
    # and neither works when the mask is the scalar np.ma.nomask.
    field = np.ma.masked_array(np.ma.getdata(field),
                               mask=np.ma.getmaskarray(field))

    ksp = create_solver()

    if matrix is None:
        A = assemble_matrix(field.mask)
    else:
        A = matrix

    # obtain solution & RHS vectors
    x, b = A.getVecs()

    assemble_rhs(b, field)

    # use the mean of the non-missing values as the initial guess
    initial_guess = np.mean(field)
    x.set(initial_guess)

    ksp.setOperators(A)

    # Solve Ax = b
    ksp.solve(b, x)

    # transfer solution to processor 0
    vec0, scatter = create_scatter(x)
    scatter_to_0(x, vec0, scatter)

    return vec0, A
192
193
def create_scatter(vector):
    """Create the scatter context that gathers 'vector' on processor 0.

    The scatter is also performed once (forward mode), followed by a
    barrier on the communicator of 'vector'.

    Returns the pair (V0, scatter): the destination vector created by
    PETSc.Scatter.toZero() and the scatter context itself.
    """
    comm = vector.getComm()
    scatter, V0 = PETSc.Scatter.toZero(vector)
    scatter.scatter(vector, V0, False, PETSc.Scatter.Mode.FORWARD)
    comm.barrier()

    return V0, scatter
203
204
def scatter_to_0(vector, vector_0, scatter):
    """Scatter the distributed 'vector' into 'vector_0' on processor 0.

    Uses the scatter context 'scatter' in forward mode (inserting
    values, not adding them) and then waits on a barrier of the
    communicator of 'vector'.
    """
    communicator = vector.getComm()
    forward = PETSc.Scatter.Mode.FORWARD
    scatter.scatter(vector, vector_0, False, forward)
    communicator.barrier()
210
211
def scatter_from_0(vector_0, vector, scatter):
    """Scatter 'vector_0' on processor 0 to the distributed 'vector'.

    Uses the scatter context 'scatter' in reverse mode (inserting
    values, not adding them) and then waits on a barrier of the
    communicator of 'vector'.
    """
    comm = vector.getComm()
    # In reverse mode the roles of the two vectors are swapped relative
    # to the forward scatter: the source is vector_0 and the destination
    # is the distributed vector.
    scatter.scatter(vector_0, vector, False, PETSc.Scatter.Mode.REVERSE)
    comm.barrier()
217
218
def fill_2d_record(data, matrix=None):
    """Fill missing values in a 2D record.

    data is a 2D array, usually a NumPy masked array. matrix is an
    optional matrix from a previous call, passed on to fill_missing()
    for reuse.

    If data has no masked values (a plain array, a masked array without
    a mask, or a mask that is false everywhere) it is returned
    unchanged together with 'matrix', so a matrix cached by the caller
    is not lost.

    Otherwise returns (filled_data, A), where A is the matrix used. On
    processor 0 filled_data is the solution reshaped to data.shape; on
    other processors it is whatever fill_missing() returned.
    """
    # Nothing to fill: skip the solve. np.ma.is_masked() also handles
    # inputs that have no 'mask' attribute or whose mask is 'nomask'.
    if not np.ma.is_masked(data):
        return data, matrix

    filled_data, A = fill_missing(data, matrix)
    if PETSc.COMM_WORLD.getRank() == 0:
        filled_data = filled_data[:].reshape(data.shape)

    return filled_data, A
229
230
def test():
    """Test fill_missing() using synthetic data.

    Builds a smooth field on an M-by-N grid, masks a random subset of
    points plus two quarter-disks at opposite corners, fills the
    masked values and, on processor 0, plots the original field, the
    masked field and the filled result.
    """
    N = 201
    # the number of samples passed to np.linspace() has to be an integer
    M = int(N * 1.5)
    x = np.linspace(-1, 1, N)
    y = np.linspace(-1, 1, M)
    xx, yy = np.meshgrid(x, y)
    zz = np.sin(2.5 * np.pi * xx) * np.cos(2.0 * np.pi * yy)

    # random values in {0, 1/K, ..., (K-1)/K}; every non-zero value
    # marks a masked point
    K = 10
    mask = np.random.randint(0, K, zz.size).reshape(zz.shape) / float(K)

    # mask the parts of the domain within distance 1 of two opposite corners
    mask[(xx - 1.0) ** 2 + (yy - 1.0) ** 2 < 1.0] = 1
    mask[(xx + 1.0) ** 2 + (yy + 1.0) ** 2 < 1.0] = 1

    field = np.ma.array(zz, mask=mask)

    zzz, _ = fill_missing(field)

    rank = PETSc.COMM_WORLD.getRank()
    if rank == 0:
        zzz0_np = zzz[:].reshape(field.shape)

        import pylab as plt

        plt.figure(1)
        plt.imshow(zz, interpolation='nearest')

        plt.figure(2)
        plt.imshow(field, interpolation='nearest')

        plt.figure(3)
        plt.imshow(zzz0_np, interpolation='nearest')

        plt.show()
266
267
def fill_variable(nc, name):
    """Fill missing values in one variable of the NetCDF dataset nc.

    name is the name of the variable. 2D variables are filled
    directly; 3D variables are filled one record at a time along the
    first dimension. Variables with any other number of dimensions are
    skipped with a message.

    Only processor 0 writes the filled values back to the variable.
    After a variable has been filled, its _FillValue and missing_value
    attributes are removed (failures to remove them are ignored).
    """
    # Imported here because the module-level name 'time' is only
    # defined when this file is run as a script.
    from time import time

    PETSc.Sys.Print("Processing %s..." % name)
    t0 = time()

    var = nc.variables[name]

    comm = PETSc.COMM_WORLD
    rank = comm.getRank()

    if var.ndim == 3:
        A = None
        previous_mask = None
        n_records = var.shape[0]
        for t in range(n_records):
            PETSc.Sys.Print("Processing record %d/%d..." % (t + 1, n_records))
            data = var[t, :, :]

            # The matrix depends on the mask, so it can be reused only
            # while the set of missing points stays the same.
            mask = np.ma.getmaskarray(data)
            if previous_mask is None or not np.array_equal(mask, previous_mask):
                A = None
            previous_mask = mask

            filled_data, A = fill_2d_record(data, A)
            if rank == 0:
                var[t, :, :] = filled_data

            comm.barrier()
        PETSc.Sys.Print("Time elapsed: %5f seconds." % (time() - t0))
    elif var.ndim == 2:
        data = var[:, :]

        filled_data, _ = fill_2d_record(data)
        if rank == 0:
            var[:, :] = filled_data

        comm.barrier()
        PETSc.Sys.Print("Time elapsed: %5f seconds." % (time() - t0))
    else:
        PETSc.Sys.Print("Skipping the %dD variable %s." % (var.ndim, name))
        return

    # Remove the attributes marking missing values. This is best
    # effort: the attribute may not be present, and the dataset may not
    # be writable on this processor.
    for attribute in ('_FillValue', 'missing_value'):
        try:
            delattr(var, attribute)
        except Exception:
            pass
315
316
def add_history(nc):
    """Update the history attribute in a NetCDF file nc.

    On processor 0, prepends a line consisting of the current time and
    the command line (sys.argv) to the global 'history' attribute,
    creating the attribute if it does not exist. Does nothing on the
    other processors.
    """
    # Imported here because the module-level name 'asctime' is only
    # defined when this file is run as a script.
    from time import asctime

    comm = PETSc.COMM_WORLD
    rank = comm.getRank()

    if rank != 0:
        return

    # add history global attribute (after checking if present)
    historysep = ' '
    historystr = asctime() + ': ' + historysep.join(sys.argv) + '\n'
    if 'history' in nc.ncattrs():
        nc.history = historystr + nc.history  # prepend to history string
    else:
        nc.history = historystr
332
333
if __name__ == "__main__":
    # Command-line driver: copy INPUT to a temporary file next to
    # OUTPUT (on processor 0), fill the requested variables in that
    # copy and finally move it to OUTPUT.
    from argparse import ArgumentParser
    import os
    import os.path
    import tempfile
    import shutil
    from time import time, asctime

    try:
        from netCDF4 import Dataset as NC
    except ImportError:
        PETSc.Sys.Print("netCDF4 is not installed!")
        sys.exit(1)

    parser = ArgumentParser()
    parser.description = "Fill missing values by solving the Laplace equation in on the missing values and using present values as Dirichlet B.C."

    parser.add_argument("INPUT", nargs=1, help="Input file name.")
    parser.add_argument("OUTPUT", nargs=1, help="Output file name.")
    parser.add_argument("-a", "--all", dest="all", action="store_true",
                        help="Process all variables.")
    parser.add_argument("-v", "--vars", dest="variables",
                        help="comma-separated list of variables to process")

    # unknown options are ignored here (sys.argv is also passed to
    # petsc4py.init() at the top of the file)
    options, _ = parser.parse_known_args()

    input_filename = options.INPUT[0]
    output_filename = options.OUTPUT[0]

    if options.all:
        nc = NC(input_filename)
        variables = list(nc.variables.keys())
        nc.close()
    elif options.variables is not None:
        variables = options.variables.split(',')
    else:
        PETSc.Sys.Print("Please specify variables using the -v option.")
        sys.exit(-1)

    # Done processing command-line options.

    comm = PETSc.COMM_WORLD
    rank = comm.getRank()

    t0 = time()

    # name of the temporary file; set on processor 0 only
    tmp_filename = None

    def remove_temporary_file():
        "Remove the temporary file, if it was created; errors are ignored."
        if tmp_filename is None:
            return
        try:
            os.remove(tmp_filename)
        except OSError:
            pass

    PETSc.Sys.Print("Filling missing values in %s and saving results to %s..." % (input_filename,
                                                                                   output_filename))
    if rank == 0:
        PETSc.Sys.Print("Creating a temporary file...")
        # find the name of the directory with the output file:
        dirname = os.path.dirname(os.path.abspath(output_filename))
        try:
            (handle, tmp_filename) = tempfile.mkstemp(prefix="fill_missing_",
                                                      suffix=".nc",
                                                      dir=dirname)

            os.close(handle)  # mkstemp returns a file handle (which we don't need)
        except OSError:
            # tmp_filename is not available if mkstemp() failed, so
            # report the directory instead
            PETSc.Sys.Print("ERROR: Can't create a temporary file in %s, Exiting..." % dirname)
            sys.exit(-1)

        try:
            PETSc.Sys.Print("Copying input file %s to %s..." % (input_filename,
                                                                tmp_filename))
            shutil.copy(input_filename, tmp_filename)
        except OSError:
            PETSc.Sys.Print("ERROR: Can't copy %s, Exiting..." % input_filename)
            remove_temporary_file()
            sys.exit(-1)

    # processor 0 modifies the temporary copy; all the others only read
    # the input file
    try:
        if rank == 0:
            nc = NC(tmp_filename, 'a')
        else:
            nc = NC(input_filename, 'r')
    except Exception as message:
        PETSc.Sys.Print(message)
        PETSc.Sys.Print("Note: %s was not modified." % output_filename)
        remove_temporary_file()
        sys.exit(-1)

    add_history(nc)

    for name in variables:
        try:
            fill_variable(nc, name)
        except Exception as message:
            PETSc.Sys.Print("ERROR:", message)
            PETSc.Sys.Print("Note: %s was not modified." % output_filename)
            remove_temporary_file()
            sys.exit(-1)
    nc.close()

    if rank == 0:
        try:
            shutil.move(tmp_filename, output_filename)
        except OSError:
            PETSc.Sys.Print("Error moving %s to %s. Exiting..." % (tmp_filename,
                                                                   output_filename))
            sys.exit(-1)

    PETSc.Sys.Print("Total time elapsed: %5f seconds." % (time() - t0))