/*******************************************************************************
* Copyright 2014-2020 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file ComputeDotProduct.cpp

 HPCG routine
 */
#ifndef HPCG_NO_MPI
#include <mpi.h>
#include "mytimer.hpp"
#endif

#include "ComputeDotProduct.hpp"
#include "ComputeDotProduct_ref.hpp"

#include "Helpers.hpp"
#include "EsimdHelpers.hpp"

using LSCAtomicOp = sycl::ext::intel::esimd::native::lsc::atomic_op;

using namespace sycl;
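/*!
  Routine to compute the dot product of two vectors.

  This is an optimized replacement for the reference dot-product implementation:
  the local (per-process) part is computed on the device by
  ComputeDotProductLocal(), and, when MPI is enabled, the per-process partial
  results are summed across all processes with MPI_Allreduce.

  @param[in]  n the number of vector elements (on this processor)
  @param[in]  x, y the input vectors (data should be device-accessible)
  @param[out] result a pointer to a scalar value; on exit it will contain the result.
              In serial builds (HPCG_NO_MPI defined) it should be device-accessible;
              in parallel (MPI) builds it should be host-accessible.
              TODO: make the two cases consistent.
  @param[in,out] time_allreduce accumulator to which the time spent in the
              communication between processes is added (MPI builds only)
  @param[in]  main_queue the SYCL queue all device work is submitted to
  @param[in]  deps events that must complete before the computation starts

  @return sycl::event to wait on (in MPI builds the routine has already waited
          internally and returns a default-constructed event)

*/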

// Compile-time unrolled multiply-accumulate used by the dot-product kernel.
//
// For every step s in [0, uroll) one block_size-wide vector is loaded from x
// and one from y (esimd_lsc_block_load with offset s * block_size), and their
// element-wise product is added into res_vec. The unrolling is expressed as a
// recursion on the template parameter s, which stops once s reaches uroll.
//
// `st` and `uc` are load options declared outside this file
// (NOTE(review): presumably cache hints from EsimdHelpers.hpp -- confirm).
template <local_int_t block_size, local_int_t uroll, local_int_t s = 0>
static inline void dot_impl(esimd::simd<double, block_size> &res_vec,
              const double *x, const double* y) {
    if constexpr (s >= uroll) {
        // All uroll steps have been emitted: end of the recursion.
        return;
    } else {
        auto lhs = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(x, s * block_size);
        auto rhs = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(y, s * block_size);
        res_vec += lhs * rhs;

        // Emit the next unrolled step.
        dot_impl<block_size, uroll, s + 1>(res_vec, x, y);
    }
}

// Computes the process-local dot product of the first n entries of x and y.
//
// *result is first zeroed by a queue memset that waits on `deps`. An ESIMD
// kernel that depends on that memset is then launched: every work-item reduces
// one chunk of uroll * block_size elements into a scalar and atomically adds
// that scalar to *result.
//
// n must be a multiple of block_size (asserted below).
//
// Returns the event of the kernel submission.
sycl::event ComputeDotProductLocal(
    const local_int_t n, const Vector & x, const Vector & y, double * result, sycl::queue & main_queue,
    const std::vector<sycl::event> &deps) {

    constexpr local_int_t block_size = HPCG_BLOCK_SIZE;
    constexpr local_int_t uroll = 16;
    const local_int_t nWG = 8; // work-group size used for the launch

    // The kernel only handles remainders in whole blocks, so n has to be a
    // multiple of block_size.
    assert(n % block_size == 0);

    const double * x_data = x.values;
    const double * y_data = y.values;

    // One work-item per chunk of uroll * block_size elements.
    const local_int_t nBlocks = ceil_div(n, uroll * block_size);

    // The accumulator must be zero before any work-item adds into it.
    auto zeroed = main_queue.memset(result, 0, sizeof(double), deps);

    return main_queue.submit([&](sycl::handler &cgh) {
        cgh.depends_on(zeroed);

        auto chunk_dot = [=](sycl::nd_item<1> item) SYCL_ESIMD_KERNEL
            {
                const local_int_t chunk = item.get_global_id(0);

                // Work-items that only pad the global range up to a multiple
                // of the work-group size have nothing to do.
                if (chunk >= nBlocks) return;

                local_int_t first = chunk * uroll * block_size;
                esimd::simd<double, block_size> acc(0);

                // Every chunk but the last is complete; the last one is
                // complete only when it ends exactly at n.
                const bool full_chunk =
                    (chunk < nBlocks - 1) || (first + uroll * block_size == n);

                if (full_chunk) {
                    dot_impl<block_size, uroll>(acc, x_data + first, y_data + first);
                }
                else {
                    // Partial last chunk: process it one block at a time.
                    while (first < n) {
                        dot_impl<block_size, 1>(acc, x_data + first, y_data + first);
                        first += block_size;
                    }
                }

                // Collapse the vector accumulator to a scalar and add it to
                // the shared result.
                auto partial = esimd::reduce<double>(acc, std::plus<>());
                sycl::ext::intel::esimd::atomic_update<LSCAtomicOp::fadd, double, 1>(result, 0, partial);
            };

        // Global size is nBlocks rounded up to a multiple of the work-group size.
        cgh.parallel_for<class ddot_esimd_kernel>(
            sycl::nd_range<1>(ceil_div(nBlocks, nWG) * nWG, nWG), chunk_dot);
    });
}


// Dot product of x and y across all processes.
//
// Serial build (HPCG_NO_MPI defined): the local dot product is written straight
// into *result on the device and the kernel event is returned; time_allreduce
// is not touched.
//
// MPI build: the local dot product is computed into a temporary device scalar,
// copied to a temporary host scalar, summed over MPI_COMM_WORLD with
// MPI_Allreduce and stored into *result from the host. The routine blocks until
// that is done and returns a default-constructed event. The interval added to
// time_allreduce starts right after the local kernel has been submitted (so it
// includes waiting for the kernel and the device-to-host copy) and ends after
// the temporaries have been freed.
sycl::event ComputeDotProduct(
    const local_int_t n, const Vector & x, const Vector & y,
    double * result, double & time_allreduce, sycl::queue & main_queue,
    const std::vector<sycl::event> &deps) {

#ifdef HPCG_NO_MPI
    // Single process: the local result is already the final result.
    return ComputeDotProductLocal(n, x, y, result, main_queue, deps);
#else
    // Scratch scalars: one on the device for the kernel to accumulate into,
    // one on the host to hand to MPI.
    double * dev_partial = sycl::malloc_device<double>(1, main_queue);
    double * host_partial = sycl::malloc_host<double>(1, main_queue);

    auto local_dot_done = ComputeDotProductLocal(n, x, y, dev_partial, main_queue, deps);

    const double t_start = mytimer();

    // Bring the local partial sum to the host once the kernel has finished.
    auto copy_done = main_queue.memcpy(host_partial, dev_partial, sizeof(double),
                                       {local_dot_done});
    copy_done.wait();

    double reduced = 0.0;
    MPI_Allreduce(host_partial, &reduced, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    *result = reduced;

    sycl::free(dev_partial, main_queue);
    sycl::free(host_partial, main_queue);

    time_allreduce += mytimer() - t_start;

    // Everything has completed synchronously; nothing is left to wait on.
    return sycl::event();
#endif
}
