module TourEval where

import Data.Array
import Data.Maybe
import List
import Point 

import Debug.Trace
import Control.Exception


-- | Entry point: build a tour over @pts@ by delegating to 'fi'.
--
-- The result is passed through 'assert', which checks that the tour has
-- exactly as many entries as the input list.
toureval pts = assert (length tour == length pts) tour
  where
    tour = fi pts

-- | Insertion driver: builds the neighbour table, seeds a two-point tour and
-- hands both, together with the queue of points not yet on the tour, to
-- 'fiIter'.
fi pts = fiIter checkedStart allnn pending
  where
    -- per-point neighbour lists, see 'allnnEval'
    allnn = allnnEval pts
    -- two-point seed tour chosen by 'startTour'
    start = startTour allnn
    -- queue of the remaining points, built from the unchecked seed
    pending = nonTour pts start
    -- the seed handed to 'fiIter' is guarded by the 'startTourOK' assertion
    checkedStart = assert (startTourOK start pts) start


-- | Main loop: consume the queue of non-tour points until it is empty.
--
-- Each queue entry is @(distance, tour point, non-tour point)@. The head entry
-- is popped and the non-tour point's current nearest tour neighbour is looked
-- up with 'nearestTourNeighbor':
--
-- * if it still equals the recorded tour point, the non-tour point is
--   inserted next to it via 'tourInsert';
-- * otherwise the point is put back on the queue, paired with its new
--   neighbour, via 'nonTourpqUpdate', and the tour is left unchanged.
fiIter tour _ [] = tour
fiIter tour allnn ((_, tourpt, ntourpt) : rest)
  | nearest == tourpt = fiIter (tourInsert ntourpt tourpt tour) allnn rest
  | otherwise = fiIter tour allnn (nonTourpqUpdate (nearest, ntourpt) rest)
  where
    nearest = nearestTourNeighbor ntourpt tour allnn


-- | Neighbour table: pair every point of @pts@ with the points of @pts@
-- sorted by 'distsrc' relative to it, minus the first entry of that sorted
-- list.
--
-- The whole input (the point itself included) is sorted and the head is
-- dropped with 'tail'.
-- NOTE(review): this presumably relies on the point itself sorting first
-- (distance zero to itself) -- confirm against 'sqrdist'.
--
-- An earlier variant, kept for reference, filtered the point out explicitly
-- instead of dropping the head:
--
-- > allnnEval pts = map (\pt -> (pt,sortBy (distsrc pt) [np | np <- pts, np/=pt])) pts
allnnEval pts = [ (pt, tail (sortBy (distsrc pt) pts)) | pt <- pts ]
  

-- | Sanity check on a neighbour table: every entry's neighbour list must hold
-- exactly one element fewer than the table has entries.
allnnOK allnn = all complete allnn
  where
    complete (_, nn) = length nn + 1 == length allnn

-- | Ordering of @pt0@ and @pt1@ relative to a common source point @pt@:
-- compares the pairs @(pt,pt0)@ and @(pt,pt1)@ with 'distcmp'.
distsrc pt pt0 pt1 = distcmp (fromSource pt0) (fromSource pt1)
  where
    fromSource target = (pt, target)
 
