//
// Two-dimensional Hopfield net demonstration
import java.awt.*;
import java.io.*;
import java.applet.*;
import java.util.*;

/**
 * Applet demonstrating a two-dimensional Hopfield network.
 *
 * <p>The applet has two screens. The main screen shows a single editable test
 * pattern; the training screen shows {@link #MAXPATTERNS} editable training
 * patterns. Clicking a cell flips it between +1 (drawn black) and -1 (drawn
 * white). Editing a training pattern recomputes the connection weights, and
 * the "Single Step" button applies one synchronous network update to the test
 * pattern.
 *
 * <p>All pattern arrays are indexed {@code [column][row]}.
 */
public class hop2d extends Applet {
  static final int MAXPATTERNS = 4;
  static final int ROWS = 8;             // Dimensions of the images
  static final int COLS = 8;
  static final int MAINSCREEN = 0;       // Constants for drawing the screens
  static final int TRAINPATTS = 1;

  // Screen layout, shared by the drawing code and the mouse hit-testing so the
  // two can never drift apart.
  private static final int CELL_SIZE = 12;     // Width and height of one cell in pixels
  private static final int GRID_TOP = 50;      // y of the top edge of every grid
  private static final int MAIN_LEFT = 50;     // x of the left edge of the test pattern grid
  private static final int TRAIN_LEFT = 5;     // x of the left edge of training pattern 0
  private static final int TRAIN_PITCH = 110;  // x distance between successive training grids

  private static final String LABEL_VIEW_TRAINING = "View Training Patterns";
  private static final String LABEL_VIEW_MAIN = "View Main Screen";

  Button viewTrainPattern;   // Click on this button to switch between the two screens
  Button singleStep;         // Click on this button to run a test one (more) step
  int trainPatterns[][][] = new int[MAXPATTERNS][COLS][ROWS];
  int testPattern[][] = new int[COLS][ROWS];
  int t[][][][] = new int[COLS][ROWS][COLS][ROWS];   // connection weights
  int screenMode = MAINSCREEN;

  /** Builds the buttons, loads the built-in training patterns and blanks the test pattern. */
  @Override
  public void init() {
    setBackground(Color.cyan);
    singleStep = new Button("Single Step");
    add(singleStep);
    viewTrainPattern = new Button(LABEL_VIEW_TRAINING);
    add(viewTrainPattern);
    setupTrainingPatterns();
    // Initialise test pattern to blank
    for (int i = 0; i < COLS; i++) {
      for (int j = 0; j < ROWS; j++) {
        testPattern[i][j] = -1;
      }
    }
  }

  /** Draws whichever screen is currently selected by {@code screenMode}. */
  @Override
  public void paint(Graphics g) {
    switch (screenMode) {
      case MAINSCREEN:
        drawMainScreen(g);
        break;
      case TRAINPATTS:
        drawTrainingPatterns(g);
        break;
    }
  }

  /** Draws the caption and the grid of the test pattern. */
  public void drawMainScreen(Graphics g) {
    g.drawString("Test pattern", 50, 40);
    for (int j = 0; j < ROWS; j++) {
      for (int i = 0; i < COLS; i++) {
        drawCell(g, testPattern[i][j], MAIN_LEFT + CELL_SIZE * i, GRID_TOP + CELL_SIZE * j);
      }
    }
  }

  /** Draws every training pattern side by side, each with a "Pattern n" caption. */
  public void drawTrainingPatterns(Graphics g) {
    for (int pattern = 0; pattern < MAXPATTERNS; pattern++) {
      int left = TRAIN_LEFT + TRAIN_PITCH * pattern;
      g.drawString("Pattern " + pattern, left, 45);
      for (int i = 0; i < COLS; i++) {
        for (int j = 0; j < ROWS; j++) {
          drawCell(g, trainPatterns[pattern][i][j], left + CELL_SIZE * i, GRID_TOP + CELL_SIZE * j);
        }
      }
    }
  }

