/*
 * Copyright (c) 2021, 2026 Contributors to the Eclipse Foundation
 *
 * This program and the accompanying materials are made
 * available under the terms of the Eclipse Public License 2.0
 * which is available at https://www.eclipse.org/legal/epl-2.0/
 *
 * SPDX-License-Identifier: EPL-2.0
 */
package org.eclipse.lsat.common.scheduler.algorithm;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Queue;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.StreamSupport;

import org.eclipse.lsat.common.scheduler.graph.Constraint;
import org.eclipse.lsat.common.scheduler.graph.GraphFactory;
import org.eclipse.lsat.common.scheduler.graph.JitConstraint;
import org.eclipse.lsat.common.scheduler.graph.Task;
import org.eclipse.lsat.common.scheduler.graph.TimeConstraint;

import lsat_graph.DispatchGroupTask;
import lsat_graph.PeripheralActionTask;
import lsat_graph.ReleaseTask;

/**
 * Implementation of the Bellman-Ford-Scheduler which schedules both ASAP and ALAP tasks, while allowing for JIT
 * (just-in-time) constraints and time-constraints with a lower and upper bound.
 * The Bellman-Ford-Scheduler gets its name from the Bellman-Ford shortest path algorithm which it uses to compute the
 * schedule. This Bellman-Ford implementation is inspired by https://www.geeksforgeeks.org/dsa/shortest-path-faster-algorithm/.
 * The interface to the implementation is the {@link #schedule} function. In a first pass it schedules all tasks in an
 * ASAP way, taking constraints into account. In a second pass all tasks that support ALAP scheduling are shifted as
 * late as possible, again taking constraints into account.
 */
public class BellmanFordImpl {
    /** Scale used by {@link #round(double)}: start times are rounded to 9 decimal places. */
    private static final double ROUNDING_SCALE = 1.0E9;

    private static final GraphFactory GRAPH_FACTORY = GraphFactory.eINSTANCE;

    /** Direction of a scheduling pass: ASAP runs from the start node, ALAP runs backwards from the end node. */
    private enum SchedulingType {
        ASAP, ALAP
    }

    /**
     * Computes a start time for every task.
     * <p>
     * A first (ASAP) pass computes the earliest start time of every task. A second (ALAP) pass works on the reversed
     * dependency graph, bounded by the duration of the ASAP schedule; tasks that do not support ALAP scheduling are
     * pinned to their ASAP time in that pass. The resulting start time of a task is the later of both, rounded to 9
     * decimal places.
     *
     * @param tasks the tasks to schedule (iterated several times)
     * @param constraints the JIT and time constraints between tasks (iterated several times)
     * @return the start time of every task in {@code tasks}
     * @throws SchedulerException if the constraints cannot be satisfied (a negative cycle was detected)
     */
    public static <T extends Task> Map<Task, BigDecimal> schedule(Iterable<T> tasks, Iterable<Constraint> constraints)
            throws SchedulerException
    {
        var asapResults = new BellmanFordImpl(SchedulingType.ASAP).doSchedule(tasks, constraints);
        var alapResults = new BellmanFordImpl(SchedulingType.ALAP).doSchedule(tasks, constraints, asapResults);

        // Combine results of ASAP and ALAP scheduling
        double scheduleDuration = computeScheduleDuration(asapResults);
        var combinedResults = new HashMap<Task, BigDecimal>();
        for (T task: tasks) {
            double startTimeAsap = asapResults.get(task);
            // ALAP distances are measured backwards from the end of the schedule
            double startTimeAlap = scheduleDuration - alapResults.get(task) - task.getExecutionTime().doubleValue();
            combinedResults.put(task, BigDecimal.valueOf(round(Math.max(startTimeAsap, startTimeAlap))));
        }
        return combinedResults;
    }

    private final SchedulingType type;

    /**
     * Graph of this pass. Forward edges carry negated weights, so that the negated shortest-path distances computed
     * by {@link Graph#schedule(Task)} are longest-path times.
     */
    private final Graph graph = new Graph();

