#include <stdlib.h>

#include <cmath>
#include <fstream>
#include <iostream>
#include <string>

#include <ida/ida.h>                 /* prototypes for IDA fcts., consts.    */
#include <nvector/nvector_serial.h>    /* access to serial N_Vector            */
#include <sunmatrix/sunmatrix_dense.h> /* access to dense SUNMatrix            */
#include <sunlinsol/sunlinsol_dense.h> /* access to dense SUNLinearSolver      */
#include <sundials/sundials_types.h>   /* defs. of realtype, sunindextype      */
#include <sundials/sundials_math.h>    /* defs. of SUNRabs, SUNRexp, etc.      */


/*
 * Model parameters shared with the residual function.
 *
 * main() fills p[0]..p[7] and registers a pointer to this struct with
 * IDASetUserData(); res() receives it back as its void* argument and reads
 * the entries as p1..p8 (p[i] corresponds to p(i+1)).
 */
struct ModelData {
    double p[8];
};

/* Number of unknowns: the length of every N_Vector created in main() and the
 * row/column count of the dense matrix used by the linear solver. */
#define NEQ 10

/* Writes t followed by the first N entries of y as one comma-separated line. */
void output_result(std::ofstream &outputfile, double t, N_Vector &y, int N);
/* Residual callback passed to IDAInit(); params is the pointer registered
 * with IDASetUserData() (a ModelData* in this program). */
static int res(realtype t, N_Vector y, N_Vector yp, N_Vector resval, void *params);