  /**
   * Draws one cell as a circle with a black outline.
   *
   * @param g     graphics context to draw on
   * @param value cell state; 1 is filled black, anything else is filled white
   * @param x     left edge of the cell
   * @param y     top edge of the cell
   */
  public void drawCell(Graphics g, int value, int x, int y) {
    if (value == 1) {
      g.setColor(Color.black);
    } else {
      g.setColor(Color.white);
    }
    g.fillOval(x, y, CELL_SIZE, CELL_SIZE);   // Draws a white circle or a black one
    g.setColor(Color.black);
    g.drawOval(x, y, CELL_SIZE, CELL_SIZE);
  }

  /**
   * Processes the button clicks.
   *
   * @return true when the event came from one of this applet's buttons and was
   *         handled here, false otherwise so that it is not silently consumed
   */
  @Override
  public boolean action(Event e, Object arg) {
    if (e.target == viewTrainPattern) {
      // Toggle on the screen mode rather than comparing label strings with ==;
      // the label and the mode are always changed together.
      if (screenMode == MAINSCREEN) {
        viewTrainPattern.setLabel(LABEL_VIEW_MAIN);
        screenMode = TRAINPATTS;
      } else {
        viewTrainPattern.setLabel(LABEL_VIEW_TRAINING);
        screenMode = MAINSCREEN;
      }
      repaint();
      return true;
    }
    if (e.target == singleStep) {
      runSingleStep();   // runSingleStep() requests the repaint itself
      return true;
    }
    return false;
  }

  /** Handles general mouse clicks depending on the drawing mode. */
  @Override
  public boolean mouseDown(Event event, int x, int y) {
    switch (screenMode) {
      case MAINSCREEN:
        mouseMainScreen(x, y);
        break;
      case TRAINPATTS:
        mouseTrain(x, y);
        break;
    }
    return true;
  }

  /** Flips the test-pattern cell under (x, y), if any, and repaints. */
  public void mouseMainScreen(int x, int y) {
    int col = cellAt(x, MAIN_LEFT, COLS);
    int row = cellAt(y, GRID_TOP, ROWS);
    if (col < 0 || row < 0) {
      return;
    }
    testPattern[col][row] = -testPattern[col][row];
    repaint();
  }

  /** Flips the training-pattern cell under (x, y), if any, then retrains and repaints. */
  public void mouseTrain(int x, int y) {
    int relX = x - TRAIN_LEFT;
    if (relX < 0) {
      return;
    }
    int pattern = relX / TRAIN_PITCH;
    if (pattern >= MAXPATTERNS) {
      return;
    }
    // Position inside this pattern's slot; the tail of each slot is the empty
    // gap between two grids, which cellAt() reports as -1.
    int col = cellAt(relX % TRAIN_PITCH, 0, COLS);
    int row = cellAt(y, GRID_TOP, ROWS);
    if (col < 0 || row < 0) {
      return;
    }
    trainPatterns[pattern][col][row] = -trainPatterns[pattern][col][row];  // Swap -1 to 1 and vice-versa
    retrain();
    repaint();
  }

  /** Applies one synchronous update of the network to the test pattern and repaints. */
  public void runSingleStep() {
    int[][] next = nextState(t, testPattern);
    // Copy the new values back into the test pattern ready to be displayed
    for (int i = 0; i < COLS; i++) {
      System.arraycopy(next[i], 0, testPattern[i], 0, ROWS);
    }
    repaint();
  }

  /** Loads the four built-in training patterns and computes the weights for them. */
  public void setupTrainingPatterns() {
    initialisePattern(0, "00000000", "01111110", "01000010", "01000010", "01000010", "01000010",
                         "01111110", "00000000");
    initialisePattern(1, "00000000", "00010000", "00010000", "00010000", "11111110", "00010000",
                         "00010000", "00010000");
    initialisePattern(2, "10011110", "10010010", "10010010", "10010010", "10010010", "10010010",
                         "10010010", "11110011");
    initialisePattern(3, "00000000", "00010000", "00011000", "00100100", "00100100", "01000010",
                         "01000010", "11111111");
    retrain();
  }

  /**
   * Sets training pattern {@code patNum} from eight row strings, top row first.
   * In each string the character '1' means on (+1); anything else means off (-1).
   */
  public void initialisePattern(int patNum, String line0, String line1, String line2, String line3,
                                String line4, String line5, String line6, String line7) {
    String[] lines = { line0, line1, line2, line3, line4, line5, line6, line7 };
    for (int lineNum = 0; lineNum < lines.length; lineNum++) {
      initialiseline(patNum, lineNum, lines[lineNum]);
    }
  }