    /** Artificial zero-duration node linked to all tasks without incoming dependency edges. */
    private final Task startNode;

    /** Artificial zero-duration node linked to all tasks without outgoing dependency edges. */
    private final Task endNode;

    private BellmanFordImpl(SchedulingType t) {
        type = t;
        startNode = createBoundaryNode();
        endNode = createBoundaryNode();
    }

    /** Creates a zero-duration helper task and adds it as a vertex to the graph. */
    private Task createBoundaryNode() {
        var node = GRAPH_FACTORY.createTask();
        node.setExecutionTime(BigDecimal.ZERO);
        graph.addVertex(node);
        return node;
    }

    /** Runs this pass without results of a previous ASAP pass. */
    private <T extends Task> Map<Task, Double> doSchedule(Iterable<T> tasks, Iterable<Constraint> constraints)
            throws SchedulerException
    {
        return doSchedule(tasks, constraints, null);
    }

    /**
     * Builds the graph for this pass and computes the (longest-path) time of every vertex.
     *
     * @param asapResults results of the ASAP pass used to bound an ALAP pass, or {@code null}
     * @return the time of every vertex, including the artificial start and end nodes
     * @throws SchedulerException if the graph contains a negative cycle, listing the constraints touching it
     */
    private <T extends Task> Map<Task, Double> doSchedule(Iterable<T> tasks, Iterable<Constraint> constraints,
            Map<Task, Double> asapResults) throws SchedulerException
    {
        try {
            buildGraph(tasks, constraints);
            if (asapResults != null) {
                addAsapConstraints(tasks, asapResults);
            }
            return graph.schedule((type == SchedulingType.ASAP) ? startNode : endNode);
        } catch (Graph.NegativeCycleDetectedException exception) {
            var affectedConstraints = findAffectedConstraints(exception, constraints);
            var errorMsg = "Could not compute a schedule due to invalid constraints:\n";
            errorMsg += affectedConstraints.stream().map(c -> " - " + ConstraintsUtil.constraintToString(c))
                    .collect(Collectors.joining("\n"));
            throw new SchedulerException(errorMsg, exception);
        }
    }

    /** Adds the tasks, their dependencies and the constraint edges (forward edges first) to the graph. */
    private <T extends Task> void buildGraph(Iterable<T> tasks, Iterable<Constraint> constraints)
            throws Graph.NegativeCycleDetectedException
    {
        addTasks(tasks);
        addDependencies(tasks);
        addConstraintsForwardEdges(constraints);
        addConstraintsBackwardEdges(constraints);
    }

    /**
     * Bounds an ALAP pass by the ASAP results: the start and end node are kept exactly the ASAP schedule duration
     * apart, and every task that does not support ALAP scheduling is fixed at its ASAP position.
     */
    private <T extends Task> void addAsapConstraints(Iterable<T> tasks, Map<Task, Double> asapResults) {
        // Set maximum duration of schedule by linking the start and end node with a time-constraint
        var duration = computeScheduleDuration(asapResults);
        addEdge(endNode, startNode, -duration); // weights of forward edges are negated
        addEdge(startNode, endNode, duration); // weights of backward edges are not negated

        // Fixate every task which cannot be moved as an ALAP task using time-constraints
        for (T task: tasks) {
            if (!supportsALAP(task)) {
                var weight = asapResults.get(task) + task.getExecutionTime().doubleValue();
                addEdge(task, startNode, -weight); // weights of forward edges are negated
                addEdge(startNode, task, weight); // weights of backward edges are not negated
            }
        }
    }

    /** Adds every task as a vertex of the graph. */
    private <T extends Task> void addTasks(Iterable<T> tasks) {
        for (T task: tasks) {
            graph.addVertex(task);
        }
    }

