!! [repository web-view residue] Skip to content | Snippets Groups Projects
!! File: opencl.F90 (51.4 KB) | Newer Older
!! Copyright (C) 2010 X. Andrade
!!
!! This program is free software; you can redistribute it and/or modify
!! it under the terms of the GNU General Public License as published by
!! the Free Software Foundation; either version 2, or (at your option)
!! any later version.
!!
!! This program is distributed in the hope that it will be useful,
!! but WITHOUT ANY WARRANTY; without even the implied warranty of
!! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
!! GNU General Public License for more details.
!!
!! You should have received a copy of the GNU General Public License
!! along with this program; if not, write to the Free Software
!! Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA
!! 02111-1307, USA.
!!
!! $Id: opencl.F90 3587 2007-11-22 16:43:00Z xavier $
#endif
#ifdef HAVE_CLAMDFFT
  use clAmdFft
  use datasets_m
  use messages_m
  use parser_m
    opencl_init,                  &
    opencl_end,                   &
    opencl_create_buffer,         &
    opencl_write_buffer,          &
    opencl_read_buffer,           &
    opencl_release_buffer,        &
    opencl_finish,                &
    opencl_set_kernel_arg,        &
    opencl_max_workgroup_size,    &
    opencl_kernel_workgroup_size, &
    opencl_kernel_run,            &
    opencl_build_program,         &
    opencl_release_program,       &
    clblas_print_error,           &
    clfft_print_error,            &
    opencl_set_buffer_to_zero
    type(cl_platform_id)   :: platform_id
    type(cl_context)       :: context
    type(cl_command_queue) :: command_queue
    type(cl_device_id)     :: device
    integer                :: max_workgroup_size
    integer                :: local_memory_size
    logical                :: enabled
  !> Handle for a device buffer created through opencl_create_buffer.
  !! Besides the OpenCL memory object it records the element count and the
  !! element type; opencl_release_buffer, opencl_get_buffer_size and
  !! opencl_get_buffer_type read these two components back.
  type opencl_mem_t
    type(cl_mem)           :: mem   !< OpenCL memory object
    integer(SIZEOF_SIZE_T) :: size  !< number of elements in the buffer
    type(type_t)           :: type  !< element type of the buffer
  end type opencl_mem_t

  !> Global OpenCL state: platform, context, command queue, device and the
  !! cached device limits (see the opencl_t components).
  type(opencl_t), public :: opencl

  ! Kernel handles exported to the rest of the code. All of them are
  ! created in opencl_init from the .cl sources under conf%share/opencl.
  type(cl_kernel), public :: kernel_vpsi, kernel_vpsi_spinors
  type(cl_kernel), public :: kernel_daxpy, kernel_zaxpy
  type(cl_kernel), public :: kernel_copy
  type(cl_kernel), public :: kernel_projector_bra, kernel_projector_ket
  type(cl_kernel), public :: dpack, zpack, dunpack, zunpack
  type(cl_kernel), public :: kernel_subarray_gather
  type(cl_kernel), public :: kernel_density_real, kernel_density_complex
  type(cl_kernel), public :: kernel_phase
  type(cl_kernel), public :: dkernel_dot_matrix, zkernel_dot_matrix, zkernel_dot_matrix_spinors
  type(cl_kernel), public :: dkernel_dot_vector, zkernel_dot_vector
  type(cl_kernel), public :: kernel_nrm2_vector
  type(cl_kernel), public :: dzmul, zzmul

  ! Kernel private to this module (not exported).
  type(cl_kernel) :: set_zero

  ! Generic entry points: the compiler picks the specific procedure from
  ! the type and rank of the actual arguments.
  interface opencl_create_buffer
    module procedure opencl_create_buffer_4
  end interface opencl_create_buffer

  ! Buffer writes for integer (i), real (d) and complex (z) host arrays;
  ! the _1/_2/_3 suffixes presumably denote the array rank.
  interface opencl_write_buffer
    module procedure iopencl_write_buffer_1, dopencl_write_buffer_1, zopencl_write_buffer_1
    module procedure iopencl_write_buffer_2, dopencl_write_buffer_2, zopencl_write_buffer_2
    module procedure iopencl_write_buffer_3, dopencl_write_buffer_3, zopencl_write_buffer_3
  end interface opencl_write_buffer

  ! Buffer reads, with the same i/d/z and rank layout as the writes.
  interface opencl_read_buffer
    module procedure iopencl_read_buffer_1, dopencl_read_buffer_1, zopencl_read_buffer_1
    module procedure iopencl_read_buffer_2, dopencl_read_buffer_2, zopencl_read_buffer_2
    module procedure iopencl_read_buffer_3, dopencl_read_buffer_3, zopencl_read_buffer_3
  end interface opencl_read_buffer

  ! Kernel-argument setters: a device buffer, scalar data, or a
  ! local-memory allocation given as (kernel, narg, type, size).
  interface opencl_set_kernel_arg
    module procedure                 &
      opencl_set_kernel_arg_buffer,  &
      iopencl_set_kernel_arg_data,   &
      dopencl_set_kernel_arg_data,   &
      zopencl_set_kernel_arg_data,   &
      opencl_set_kernel_arg_local
  end interface opencl_set_kernel_arg
  ! Profiling handles; prof_kernel_run is used by opencl_kernel_run and
  ! opencl_finish.
  type(profile_t), save :: prof_read
  type(profile_t), save :: prof_write
  type(profile_t), save :: prof_kernel_run

  ! Symbolic values of the OpenCLDevice input variable. Non-negative
  ! values are treated by opencl_init as explicit device indices.
  integer, parameter :: OPENCL_GPU         = -1
  integer, parameter :: OPENCL_CPU         = -2
  integer, parameter :: OPENCL_ACCELERATOR = -3
  integer, parameter :: OPENCL_DEFAULT     = -4

  ! Vendor selectors of the OpenCLPlatform input variable (non-negative
  ! values select a platform by index). CL_PLAT_INVALID is what
  ! get_platform_id returns for an unrecognized platform name.
  integer, parameter :: CL_PLAT_INVALID = -1
  integer, parameter :: CL_PLAT_AMD     = -2
  integer, parameter :: CL_PLAT_NVIDIA  = -3
  integer, parameter :: CL_PLAT_ATI     = -4
  integer, parameter :: CL_PLAT_INTEL   = -5

  ! a "convenience" public variable for OpenCL status codes
  integer, public :: cl_status

  ! Largest .cl source file (in characters) opencl_build_program can read.
  integer, parameter :: OPENCL_MAX_FILE_LENGTH = 10000

  ! Live device-buffer bookkeeping (count and bytes); opencl_end warns
  ! when buffers are still allocated at shutdown.
  integer    :: buffer_alloc_count
  integer(8) :: allocated_mem

    !> Tells whether OpenCL is active for this run.
    !! Always .false. in builds without HAVE_OPENCL; otherwise it returns the
    !! opencl%enabled flag that opencl_init derives from DisableOpenCL.
    pure function opencl_is_enabled() result(is_on)
      logical :: is_on

      is_on = .false.
