"""
Permutations.py, by KWR and RJL.  Models permutations in two forms:
1. as a 1-1 map from any set U to itself.
2. as a partition of U into cycles.
Python allows any immutable type to be the key type for a mapping (dict).
Thus we can use r-subsets of U directly as arguments for the r-fold power.
We can later convert to a Permutation object on [1 .. N] with Nat as type.

The closest standard (essentially built-in) package for what we need is "itertools":
https://docs.python.org/3/library/itertools.html
One lurking issue is that this package provides /generators/ which give a use-once
stream.  You can capture such a stream to a list---see the "4-horse race" example at
http://stackoverflow.com/questions/231767/what-does-the-yield-keyword-do
Since we're not caring about memory usage (yet), I've done so here.  

The data elements, using mythical square brackets for intended types, are:

   tuple[U] domain   #fixed and immutable, hence hashable for dictionary
   dict[U,U] mp
   dict[U,U] mpinv
   list[tuple[U]] cycles
   
The domain field is technically redundant---we could make a tuple out of
keys(mp)---but it is helpful in writing code functions and the fact that
it is immutable whereas mp and mpinv are mutable is important.
Among things to get used to: attributes assigned in the class body (outside the
"__init__(...)" constructor) are shared by all instances, like "static" members in C++.
"""

from itertools import *     # allows product, combinations etc. with no prefix
from copy import *          # for deepcopy if needed

def sstr(x):
   """Render x as text; lists and tuples become "[a,b,...]", recursively.

   Strings come back unchanged (in Python 2.x, test isinstance(x, basestring)
   instead).  Every other value goes through str().
   """
   if isinstance(x, str):
      return x
   if not isinstance(x, (list, tuple)):
      return str(x)
   parts = [sstr(item) for item in x]   # recurse into nested sequences
   return "[" + ",".join(parts) + "]"

class Permutation:
   """A permutation of a finite domain, kept in several synchronized forms.

   Fields (types as in the module docstring):
      domain -- tuple of the elements being permuted
      mp     -- dict sending each element to its image
      mpinv  -- dict sending each image back to its preimage
      cycles -- list of tuples, one per cycle of mp

   Code that edits mp directly must call refresh() afterwards so that mpinv
   and cycles are rebuilt to match.
   """

   def __init__(self,dom):
      """Create the identity permutation on dom (dom is stored as given)."""
      self.domain = dom
      self.mp = {arg:arg for arg in dom}
      self.mpinv = {arg:arg for arg in dom}
      self.cycles = [(arg,) for arg in dom]   # comma needed in Python singleton tuple

   # CLASS INV: The inverse and cycle form are maintained by every operation
   # CLASS INV: domain is always sorted natively
   #   NOTE(review): composewith() re-sorts the domain, but a Permutation built
   #   directly (e.g. by cycle()) keeps whatever order it was given.

   def strmp(self, ownlines=False):
      """Return mp as "[(x,mp[x]), ...]" in domain order; ownlines puts one pair per line."""
      sep = '\n' if ownlines else ", "
      return sep.join( '('+sstr(x)+','+sstr(self.mp[x])+')' for x in self.domain).join('[]')

   def strmpinv(self, ownlines=False):
      """Return mpinv as "[(x,mpinv[x]), ...]" in domain order; ownlines puts one pair per line."""
      sep = '\n' if ownlines else ", "
      return sep.join( '('+sstr(x)+','+sstr(self.mpinv[x])+')' for x in self.domain).join('[]')

   def strcycles(self, ownlines=False):
      """Return the cycles as "[(a,b,...)(c,...)...]".

      Cycles are not separated by commas; with ownlines, the opening bracket
      and each cycle are followed by a newline.
      """
      ostr = "["
      if ownlines:
         ostr += '\n'
      for X in self.cycles:   # each X is one cycle
         ostr += ','.join(map(sstr,X)).join('()')
         if ownlines:
            ostr += '\n'
      ostr += ']'
      return ostr

   # Intended as private method (callers should use refresh()).
   def makeCycles(self):
      """Rebuild self.cycles from self.mp.

      Cycles are discovered in domain order and stored as tuples, each
      starting at its earliest element in that order.  References to the
      domain elements are stored without deep copying.

      Raises ValueError if mp is not a bijection on the domain (the walk
      revisits an element before returning to its start), instead of
      looping forever.
      """
      self.cycles = []
      seen = set()
      for start in self.domain:
         if start in seen:
            continue
         seen.add(start)
         members = [start]
         y = self.mp[start]
         while y != start:
            if y in seen:   # only possible when mp is not one-to-one
               raise ValueError("mp is not a permutation: %r is reached twice" % (y,))
            seen.add(y)
            members.append(y)
            y = self.mp[y]
         self.cycles.append(tuple(members))

   def refresh(self):
      """Recompute cycles and the inverse map from self.mp.

      mpinv is rebuilt from scratch rather than patched, so entries left
      over from a previous domain (e.g. after the domain is reassigned)
      do not survive.
      """
      self.makeCycles()
      self.mpinv = {self.mp[x]: x for x in self.domain}

   def swap(self,a,b):
      """Exchange the images a and b: whatever mapped to a now maps to b, and vice versa."""
      ainv = self.mpinv[a]
      binv = self.mpinv[b]
      (self.mp[binv],self.mp[ainv]) = (a,b)  # parallel reference assignment
      self.refresh()

   def composewith(self,perm):
      """Replace self by "self, then perm", i.e. x |--> perm(self(x)).

      perm fixes every element outside perm.domain.  Elements of perm.domain
      not yet in self.domain are added (self fixes them), so the domain can
      grow, e.g. when multiplying cycles.  The domain is re-sorted afterwards.
      """
      perm_dom = set(perm.domain)   # O(1) membership instead of tuple scans
      for x in self.domain:
         y = self.mp[x]
         if y in perm_dom:
            self.mp[x] = perm.mp[y]

      present = set(self.domain)
      extra = []
      for z in perm.domain:
         if z not in present:
            present.add(z)
            extra.append(z)
            self.mp[z] = perm.mp[z]
      self.domain = tuple(sorted(tuple(self.domain) + tuple(extra)))
      self.refresh()

   def numcycles(self):
      """Return the number of cycles, counting fixed points as 1-cycles."""
      return len(self.cycles)

   def __add__(self,other):
      """Disjoint sum: tag self's elements with 1 and other's with 2, then relabel onto 1..N."""
      return flatten(compose(cartprod(cycle(1),self),cartprod(cycle(2),other)))

   def __rmul__(self,m):
      """m*perm: the disjoint sum of m copies of perm, built by repeated +.

      For m <= 1 the loop does not run and self itself (not a copy) is returned.
      """
      ret = self
      for i in range(m-1):
         ret = ret + self
      return ret