    /**
     * Adds the dependency edges between tasks, weighted with the negated execution time of the task the edge leaves.
     * Tasks without incoming/outgoing dependencies are linked to the start/end node. In an ALAP pass all edges are
     * reversed.
     */
    private <T extends Task> void addDependencies(Iterable<T> tasks) {
        for (T task: tasks) {
            var executionTime = task.getExecutionTime().doubleValue();

            // Link all tasks without incoming edges (i.e. start tasks) to the start node
            if (task.getIncomingEdges().isEmpty()) {
                if (type == SchedulingType.ASAP) {
                    addEdge(startNode, task, 0.0);
                } else {
                    addEdge(task, startNode, -executionTime);
                }
            }

            // Link all tasks without outgoing edges (i.e. end tasks) to the end node
            if (task.getOutgoingEdges().isEmpty()) {
                if (type == SchedulingType.ASAP) {
                    addEdge(task, endNode, -executionTime);
                } else {
                    addEdge(endNode, task, 0.0);
                }
            }

            // Link a task to all tasks connected to this task using the task's execution time as weight
            if (type == SchedulingType.ASAP) {
                for (var edge: task.getOutgoingEdges()) {
                    addEdge(task, (Task)edge.getTargetNode(), -executionTime);
                }
            } else {
                for (var edge: task.getIncomingEdges()) {
                    addEdge(task, (Task)edge.getSourceNode(), -executionTime);
                }
            }
        }
    }

    /**
     * Adds for every constraint a forward edge from its begin to its end task, weighted with the negated execution
     * time of the begin task plus (for time constraints) the lower bound.
     */
    private void addConstraintsForwardEdges(Iterable<Constraint> constraints) {
        for (var constraint: constraints) {
            var begin = beginOf(constraint);
            var weight = begin.getExecutionTime();
            if (constraint instanceof TimeConstraint timeConstraint) {
                weight = weight.add(timeConstraint.getLowerBound());
            }
            addEdge(begin, endOf(constraint), -weight.doubleValue()); // weights of forward edges are also negated
        }
    }

    /**
     * Adds for every constraint a backward edge from its end to its begin task. Its weight is the execution time of
     * the begin task plus the upper bound for a time constraint, the negated shortest-path distance from begin to end
     * (in the graph without backward constraint edges) for a JIT constraint, and 0 otherwise.
     */
    private void addConstraintsBackwardEdges(Iterable<Constraint> constraints) {
        // Compute the weight of every constraint before adding any backward edge, so that all JIT weights are derived
        // from the same graph.
        var weights = new HashMap<Constraint, BigDecimal>();
        // The graph is not modified while computing the weights, so distances can be shared between JIT constraints
        // that have the same begin node.
        var distancesFrom = new HashMap<Task, Map<Task, Double>>();
        for (var constraint: constraints) {
            var begin = beginOf(constraint);
            var end = endOf(constraint);

            var weight = BigDecimal.ZERO;
            if (constraint instanceof JitConstraint) {
                double path = distancesFrom.computeIfAbsent(begin, graph::shortestPath).get(end);
                weight = new BigDecimal(-path); // negate the weight of the forward edge
            } else if (constraint instanceof TimeConstraint timeConstraint) {
                weight = begin.getExecutionTime().add(timeConstraint.getUpperBound());
            }
            weights.put(constraint, weight);
        }

        for (var constraint: constraints) {
            // weights of backward edges are not negated
            addEdge(endOf(constraint), beginOf(constraint), weights.get(constraint).doubleValue());
        }
    }

    /** Returns the task a constraint starts from in the time direction of this pass. */
    private Task beginOf(Constraint constraint) {
        return (type == SchedulingType.ASAP) ? constraint.getSource() : constraint.getTarget();
    }

    /** Returns the task a constraint leads to in the time direction of this pass. */
    private Task endOf(Constraint constraint) {
        return (type == SchedulingType.ASAP) ? constraint.getTarget() : constraint.getSource();
    }

    private void addEdge(Task begin, Task end, double weight) {
        graph.addEdge(begin, end, weight);
    }

