/*******************************************************************************
* Copyright 2014-2020 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file ComputeSPMV.cpp

 HPCG routine
 */

#include "ComputeSPMV.hpp"
#include "ComputeSPMV_ref.hpp"
#include "ComputeDotProduct.hpp"
#include "UsmUtil.hpp"
#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"

#include <mpi.h>
#include "Geometry.hpp"
#include <cstdlib>
#endif

#include "CustomKernels.hpp"
#include "VeryBasicProfiler.hpp"

// Profiling hooks used by the routines in this file.
//
// With BASIC_PROFILING defined they forward to `optData->profiler`, so a local
// variable named `optData` must be in scope at every expansion site.
// END_PROFILE_WAIT additionally calls `wait()` on `event` before ending the
// measurement. Without BASIC_PROFILING all three expand to an empty statement
// and their arguments are not evaluated.
//
// Each macro is wrapped in `do { ... } while (0)` so that it always behaves as
// exactly one statement (safe as the body of an un-braced if/else); the call
// site supplies the terminating semicolon.
#ifdef BASIC_PROFILING
#define BEGIN_PROFILE(n) do { optData->profiler->begin(n); } while (0)
#define END_PROFILE(n) do { optData->profiler->end(n); } while (0)
#define END_PROFILE_WAIT(n, event) do { (event).wait(); optData->profiler->end(n); } while (0)
#else
#define BEGIN_PROFILE(n) do { } while (0)
#define END_PROFILE(n) do { } while (0)
#define END_PROFILE_WAIT(n, event) do { } while (0)
#endif

using namespace sycl;

/*!
  Routine to compute the sparse matrix-vector product y = Ax.

  Unless HPCG_LOCAL_LONG_LONG is defined, the routine calls ExchangeHalo on x
  (MPI builds only) and then runs custom::SpGEMV on main_queue using the
  matrix stored in the optimization data attached to A.
  With HPCG_LOCAL_LONG_LONG defined it calls the reference implementation,
  ComputeSPMV_ref, instead.

  @param[in]     A          the known system matrix; A.optimizationData is read as a pointer to optData
  @param[in,out] x          the known vector; it is handed to ExchangeHalo in MPI builds
  @param[out]    y          On exit contains the result: Ax.
  @param[in]     main_queue SYCL queue passed to ExchangeHalo and custom::SpGEMV
  @param[in,out] ierr       incremented by one if an exception is caught around the custom::SpGEMV call
  @param[in]     deps       events the computation depends on; forwarded to ExchangeHalo (MPI builds)
                            or directly to custom::SpGEMV, and waited on before the reference path runs

  @return the event returned by custom::SpGEMV. A default-constructed event is
          returned on the reference path, and also when custom::SpGEMV itself throws.

  @see ComputeSPMV_ref
*/

sycl::event ComputeSPMV( const SparseMatrix & A, Vector & x, Vector & y, sycl::queue & main_queue,
                         int& ierr, const std::vector<sycl::event> & deps)
{
    // Declared ahead of the preprocessor branches: the common return statement
    // at the end of the function needs it in both configurations.
    sycl::event last_ev;

#ifdef HPCG_LOCAL_LONG_LONG
    // The reference path does not submit anything to the queue and does not
    // report errors through ierr.
    (void)main_queue;
    (void)ierr;

    // ComputeSPMV_ref is called synchronously and receives no dependency list,
    // so make sure everything the caller asked us to wait for has finished
    // before x is read.
    sycl::event::wait(deps);
    // NOTE(review): any status returned by ComputeSPMV_ref is not inspected
    // here -- confirm whether it should be folded into ierr.
    ComputeSPMV_ref(A,x,y);
#else
    // The profiling macros expect a local with exactly this name.
    struct optData *optData = (struct optData *)A.optimizationData;

    custom::sparseMatrix *sparseM = (custom::sparseMatrix *)optData->esbM;

    BEGIN_PROFILE("SPMV:halo");
#ifndef HPCG_NO_MPI
    // The SpMV kernel depends on the halo exchange, which in turn depends on deps.
    sycl::event halo_ev = ExchangeHalo(A, x, main_queue, deps);
    std::vector<sycl::event> halo_deps({halo_ev});
    END_PROFILE_WAIT("SPMV:halo", halo_ev);
#else
    // No halo exchange without MPI: the kernel depends directly on deps.
    const std::vector<sycl::event>& halo_deps = deps;
    END_PROFILE("SPMV:halo");
#endif

    try {
        BEGIN_PROFILE("SPMV:gemv");
        last_ev = custom::SpGEMV(main_queue, sparseM, x.values, y.values, halo_deps);
        END_PROFILE_WAIT("SPMV:gemv", last_ev);
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SPMV:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SPMV:\n" << e.what() << std::endl;
        ierr += 1;
        return last_ev;
    }

#endif
    return last_ev;
}

/*!
  Routine to compute the sparse matrix-vector product y = Ax combined with a
  dot product, both performed by a single call to custom::SpGEMV_DOT.

  In MPI builds ExchangeHalo is called on x first and the kernel is made to
  depend on the resulting event; otherwise the kernel depends directly on deps.

  @param[in]     A          the known system matrix; A.optimizationData is read as a pointer to optData
  @param[in,out] x          the known vector; it is handed to ExchangeHalo in MPI builds
  @param[out]    y          output vector passed to custom::SpGEMV_DOT
  @param[out]    pAp        forwarded to custom::SpGEMV_DOT; presumably receives the dot-product result -- confirm against the kernel
  @param[in]     main_queue SYCL queue passed to ExchangeHalo and custom::SpGEMV_DOT
  @param[in,out] ierr       incremented by one if an exception is caught around the kernel call
  @param[in]     deps       events the computation depends on

  @return the event returned by custom::SpGEMV_DOT, or a default-constructed
          event if an exception was caught
*/
sycl::event ComputeSPMV_DOT( const SparseMatrix & A, Vector & x, Vector & y, double * pAp,
                             sycl::queue & main_queue, int& ierr, const std::vector<sycl::event> & deps)
{
    // The profiling macros expect a local with exactly this name.
    struct optData *optData = (struct optData *)A.optimizationData;
    custom::sparseMatrix *esbMatrix = (custom::sparseMatrix *)optData->esbM;

    BEGIN_PROFILE("SPMV_DOT:halo");
#ifndef HPCG_NO_MPI
    sycl::event exchange_ev = ExchangeHalo(A, x, main_queue, deps);
    std::vector<sycl::event> kernel_deps{exchange_ev};
    END_PROFILE_WAIT("SPMV_DOT:halo", exchange_ev);
#else
    const std::vector<sycl::event>& kernel_deps = deps;
    END_PROFILE("SPMV_DOT:halo");
#endif

    // Stays default-constructed unless the whole gemv + dot step succeeds,
    // which is what gets handed back to the caller on failure.
    sycl::event result;
    try {
        BEGIN_PROFILE("SPMV_DOT:gemv_dot");
        sycl::event kernel_ev =
            custom::SpGEMV_DOT(main_queue, esbMatrix, x.values, y.values, pAp, kernel_deps);
        END_PROFILE_WAIT("SPMV_DOT:gemv_dot", kernel_ev);
        result = kernel_ev;
    }
    catch (sycl::exception const &e) {
        std::cout << "\t\tCaught synchronous SYCL exception in SPMV_DOT:\n" << e.what() << std::endl;
        ierr += 1;
    }
    catch (std::exception const &e) {
        std::cout << "\t\tCaught std exception in SPMV_DOT:\n" << e.what() << std::endl;
        ierr += 1;
    }

    return result;
}
