/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.gpu.context;

import jcuda.CudaException;
import jcuda.Pointer;
import jcuda.runtime.JCuda;
import jcuda.runtime.cudaError;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.instructions.gpu.context.GPUMemoryAllocator;

/**
 * {@link GPUMemoryAllocator} that obtains device memory directly from the CUDA runtime through
 * {@code cudaMalloc} / {@code cudaFree} (via JCuda).
 *
 * <p>When an allocation fails with {@code cudaErrorMemoryAllocation}, the memory reported as
 * available at that moment is recorded as "unusable" and subtracted from the available memory in
 * {@link #canAllocate(long)} until {@link #resetUnusableFreeMemory()} is called. This value is a
 * static field and is therefore shared by all instances of this class.
 */
public class CudaMemoryAllocator
implements GPUMemoryAllocator {
    /** Available memory observed at the last out-of-memory failure; 0 when reset. */
    private static long unusableFreeMem = 0L;

    /**
     * Allocates {@code size} bytes of device memory into {@code devPtr}.
     *
     * @param devPtr pointer that receives the device address
     * @param size   number of bytes to allocate
     * @throws CudaException if {@code cudaMalloc} fails, whether JCuda reports the failure by
     *                       throwing or by returning a non-success status code
     */
    @Override
    public void allocate(Pointer devPtr, long size) {
        int status;
        try {
            status = JCuda.cudaMalloc(devPtr, size);
        }
        catch (CudaException e) {
            // JCuda throws instead of returning a status when its exceptions are enabled.
            // Constant-first equals: the exception message may be null.
            if ("cudaErrorMemoryAllocation".equals(e.getMessage())) {
                this.recordUnusableFreeMemory();
            }
            throw new CudaException("cudaMalloc failed: " + e.getMessage(), e);
        }
        // With JCuda exceptions disabled the failure is only visible through the status code;
        // ignoring it would hand the caller an unallocated pointer.
        if (status != cudaError.cudaSuccess) {
            if (status == cudaError.cudaErrorMemoryAllocation) {
                this.recordUnusableFreeMemory();
            }
            throw new CudaException("cudaMalloc failed: " + cudaError.stringFor(status));
        }
    }

    /** Remembers the currently available memory as unusable after an out-of-memory failure. */
    private void recordUnusableFreeMemory() {
        unusableFreeMem = this.getAvailableMemory();
    }

    /**
     * Releases device memory previously obtained from {@link #allocate(Pointer, long)}.
     *
     * @param devPtr device pointer to free
     * @throws CudaException if {@code cudaFree} returns a non-success status
     */
    @Override
    public void free(Pointer devPtr) throws CudaException {
        int status = JCuda.cudaFree(devPtr);
        if (status != cudaError.cudaSuccess) {
            throw new CudaException("cudaFree failed:" + cudaError.stringFor(status));
        }
    }

    /**
     * Checks whether an allocation of {@code size} bytes is expected to succeed.
     *
     * @param size requested number of bytes
     * @return {@code true} if {@code size} does not exceed the available memory minus the memory
     *         recorded as unusable at the last out-of-memory failure
     */
    @Override
    public boolean canAllocate(long size) {
        return size <= this.getAvailableMemory() - unusableFreeMem;
    }

    /**
     * Returns the free device memory reported by {@code cudaMemGetInfo}, scaled by
     * {@link DMLScript#GPU_MEMORY_UTILIZATION_FACTOR}.
     *
     * @return usable number of free bytes
     */
    @Override
    public long getAvailableMemory() {
        long[] free = {0L};
        long[] total = {0L};
        // NOTE(review): the status is not checked; if the call fails without throwing, free[0]
        // presumably stays 0, which makes canAllocate conservative. Confirm this is intended.
        JCuda.cudaMemGetInfo(free, total);
        return (long) ((double) free[0] * DMLScript.GPU_MEMORY_UTILIZATION_FACTOR);
    }

    /** Clears the unusable-memory estimate recorded after an out-of-memory failure. */
    public static void resetUnusableFreeMemory() {
        unusableFreeMem = 0L;
    }
}