    /**
     * Returns the constraints whose source or target lies on the detected negative cycle, in iteration order of
     * {@code constraints}.
     */
    private Set<Constraint> findAffectedConstraints(Graph.NegativeCycleDetectedException exception,
            Iterable<Constraint> constraints)
    {
        // Hash the cycle once instead of scanning the list for every constraint endpoint
        Set<Task> cycleNodes = new HashSet<>(exception.getCycle());
        return StreamSupport.stream(constraints.spliterator(), false)
                .filter(c -> cycleNodes.contains(c.getSource()) || cycleNodes.contains(c.getTarget()))
                .collect(Collectors.toCollection(LinkedHashSet::new));
    }

    /** Returns the largest time in {@code results}, or 0 if it is empty. */
    private static Double computeScheduleDuration(Map<Task, Double> results) {
        return results.values().stream().max(Double::compare).orElse(0.0);
    }

    /** Tells whether a task may be shifted by the ALAP pass. */
    private static <T extends Task> boolean supportsALAP(T task) {
        return switch (task) {
            case DispatchGroupTask ignored -> false;
            case PeripheralActionTask actionTask -> actionTask.getAction().scheduleAlap();
            case ReleaseTask ignored -> false;
            default -> true;
        };
    }

    /** Rounds {@code value} to 9 decimal places. */
    private static double round(double value) {
        return Math.round(value * ROUNDING_SCALE) / ROUNDING_SCALE;
    }

    /**
     * Directed graph over tasks with weighted edges, providing a queue-based Bellman-Ford (SPFA) single-source
     * shortest path computation that detects negative-weight cycles.
     * <p>
     * The class is static so that neither the graph nor the exceptions it throws hold a hidden reference to the
     * enclosing scheduler.
     */
    static class Graph {
        /** An edge is only relaxed if it improves a distance by more than this amount. */
        private static final double EPSILON = 1.0E-9;

        private static final String GRAPH_CONTAINS_NEGATIVE_WEIGHT_CYCLE = "Graph contains a negative-weight cycle";

        // Map of a complex type vertex to an integer index
        private final Map<Task, Integer> vertexToIndex = new HashMap<>();
        // List from integer index back to complex type vertex
        private final List<Task> indexToVertex = new ArrayList<>();
        // For each vertex: the outgoing edges with their weights
        private final List<List<Pair>> edges = new ArrayList<>();

        /** Adds {@code task} as a vertex; adding a task that already is a vertex has no effect. */
        public void addVertex(Task task) {
            if (vertexToIndex.putIfAbsent(task, indexToVertex.size()) == null) {
                indexToVertex.add(task);
                edges.add(new ArrayList<>());
            }
        }

        /** Adds a directed edge; both tasks must already be vertices. */
        public void addEdge(Task frm, Task to, double weight) {
            int iFrm = vertexToIndex.get(frm);
            int iTo = vertexToIndex.get(to);
            edges.get(iFrm).add(new Pair(iTo, weight));
        }

        /**
         * Returns the negated shortest-path distances from {@code source}; unreachable vertices get
         * {@code -Double.MAX_VALUE}.
         *
         * @throws NegativeCycleDetectedException if a negative cycle is reachable from {@code source}
         */
        public Map<Task, Double> schedule(Task source) throws Graph.NegativeCycleDetectedException {
            return shortestPath(source, -1.0);
        }

        /**
         * Returns the shortest-path distances from {@code source}; unreachable vertices get {@code Double.MAX_VALUE}.
         *
         * @throws NegativeCycleDetectedException if a negative cycle is reachable from {@code source}
         */
        public Map<Task, Double> shortestPath(Task source) throws Graph.NegativeCycleDetectedException {
            return shortestPath(source, 1.0);
        }