################## end of class ####################

def prp(perm, ownlines=False):
   """Print perm in map form, as rendered by perm.strmp(ownlines)."""
   rendered = perm.strmp(ownlines)
   print(rendered)

def pri(perm, ownlines=False):
   """Print perm's inverse map, as rendered by perm.strmpinv(ownlines)."""
   rendered = perm.strmpinv(ownlines)
   print(rendered)

def prc(perm, ownlines=True):
   """Print perm's cycles, as rendered by perm.strcycles(ownlines) (one per line by default)."""
   rendered = perm.strcycles(ownlines)
   print(rendered)

def spec(perm):
   """Print the cycle type of perm as "<number of cycles> : [len1,len2,...]".

   Lengths are listed in the order of perm.cycles.  The list of lengths is
   also returned so callers can use it without parsing the printed output.
   """
   seq = [len(cyc) for cyc in perm.cycles]   # 'cyc' avoids shadowing cycle()
   print(len(seq), ":", sstr(seq))
   return seq

def idperm(n):
   """Return the identity permutation on the integers 1..n."""
   domain = tuple(range(1, n + 1))
   return Permutation(domain)

def cycle(head, *args):
   """Return the cyclic permutation head -> args[0] -> ... -> args[-1] -> head.

   The elements need not be consecutive integers, and the domain keeps them
   in the order given.  They must be distinct: a repeated element does not
   describe a cycle and would yield a map that is not one-to-one, so
   ValueError is raised instead.

   Note: this definition shadows itertools.cycle from the star import.
   """
   elems = (head,) + args
   if len(set(elems)) != len(elems):
      raise ValueError("cycle elements must be distinct: %r" % (elems,))
   perm = Permutation(elems)
   # Pair each element with its successor, wrapping the last back to head.
   for src, dst in zip(elems, elems[1:] + (head,)):
      perm.mp[src] = dst
   perm.refresh()
   return perm

def C(n):
   """Return the n-cycle (1 2 ... n); for n <= 1 this is the identity on {1}."""
   rest = range(2, n + 1)
   return cycle(1, *rest)
   