#ifdef HAVE_OPENCL
      is_on = opencl%enabled
#endif
    end function opencl_is_enabled

    ! ------------------------------------------
    !> Initializes OpenCL support for this run.
    !!
    !! Reads the input variables DisableOpenCL, OpenCLPlatform, OpenCLDevice
    !! and OpenCLBenchmark. It selects a platform and a device, using
    !! select_device to spread MPI ranks over devices when
    !! base_grp%size > 1. It then creates the context and a command queue
    !! with CL_QUEUE_PROFILING_ENABLE, and caches the device's maximum
    !! workgroup size and local memory size in opencl. Finally it builds the
    !! module kernels from the .cl files under conf%share/opencl.
    !! If OpenCL is disabled, it returns right after reading DisableOpenCL.
    !!
    !! NOTE(review): this copy of the routine has lost lines. The following
    !! are not visible here:
    !!   - PUSH_SUB(opencl_init), although POP_SUB is used;
    !!   - the #endif closing "#ifdef HAVE_OPENCL" around the declarations;
    !!   - the allocation of allplatforms;
    !!   - the "do idev" header of the device listing loop;
    !!   - the "select case(idevice)" header and several case labels;
    !!   - some "end if" statements;
    !!   - "end subroutine select_device".
    !! Restore them from version control before relying on this code.
    subroutine opencl_init(base_grp)
      type(mpi_grp_t),  intent(inout) :: base_grp

      logical  :: disable, default, run_benchmark
      integer  :: device_type
      integer  :: idevice, iplatform, ndevices, idev, cl_status, ret_devices, nplatforms, iplat
      character(len=256) :: device_name
#ifdef HAVE_OPENCL
      type(cl_program) :: prog
      type(cl_platform_id), allocatable :: allplatforms(:)
      type(cl_device_id), allocatable :: alldevices(:)
      type(profile_t), save :: prof_init
      !%Variable DisableOpenCL
      !%Type logical
      !%Default yes
      !%Section Execution::OpenCL
      !%Description
      !% If Octopus was compiled with OpenCL support, it will try to
      !% initialize and use an OpenCL device. By setting this variable
      !% to <tt>yes</tt> you tell Octopus not to use OpenCL.
      !%End
      ! NOTE(review): "!%Default yes" above only holds for builds without
      ! HAVE_OPENCL; with OpenCL compiled in, the default below is .false.

#ifndef HAVE_OPENCL
      default = .true.
#else
      default = .false.
#endif
      call parse_logical(datasets_check('DisableOpenCL'), default, disable)
      opencl%enabled = .not. disable
#ifndef HAVE_OPENCL
      if(opencl%enabled) then
        message(1) = 'Octopus was compiled without OpenCL support.'
        call messages_fatal(1)
      end if
