package proxy;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

import net.zortrium.util.Options;

/**
 * A simple proxy server for forwarding to an arbitrary remote service. Handles
 * an arbitrary number of concurrent clients and displays client connection and
 * disconnection information. This code may be used, modified, and redistributed
 * provided that the author tag below remains intact.
 * @author Sean Barker (zortrium@gmail.com)
 */
public class ProxyServer {

  /**
   * Number of bytes moved per read when relaying data between two sockets.
   */
  private static final int BUFFER_SIZE = 8192;

  /**
   * Read command line arguments, verify that all needed information is
   * provided, then start the proxy server.
   * @param args Command line arguments.
   */
  public static void main(String[] args) {
    Options.setUsageString("java -jar proxy.jar");
    Options.add(Integer.class, "proxyPort").setShortFlag('p').setDescription(
        "Local port to listen to").setRequired(true);
    Options.add(Integer.class, "servicePort").setShortFlag('s').setDescription(
        "Remote server port to connect to").setRequired(true);
    Options.add(String.class, "serviceDomain").setShortFlag('d').setDescription(
        "Domain name of remote server").setRequired(true);
    Options.readArguments(args);
    try {
      new ProxyServer(Options.get(Integer.class, "proxyPort"), Options.get(Integer.class,
          "servicePort"), Options.get(String.class, "serviceDomain"));
    } catch (IOException e) {
      System.err.println("Unable to create proxy server: " + e);
    }
  }

  /**
   * Address of the remote server clients are connected to.
   */
  private final InetAddress serviceAddress;

  /**
   * Setup the proxy server and wait indefinitely for clients. Each accepted
   * client is handed to a worker thread right away, so the reverse DNS lookup
   * used for logging and the (possibly slow) connection to the remote server
   * never delay accepting the next client.
   * @param proxyPort Local listener port.
   * @param servicePort Remote server port.
   * @param serviceDomainName Remote server domain name.
   * @throws IOException If the proxy server could not be started.
   */
  private ProxyServer(int proxyPort, int servicePort, String serviceDomainName) throws IOException {
    this.serviceAddress = InetAddress.getByName(serviceDomainName);
    final ServerSocket listener = new ServerSocket(proxyPort);
    final ExecutorService threadPool = Executors.newCachedThreadPool();
    System.out.println("Proxy server successfully started.");
    while (true) {
      try {
        Socket toClient = listener.accept();
        threadPool.execute(new ClientSession(toClient, servicePort, threadPool));
      } catch (IOException e) {
        System.err.println("An unexpected error occurred: " + e);
      }
    }
  }

  /**
   * Close the given socket if it is not already closed. Synchronized on the
   * socket so that, when both relay threads try to close the same socket, only
   * one of them actually closes it (and prints the disconnect message).
   * @param sock The socket to close.
   * @param isClient Whether {@code sock} is the client-side socket, in which
   *     case a disconnect message is printed when it is closed.
   */
  private static void closeSocket(Socket sock, boolean isClient) {
    synchronized (sock) {
      if (sock.isClosed()) {
        return;
      }
      try {
        sock.close();
      } catch (IOException ignored) {
        // The socket is being discarded; there is nothing useful to do on failure.
      }
      if (isClient) {
        System.out.println(sock.getInetAddress().getCanonicalHostName() + " disconnected");
      }
    }
  }

  /**
   * Handles one accepted client: announces it, connects to the remote server,
   * and relays data in both directions. If the remote server cannot be
   * reached, the client socket is closed instead of being leaked.
   */
  private class ClientSession implements Runnable {

    /**
     * The accepted client connection.
     */
    private final Socket toClient;

    /**
     * Remote server port to connect to.
     */
    private final int servicePort;

    /**
     * Pool used to run the client-to-server relay.
     */
    private final ExecutorService threadPool;

    /**
     * Store everything needed to set up the proxied session.
     * @param toClient The accepted client socket.
     * @param servicePort Remote server port.
     * @param threadPool Pool on which the client-to-server relay is run.
     */
    ClientSession(Socket toClient, int servicePort, ExecutorService threadPool) {
      this.toClient = toClient;
      this.servicePort = servicePort;
      this.threadPool = threadPool;
    }

    /**
     * Connect to the remote server and relay data until either side closes.
     */
    @Override
    public void run() {
      System.out.println(toClient.getInetAddress().getCanonicalHostName() + " connected");
      final Socket toServer;
      try {
        toServer = new Socket(serviceAddress, servicePort);
      } catch (IOException e) {
        System.err.println("An unexpected error occurred: " + e);
        // Nothing can be relayed without a server connection; release the client.
        closeSocket(toClient, true);
        return;
      }
      threadPool.execute(new ProxyConnection(toClient, toServer, toClient));
      // Reuse this worker thread for the server-to-client direction.
      new ProxyConnection(toServer, toClient, toClient).run();
    }
  }

  /**
   * A one-way proxied connection. All incoming data is forwarded on through the
   * forwarding socket. Once one socket closes, the other socket is also closed.
   */
  private static final class ProxyConnection implements Runnable {

    /**
     * The incoming socket connection.
     */
    private final Socket incoming;

    /**
     * The outgoing socket connection.
     */
    private final Socket forwarding;

    /**
     * The client-side socket of this session (one of {@code incoming} or
     * {@code forwarding}); used only to decide which close is announced.
     */
    private final Socket client;

    /**
     * Store the sockets involved in the proxied connection.
     * @param incoming The socket on which data arrives.
     * @param forwarding The socket on which data is forwarded.
     * @param client The client-side socket of the session.
     */
    ProxyConnection(Socket incoming, Socket forwarding, Socket client) {
      this.incoming = incoming;
      this.forwarding = forwarding;
      this.client = client;
    }

    /**
     * Forward data until no more data remains, then close both sockets.
     */
    @Override
    public void run() {
      try {
        InputStream input = incoming.getInputStream();
        OutputStream output = forwarding.getOutputStream();
        byte[] buffer = new byte[BUFFER_SIZE];
        for (int count = input.read(buffer); count != -1; count = input.read(buffer)) {
          output.write(buffer, 0, count);
        }
      } catch (IOException ignored) {
        // Expected when the opposite relay closes these sockets or a peer
        // resets the connection; either way this session is over.
      } finally {
        closeSocket(incoming, incoming == client);
        closeSocket(forwarding, forwarding == client);
      }
    }
  }

}