def compose(perm1,perm2):
   """Return a new permutation x |--> perm2(perm1(x)), applying perm1 first.

   Neither argument is modified.  The result's domain is the sorted union of
   both domains; each factor fixes the elements outside its own domain.
   """
   result = Permutation(perm1.domain)
   for factor in (perm1, perm2):
      result.composewith(factor)
   return result
   
def cartprod(perm1,perm2):
   """Return the product action on pairs: (a,b) |--> (perm1(a), perm2(b)).

   Keys are ordered tuples (a,b), not sets, in itertools.product order.
   """
   perm3 = Permutation(tuple(product(perm1.domain, perm2.domain)))
   images = {(a, b): (perm1.mp[a], perm2.mp[b]) for (a, b) in perm3.domain}
   perm3.mp.update(images)
   perm3.refresh()
   return perm3


def flatprod(perm1,perm2):
   """Return the action of perm1 x perm2 on merged tuples a + b.

   Intended for domains whose elements are already tuples (T, U): instead of
   the pair (a, b), the key is the sorted union tuple(sorted(a + b)), mapped
   to tuple(sorted(perm1(a) + perm2(b))).  Pairs where a or b is not a
   tuple/list are skipped.  The result's domain is the sorted set of keys.

   The result is created directly on that final domain, so mpinv holds no
   leftover entries for the discarded (a, b) pairs and no throwaway
   product-sized permutation is built.
   """
   mapping = {}
   for (a, b) in product(perm1.domain, perm2.domain):
      if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
         newarg = tuple(sorted(a + b))
         mapping[newarg] = tuple(sorted(perm1.mp[a] + perm2.mp[b]))
   perm3 = Permutation(tuple(sorted(mapping)))
   perm3.mp = mapping
   perm3.refresh()
   return perm3
   
# Keys become sets, which we maintain as always-sorted tuples
def setpower(perm,k):
   """Return the permutation induced by perm on the k-element subsets of its domain.

   Each subset X is stored as a sorted tuple and maps to the sorted tuple of
   perm's images of its elements.  Subsets appear in itertools.combinations
   order over perm.domain.

   Each key is sorted individually so that keys and images agree even when
   perm.domain is not sorted.  For example, with cycle(3,1,2) the raw
   combination (3,2) would map to (1,3), which is not itself a raw key.
   For a sorted domain the keys are exactly the combinations() tuples.
   """
   subsets = tuple(tuple(sorted(X)) for X in combinations(perm.domain, k))
   permk = Permutation(subsets)
   for X in subsets:
      permk.mp[X] = tuple(sorted(perm.mp[y] for y in X))   # the magic line
   permk.refresh()
   return permk

# REQ: A.domain and B.domain are disjoint
def comppower(A,B,k):
   """Assemble the action on k-subsets of A.domain + B.domain piece by piece.

   For each j in 0..k, flatprod(setpower(A, j), setpower(B, k-j)) acts on
   the subsets made of j elements from A.domain and k-j from B.domain.
   The pieces are combined with compose(); their domains are expected to be
   disjoint, so each composition adds the new piece's elements.
   """
   result = flatprod(setpower(A, 0), setpower(B, k))
   for j in range(1, k + 1):
      piece = flatprod(setpower(A, j), setpower(B, k - j))
      result = compose(result, piece)

   result.domain = tuple(sorted(result.mp.keys()))
   result.refresh()
   return result


def flatten(perm):
   """Return the same permutation relabelled onto the integers 1..N.

   N is len(perm.domain); the element at (1-based) position i of
   perm.domain becomes i.
   """
   iperm = idperm(len(perm.domain))
   label = {elem: pos for pos, elem in enumerate(perm.domain, start=1)}
   for elem in perm.domain:
      iperm.mp[label[elem]] = label[perm.mp[elem]]
   iperm.refresh()
   return iperm
      
   
# Some sample code: C_4 + 2*C_3 (ten points) acting on its unordered pairs.
u = C(4)             # the 4-cycle on 1..4 (same as cycle(1,2,3,4))
v = C(3)             # the 3-cycle on 1..3
w = u + 2*v          # disjoint sum of u and two copies of v, relabelled onto 1..10
x = setpower(w, 2)   # induced action on 2-element subsets
N = x.numcycles()
print("Number of cycles in x = setpower(C_4 + 2*C_3, 2) is ", N)
f = flatten(x)       # the same action, relabelled onto 1..45

print("\nThe cycle structure of x is:")
prc(x, True)

print("\nFlattened into a permutation of ints:")
prp(f)
print("=")
prc(f, True)
