package fr.univ.dblp.simulation;

import fr.univ.dblp.utils.RandomSampler;
import org.graphstream.graph.Graph;
import org.graphstream.graph.Node;

import java.util.*;

/**
 * Simulateur de propagation virale dans un réseau.
 * Implémente le modèle SIS (Susceptible-Infected-Susceptible).
 *
 * Paramètres du modèle :
 * - β (beta) : Taux de transmission par contact (1 mail/semaine = 1/7 par jour)
 * - γ (gamma) : Taux de récupération (2 mises à jour/mois = 2/30 par jour)
 */
public class ViralSimulator {

    private final Graph graph;
    private final double transmissionRate; // β - probabilité de transmission par contact par jour
    private final double recoveryRate;     // γ - probabilité de récupération par jour
    private final Set<String> immuneNodes; // Nœuds immunisés de manière permanente
    private final Random random;

    /**
     * Constantes par défaut basées sur l'énoncé du TP
     */
    public static final double DEFAULT_TRANSMISSION_RATE = 1.0 / 7.0;  // 1 mail par semaine
    public static final double DEFAULT_RECOVERY_RATE = 2.0 / 30.0;     // 2 mises à jour par mois
    public static final int DEFAULT_SIMULATION_DAYS = 90;              // 3 mois

    /**
     * Constructeur avec paramètres par défaut
     */
    public ViralSimulator(Graph graph) {
        this(graph, DEFAULT_TRANSMISSION_RATE, DEFAULT_RECOVERY_RATE, new HashSet<>());
    }

    /**
     * Constructeur avec immunisation
     */
    public ViralSimulator(Graph graph, Set<String> immuneNodes) {
        this(graph, DEFAULT_TRANSMISSION_RATE, DEFAULT_RECOVERY_RATE, immuneNodes);
    }

    /**
     * Constructeur complet
     */
    public ViralSimulator(Graph graph, double transmissionRate, double recoveryRate, Set<String> immuneNodes) {
        this.graph = graph;
        this.transmissionRate = transmissionRate;
        this.recoveryRate = recoveryRate;
        this.immuneNodes = new HashSet<>(immuneNodes);
        this.random = new Random();
    }

    /**
     * Lance une simulation de propagation virale
     *
     * @param simulationDays Durée de la simulation en jours
     * @param scenarioName Nom du scénario pour les résultats
     * @return Résultats de la simulation
     */
    public SimulationResult simulate(int simulationDays, String scenarioName) {
        // Initialiser tous les nœuds
        initializeNodes();

        // Sélectionner un patient zéro aléatoire parmi les non-immunisés
        Node patientZero = selectPatientZero();
        if (patientZero == null) {
            return createEmptyResult(simulationDays, scenarioName);
        }

        patientZero.setAttribute("state", NodeState.INFECTED);

        // Listes pour stocker l'évolution temporelle
        List<Integer> susceptibleOverTime = new ArrayList<>();
        List<Integer> infectedOverTime = new ArrayList<>();
        List<Integer> immuneOverTime = new ArrayList<>();

        // Simulation jour par jour
        for (int day = 0; day <= simulationDays; day++) {
            // Enregistrer l'état actuel
            int[] counts = countNodesByState();
            susceptibleOverTime.add(counts[0]);
            infectedOverTime.add(counts[1]);
            immuneOverTime.add(counts[2]);

            // Si plus d'infectés, arrêter la simulation
            if (counts[1] == 0 && day > 0) {
                // Remplir le reste avec des zéros
                for (int i = day + 1; i <= simulationDays; i++) {
                    susceptibleOverTime.add(counts[0]);
                    infectedOverTime.add(0);
                    immuneOverTime.add(counts[2]);
                }
                break;
            }

            // Étape de transmission
            performTransmissionStep();

            // Étape de récupération
            performRecoveryStep();
        }

        return new SimulationResult(
            susceptibleOverTime,
            infectedOverTime,
            immuneOverTime,
            graph.getNodeCount(),
            simulationDays,
            scenarioName
        );
    }

    /**
     * Initialise l'état de tous les nœuds
     */
    private void initializeNodes() {
        for (Node node : graph) {
            if (immuneNodes.contains(node.getId())) {
                node.setAttribute("state", NodeState.IMMUNE);
            } else {
                node.setAttribute("state", NodeState.SUSCEPTIBLE);
            }
        }
    }

    /**
     * Sélectionne un patient zéro aléatoire parmi les nœuds non immunisés
     */
    private Node selectPatientZero() {
        List<Node> nonImmuneNodes = new ArrayList<>();
        for (Node node : graph) {
            NodeState state = (NodeState) node.getAttribute("state");
            if (state == NodeState.SUSCEPTIBLE) {
                nonImmuneNodes.add(node);
            }
        }

        if (nonImmuneNodes.isEmpty()) {
            return null;
        }

        return nonImmuneNodes.get(random.nextInt(nonImmuneNodes.size()));
    }

