tfill_missing_petsc.py - pism - [fork] customized build of PISM, the parallel ice sheet model (tillflux branch)
 (HTM) git clone git://src.adamsgaard.dk/pism
 (DIR) Log
 (DIR) Files
 (DIR) Refs
 (DIR) LICENSE
       ---
       tfill_missing_petsc.py (12203B)
       ---
            1 #!/usr/bin/env python3
            2 
            3 # @package fill_missing
            4 # @brief This script solves the Laplace equation as a method of filling holes in map-plane data.
            5 #
            6 # Uses an approximation to Laplace's equation
            7 #         @f[ \nabla^2 u = 0 @f]
            8 # to smoothly replace missing values in two-dimensional NetCDF
            9 # variables with the average of the ``nearby'' non-missing values.
           10 #
           11 # Here is a hypothetical example, filling the missing values in the variables
           12 # `topg` and `usurf` in `data.nc` :
           13 # \code
           14 # fill_missing.py -v topg,usurf data.nc data_smoothed.nc
           15 # \endcode
           16 # Generally variables should be filled one at a time.
           17 #
           18 # Each of the requested variables must have missing values specified
           19 # using the _FillValue attribute.
           20 
           21 import petsc4py
           22 import sys
           23 petsc4py.init(sys.argv)
           24 
           25 from petsc4py import PETSc
           26 import numpy as np
           27 
           28 
           29 def assemble_matrix(mask):
           30     """Assemble the matrix corresponding to the standard 5-point stencil
           31     approximation of the Laplace operator on the domain defined by
           32     mask == True, where mask is a 2D NumPy array. The stencil wraps
           33     around the grid, i.e. this is an approximation of the Laplacian
           34     on a torus.
           35 
           36     The grid spacing is ignored, which is equivalent to assuming equal
           37     spacing in x and y directions.
           38     """
           39     PETSc.Sys.Print("Assembling the matrix...")
           40     # grid size
           41     nrow, ncol = mask.shape
           42 
           43     # create sparse matrix
           44     A = PETSc.Mat()
           45     A.create(PETSc.COMM_WORLD)
           46     A.setSizes([nrow * ncol, nrow * ncol])
           47     A.setType('aij')  # sparse
           48     A.setPreallocationNNZ(5)
           49 
           50     # precompute values for setting
           51     # diagonal and non-diagonal entries
           52     diagonal = 4.0
           53     offdx = - 1.0
           54     offdy = - 1.0
           55 
           56     def R(i, j):
           57         "Map from the (row,column) pair to the linear row number."
           58         return i * ncol + j
           59 
           60     # loop over owned block of rows on this
           61     # processor and insert entry values
           62     row_start, row_end = A.getOwnershipRange()
           63     for row in range(row_start, row_end):
           64         A[row, row] = diagonal
           65 
           66         i = row // ncol    # map row number to
           67         j = row - i * ncol  # grid coordinates
           68 
           69         if mask[i, j] == False:
           70             continue
           71 
           72         # i
           73         if i == 0:              # top row
           74             col = R(nrow - 1, j)
           75             A[row, col] = offdx
           76 
           77         if i > 0:               # interior
           78             col = R(i - 1, j)
           79             A[row, col] = offdx
           80 
           81         if i < nrow - 1:        # interior
           82             col = R(i + 1, j)
           83             A[row, col] = offdx
           84 
           85         if i == nrow - 1:       # bottom row
           86             col = R(0, j)
           87             A[row, col] = offdx
           88 
           89         # j
           90         if j == 0:              # left-most column
           91             col = R(i, ncol - 1)
           92             A[row, col] = offdy
           93 
           94         if j > 0:               # interior
           95             col = R(i, j - 1)
           96             A[row, col] = offdy
           97 
           98         if j < ncol - 1:        # interior
           99             col = R(i, j + 1)
          100             A[row, col] = offdy
          101 
          102         if j == ncol - 1:       # right-most column
          103             col = R(i, 0)
          104             A[row, col] = offdy
          105 
          106     # communicate off-processor values
          107     # and setup internal data structures
          108     # for performing parallel operations
          109     A.assemblyBegin()
          110     A.assemblyEnd()
          111 
          112     PETSc.Sys.Print("done.")
          113     return A
          114 
          115 
          116 def assemble_rhs(rhs, X):
          117     """Assemble the right-hand side of the system approximating the
          118     Laplace equation.
          119 
          120     Modifies rhs in place; sets Dirichlet BC using X where X.mask ==
          121     False.
          122     """
          123     # PETSc.Sys.Print("Setting Dirichlet BC...")
          124     nrow, ncol = X.shape
          125     row_start, row_end = rhs.getOwnershipRange()
          126 
          127     # The right-hand side is zero everywhere except for Dirichlet
          128     # nodes.
          129     rhs.set(0.0)
          130 
          131     for row in range(row_start, row_end):
          132         i = row // ncol    # map row number to
          133         j = row - i * ncol  # grid coordinates
          134 
          135         if X.mask[i, j] == False:
          136             rhs[row] = 4.0 * X[i, j]
          137 
          138     rhs.assemble()
          139     # PETSc.Sys.Print("done.")
          140 
          141 
          142 def create_solver():
          143     "Create the KSP solver"
          144     # create linear solver
          145     ksp = PETSc.KSP()
          146     ksp.create(PETSc.COMM_WORLD)
          147 
          148     # Use algebraic multigrid:
          149     pc = ksp.getPC()
          150     pc.setType(PETSc.PC.Type.GAMG)
          151     ksp.setFromOptions()
          152 
          153     ksp.setInitialGuessNonzero(True)
          154 
          155     return ksp
          156 
          157 
          158 def fill_missing(field, matrix=None):
          159     """Fill missing values in a NumPy array 'field' using the matrix
          160     'matrix' approximating the Laplace operator."""
          161 
          162     ksp = create_solver()
          163 
          164     if matrix is None:
          165         A = assemble_matrix(field.mask)
          166     else:
          167         # PETSc.Sys.Print("Reusing the matrix...")
          168         A = matrix
          169 
          170     # obtain solution & RHS vectors
          171     x, b = A.getVecs()
          172 
          173     assemble_rhs(b, field)
          174 
          175     initial_guess = np.mean(field)
          176 
          177     # set the initial guess
          178     x.set(initial_guess)
          179 
          180     ksp.setOperators(A)
          181 
          182     # Solve Ax = b
          183     # PETSc.Sys.Print("Solving...")
          184     ksp.solve(b, x)
          185     # PETSc.Sys.Print("done.")
          186 
          187     # transfer solution to processor 0
          188     vec0, scatter = create_scatter(x)
          189     scatter_to_0(x, vec0, scatter)
          190 
          191     return vec0, A
          192 
          193 
          194 def create_scatter(vector):
          195     "Create the scatter to processor 0."
          196     comm = vector.getComm()
          197     rank = comm.getRank()
          198     scatter, V0 = PETSc.Scatter.toZero(vector)
          199     scatter.scatter(vector, V0, False, PETSc.Scatter.Mode.FORWARD)
          200     comm.barrier()
          201 
          202     return V0, scatter
          203 
          204 
          205 def scatter_to_0(vector, vector_0, scatter):
          206     "Scatter a distributed 'vector' to 'vector_0' on processor 0 using 'scatter'."
          207     comm = vector.getComm()
          208     scatter.scatter(vector, vector_0, False, PETSc.Scatter.Mode.FORWARD)
          209     comm.barrier()
          210 
          211 
          212 def scatter_from_0(vector_0, vector, scatter):
          213     "Scatter 'vector_0' on processor 0 to a distributed 'vector' using 'scatter'."
          214     comm = vector.getComm()
          215     scatter.scatter(vector, vector_0, False, PETSc.Scatter.Mode.REVERSE)
          216     comm.barrier()
          217 
          218 
          219 def fill_2d_record(data, matrix=None):
          220     "Fill missing values in a 2D record."
          221 
          222     if getattr(data, "mask", None) is None:
          223         return data, None
          224     filled_data, A = fill_missing(data, matrix)
          225     if PETSc.COMM_WORLD.getRank() == 0:
          226         filled_data = filled_data[:].reshape(data.shape)
          227 
          228     return filled_data, A
          229 
          230 
          231 def test():
          232     "Test fill_missing() using synthetic data."
          233     N = 201
          234     M = N * 1.5
          235     x = np.linspace(-1, 1, N)
          236     y = np.linspace(-1, 1, M)
          237     xx, yy = np.meshgrid(x, y)
          238     zz = np.sin(2.5 * np.pi * xx) * np.cos(2.0 * np.pi * yy)
          239 
          240     K = 10
          241     mask = np.random.randint(0, K, zz.size).reshape(zz.shape) / float(K)
          242 
          243     mask[(xx - 1.0) ** 2 + (yy - 1.0) ** 2 < 1.0] = 1
          244     mask[(xx + 1.0) ** 2 + (yy + 1.0) ** 2 < 1.0] = 1
          245 
          246     field = np.ma.array(zz, mask=mask)
          247 
          248     zzz, _ = fill_missing(field)
          249 
          250     rank = PETSc.COMM_WORLD.getRank()
          251     if rank == 0:
          252         zzz0_np = zzz[:].reshape(field.shape)
          253 
          254         import pylab as plt
          255 
          256         plt.figure(1)
          257         plt.imshow(zz, interpolation='nearest')
          258 
          259         plt.figure(2)
          260         plt.imshow(field, interpolation='nearest')
          261 
          262         plt.figure(3)
          263         plt.imshow(zzz0_np, interpolation='nearest')
          264 
          265         plt.show()
          266 
          267 
          268 def fill_variable(nc, name):
          269     "Fill missing values in one variable."
          270     PETSc.Sys.Print("Processing %s..." % name)
          271     t0 = time()
          272 
          273     var = nc.variables[name]
          274 
          275     comm = PETSc.COMM_WORLD
          276     rank = comm.getRank()
          277 
          278     if var.ndim == 3:
          279         A = None
          280         n_records = var.shape[0]
          281         for t in range(n_records):
          282             PETSc.Sys.Print("Processing record %d/%d..." % (t + 1, n_records))
          283             data = var[t, :, :]
          284 
          285             filled_data, A = fill_2d_record(data, A)
          286             if rank == 0:
          287                 var[t, :, :] = filled_data
          288 
          289             comm.barrier()
          290         PETSc.Sys.Print("Time elapsed: %5f seconds." % (time() - t0))
          291     elif var.ndim == 2:
          292         data = var[:, :]
          293 
          294         filled_data, _ = fill_2d_record(data)
          295         if rank == 0:
          296             var[:, :] = filled_data
          297 
          298         comm.barrier()
          299         PETSc.Sys.Print("Time elapsed: %5f seconds." % (time() - t0))
          300     else:
          301         PETSc.Sys.Print("Skipping the %dD variable %s." % (var.ndim, name))
          302         return
          303 
          304     # Remove the _FillValue attribute:
          305     try:
          306         delattr(var, '_FillValue')
          307     except:
          308         pass
          309 
          310     # Remove the missing_value attribute:
          311     try:
          312         delattr(var, 'missing_value')
          313     except:
          314         pass
          315 
          316 
          317 def add_history(nc):
          318     "Update the history attribute in a NetCDF file nc."
          319     comm = PETSc.COMM_WORLD
          320     rank = comm.getRank()
          321 
          322     if rank != 0:
          323         return
          324 
          325     # add history global attribute (after checking if present)
          326     historysep = ' '
          327     historystr = asctime() + ': ' + historysep.join(sys.argv) + '\n'
          328     if 'history' in nc.ncattrs():
          329         nc.history = historystr + nc.history  # prepend to history string
          330     else:
          331         nc.history = historystr
          332 
          333 
          334 if __name__ == "__main__":
          335     from argparse import ArgumentParser
          336     import os
          337     import os.path
          338     import tempfile
          339     import shutil
          340     from time import time, asctime
          341 
          342     try:
          343         from netCDF4 import Dataset as NC
          344     except:
          345         PETSc.Sys.Print("netCDF4 is not installed!")
          346         sys.exit(1)
          347 
          348     parser = ArgumentParser()
          349     parser.description = "Fill missing values by solving the Laplace equation in on the missing values and using present values as Dirichlet B.C."
          350 
          351     parser.add_argument("INPUT", nargs=1, help="Input file name.")
          352     parser.add_argument("OUTPUT", nargs=1, help="Output file name.")
          353     parser.add_argument("-a", "--all", dest="all", action="store_true",
          354                         help="Process all variables.")
          355     parser.add_argument("-v", "--vars", dest="variables",
          356                         help="comma-separated list of variables to process")
          357 
          358     options, _ = parser.parse_known_args()
          359 
          360     input_filename = options.INPUT[0]
          361     output_filename = options.OUTPUT[0]
          362 
          363     if options.all:
          364         nc = NC(input_filename)
          365         variables = list(nc.variables.keys())
          366         nc.close()
          367     else:
          368         try:
          369             variables = (options.variables).split(',')
          370         except:
          371             PETSc.Sys.Print("Please specify variables using the -v option.")
          372             sys.exit(-1)
          373 
          374     # Done processing command-line options.
          375 
          376     comm = PETSc.COMM_WORLD
          377     rank = comm.getRank()
          378 
          379     t0 = time()
          380 
          381     PETSc.Sys.Print("Filling missing values in %s and saving results to %s..." % (input_filename,
          382                                                                                   output_filename))
          383     if rank == 0:
          384         try:
          385             PETSc.Sys.Print("Creating a temporary file...")
          386             # find the name of the directory with the output file:
          387             dirname = os.path.dirname(os.path.abspath(output_filename))
          388             (handle, tmp_filename) = tempfile.mkstemp(prefix="fill_missing_",
          389                                                       suffix=".nc",
          390                                                       dir=dirname)
          391 
          392             os.close(handle)  # mkstemp returns a file handle (which we don't need)
          393         except IOError:
          394             PETSc.Sys.Print("ERROR: Can't create %s, Exiting..." % tmp_filename)
          395 
          396         try:
          397             PETSc.Sys.Print("Copying input file %s to %s..." % (input_filename,
          398                                                                 tmp_filename))
          399             shutil.copy(input_filename, tmp_filename)
          400         except IOError:
          401             PETSc.Sys.Print("ERROR: Can't copy %s, Exiting..." % input_filename)
          402 
          403     try:
          404         if rank == 0:
          405             nc = NC(tmp_filename, 'a')
          406         else:
          407             nc = NC(input_filename, 'r')
          408     except Exception as message:
          409         PETSc.Sys.Print(message)
          410         PETSc.Sys.Print("Note: %s was not modified." % output_filename)
          411         sys.exit(-1)
          412 
          413     add_history(nc)
          414 
          415     for name in variables:
          416         try:
          417             fill_variable(nc, name)
          418         except Exception as message:
          419             PETSc.Sys.Print("ERROR:", message)
          420             PETSc.Sys.Print("Note: %s was not modified." % output_filename)
          421             sys.exit(-1)
          422     nc.close()
          423 
          424     try:
          425         if rank == 0:
          426             shutil.move(tmp_filename, output_filename)
          427     except:
          428         PETSc.Sys.Print("Error moving %s to %s. Exiting..." % (tmp_filename,
          429                                                                output_filename))
          430         sys.exit(-1)
          431 
          432     PETSc.Sys.Print("Total time elapsed: %5f seconds." % (time() - t0))