int
main(int argc, char** argv)
{
    N_Vector y, yp, abstol, id;
    double T0 = 0.0; // Start time

    ModelData params;
    params.p[0] = 21.893;
    params.p[1] = 2.14e9;
    params.p[2] = 32.318;
    params.p[3] = 21.893;
    params.p[4] = 1.07e9;
    params.p[5] = 7.65e-18;
    params.p[6] = 4.03e-11;
    params.p[7] = 5.32e-18;

    y = N_VNew_Serial(NEQ);
    yp = N_VNew_Serial(NEQ);
    if(y == NULL || yp == NULL) {
        std::cerr << "Error allocating y or yp vector, exiting" << std::endl;
        return -1;
    }

    // Set initial values
    NV_Ith_S(y, 0) = 1.5776;
    NV_Ith_S(y, 1) = 8.32;
    NV_Ith_S(y, 2) = 0;
    NV_Ith_S(y, 3) = 0;
    NV_Ith_S(y, 4) = 0;
    NV_Ith_S(y, 5) = 0.0131;
    NV_Ith_S(y, 6) = 0.5*(-params.p[6] + sqrt((pow(params.p[6], 2) + 4*params.p[6]*1.5776)));
    NV_Ith_S(y, 7) = 0.5*(-params.p[6] + sqrt((pow(params.p[6], 2) + 4*params.p[6]*1.5776)));
    NV_Ith_S(y, 8) = 0;
    NV_Ith_S(y, 9) = 0;
    
    for(int k=0; k < NEQ; k++) {
        NV_Ith_S(yp, k) = 0.0;
    }
    
    void *ida_mem = IDACreate();
    if(ida_mem == NULL) {
        std::cerr << "Error allocating ida object, exiting" << std::endl;
        return -1;
    }

    if(IDAInit(ida_mem, res, T0, y, yp) < 0) {
        std::cerr << "Error allocating space for solver, exiting" << std::endl;
        return -1;
    }
    
    realtype reltol = RCONST(1.0e-7);
    abstol = N_VNew_Serial(NEQ);
    for(int k=0; k < NEQ; k++) {
        NV_Ith_S(abstol, k) = 1.0e-7;
    }
    if(IDASVtolerances(ida_mem, reltol, abstol) < 0) {
        std::cerr << "Error setting tolerances for solver, exiting" << std::endl;
        return -1;
    }

    /* Set ID vector */
    id = N_VNew_Serial(NEQ);
    NV_Ith_S(id, 0) = 1.0;
    NV_Ith_S(id, 1) = 1.0;
    NV_Ith_S(id, 2) = 1.0;
    NV_Ith_S(id, 3) = 1.0;
    NV_Ith_S(id, 4) = 1.0;
    NV_Ith_S(id, 5) = 1.0;
    NV_Ith_S(id, 6) = 1.0;
    NV_Ith_S(id, 7) = 0.0;
    NV_Ith_S(id, 8) = 0.0;
    NV_Ith_S(id, 9) = 0.0;
    if(IDASetId(ida_mem, id) < 0) {
        std::cerr << "Error setting ID, exiting" << std::endl;
        return -1;
    }

    if(IDASetUserData(ida_mem, &params) < 0) {
        std::cerr << "Error setting user data, exiting" << std::endl;
        return -1;
    }

    /* Create dense SUNMatrix for use in linear solves */
    SUNMatrix A = SUNDenseMatrix(NEQ, NEQ);
    if(A == NULL) {
        std::cerr << "Error allocating SUNMatrix for linear solves, exiting" << std::endl;
        return -1;
    }

    /* Create dense SUNLinearSolver object */
    SUNLinearSolver LS = SUNLinSol_Dense(y, A);
    if(LS == NULL) {
        std::cerr << "Error allocating SUNLinSol_Dense, exiting" << std::endl;
        return -1;
    }

     if(IDASetLinearSolver(ida_mem, LS, A) < 0) {
         std::cerr << "Error attaching SUNMatrix and SUNLinSol_dense matrices to solver, exiting" << std::endl;
         return -1;
     }
     
    N_Vector res_vec = N_VNew_Serial(NEQ);
    res(0.0, y, yp, res_vec, (void *)&params);

    if(IDACalcIC(ida_mem, IDA_YA_YDP_INIT, 0.1) < 0) {
        std::cerr << "Error computing consistent initial conditions, exiting" << std::endl;
        for(int k=0; k < NEQ; k++) {
            std::cout << NV_Ith_S(y, k) << " ";
        }
        std::cout << std::endl;
        return -1;
    }
     
    if(IDAGetConsistentIC(ida_mem, y, yp) < 0) {
         std::cerr << "Error retrieving consistent initial conditions, exiting" << std::endl;
         return -1;
    }

    std::cout << "Consistent Initial conditions" << std::endl;
    for(int k=0; k < NEQ; k++) {
        std::cout << k+1 << ": " << NV_Ith_S(y, k) << ", " << NV_Ith_S(yp, k) << std::endl;
    }
    
    N_Vector res_val = N_VNew_Serial(NEQ);
    res(0.0, y, yp, res_val, &params);
    std::cout << "Calling residual equation with consistent initial conditions" << std::endl;
    double max_res = -1.0;
    for(int k=0; k < NEQ; k++) {
        if(abs(NV_Ith_S(res_val, k)) > max_res) {
            max_res = abs(NV_Ith_S(res_val, k));
        }
    }
    std::cout << "  max residual value: " << max_res << std::endl;
    N_VDestroy(res_val);

    IDASetMaxNumSteps(ida_mem, 500);
    
    
    std::cout << "Start solving" << std::endl;
    double Tend = 10.0;
    int num_steps = 100;
    double dt = Tend/num_steps;
    double tout = dt;
    realtype t = T0;
    
    std::string filename = "solution.csv";
    std::ofstream outputfile;
    std::cout << "Output to file " << filename << std::endl;
    outputfile.open(filename);

    for (int iout=1; iout <= num_steps; iout++, tout += dt) {
        output_result(outputfile, t, y, NEQ);
        if(IDASolve(ida_mem, tout, &t, y, yp, IDA_NORMAL) < 0) {
            std::cerr << "Error solving, exiting" << std::endl;
            break;
        }
    }
    output_result(outputfile, t, y, NEQ);
    outputfile.close();
    
    N_VDestroy(y);
    N_VDestroy(yp);
    N_VDestroy(abstol);
    N_VDestroy(id);

    IDAFree(&ida_mem);
    SUNLinSolFree(LS);
    SUNMatDestroy(A);

    return 0;
}