    /**
     * Étape de transmission : les nœuds infectés tentent d'infecter leurs voisins
     */
    private void performTransmissionStep() {
        // Collecter tous les nœuds infectés
        List<Node> infectedNodes = new ArrayList<>();
        for (Node node : graph) {
            NodeState state = (NodeState) node.getAttribute("state");
            if (state == NodeState.INFECTED) {
                infectedNodes.add(node);
            }
        }

        // Liste des nouvelles infections (à appliquer après pour éviter les modifications concurrentes)
        List<Node> newlyInfected = new ArrayList<>();

        // Pour chaque nœud infecté
        for (Node infected : infectedNodes) {
            // Examiner chaque voisin
            infected.edges().forEach(edge -> {
                Node neighbor = edge.getOpposite(infected);
                NodeState neighborState = (NodeState) neighbor.getAttribute("state");

                // Si le voisin est susceptible, tenter l'infection
                if (neighborState == NodeState.SUSCEPTIBLE) {
                    if (random.nextDouble() < transmissionRate) {
                        newlyInfected.add(neighbor);
                    }
                }
            });
        }

        // Appliquer les nouvelles infections
        for (Node node : newlyInfected) {
            node.setAttribute("state", NodeState.INFECTED);
        }
    }

    /**
     * Étape de récupération : les nœuds infectés peuvent se rétablir
     */
    private void performRecoveryStep() {
        List<Node> toRecover = new ArrayList<>();

        for (Node node : graph) {
            NodeState state = (NodeState) node.getAttribute("state");
            if (state == NodeState.INFECTED) {
                if (random.nextDouble() < recoveryRate) {
                    toRecover.add(node);
                }
            }
        }

        // Appliquer les récupérations (retour à SUSCEPTIBLE dans le modèle SIS)
        for (Node node : toRecover) {
            node.setAttribute("state", NodeState.SUSCEPTIBLE);
        }
    }

    /**
     * Compte le nombre de nœuds dans chaque état
     *
     * @return Tableau [susceptibles, infectés, immunisés]
     */
    private int[] countNodesByState() {
        int susceptible = 0;
        int infected = 0;
        int immune = 0;

        for (Node node : graph) {
            NodeState state = (NodeState) node.getAttribute("state");
            switch (state) {
                case SUSCEPTIBLE:
                    susceptible++;
                    break;
                case INFECTED:
                    infected++;
                    break;
                case IMMUNE:
                    immune++;
                    break;
            }
        }

        return new int[]{susceptible, infected, immune};
    }

    /**
     * Crée un résultat vide (aucune épidémie possible)
     */
    private SimulationResult createEmptyResult(int simulationDays, String scenarioName) {
        List<Integer> zeros = new ArrayList<>();
        List<Integer> immune = new ArrayList<>();
        for (int i = 0; i <= simulationDays; i++) {
            zeros.add(0);
            immune.add(immuneNodes.size());
        }

        return new SimulationResult(
            zeros,
            zeros,
            immune,
            graph.getNodeCount(),
            simulationDays,
            scenarioName
        );
    }

    /**
     * Lance plusieurs simulations et retourne les résultats moyennés
     *
     * @param numRuns Nombre de simulations à effectuer
     * @param simulationDays Durée de chaque simulation
     * @param scenarioName Nom du scénario
     * @return Résultats moyennés
     */
    public SimulationResult simulateMultipleRuns(int numRuns, int simulationDays, String scenarioName) {
        List<SimulationResult> allResults = new ArrayList<>();

        for (int run = 0; run < numRuns; run++) {
            SimulationResult result = simulate(simulationDays, scenarioName + "_run" + run);
            allResults.add(result);
        }

        return averageResults(allResults, scenarioName);
    }

    /**
     * Moyenne les résultats de plusieurs simulations
     */
    private SimulationResult averageResults(List<SimulationResult> results, String scenarioName) {
        if (results.isEmpty()) {
            return createEmptyResult(DEFAULT_SIMULATION_DAYS, scenarioName);
        }

        int simulationDays = results.get(0).getSimulationDays();
        int totalNodes = results.get(0).getTotalNodes();

        List<Integer> avgSusceptible = new ArrayList<>();
        List<Integer> avgInfected = new ArrayList<>();
        List<Integer> avgImmune = new ArrayList<>();

        for (int day = 0; day <= simulationDays; day++) {
            int sumSusceptible = 0;
            int sumInfected = 0;
            int sumImmune = 0;

            for (SimulationResult result : results) {
                sumSusceptible += result.getSusceptibleOverTime().get(day);
                sumInfected += result.getInfectedOverTime().get(day);
                sumImmune += result.getImmuneCount().get(day);
            }

            avgSusceptible.add(sumSusceptible / results.size());
            avgInfected.add(sumInfected / results.size());
            avgImmune.add(sumImmune / results.size());
        }

        return new SimulationResult(
            avgSusceptible,
            avgInfected,
            avgImmune,
            totalNodes,
            simulationDays,
            scenarioName
        );
    }

    public double getTransmissionRate() {
        return transmissionRate;
    }

    public double getRecoveryRate() {
        return recoveryRate;
    }

    public Set<String> getImmuneNodes() {
        return new HashSet<>(immuneNodes);
    }
}
