# astar.py
# Written by Erez Sh, 22/07/07

import bisect, math
class PriorityQueue(object):
    """A queue in which x is returned before y if key(x) > key(y).

    Items are kept in a list sorted by ``(key(item), item)``, so items with
    equal keys are ordered by comparing the items themselves.

    add(item)                   -- O(n): O(log n) search + list insertion
    remove(item)                -- O(n): O(log n) search + list deletion;
                                   raises ValueError if item is absent
    __contains__(item) -> bool  -- O(log n)
    pop() --> item              -- O(1); raises IndexError when empty

    Note: remove() and __contains__() recompute key(item), so an item is
    only found if its key is unchanged since it was added.

    Erez Sh, 22/07/07
    """
    def __init__(self, key=lambda x: x):
        self.l = []      # sorted list of (key(item), item) pairs
        self.key = key

    def _index_of(self, item):
        """Return the index of item's entry in self.l, or None if absent."""
        if not self.l:
            # Short-circuit so key() is never evaluated on an empty queue.
            return None
        prioritized_item = self.key(item), item
        pos = bisect.bisect_left(self.l, prioritized_item)
        # bisect_left returns len(self.l) when the pair sorts after every
        # stored entry; check bounds before indexing to avoid IndexError.
        if pos < len(self.l) and self.l[pos] == prioritized_item:
            return pos
        return None

    def add(self, item):
        """Insert item, keeping the list sorted by (key(item), item)."""
        bisect.insort(self.l, (self.key(item), item))

    def remove(self, item):
        """Remove item from the queue; raise ValueError if it is not present."""
        pos = self._index_of(item)
        if pos is None:
            raise ValueError()
        del self.l[pos]

    def __contains__(self, item):
        return self._index_of(item) is not None

    def __len__(self):
        return len(self.l)

    def pop(self):
        """Remove and return the item with the largest key (IndexError if empty)."""
        return self.l.pop()[1]

class AStar(object):
    """A* implementation.
    On __init__ accepts node_map, starting_node and target_node,
    where node_map must be an object that provides these methods:
        * get_priority(node) -> int
        * arrival_cost(node) -> int
        * parent(node) -> node
        * neighbors(node) -> node iterable
        * movement_cost(from_node, to_node) -> int
        * set_node(node, parent, cost)

    Open nodes with a smaller get_priority() value are expanded first.

    To solve call step() repeatedly while it returns None.
    Eventually it will return the path as a reversed list of nodes:
    the target node first, following parent() links back toward the start.
    The starting node itself (the node whose parent is None) is not included.
    If the open set runs out before the target is reached (no path exists),
    step() raises IndexError from popping the empty open queue.

    Erez Sh, 22/07/07
    """

    INFINITE_COST = 1.0e400    # Overflows to float infinity. Taken from AIMA

    def __init__(self, node_map, starting_node, target_node):
        # 'open' queue is used to determine next node. We negate the priority so that smaller values will pop first
        self.open = PriorityQueue( lambda node: -node_map.get_priority(node) )

        node_map.set_node(starting_node, None, 0)
        self.open.add(starting_node)

        self.closed = set() # Used to determine where we've already visited

        self.node_map = node_map
        self.target_node = target_node

    def step(self):
        """Expand the best open node.

        Returns the path (see class docstring) when the popped node is the
        target, otherwise None. Raises IndexError if the open queue is empty.
        """
        node_map = self.node_map

        current = self.open.pop()

        # If reached target, return found path
        if current == self.target_node:
            return self.get_path_to_node(current)

        self.closed.add(current)

        cur_cost = node_map.arrival_cost(current)
        for neighbor in node_map.neighbors(current):
            # Closed nodes are never reopened; this may cause problems
            # with bad (inconsistent) heuristics.
            if neighbor in self.closed:
                continue

            new_cost = cur_cost + node_map.movement_cost(current, neighbor)

            already_open = neighbor in self.open
            if already_open and new_cost >= node_map.arrival_cost(neighbor):
                continue  # The known route to neighbor is at least as cheap

            if already_open:
                # Remove *before* set_node(): the queue finds entries by their
                # priority key, which presumably depends on the state that
                # set_node() is about to change.
                self.open.remove(neighbor)

            node_map.set_node(neighbor, current, new_cost)

            # Add to queue for later inspection, node_map determines the priority for us
            self.open.add(neighbor)

    def get_path_to_node(self, node):
        """Return [node, parent(node), ...] up to but excluding the start node.

        Parent links are followed until a node whose parent is None is reached.
        That node (the starting node) is not appended.
        """
        path = []
        parent = self.node_map.parent(node)
        while parent is not None:  # Only starting-node has no parent
            path.append(node)
            node, parent = parent, self.node_map.parent(parent)
        return path