/*
 * Writes one comma-separated row to outputfile: the time t followed by the
 * first N entries of y, terminated by a newline (the stream is flushed).
 *
 * Each value is preceded by ", ", so the row is "t, y0, y1, ..., y(N-1)".
 * With N <= 0 only t is written; no element of y is accessed.
 */
void
output_result(std::ofstream &outputfile, double t, N_Vector &y, int N)
{
    outputfile << t;
    for(int k=0; k < N; k++) {
        outputfile << ", " << NV_Ith_S(y, k);
    }
    outputfile << std::endl;
}


/*
 * Residual function for the solver.
 *
 * Reads the ten state values from y_vec, the first seven derivative values
 * from yp_vec and the eight parameters from the ModelData behind user_data,
 * then writes the ten residual entries into resval. Entries 0-6 contain a
 * derivative term; entries 7-9 depend on state values only. The time
 * argument t is not used.
 *
 * Always returns 0.
 */
static int
res(realtype t, N_Vector y_vec, N_Vector yp_vec, N_Vector resval, void *user_data)
{
    (void)t;

    const double *coeff = static_cast<ModelData *>(user_data)->p;
    const double k1 = coeff[0], k2 = coeff[1], k3 = coeff[2], k4 = coeff[3];
    const double k5 = coeff[4], k6 = coeff[5], k7 = coeff[6], k8 = coeff[7];

    // State values (1-based names mirror the parameter numbering).
    const double s1  = NV_Ith_S(y_vec, 0);
    const double s2  = NV_Ith_S(y_vec, 1);
    const double s3  = NV_Ith_S(y_vec, 2);
    const double s4  = NV_Ith_S(y_vec, 3);
    const double s5  = NV_Ith_S(y_vec, 4);
    const double s6  = NV_Ith_S(y_vec, 5);
    const double s7  = NV_Ith_S(y_vec, 6);
    const double s8  = NV_Ith_S(y_vec, 7);
    const double s9  = NV_Ith_S(y_vec, 8);
    const double s10 = NV_Ith_S(y_vec, 9);

    // Derivative values; only the first seven components are used.
    const double d1 = NV_Ith_S(yp_vec, 0);
    const double d2 = NV_Ith_S(yp_vec, 1);
    const double d3 = NV_Ith_S(yp_vec, 2);
    const double d4 = NV_Ith_S(yp_vec, 3);
    const double d5 = NV_Ith_S(yp_vec, 4);
    const double d6 = NV_Ith_S(yp_vec, 5);
    const double d7 = NV_Ith_S(yp_vec, 6);

    // Residual entries that involve a derivative.
    NV_Ith_S(resval, 0) = -d1 - k3*s2*s8;
    NV_Ith_S(resval, 1) = -d2 - k1*s2*s6 + k2*s10 - k3*s2*s8;
    NV_Ith_S(resval, 2) = -d3 + k3*s2*s8 + k4*s4*s6 - k5*s9;
    NV_Ith_S(resval, 3) = -d4 - k4*s4*s6 + k5*s9;
    NV_Ith_S(resval, 4) = -d5 + k1*s2*s6 - k2*s10;
    NV_Ith_S(resval, 5) = -d6 - k1*s2*s6 - k4*s4*s6 + k2*s10 + k5*s9;
    NV_Ith_S(resval, 6) = -d7 - 0.0131 + s6 + s8 + s9 + s10;

    // Residual entries that depend on state values only.
    NV_Ith_S(resval, 7) = s8 - k7*s1/(k7 + s7);
    NV_Ith_S(resval, 8) = s9 - k8*s3/(k8 + s7);
    NV_Ith_S(resval, 9) = s10 - k6*s5/(k6 + s7);

    return 0;
}