! Copyright (C) 2022  Light and Molecules Group

! This program is free software: you can redistribute it and/or modify
! it under the terms of the GNU General Public License as published by
! the Free Software Foundation, either version 3 of the License, or
! (at your option) any later version.

! This program is distributed in the hope that it will be useful,
! but WITHOUT ANY WARRANTY; without even the implied warranty of
! MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
! GNU General Public License for more details.

! You should have received a copy of the GNU General Public License
! along with this program.  If not, see <https://www.gnu.org/licenses/>.

program analysis
  !! Statistical analysis program for Newton-X.
  !!
  !! The program is intended to be launched from the directory where
  !! the ./TRAJECTORIES folder can be found. It visits each TRAJn
  !! trajectory folder in turn (n = 1, ..., ntraj) and computes the
  !! following averages (and the associated variance or standard
  !! deviation):
  !!
  !! - Total energies, Kinetic energies, Potential energies
  !! - Electronic populations
  !!
  !! It will also output the fraction of trajectories in each state at
  !! each step of the dynamics. To do so, we start by setting an array
  !! of size nstate [0, 0, ..., 0, 0, ntraj, 0, 0, ..., 0] where ntraj
  !! is at position nstatdyn_init. Then at each step, when a hopping
  !! occurs, we adjust the numbers in the array. When all trajectories
  !! are done, the final array contains, at each step, how many
  !! trajectories are in the given state. We just have to divide the
  !! array by the number of trajectories to get the final result.
  !!
  !! All parameters are read from the ``nxstats`` namelist found in
  !! the file ``analysis.inp`` of the current directory. ``nsteps``
  !! and ``ntraj`` have no usable default and must be given there.
  use mod_kinds, only: dp
  use mod_constants
#ifdef USE_HDF5
  use hdf5
  use mod_h5md
#endif

  use mod_statistics
  implicit none

  character(len=MAX_STR_SIZE) :: path_to_traj
  !! Path to individual trajectories folders.
  character(len=MAX_STR_SIZE) :: filename
  !! File to read
  character(len=MAX_STR_SIZE) :: io_msg
  !! Message returned by a failing I/O statement.

  integer :: nstat
  !! Number of states included in the simulation.
  integer :: nsteps
  !! Number of steps in each dynamics.
  integer :: ntraj
  !! Total number of trajectories.

  integer :: nstatdyn_init
  !! State holding all trajectories when ``fractraj`` is initialized.

  integer :: offset
  !! Start to read the files at line ``offset``.
  integer, dimension(:), allocatable :: usecols
  !! Array determining which columns to use from filename.
  integer :: ncols
  !! Number of columns that will be read from the file.

  logical :: do_energies, do_fractraj, do_populations
  logical :: print_std ! If true, prints standard deviation (else,
  ! variance)

  integer, parameter :: MAX_COLS=1024
  !! Largest number of entries accepted for ``usecols``.

#ifdef USE_HDF5
  type(nx_h5md_t) :: h5md_f
  integer :: hdferr
  integer, dimension(:), allocatable :: step