        // This shortest path implementation was inspired by
        // https://www.geeksforgeeks.org/dsa/shortest-path-faster-algorithm/
        private Map<Task, Double> shortestPath(Task source, double scale) throws Graph.NegativeCycleDetectedException {
            final int sourceIndex = vertexToIndex.get(source);
            final int numVertices = indexToVertex.size();

            // Shortest distance found so far and the predecessor on that path (-1: none)
            double[] d = new double[numVertices];
            Arrays.fill(d, Double.MAX_VALUE);
            int[] pred = new int[numVertices];
            Arrays.fill(pred, -1);

            // Whether a vertex is currently present in the queue
            boolean[] inQueue = new boolean[numVertices];

            // How often each vertex has been (re-)queued since the last negative cycle check
            int[] visitCnt = new int[numVertices];

            // Start shortest path computations from the source
            d[sourceIndex] = 0;
            Queue<Integer> q = new LinkedList<>();
            q.add(sourceIndex);
            inQueue[sourceIndex] = true;

            while (!q.isEmpty()) {
                int u = q.poll();
                inQueue[u] = false;

                // Relax all edges leaving u
                for (Pair edge: edges.get(u)) {
                    int v = edge.target();
                    double candidate = d[u] + edge.weight();
                    if (d[v] > candidate + EPSILON) {
                        d[v] = candidate;
                        pred[v] = u;

                        if (!inQueue[v]) {
                            q.add(v);
                            inQueue[v] = true;
                            if (++visitCnt[v] > numVertices) {
                                // Such frequent re-queuing indicates a negative cycle. Any cycle in the predecessor
                                // graph is a negative cycle; if none has formed yet, keep relaxing and check again
                                // later. With a real negative cycle, distances keep dropping until a predecessor
                                // cycle necessarily exists, so this terminates.
                                var cycle = findPredecessorCycle(pred);
                                if (cycle != null) {
                                    throw new NegativeCycleDetectedException(GRAPH_CONTAINS_NEGATIVE_WEIGHT_CYCLE,
                                            cycle);
                                }
                                Arrays.fill(visitCnt, 0);
                            }
                        }
                    }
                }
            }

            var results = new HashMap<Task, Double>();
            for (var entry: vertexToIndex.entrySet()) {
                results.put(entry.getKey(), scale * d[entry.getValue()]);
            }
            return results;
        }

        /**
         * Looks for a cycle in the predecessor graph described by {@code pred} (-1 marks a vertex without predecessor).
         *
         * @return the vertices of a cycle in edge direction, or {@code null} if the predecessor graph is acyclic
         */
        private List<Task> findPredecessorCycle(int[] pred) {
            // 0 = not visited, 1 = on the predecessor chain currently being walked, 2 = fully explored
            byte[] state = new byte[pred.length];
            for (int first = 0; first < pred.length; first++) {
                int cur = first;
                while (cur != -1 && state[cur] == 0) {
                    state[cur] = 1;
                    cur = pred[cur];
                }
                if (cur != -1 && state[cur] == 1) {
                    // The walk returned to a vertex of its own chain: cur lies on a cycle
                    return collectCycle(pred, cur);
                }
                for (int v = first; v != -1 && state[v] == 1; v = pred[v]) {
                    state[v] = 2;
                }
            }
            return null;
        }

        /** Collects the predecessor cycle through {@code start}, listed in edge direction. */
        private List<Task> collectCycle(int[] pred, int start) {
            var cycle = new ArrayList<Task>();
            int cur = start;
            do {
                cycle.add(indexToVertex.get(cur));
                cur = pred[cur];
            } while (cur != start);
            // Predecessors point against the edge direction
            Collections.reverse(cycle);
            return cycle;
        }

        /** Outgoing edge: index of the target vertex and the edge weight. */
        record Pair(int target, double weight) {
        }

        /** Thrown when a negative-weight cycle is reachable from the source of a shortest path computation. */
        public static class NegativeCycleDetectedException extends RuntimeException {
            private static final long serialVersionUID = -5497793361920910984L;

            private final List<Task> cycle;

            public NegativeCycleDetectedException(String message, List<Task> cycle) {
                super(message);
                this.cycle = cycle;
            }

            /** Returns the vertices of the detected cycle in edge direction. */
            public List<Task> getCycle() {
                return cycle;
            }
        }
    }
}
