"""Illustrate the SVD geometrically.

Copyright (C) 2006 Stefan van der Walt

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:

1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the
   distribution.
3. The name of the author may not be used to endorse or promote
   products derived from this software without specific prior written
   permission.

THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT,
INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
POSSIBILITY OF SUCH DAMAGE.
"""

import numpy as N
import pylab as P
import copy

class CirclePoint(object):
    """
    Draggable arrow on a unit circle.

    A unit circle centred on the origin is drawn on the current axes together
    with an arrow from the origin to a point on that circle.  Pressing the
    left mouse button within ``epsilon`` pixels of the arrow head grabs the
    arrow, moving the mouse rotates it, and releasing the button lets go.
    Whenever the arrow is redrawn, ``external_hook`` (if not None) is called
    with no arguments.
    """

    def __init__(self, epsilon=10):
        """Draw the circle and the arrow on the current axes.

        Parameters
        ----------
        epsilon : float
            Grab radius, in pixels, around the arrow head.
        """
        axes = P.gca()
        circ = P.Circle((0, 0), 1.)
        circ.set_fill(False)
        circ.set_edgecolor('b')
        axes.add_patch(circ)

        canvas = circ.figure.canvas
        canvas.mpl_connect('button_press_event', self.button_press_callback)
        canvas.mpl_connect('button_release_event', self.button_release_callback)
        canvas.mpl_connect('motion_notify_event', self.motion_notify_callback)

        self.epsilon = epsilon
        self.circ = circ
        self.arrow = None  # created by update_arrow (via set_angle below)
        self.canvas = canvas
        self.external_hook = None  # optional zero-argument callable
        self.axes = axes
        self.nr_pts = 0  # not read anywhere in this class

        # Arrow colour for each interaction state.
        self.arrow_colour = {'default': 'r', 'selected': 'g'}
        self.arrow_mode = 'default'

        self.set_angle(N.pi / 4.)

    def get_angle(self):
        """Return the angle of the arrow, in radians."""
        return self.__angle

    def set_angle(self, theta):
        """Point the arrow at angle *theta* (radians) and redraw it."""
        self.__angle = theta

        self.update_arrow()

    angle = property(fget=get_angle, fset=set_angle,
                     doc="Angle of the arrow.")

    def update_arrow(self):
        """Replace the arrow to match the current angle and mode, then redraw."""
        # Create a new arrow and remove the previous one from the axes.
        if self.arrow is not None:
            self.arrow.remove()
        ex, ey = self.pos
        self.arrow = self.axes.arrow(0, 0, ex, ey, width=0.01,
                                     length_includes_head=True)

        ac = self.arrow_colour[self.arrow_mode]
        self.arrow.set_edgecolor(ac)
        self.arrow.set_facecolor(ac)

        # Run the hook before drawing so anything it adds shows up in this
        # redraw rather than the next one.
        if self.external_hook is not None:
            self.external_hook()
        self.canvas.draw_idle()

    @property
    def pos(self):
        """Return position of arrow tip (x,y)."""
        a = self.angle
        return (N.cos(a), N.sin(a))

    def button_press_callback(self, event):
        """Grab the arrow on a left click within ``epsilon`` pixels of its tip."""
        if event.inaxes is None or event.button != 1:
            return

        # Compare in display (pixel) coordinates, the units of epsilon.
        x, y = self.axes.transData.transform(self.pos)
        if N.hypot(event.x - x, event.y - y) > self.epsilon:
            return

        # Arrow selected
        self.arrow_mode = 'selected'
        self.update_arrow()

    def button_release_callback(self, event):
        """Release a grabbed arrow; releases that grabbed nothing are ignored."""
        if self.arrow_mode != 'selected':
            return
        self.arrow_mode = 'default'
        self.update_arrow()

    def motion_notify_callback(self, event):
        """While the arrow is grabbed, point it towards the mouse cursor."""
        if self.arrow_mode != 'selected':
            return
        # Map the pixel position back to data coordinates.
        xt, yt = self.axes.transData.inverted().transform((event.x, event.y))
        # The property setter (set_angle) already redraws via update_arrow.
        self.angle = N.arctan2(yt, xt)

class SVD_Geometry:
    """Interactive picture of the SVD of a 2x2 matrix.

    With ``M = U S V^T``, the figure shows the columns of ``U`` scaled by the
    singular values (magenta, labelled 'U') and the columns of ``V`` (cyan,
    labelled 'V'), together with a draggable arrow on the unit circle.  Each
    redraw of that arrow marks ``M p`` (``p`` being the arrow tip) with a
    small dot, so dragging the arrow traces the image of the unit circle.
    """

    def __init__(self, M):
        """Build the figure for *M* and show it.

        Parameters
        ----------
        M : (2, 2) array_like
            Matrix to decompose and visualise.

        Notes
        -----
        Ends with ``pylab.show()``; depending on the backend and interactive
        mode this may block until the window is closed.
        """
        M = N.asarray(M, dtype=float)
        # Set the state plot_tf reads before it is installed as the hook.
        self.M = M

        P.figure()
        cp = CirclePoint()
        self.cp = cp
        cp.external_hook = self.plot_tf

        # numpy returns V transposed: the rows of Vt are the columns of V.
        U, S, Vt = N.linalg.svd(M)
        # Rows: S[0]*U[:, 0], S[1]*U[:, 1], V[:, 0], V[:, 1].
        vectors = N.vstack([(U * S).T, Vt])
        colours = ['m', 'm', 'c', 'c']
        labels = ['U', '', 'V', '']

        # Start at +/-1.5 so the unit circle fits with a margin, then grow
        # the box to contain every arrow.
        axis_max = N.array([1.5, 1.5])
        axis_min = N.array([-1.5, -1.5])
        for ev, colour, label in zip(vectors, colours, labels):
            a = P.arrow(0, 0, ev[0], ev[1], width=0.01,
                        length_includes_head=True)
            a.set_edgecolor(colour)
            a.set_facecolor(colour)
            P.text(ev[0] / 1.5, ev[1] / 1.5, label)
            axis_max = N.maximum(axis_max, ev)
            axis_min = N.minimum(axis_min, ev)

        # M maps the unit circle onto an ellipse whose half-width along
        # coordinate axis i is the norm of row i of M; include it so the
        # dots placed by plot_tf stay inside the view.
        ellipse_extent = N.sqrt((M ** 2).sum(axis=1)).max()

        # Maximum axis dimension is extent of bounding box
        am = max(axis_max.max(), abs(axis_min.min()), ellipse_extent)
        P.axis('equal')
        P.axis([-am, am, -am, am])
        P.title("Geometrical Illustration of the SVD").set_weight('bold')
        P.xlabel("Click and drag the tip of the arrow.")
        P.show()

    def plot_tf(self):
        """Mark ``M p`` with a small dot, where ``p`` is the arrow tip."""
        z = N.dot(self.M, N.array(self.cp.pos))
        self.cp.axes.add_patch(P.Circle(tuple(z), 0.01))
    
# Launch the interactive demo on a fixed example 2x2 matrix.
demo = SVD_Geometry(N.array([[0.7, 1.4],
                             [1.2, 0.1]]))