  /**
   * Sets one row of a training pattern from a string of '0'/'1' characters.
   * Characters beyond the grid width are ignored, and a string shorter than
   * the grid width leaves the remaining cells off (-1), so every cell of the
   * row always ends up as a valid +1/-1 state.
   *
   * @param patNum  index of the training pattern
   * @param lineNum row to set
   * @param line    row contents, leftmost column first
   */
  public void initialiseline(int patNum, int lineNum, String line) {
    fillRow(trainPatterns[patNum], lineNum, line);
  }

  /** Recomputes the connection weights {@code t} from the current training patterns. */
  public void retrain() {
    computeWeights(trainPatterns, t);
  }

  /**
   * Maps a pixel coordinate to a cell index along one axis of a grid.
   * Each cell covers the half-open pixel range
   * {@code [origin + index * CELL_SIZE, origin + (index + 1) * CELL_SIZE)}, so
   * every pixel of the grid belongs to exactly one cell.
   *
   * @param pos    pixel coordinate that was clicked
   * @param origin pixel coordinate of the grid's first edge on this axis
   * @param count  number of cells along this axis
   * @return the cell index, or -1 when {@code pos} is outside the grid
   */
  static int cellAt(int pos, int origin, int count) {
    int offset = pos - origin;
    if (offset < 0) {
      return -1;   // Guard first: integer division would round small negatives to 0
    }
    int index = offset / CELL_SIZE;
    return index < count ? index : -1;
  }

  /**
   * Writes one row of a {@code [column][row]} pattern from a '0'/'1' string.
   * Every column of the pattern is written: '1' gives +1; any other character,
   * or a missing character, gives -1.
   */
  static void fillRow(int[][] pattern, int lineNum, String line) {
    for (int i = 0; i < pattern.length; i++) {
      boolean on = i < line.length() && line.charAt(i) == '1';
      pattern[i][lineNum] = on ? +1 : -1;
    }
  }

  /**
   * Fills {@code weights} from the given patterns: the weight between cells
   * (i, j) and (k, l) is the sum over all patterns of the product of the two
   * cell values, and the weight from a cell to itself is zero.
   *
   * @param patterns patterns indexed {@code [pattern][column][row]}
   * @param weights  output, indexed {@code [column][row][column][row]}; its
   *                 dimensions define the grid size
   */
  static void computeWeights(int[][][] patterns, int[][][][] weights) {
    int cols = weights.length;
    int rows = cols > 0 ? weights[0].length : 0;
    for (int i = 0; i < cols; i++) {
      for (int j = 0; j < rows; j++) {
        for (int k = 0; k < cols; k++) {
          for (int l = 0; l < rows; l++) {
            int sum = 0;
            if (i != k || j != l) {   // No self-connections
              for (int[][] pattern : patterns) {
                sum += pattern[i][j] * pattern[k][l];
              }
            }
            weights[i][j][k][l] = sum;
          }
        }
      }
    }
  }

  /**
   * Computes one synchronous update: every new cell value is derived from the
   * old state only. A cell becomes +1 when its weighted input sum is strictly
   * positive and -1 otherwise (a sum of exactly zero therefore yields -1).
   *
   * @param weights connection weights indexed {@code [column][row][column][row]}
   * @param state   current state indexed {@code [column][row]}; not modified
   * @return a newly allocated array holding the next state
   */
  static int[][] nextState(int[][][][] weights, int[][] state) {
    int cols = state.length;
    int rows = cols > 0 ? state[0].length : 0;
    int[][] next = new int[cols][rows];
    for (int k = 0; k < cols; k++) {
      for (int l = 0; l < rows; l++) {
        int sum = 0;
        for (int i = 0; i < cols; i++) {
          for (int j = 0; j < rows; j++) {
            sum += weights[i][j][k][l] * state[i][j];
          }
        }
        next[k][l] = sum > 0 ? 1 : -1;   // Pass through hard-limiting non-linearity
      }
    }
    return next;
  }
}