-- | Order two point pairs by the value 'sqrdist' assigns to each of them
-- (per the original comment, the squared distance between the pair's points).
distcmp pp0 pp1 = compare (sqrdist pp0) (sqrdist pp1)


---------------- start tour  -----------------
-- | Seed tour of two points.
--
-- Every point is paired with the last entry of its neighbour list, and the
-- pair ranked highest by 'distcmp' becomes the initial tour.
startTour allnn = [fst seed, snd seed]
  where
    seed = maximumBy distcmp [ (pt, last nn) | (pt, nn) <- allnn ]

-- | Check a two-point seed tour against all points: no point of @pts@ may be
-- further (by 'sqrdist') from either seed point than the seed points are from
-- each other.
--
-- Only defined for a tour of exactly two points.
startTourOK [pt0,pt1] pts = all withinSeedSpan pts
  where
    seedSpan = sqrdist (pt1, pt0)
    withinSeedSpan pt =
      seedSpan >= sqrdist (pt, pt0) && seedSpan >= sqrdist (pt, pt1)

-- nearest tour neighbor finding

-- | First entry of @ntp@'s neighbour list (as stored in @allnn@) that is
-- currently a member of @tour@.
--
-- Partial: fails if @ntp@ has no entry in @allnn@, or if none of its
-- neighbours is on the tour.
nearestTourNeighbor ntp tour allnn = head [ pt | pt <- nn, pt `elem` tour ]
  where
    Just nn = lookup ntp allnn


--  tour insertion

-- | Insert @nTpt@ into @tour@ next to the tour point @pt@.
--
-- Both placements (directly before and directly after @pt@) are built, and
-- the one with the smaller 'tourWeight' is returned; on a tie the "before"
-- placement wins. The chosen tour is wrapped in an 'assert' checking that it
-- is exactly one element longer than the input tour.
tourInsert nTpt pt tour
  | tourWeight before <= tourWeight after = grownByOne before
  | otherwise = grownByOne after
  where
    before = tourInsertBefore nTpt pt tour
    after = tourInsertAfter nTpt pt tour
    grownByOne candidate =
      assert (length candidate == length tour + 1) candidate


-- | Place @nTpt@ immediately before the first occurrence of @pt@ in @tour@.
-- If @pt@ does not occur, @nTpt@ ends up at the end of the tour.
tourInsertBefore nTpt pt tour = front ++ nTpt : back
  where
    -- front: everything ahead of pt; back: pt and everything after it
    (front, back) = span (/= pt) tour

-- | Place @nTpt@ immediately after the first occurrence of @pt@ in @tour@.
--
-- If @pt@ does not occur in the tour, @nTpt@ is appended at the end. This
-- matches what 'tourInsertBefore' does in that situation. (The previous
-- implementation took 'head' and 'tail' of an empty list here and crashed.)
tourInsertAfter nTpt pt tour =
  case break (== pt) tour of
    -- pt found: keep it in place and slot the new point right behind it
    (front, hit : rest) -> front ++ hit : nTpt : rest
    -- pt absent: nothing to anchor on, so append
    (front, [])         -> front ++ [nTpt]


------ tour weight

-- | Weight of a closed tour: the sum of 'sqrdist' over every pair of
-- consecutive tour points, including the closing pair from the last point
-- back to the first.
tourWeight :: [Point Int]->Int
tourWeight tour = sum (zipWith edge closed (tail closed))
  where
    -- the tour with its first point repeated at the end, to close the cycle
    closed = tour ++ [head tour]
    edge from to = sqrdist (from, to)

-- nontour priority queue
-- | Initial priority queue of the points that are not on the two-point seed
-- tour @[pt0,pt1]@.
--
-- Every point of @pts@ other than @pt0@ and @pt1@ becomes a
-- @(distance, tour point, point)@ entry via 'tourNeighbor'. The entries are
-- ordered largest distance first.
--
-- The order matters. 'fiIter' always pops the head of the queue, and
-- 'nonTourpqUpdate' re-inserts an entry behind the run of entries whose
-- distance is strictly greater, so it keeps the queue sorted only if the
-- queue is in descending order. The previous ascending 'sort' broke that
-- invariant as soon as the first entry was re-queued.
--
-- Only defined for a seed tour of exactly two points.
nonTour pts [pt0,pt1] =
  sortBy (flip compare)
         (map (tourNeighbor pt0 pt1)
              (filter (\pt -> pt /= pt0 && pt /= pt1) pts))

-- | Queue entry for @pt@ relative to the two tour points @pt0@ and @pt1@:
-- @(sqrdist to the closer tour point, that tour point, pt)@.
-- When both are equally close, @pt0@ is chosen.
tourNeighbor pt0 pt1 pt
  | toFirst <= toSecond = (toFirst, pt0, pt)
  | otherwise = (toSecond, pt1, pt)
  where
    toFirst = sqrdist (pt0, pt)
    toSecond = sqrdist (pt1, pt)
 

-- | Put the non-tour point @ntourpt@ back on the queue, now paired with the
-- tour point @nearTN@ and their 'sqrdist'.
--
-- The new entry goes directly behind the leading run of entries whose
-- distance is strictly greater than the new one. A queue that is in
-- descending order of distance therefore stays in descending order. The
-- result is wrapped in an 'assert' checking that the queue grew by exactly
-- one entry.
nonTourpqUpdate (nearTN,ntourpt) nontourpq =
  assert (length updated == length nontourpq + 1) updated
  where
    d = sqrdist (nearTN, ntourpt)
    -- ahead: entries strictly farther than d; behind: everything from the
    -- first entry that is not
    (ahead, behind) = span (\(d1, _, _) -> d1 > d) nontourpq
    updated = ahead ++ (d, nearTN, ntourpt) : behind