#endif

      if(.not. opencl_is_enabled()) then
        POP_SUB(opencl_init)
        return
      end if

      !%Variable OpenCLPlatform
      !%Type integer
      !%Default 0
      !%Section Execution::OpenCL
      !%Description
      !% This variable selects the OpenCL platform that Octopus will
      !% use. You can give an explicit platform number or use one of
      !% the options that select a particular vendor
      !% implementation. Platform 0 is used by default.
      !%Option amd -2
      !% Use the AMD OpenCL platform.
      !%Option nvidia -3
      !% Use the Nvidia OpenCL platform.
      !%Option ati -4
      !% Use the ATI (old AMD) OpenCL platform.
      !%Option intel -5
      !% Use the Intel OpenCL platform.
      !%End
      call parse_integer(datasets_check('OpenCLPlatform'), 0, iplatform)

      !%Variable OpenCLDevice
      !%Type integer
      !%Default 0
      !%Section Execution::OpenCL
      !%Description
      !% This variable selects the OpenCL device that Octopus will
      !% use. You can specify one of the options below or a numerical
      !% id to select a specific device.
      !%Option gpu -1
      !% If available, Octopus will use a GPU for OpenCL. This is the default.
      !%Option cpu -2
      !% If available, Octopus will use a CPU for OpenCL.
      !%Option accelerator -3
      !% If available, Octopus will use an accelerator for OpenCL.
      !%Option cl_default -4
      !% Octopus will use the default device specified by the OpenCL
      !% implementation.
      !%End
      call parse_integer(datasets_check('OpenCLDevice'), OPENCL_GPU, idevice)
      if(idevice < OPENCL_DEFAULT) then
        message(1) = 'Invalid OpenCLDevice.'
        call messages_fatal(1)
      end if
      call messages_print_stress(stdout, "OpenCL")

      call profiling_in(prof_init, 'CL_INIT')
      
      ! The first call returns the number of platforms (nplatforms); the
      ! second fills allplatforms with their ids.
      ! NOTE(review): allplatforms is not allocated in the visible lines
      ! before it is filled.
      call clGetPlatformIDs(nplatforms, cl_status)
      if(cl_status /= CL_SUCCESS) call opencl_print_error(cl_status, "GetPlatformIDs")
      call clGetPlatformIDs(allplatforms, iplat, cl_status)
      if(cl_status /= CL_SUCCESS) call opencl_print_error(cl_status, "GetPlatformIDs")

      call messages_write('Info: Available CL platforms: ')
      call messages_write(nplatforms)
      call messages_info()

      do iplat = 1, nplatforms
        call clGetPlatformInfo(allplatforms(iplat), CL_PLATFORM_NAME, device_name, cl_status)

        ! A negative OpenCLPlatform value is a vendor selector (CL_PLAT_*).
        ! The first platform whose name matches it becomes the (0-based)
        ! iplatform; iplatform is then non-negative and no longer matches.
        if(iplatform < 0) then
          if(iplatform == get_platform_id(device_name)) iplatform = iplat - 1
        end if

        ! The selected platform is marked with '*' in the listing.
        if(iplatform == iplat - 1) then
          call messages_write('    * Platform ')
        else
          call messages_write('      Platform ')
        end if

        call messages_write(iplat - 1)
        call messages_write(' : '//device_name)
        call clGetPlatformInfo(allplatforms(iplat), CL_PLATFORM_VERSION, device_name, cl_status)
        call messages_write(' ('//trim(device_name)//')')
        call messages_info()
      end do

      ! NOTE(review): no messages_fatal call and no "end if" closing this
      ! block are visible. As written, an invalid iplatform would still
      ! reach allplatforms(iplatform + 1) below.
      if(iplatform >= nplatforms .or. iplatform < 0) then
        call messages_write('Requested CL platform does not exist')
        if(iplatform > 0) then 
          call messages_write('(platform = ')
          call messages_write(iplatform)
          call messages_write(').')
        end if
      opencl%platform_id = allplatforms(iplatform + 1)

      SAFE_DEALLOCATE_A(allplatforms)
      ! As for the platforms: first count the devices of the selected
      ! platform, then list them.
      call clGetDeviceIDs(opencl%platform_id, CL_DEVICE_TYPE_ALL, ndevices, cl_status)

      call messages_write('Info: Available CL devices: ')
      call messages_write(ndevices)
      call messages_info()

      SAFE_ALLOCATE(alldevices(1:ndevices))

      ! list all devices

      call clGetDeviceIDs(opencl%platform_id, CL_DEVICE_TYPE_ALL, alldevices, ret_devices, cl_status)
      ! NOTE(review): the "do idev = ..." header of this listing loop is
      ! not visible.
        call messages_write('      Device ')
        call clGetDeviceInfo(alldevices(idev), CL_DEVICE_NAME, device_name, cl_status)
        call messages_write(' : '//device_name)
        call messages_info()
      end do

      ! Map the OpenCLDevice selector to an OpenCL device type.
      ! NOTE(review): the "select case" header (presumably on idevice) and
      ! the OPENCL_GPU, OPENCL_CPU and OPENCL_DEFAULT case labels are not
      ! visible. Non-negative idevice values (explicit indices) reach
      ! "case default", i.e. all device types.
          device_type = CL_DEVICE_TYPE_GPU
          device_type = CL_DEVICE_TYPE_CPU
        case(OPENCL_ACCELERATOR)
          device_type = CL_DEVICE_TYPE_ACCELERATOR
          device_type = CL_DEVICE_TYPE_DEFAULT
        case default
          device_type = CL_DEVICE_TYPE_ALL
      end select

      ! now get a list of the selected type
      call clGetDeviceIDs(opencl%platform_id, device_type, alldevices, ret_devices, cl_status)
      
      if(ret_devices < 1) then
        ! we didn't find a device of the selected type, we ask for the default device
        call clGetDeviceIDs(opencl%platform_id, CL_DEVICE_TYPE_DEFAULT, alldevices, ret_devices, cl_status)

        if(ret_devices < 1) then
          ! if this does not work, we ask for all devices
          call clGetDeviceIDs(opencl%platform_id, CL_DEVICE_TYPE_ALL, alldevices, ret_devices, cl_status)
        end if
        
        if(ret_devices < 1) then
          call messages_write('Cannot find an OpenCL device')
          call messages_fatal()
        end if
      end if

      ! the number of devices can be smaller
      ndevices = ret_devices

      ! NOTE(review): the condition enclosing this block (presumably "no
      ! explicit device index was given", i.e. idevice < 0) is not visible.
        if(base_grp%size > 1) then
          ! with MPI we have to select the device so multiple GPUs in one
          ! node are correctly distributed
          call select_device(idevice)
        else
          idevice = 0
        end if
      if(idevice >= ndevices) then
        call messages_write('Requested CL device does not exist (device = ')
        call messages_write(idevice)
        call messages_write(', platform = ')
        call messages_write(iplatform)
        call messages_write(').')
        call messages_fatal()
      end if

      opencl%device = alldevices(idevice + 1)

      if(mpi_grp_is_root(base_grp)) call device_info()
      opencl%context = clCreateContext(opencl%platform_id, opencl%device, cl_status)
      if(cl_status /= CL_SUCCESS) call opencl_print_error(cl_status, "CreateContext")

      SAFE_DEALLOCATE_A(alldevices)

      opencl%command_queue = clCreateCommandQueue(opencl%context, opencl%device, CL_QUEUE_PROFILING_ENABLE, cl_status)
      if(cl_status /= CL_SUCCESS) call opencl_print_error(cl_status, "CreateCommandQueue")
      ! Cache device limits. max_workgroup_size is what
      ! opencl_max_workgroup_size returns; local_memory_size bounds the
      ! requests checked in opencl_set_kernel_arg_local.
      call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_WORK_GROUP_SIZE, opencl%max_workgroup_size, cl_status)
      call clGetDeviceInfo(opencl%device, CL_DEVICE_LOCAL_MEM_SIZE, opencl%local_memory_size, cl_status)
      ! now initialize the kernels
      ! Each .cl file is built once (optionally with a -DRTYPE_* flag), its
      ! kernels are created, and the program object is released right away.
      call opencl_build_program(prog, trim(conf%share)//'/opencl/set_zero.cl')
      call opencl_create_kernel(set_zero, prog, "set_zero")
      call opencl_release_program(prog)
      call opencl_build_program(prog, trim(conf%share)//'/opencl/vpsi.cl')
      call opencl_create_kernel(kernel_vpsi, prog, "vpsi")
      call opencl_create_kernel(kernel_vpsi_spinors, prog, "vpsi_spinors")
      call opencl_release_program(prog)
      
      call opencl_build_program(prog, trim(conf%share)//'/opencl/axpy.cl', flags = '-DRTYPE_DOUBLE')
      call opencl_create_kernel(kernel_daxpy, prog, "daxpy")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/axpy.cl', flags = '-DRTYPE_COMPLEX')
      call opencl_create_kernel(kernel_zaxpy, prog, "zaxpy")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/projector.cl')
      call opencl_create_kernel(kernel_projector_bra, prog, "projector_bra")
      call opencl_create_kernel(kernel_projector_ket, prog, "projector_ket")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/pack.cl')
      call opencl_create_kernel(dpack, prog, "dpack")
      call opencl_create_kernel(zpack, prog, "zpack")
      call opencl_create_kernel(dunpack, prog, "dunpack")
      call opencl_create_kernel(zunpack, prog, "zunpack")
      call opencl_release_program(prog)
      call opencl_build_program(prog, trim(conf%share)//'/opencl/copy.cl')
      call opencl_create_kernel(kernel_copy, prog, "copy")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/subarray.cl')
      call opencl_create_kernel(kernel_subarray_gather, prog, "subarray_gather")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/density.cl')
      call opencl_create_kernel(kernel_density_real, prog, "density_real")
      call opencl_create_kernel(kernel_density_complex, prog, "density_complex")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/phase.cl')
      call opencl_create_kernel(kernel_phase, prog, "phase")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/mesh_batch.cl')
      call opencl_create_kernel(dkernel_dot_vector, prog, "ddot_vector")
      call opencl_create_kernel(zkernel_dot_vector, prog, "zdot_vector")
      call opencl_create_kernel(dkernel_dot_matrix, prog, "ddot_matrix")
      call opencl_create_kernel(zkernel_dot_matrix, prog, "zdot_matrix")
      call opencl_create_kernel(zkernel_dot_matrix_spinors, prog, "zdot_matrix_spinors")
      call opencl_create_kernel(kernel_nrm2_vector, prog, "nrm2_vector")
      call opencl_release_program(prog)
      call opencl_build_program(prog, trim(conf%share)//'/opencl/mul.cl', flags = '-DRTYPE_DOUBLE')
      call opencl_create_kernel(dzmul, prog, "dzmul")
      call opencl_release_program(prog)

      call opencl_build_program(prog, trim(conf%share)//'/opencl/mul.cl', flags = '-DRTYPE_COMPLEX')
      call opencl_create_kernel(zzmul, prog, "zzmul")
      call opencl_release_program(prog)
      ! NOTE(review): the "#endif" for HAVE_CLAMDBLAS is not visible. As
      ! written, the FFT setup and profiling_out(prof_init) below are only
      ! compiled when both HAVE_CLAMDBLAS and HAVE_CLAMDFFT are defined;
      ! otherwise the CL_INIT profiling region opened above is left open.
#ifdef HAVE_CLAMDBLAS
      call clAmdBlasSetup(cl_status)
      if(cl_status /= clAmdBlasSuccess) call clblas_print_error(cl_status, 'clAmdBlasSetup')
#ifdef HAVE_CLAMDFFT
      call clAmdFftSetup(cl_status)
      if(cl_status /= CLFFT_SUCCESS) call clfft_print_error(cl_status, 'clAmdFftSetup')
      call profiling_out(prof_init)
#endif
      !%Variable OpenCLBenchmark
      !%Type logical
      !%Default no
      !%Section Execution::OpenCL
      !%Description
      !% If this variable is set to yes, Octopus will run some
      !% routines to benchmark the performance of the OpenCL device.
      !%End

      call parse_logical(datasets_check('OpenCLBenchmark'), .false., run_benchmark)

      if(run_benchmark) then
        call opencl_check_bandwidth()
      end if

      call messages_print_stress(stdout)


    contains
      
      ! Chooses this rank's device index as mod(rank, ndevices). It then
      ! prints the rank -> device mapping one rank at a time, with the ranks
      ! serialized by MPI barriers. Unless both HAVE_MPI and HAVE_OPENCL are
      ! defined, idevice is left unchanged.
      ! NOTE(review): PUSH_SUB is inside the #if block but POP_SUB is
      ! outside it, so they are unbalanced when the #if is false.
      ! "end subroutine select_device" is not visible below.
      subroutine select_device(idevice)
        integer, intent(inout) :: idevice
#if defined(HAVE_MPI) && defined(HAVE_OPENCL)
        integer :: irank
        character(len=256) :: device_name
David Strubbe's avatar
David Strubbe 已提交
        PUSH_SUB(opencl_init.select_device)

        idevice = mod(base_grp%rank, ndevices)

        call MPI_Barrier(base_grp%comm, mpi_err)
        call messages_write('Info: CL device distribution:')
        call messages_info()
        do irank = 0, base_grp%size - 1
          if(irank == base_grp%rank) then
            call clGetDeviceInfo(alldevices(idevice + 1), CL_DEVICE_NAME, device_name, cl_status)
            call messages_write('      MPI node ')
            call messages_write(base_grp%rank)
            call messages_write(' -> CL device ')
            call messages_write(idevice)
            call messages_write(' : '//device_name)
            call messages_info(all_nodes = .true.)
          end if
          call MPI_Barrier(base_grp%comm, mpi_err)
        end do
#endif

David Strubbe's avatar
David Strubbe 已提交
        POP_SUB(opencl_init.select_device)
      ! Writes a summary of opencl%device: vendor, name, driver version,
      ! selected limits, and whether it offers the cl_khr_fp64 and
      ! cl_amd_fp64 extensions. It is called on the root rank only.
      ! NOTE(review): the "#endif" for "#ifdef HAVE_OPENCL" below is not
      ! visible, and no final messages_info() flushing the accumulated
      ! lines is visible either.
      subroutine device_info()

#ifdef HAVE_OPENCL
        integer(8) :: val 
        character(len=256) :: val_str

David Strubbe's avatar
David Strubbe 已提交
        PUSH_SUB(opencl_init.device_info)

        call messages_new_line()
        call messages_write('Selected CL device:')
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_VENDOR, val_str, cl_status)
        call messages_write('      Device vendor          : '//trim(val_str))
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_NAME, val_str, cl_status)
        call messages_write('      Device name            : '//trim(val_str))
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DRIVER_VERSION, val_str, cl_status)
        call messages_write('      Driver version         : '//trim(val_str))
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_COMPUTE_UNITS, val, cl_status)
        call messages_write('      Compute units          :')
        call messages_write(val)
        call messages_new_line()

        ! NOTE(review): the OpenCL specification reports
        ! CL_DEVICE_MAX_CLOCK_FREQUENCY in MHz, but the unit written below
        ! is ' GHz'. Confirm and fix the label.
        call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_CLOCK_FREQUENCY, val, cl_status)
        call messages_write('      Clock frequency        :')
        call messages_write(val)
        call messages_write(' GHz')
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_GLOBAL_MEM_SIZE, val, cl_status)
        call messages_write('      Device memory          :')
        call messages_write(val, units = unit_megabytes)
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_MEM_ALLOC_SIZE, val, cl_status)
        call messages_write('      Max alloc size         :')
        call messages_write(val, units = unit_megabytes)
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_GLOBAL_MEM_CACHE_SIZE, val, cl_status)
        call messages_write('      Device cache           :')
        call messages_write(val, units = unit_kilobytes)
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_LOCAL_MEM_SIZE, val, cl_status)
        call messages_write('      Local memory           :')
        call messages_write(val, units = unit_kilobytes)
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_CONSTANT_BUFFER_SIZE, val, cl_status)
        call messages_write('      Constant memory        :')
        call messages_write(val, units = unit_kilobytes)
        call messages_new_line()

        call clGetDeviceInfo(opencl%device, CL_DEVICE_MAX_WORK_GROUP_SIZE, val, cl_status)
        call messages_write('      Max. workgroup size    :')
        call messages_write(val)
        call messages_new_line()

        call messages_write('      Extension cl_khr_fp64  :')
        call messages_write(f90_cl_device_has_extension(opencl%device, "cl_khr_fp64"))
        call messages_new_line()

        call messages_write('      Extension cl_amd_fp64  :')
        call messages_write(f90_cl_device_has_extension(opencl%device, "cl_amd_fp64"))
        call messages_new_line()

David Strubbe's avatar
David Strubbe 已提交
        POP_SUB(opencl_init.device_info)
      end subroutine device_info

    end subroutine opencl_init

    ! ------------------------------------------
    !> Maps an OpenCL platform name to one of the CL_PLAT_* vendor codes,
    !! using a case-sensitive substring search. If several vendor tags occur
    !! in the name, the precedence is Intel > NVIDIA > ATI > AMD. A name
    !! containing none of them yields CL_PLAT_INVALID.
    integer function get_platform_id(platform_name) result(platform_id)
      character(len=*), intent(in) :: platform_name

      if(index(platform_name, 'Intel') /= 0) then
        platform_id = CL_PLAT_INTEL
      else if(index(platform_name, 'NVIDIA') /= 0) then
        platform_id = CL_PLAT_NVIDIA
      else if(index(platform_name, 'ATI') /= 0) then
        platform_id = CL_PLAT_ATI
      else if(index(platform_name, 'AMD') /= 0) then
        platform_id = CL_PLAT_AMD
      else
        platform_id = CL_PLAT_INVALID
      end if
    end function get_platform_id

    ! ------------------------------------------
    ! opencl_end: tears down the OpenCL state created by opencl_init.
    ! The lines below shut down the clAmdBlas and clAmdFft libraries (when
    ! compiled in), release the kernel handles, the command queue and the
    ! context, and warn when device buffers are still allocated
    ! (buffer_alloc_count /= 0).
    ! NOTE(review): this copy is missing the "subroutine opencl_end"
    ! statement and its local declarations (e.g. ierr). It also lacks the
    ! #endif matching "#ifdef HAVE_OPENCL" and the "end if" closing
    ! "if(opencl_is_enabled())". Restore them from version control.

#ifdef HAVE_OPENCL
      ! NOTE(review): the two teardown calls below run even when OpenCL was
      ! disabled at run time. In that case opencl_init returned before the
      ! matching *Setup calls; confirm the libraries tolerate this.
#ifdef HAVE_CLAMDBLAS
      call clAmdBlasTearDown()
#endif

#ifdef HAVE_CLAMDFFT
      call clAmdFftTearDown()
#endif

      if(opencl_is_enabled()) then
        call opencl_release_kernel(kernel_vpsi)
        call opencl_release_kernel(kernel_vpsi_spinors)
        call opencl_release_kernel(set_zero)
        call opencl_release_kernel(kernel_daxpy)
        call opencl_release_kernel(kernel_zaxpy)
        call opencl_release_kernel(kernel_copy)
        call opencl_release_kernel(kernel_projector_bra)
        call opencl_release_kernel(kernel_projector_ket)
        call opencl_release_kernel(dpack)
        call opencl_release_kernel(zpack)
        call opencl_release_kernel(dunpack)
        call opencl_release_kernel(zunpack)
        call opencl_release_kernel(kernel_subarray_gather)
        call opencl_release_kernel(kernel_density_real)
        call opencl_release_kernel(kernel_density_complex)
        call opencl_release_kernel(kernel_phase)
        call opencl_release_kernel(dkernel_dot_matrix)
        call opencl_release_kernel(zkernel_dot_matrix)
        call opencl_release_kernel(dkernel_dot_vector)
        call opencl_release_kernel(zkernel_dot_vector)
        call opencl_release_kernel(zkernel_dot_matrix_spinors)
        ! NOTE(review): kernel_nrm2_vector, dzmul and zzmul are created in
        ! opencl_init but are not released in the list above. Confirm, and
        ! add the missing opencl_release_kernel calls.
        call clReleaseCommandQueue(opencl%command_queue, ierr)
        if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "ReleaseCommandQueue")
        ! NOTE(review): unlike clReleaseCommandQueue above, the status of
        ! clReleaseContext is not checked.
        call clReleaseContext(opencl%context, cl_status)

        ! Report device buffers that were created but never released.
        if(buffer_alloc_count /= 0) then
          call messages_write('OpenCL:')
          call messages_write(real(allocated_mem, REAL_PRECISION), fmt = 'f12.1', units = unit_megabytes, align_left = .true.)
          call messages_write(' in ')
          call messages_write(buffer_alloc_count)
          call messages_write(' buffers were not deallocated.')
          call messages_warning()
        end if
    end subroutine opencl_end

    ! ------------------------------------------

    !> Rounds nn up to the next multiple of the device's maximum workgroup
    !! size. The padded value can then serve as a global size that is
    !! divisible by the workgroup size, as opencl_kernel_run asserts.
    !!
    !! nn is returned unchanged in two cases: when built without OpenCL,
    !! and when no workgroup size is known yet. The latter happens when
    !! OpenCL is disabled at run time, so opencl_init never queries the
    !! device; this case previously evaluated mod(nn, 0).
    integer function opencl_padded_size(nn) result(psize)
      integer,        intent(in) :: nn

#ifdef HAVE_OPENCL
      integer :: modnn, bsize

      psize = nn

      bsize = opencl_max_workgroup_size()
      ! opencl%max_workgroup_size is only set by opencl_init when OpenCL is
      ! enabled; never divide by a zero (or unset) workgroup size.
      if(bsize <= 0) return

      modnn = mod(nn, bsize)
      if(modnn /= 0) psize = psize + bsize - modnn
#else
      psize = nn
#endif
    end function opencl_padded_size

#ifdef HAVE_OPENCL
    ! ------------------------------------------

    !> Creates a device buffer in opencl%context large enough for size
    !! elements of the given type, and stores the OpenCL memory object in
    !! this%mem. flags is passed unchanged to clCreateBuffer.
    !! NOTE(review): in this copy the declarations of type, size, fsize and
    !! ierr (and any PUSH_SUB/POP_SUB) are missing. opencl_release_buffer
    !! reads this%size and this%type and decrements buffer_alloc_count and
    !! allocated_mem, but none of these is set or incremented in the
    !! visible lines below. Confirm against version control.
    subroutine opencl_create_buffer_4(this, flags, type, size)
      type(opencl_mem_t), intent(inout) :: this
      integer,            intent(in)    :: flags
      ! Requested size: element count times types_get_size(type),
      ! presumably the byte size of one element.
      fsize = int(size, 8)*types_get_size(type)

      ASSERT(fsize >= 0)
      this%mem = clCreateBuffer(opencl%context, flags, fsize, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clCreateBuffer")
    end subroutine opencl_create_buffer_4

    ! ------------------------------------------

    !> Releases the OpenCL memory object held by this. It also removes the
    !! buffer from the module's live-buffer bookkeeping (buffer_alloc_count
    !! and allocated_mem), which opencl_end reports on at shutdown.
    subroutine opencl_release_buffer(this)
      type(opencl_mem_t), intent(inout) :: this

      integer :: ierr

      PUSH_SUB(opencl_release_buffer)

      call clReleaseMemObject(this%mem, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clReleaseMemObject")

      INCR(buffer_alloc_count, -1)
      ! allocated_mem is tracked in bytes: element count times element size.
      INCR(allocated_mem, -int(this%size, 8)*types_get_size(this%type))

      POP_SUB(opencl_release_buffer)
    end subroutine opencl_release_buffer

    ! ------------------------------------------

    !> Returns the element count recorded in the device buffer this.
    pure function opencl_get_buffer_size(this) result(nelem)
      type(opencl_mem_t),     intent(in) :: this
      integer(SIZEOF_SIZE_T)             :: nelem

      nelem = this%size
    end function opencl_get_buffer_size

    ! -----------------------------------------

    !> Returns the element type recorded in the device buffer this.
    pure function opencl_get_buffer_type(this) result(buf_type)
      type(opencl_mem_t), intent(in) :: this
      type(type_t)                   :: buf_type

      buf_type = this%type
    end function opencl_get_buffer_type

    ! -----------------------------------------

    !> Waits, via clFinish, for the commands queued on
    !! opencl%command_queue. The wait is accounted to the CL_KERNEL_RUN
    !! profiling region.
    subroutine opencl_finish()
      integer :: ierr

      PUSH_SUB(opencl_finish)

      call profiling_in(prof_kernel_run, "CL_KERNEL_RUN")

      call clFinish(opencl%command_queue, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, 'clFinish')

      ! Every profiling_in must be matched by a profiling_out.
      call profiling_out(prof_kernel_run)

      POP_SUB(opencl_finish)
    end subroutine opencl_finish

    ! ------------------------------------------

    !> Binds the device memory of buffer as argument number narg of kernel.
    subroutine opencl_set_kernel_arg_buffer(kernel, narg, buffer)
      type(cl_kernel),    intent(inout) :: kernel
      integer,            intent(in)    :: narg
      type(opencl_mem_t), intent(in)    :: buffer

      ! Local status, as in opencl_kernel_workgroup_size and
      ! opencl_build_program.
      integer :: ierr

      call clSetKernelArg(kernel, narg, buffer%mem, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clSetKernelArg_buf")

    end subroutine opencl_set_kernel_arg_buffer

    ! ------------------------------------------

    !> Declares argument narg of kernel as a local-memory allocation
    !! (clSetKernelArgLocal) for size elements of the given type. The
    !! request in bytes must be positive and must not exceed
    !! opencl%local_memory_size, the CL_DEVICE_LOCAL_MEM_SIZE value cached
    !! by opencl_init.
    !! NOTE(review): in this copy the declarations of narg, type, size,
    !! size_in_bytes and ierr are missing. The over-limit branch writes
    !! message(1:2), but no messages_fatal call is visible after it, and the
    !! if/else-if construct has no visible "end if". Restore from version
    !! control.
    subroutine opencl_set_kernel_arg_local(kernel, narg, type, size)
      type(cl_kernel),    intent(inout) :: kernel
      size_in_bytes = int(size, 8)*types_get_size(type)
      
      if(size_in_bytes > opencl%local_memory_size) then
        write(message(1), '(a,f12.6,a)') "CL Error: requested local memory: ", dble(size_in_bytes)/1024.0, " Kb"
        write(message(2), '(a,f12.6,a)') "          available local memory: ", dble(opencl%local_memory_size)/1024.0, " Kb"
      else if(size_in_bytes <= 0) then
        write(message(1), '(a,i10)') "CL Error: invalid local memory size: ", size_in_bytes
        call messages_fatal(1)
      call clSetKernelArgLocal(kernel, narg, size_in_bytes, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "set_kernel_arg_local")

    end subroutine opencl_set_kernel_arg_local

    ! ------------------------------------------

    !> Enqueues kernel on opencl%command_queue as an NDRange with global
    !! sizes globalsizes and workgroup sizes localsizes; both arrays must
    !! have the same length. Per the ASSERTs below, every global size must
    !! be a multiple of the matching local size, and no local size may
    !! exceed opencl_max_workgroup_size(). The time is accounted to the
    !! CL_KERNEL_RUN profiling region; since this only enqueues, it
    !! presumably does not include the kernel's execution time.
    !! NOTE(review): the declarations of dim and ierr and the closing
    !! "end subroutine opencl_kernel_run" are not visible in this copy.
    subroutine opencl_kernel_run(kernel, globalsizes, localsizes)
      type(cl_kernel),    intent(inout) :: kernel
      integer,            intent(in)    :: globalsizes(:)
      integer,            intent(in)    :: localsizes(:)
      
      integer(8) :: gsizes(1:3)
      integer(8) :: lsizes(1:3)
      call profiling_in(prof_kernel_run, "CL_KERNEL_RUN")

      dim = ubound(globalsizes, dim = 1)
      ASSERT(dim == ubound(localsizes, dim = 1))
      ! NOTE(review): OpenCL bounds the product of the local sizes by
      ! CL_DEVICE_MAX_WORK_GROUP_SIZE; this check bounds each dimension only.
      ASSERT(all(localsizes <= opencl_max_workgroup_size()))
      ASSERT(all(mod(globalsizes, localsizes) == 0))
     
      ! Widen to integer(8) for the binding.
      ! NOTE(review): gsizes/lsizes hold 3 entries, but nothing here
      ! asserts dim <= 3.
      gsizes(1:dim) = int(globalsizes(1:dim), 8)
      lsizes(1:dim) = int(localsizes(1:dim), 8)
      call clEnqueueNDRangeKernel(opencl%command_queue, kernel, gsizes(1:dim), lsizes(1:dim), ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "EnqueueNDRangeKernel")

      call profiling_out(prof_kernel_run)
    ! -----------------------------------------------

    !> Maximum workgroup size of the selected device, as cached in
    !! opencl%max_workgroup_size by opencl_init.
    pure function opencl_max_workgroup_size() result(wgsize)
      integer :: wgsize

      wgsize = opencl%max_workgroup_size
    end function opencl_max_workgroup_size

    ! -----------------------------------------------
    !> Queries CL_KERNEL_WORK_GROUP_SIZE for kernel on opencl%device. The
    !! integer(8) value from the query is converted explicitly to the
    !! default-integer result.
    integer function opencl_kernel_workgroup_size(kernel) result(workgroup_size)
      type(cl_kernel), intent(inout) :: kernel

      integer(8) :: workgroup_size8
      integer    :: ierr

      call clGetKernelWorkGroupInfo(kernel, opencl%device, CL_KERNEL_WORK_GROUP_SIZE, workgroup_size8, ierr)
      ! Label the error with the call that actually failed.
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clGetKernelWorkGroupInfo")

      workgroup_size = int(workgroup_size8)
    end function opencl_kernel_workgroup_size

    ! -----------------------------------------------

    !> Reads the OpenCL C source in filename, creates a program from it in
    !! opencl%context, and builds it.
    !! The build options always include:
    !!   - -w plus several -cl-* fast-math options;
    !!   - an include path to conf%share/opencl/;
    !!   - an -DEXT_*_FP64 macro chosen from the device's double-precision
    !!     extension.
    !! The caller's optional flags are appended to these options. A
    !! non-empty build log is written to stderr, and build failures are
    !! reported through opencl_print_error.
    !! NOTE(review): this copy has lost lines around the fp64 selection:
    !! the "else" before the fatal error, the "end if", and the
    !! present(flags) guard. No profiling_out(prof) matching the
    !! profiling_in below is visible either.
    subroutine opencl_build_program(prog, filename, flags)
      type(cl_program),           intent(inout) :: prog
      character(len=*),           intent(in)    :: filename
      character(len=*), optional, intent(in)    :: flags
      
      character(len = OPENCL_MAX_FILE_LENGTH) :: string
      integer :: ierr, ierrlog, iunit, irec
      type(profile_t), save :: prof
David Strubbe's avatar
David Strubbe 已提交

      PUSH_SUB(opencl_build_program)
      call profiling_in(prof, "CL_COMPILE", exclude = .true.)
      string = ''

      call io_assign(iunit)
      ! NOTE(review): the iostat of this open is not checked. If the file
      ! is missing, the first read below presumably fails and an empty
      ! source is built.
      ! NOTE(review): RECL units for unformatted direct access are
      ! processor-dependent (e.g. Intel Fortran counts 4-byte units unless
      ! -assume byterecl). Confirm that recl = 1 reads one byte per record
      ! on all target compilers.
      open(unit = iunit, file = trim(filename), access='direct', status = 'old', action = 'read', iostat = ierr, recl = 1)
      ! Read the file one record (one character) at a time into
      ! string(1:irec-1); the loop stops at the first failed read (end of
      ! file). Files that reach OPENCL_MAX_FILE_LENGTH characters are
      ! rejected.
      irec = 1
      do
        read(unit = iunit, rec = irec, iostat = ierr) string(irec:irec) 
        if (ierr /= 0) exit
        if(irec == OPENCL_MAX_FILE_LENGTH) then
          call messages_write('CL source file is too big: '//trim(filename)//'.')
          call messages_new_line()
          call messages_write("       Increase 'OPENCL_MAX_FILE_LENGTH'.")
          call messages_fatal()
        end if
        irec = irec + 1
      end do

      close(unit = iunit)
      call io_free(iunit)

      call messages_write("Building CL program '"//trim(filename)//"'.")
      call messages_info()

      prog = clCreateProgramWithSource(opencl%context, string, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clCreateProgramWithSource")

      ! build the compilation flags
      string='-w'
      ! full optimization
      string=trim(string)//' -cl-denorms-are-zero'
      string=trim(string)//' -cl-strict-aliasing'
      string=trim(string)//' -cl-mad-enable'
      string=trim(string)//' -cl-unsafe-math-optimizations'
      string=trim(string)//' -cl-finite-math-only'
      string=trim(string)//' -cl-fast-relaxed-math'

      string=trim(string)//' -I'//trim(conf%share)//'/opencl/'
      ! Prefer the AMD double-precision extension, then the Khronos one; a
      ! device with neither is rejected.
      if (f90_cl_device_has_extension(opencl%device, "cl_amd_fp64")) then
        string = trim(string)//' -DEXT_AMD_FP64'
      else if(f90_cl_device_has_extension(opencl%device, "cl_khr_fp64")) then
        string = trim(string)//' -DEXT_KHR_FP64'
        call messages_write('Octopus requires an OpenCL device with double-precision support.')
        call messages_fatal()
        string = trim(string)//' '//trim(flags)
        message(1) = "Debug info: compilation flags '"//trim(string)//"'. "
      call clBuildProgram(prog, trim(string), ierr)
      ! Fetch the build log into string even if the build failed, so that
      ! compiler diagnostics are shown before the error is reported.
      call clGetProgramBuildInfo(prog, opencl%device, CL_PROGRAM_BUILD_LOG, string, ierrlog)
      if(ierrlog /= CL_SUCCESS) call opencl_print_error(ierrlog, "clGetProgramBuildInfo")
      
      if(len(trim(string)) > 0) write(stderr, '(a)') trim(string)

      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clBuildProgram")

David Strubbe's avatar
David Strubbe 已提交
      POP_SUB(opencl_build_program)
    end subroutine opencl_build_program

    ! -----------------------------------------------

    subroutine opencl_release_program(prog)
      type(cl_program),    intent(inout) :: prog
David Strubbe's avatar
David Strubbe 已提交
      PUSH_SUB(opencl_release_program)

      call clReleaseProgram(prog, ierr)
      if(ierr /= CL_SUCCESS) call opencl_print_error(ierr, "clReleaseProgram")