#endif

  real(dp), dimension(:), allocatable :: time
  real(dp), dimension(:, :), allocatable :: en_avg
  real(dp), dimension(:, :), allocatable :: en_var
  real(dp), dimension(:, :), allocatable :: pop_avg
  real(dp), dimension(:, :), allocatable :: pop_var
  integer, dimension(:, :), allocatable :: fractraj
  !! Fraction of trajectories in each states


  integer :: u, i
  integer :: io_stat
  !! Status returned by I/O statements (0 on success).

  namelist /nxstats/ path_to_traj, offset, nsteps, &
       & ntraj, ncols, usecols, nstat, &
       & nstatdyn_init, &
       & do_energies, do_populations, do_fractraj, &
       & print_std

  ! Let's put some defaults values here for convenience
  path_to_traj = './TRAJECTORIES/TRAJ'
  offset = 2
  nsteps = -1
  ntraj = -1
  nstat = 2
  ncols = 5
  nstatdyn_init = 2
  do_energies = .true.
  do_populations = .false.
  do_fractraj = .false.
  print_std = .false.

  ! Read input parameters.
  !
  ! The namelist is read twice: the first pass uses a buffer of
  ! MAX_COLS entries for ``usecols`` only to learn ``ncols``; the
  ! array is then resized to exactly ``ncols`` entries and the
  ! namelist is read again so that ``usecols`` has its final shape.
  allocate(usecols(MAX_COLS))
  open(newunit=u, file='analysis.inp', status='old', action='read', &
       & iostat=io_stat, iomsg=io_msg)
  if (io_stat /= 0) then
     call analysis_abort('cannot open analysis.inp: '//trim(io_msg))
  end if

  read(u, nml=nxstats, iostat=io_stat, iomsg=io_msg)
  if (io_stat /= 0) then
     call analysis_abort('cannot read namelist nxstats: '//trim(io_msg))
  end if

  if (ncols < 0 .or. ncols > MAX_COLS) then
     call analysis_abort('ncols must lie between 0 and MAX_COLS')
  end if

  deallocate(usecols)
  allocate(usecols(ncols))
  rewind(u)
  read(u, nml=nxstats, iostat=io_stat, iomsg=io_msg)
  if (io_stat /= 0) then
     call analysis_abort('cannot read namelist nxstats: '//trim(io_msg))
  end if
  close(u)

  ! Reject values that would silently produce empty or out-of-bounds
  ! arrays further down (nsteps and ntraj default to -1).
  if (nsteps < 1) call analysis_abort('nsteps must be set to a positive value')
  if (ntraj < 1) call analysis_abort('ntraj must be set to a positive value')
  if (nstat < 1) call analysis_abort('nstat must be a positive value')
  if (do_fractraj) then
     if (nstatdyn_init < 1 .or. nstatdyn_init > nstat) then
        call analysis_abort('nstatdyn_init must lie between 1 and nstat')
     end if
  end if

  ! Initialize some arrays
  if (do_energies) then
     allocate(en_avg(nstat+3, nsteps))
     allocate(en_var(nstat+3, nsteps))
  end if

  if (do_populations) then
     allocate(pop_avg(nstat+1, nsteps))
     allocate(pop_var(nstat+1, nsteps))
  end if

  if (do_fractraj) then
     ! Every trajectory starts in state nstatdyn_init.
     allocate(fractraj(nsteps, nstat))
     fractraj(:, :) = 0
     fractraj(:, nstatdyn_init) = ntraj
  end if

  ! Process files
#ifdef USE_HDF5
  call h5open_f(hdferr)

  ! The time axis is taken from the first trajectory only.
  filename = trim(path_to_traj)//'1/dyn.h5'
  call h5md_f%open(filename)
  call h5_get_time_step(h5md_f%obs_id, 'total_energy', time, step)
  call h5md_f%close()

  ! Fill the first line with time data
  if (do_energies) then
     en_avg(1, :) = time(:)
     en_var(1, :) = time(:)
  end if
  if (do_populations) then
     pop_avg(1, :) = time(:)
     pop_var(1, :) = time(:)
  end if

  do i=1, ntraj
     write(filename, '(a,i0,a)')&
          & trim(path_to_traj), i, '/dyn.h5'
     write(*, *) 'Reading file ', trim(filename)
     call h5md_f%open(filename)

     if (do_fractraj) then
        write(*, *) '   Getting nstatdyn info ... '
        call h5_compare_nstatdyn(h5md_f%obs_id, fractraj, nstatdyn_init)
     end if

     if (do_energies) then
        write(*, *) '   Getting energies ...'
        call h5_stats_energies(h5md_f%obs_id, en_avg, en_var, i)
     end if

     if (do_populations) then
        write(*, *) '   Getting populations ...'
        call h5_update_avg_var(h5md_f%obs_id, 'populations', &
             & pop_avg(2:nstat+1, :), pop_var(2:nstat+1, :), i)
     end if

     call h5md_f%close()
  end do

  call h5close_f(hdferr)
#else

  ! Fraction of trajectories
  if (do_fractraj) then
     allocate(time(nsteps))
     do i=1, ntraj
        write(filename, '(a,i0,a)') &
             & trim(path_to_traj), i, '/energies.dat'
        call txt_compare_nstatdyn(filename, fractraj, offset, nsteps,&
             & nstatdyn_init, time)
     end do
  end if

  do i=1, ntraj
     if (do_energies) then
        write(filename, '(a,i0,a)') &
             & trim(path_to_traj), i, '/energies.dat'
        call txt_update_avg_var(filename, en_avg, en_var, offset, nsteps,&
             & i, usecols=usecols)
     end if

     if (do_populations) then
        write(filename, '(a,i0,a)') &
             & trim(path_to_traj), i, '/populations.dat'
        call txt_update_avg_var(filename, pop_avg, pop_var, offset, nsteps,&
             & i)
     end if
  end do
#endif

  ! Print the results
  if (do_fractraj) then
     call pretty_print_fractraj(fractraj, time, './fraction_of_traj.dat')
  end if

  if (do_energies) then
     ! Print average
     write(filename, '(a)') './energies_avg.dat'
     call pretty_print_stats(en_avg, filename, 'energies')

     ! Print standard deviation (or variance). The first row is left
     ! untouched when taking the square root: only rows 2.. hold
     ! variances.
     write(filename, '(a)') './energies_var.dat'
     if (print_std) then
        write(filename, '(a)') './energies_std.dat'
        en_var(2:, :) = sqrt(en_var(2:, :))
     end if
     call pretty_print_stats(en_var, filename, 'energies')
  end if

  if (do_populations) then
     ! Print average
     write(filename, '(a)') './populations_avg.dat'
     call pretty_print_stats(pop_avg, filename, 'populations')

     ! Print standard deviation (or variance), first row excluded as
     ! for the energies.
     write(filename, '(a)') './populations_var.dat'
     if (print_std) then
        write(filename, '(a)') './populations_std.dat'
        pop_var(2:, :) = sqrt(pop_var(2:, :))
     end if
     call pretty_print_stats(pop_var, filename, 'populations')
  end if

contains

  subroutine analysis_abort(message)
    !! Print ``message`` on standard output and terminate the program
    !! with a non-zero exit status.
    character(len=*), intent(in) :: message
    !! Description of the error.

    write(*, '(a)') 'analysis: '//trim(message)
    error stop 1
  end subroutine analysis_abort
end program analysis
