view src/main/java/jp/ac/u_ryukyu/treevnc/MyRfbProto.java @ 118:38e461e9b9c9

remove duplicated code in MyRfbProto*
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Mon, 26 May 2014 18:30:18 +0900
parents bce2ef0a2e79
children c1b14cef2704
line wrap: on
line source

package jp.ac.u_ryukyu.treevnc;

import java.io.IOException;
import java.io.OutputStream;
import java.net.BindException;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.util.LinkedList;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;

import jp.ac.u_ryukyu.treevnc.client.EchoClient;
import jp.ac.u_ryukyu.treevnc.server.RequestScreenThread;
import jp.ac.u_ryukyu.treevnc.server.VncProxyService;

import com.glavsoft.exceptions.TransportException;
import com.glavsoft.rfb.client.ClientToServerMessage;
import com.glavsoft.rfb.encoding.EncodingType;
import com.glavsoft.rfb.protocol.Protocol;
import com.glavsoft.rfb.protocol.ProtocolContext;
import com.glavsoft.transport.Reader;
import com.glavsoft.transport.Writer;
import com.glavsoft.viewer.ViewerImpl;

public class MyRfbProto {
	final static int FramebufferUpdateRequest = 3;
	final static int CheckDelay = 11;
	protected final static int FramebufferUpdate = 0;
	private ProtocolContext context;
	protected final static String versionMsg_3_856 = "RFB 003.856\n";
	private int clients;
	public MulticastQueue<LinkedList<ByteBuffer>> multicastqueue = new MulticastQueue<LinkedList<ByteBuffer>>();
	private RequestScreenThread rThread;
	private boolean proxyFlag = true;
	private EchoClient echo;
	private String proxyAddr;
	public int acceptPort = 0;
	protected boolean readyReconnect = false;
	private boolean cuiVersion;
	private long counter = 0; // packet serial number
	private VncProxyService viewer = null;
    public ServerSocket servSock;

    private static final int INFLATE_BUFSIZE = 1024 * 100;

    private Inflater inflater = new Inflater();
    private Deflater deflater = new Deflater();


	
	public MyRfbProto() {
		rThread = new RequestScreenThread(this);
	}
	
	public void setVncProxy(VncProxyService viewer) {
		this.viewer = viewer;
	}

	public boolean isRoot() {
	    return false;
	}
	
	public void newClient(AcceptThread acceptThread, final Socket newCli,
			final Writer os, final Reader is) throws IOException {
		// createBimgFlag = true;
		// rfb.addSockTmp(newCli);
		// addSock(newCli);
		final int myId = clients;
		final MulticastQueue.Client<LinkedList<ByteBuffer>> c = multicastqueue.newClient();
		final AtomicInteger writerRunning = new AtomicInteger();
		writerRunning.set(1);
		/**
		 * Timeout thread. If a client is suspended, it has top of queue
		 * indefinitely, which caused memory overflow. After the timeout, we
		 * poll the queue and discard it. Start long wait if writer is running.
		 */
		final Runnable timer = new Runnable() {
			public void run() {
				int count = 0;
				for (;;) {
					long timeout = 50000 / 8;
					try {
						synchronized (this) {
							int state, flag;
							writerRunning.set(0);
							wait(timeout);
							flag = 0;
							while ((state = writerRunning.get()) == 0) {
								c.poll(); // discard, should be timeout
								count++;
								if (flag == 0) {
									System.out.println("Discarding " + myId
											+ " count=" + count);
									flag = 1;
								}
								wait(10); // if this is too short, writer cannot
											// take the poll, if this is too
											// long, memory will overflow...
							}
							if (flag == 1)
								System.out.println("Resuming " + myId
										+ " count=" + count);
							if (state != 1) {
								System.out.println("Client died " + myId);
								break;
							}
						}
					} catch (InterruptedException e) {
					}
				}
			}
		};
		new Thread(timer, "timer-discard-multicastqueue").start();
		/**
		 * send all incoming from clients to parent.
		 */
		final Runnable reader = new Runnable() {


			public void run() {
				for (;;) {
					try {
		                final byte b[] = new byte[4096];
						final int c = is.readByte(b);
						if (c <= 0)
							throw new IOException();
						if (isRoot()) {
							if (b[0] == ClientToServerMessage.SERVER_CHANGE_REQUEST) {
								if (permitChangeScreen()) {
									ByteBuffer buf = ByteBuffer.wrap(b);
									buf.order(ByteOrder.BIG_ENDIAN);
									int length = buf.getInt(4);
									if (length == 0) 
										continue;
				                	String newHostName = new String(b, 8, length);
				                	System.out.println("Root server change request :" + newHostName);
									// please remove these numbers.
				                	if (viewer != null) {
					                    viewer.changeVNCServer(newHostName, 3200, 1980);				                		
				                	}
				                } else {
				                    continue;
				                }
							}
						} else if (b[0] == ClientToServerMessage.SERVER_CHANGE_REQUEST) {
						    ClientToServerMessage sc = new ClientToServerMessage() {
                                @Override
                                public void send(Writer writer)
                                        throws TransportException {
                                    writer.write(b,0,c);
                                }
						    };
							context.sendMessage(sc);
						}
						// System.out.println("client read "+c);
					} catch (Exception e) {
						try {
							writerRunning.set(2);
							os.close();
							is.close();
							break;
						} catch (IOException e1) {
						} catch (TransportException e1) {
							e1.printStackTrace();
						}
						return;
					}
				}
			}

			private boolean permitChangeScreen() {
				return true;
			}
		};
		/**
		 * send packets to a client
		 */
		Runnable sender = new Runnable() {
			public void run() {
				writerRunning.set(1);
				try {
					requestThreadNotify();

					/**
					 * initial connection of RFB protocol
					 */
					sendRfbVersion(os);
					// readVersionMsg(is);
					readVersionMsg(is, os);
					sendSecurityType(os);
					readSecType(is);
					sendSecResult(os);
					readClientInit(is);
					sendInitData(os);
					// after this, we discard upward packet.
					new Thread(reader, "discard-upward-comm").start(); 
					// writeFramebufferUpdateRequest(0,0, framebufferWidth,
					// framebufferHeight, false );
					for (;;) {
						LinkedList<ByteBuffer> bufs = c.poll();
						int inputIndex = 0;
						ByteBuffer header = bufs.get(inputIndex);
						if (header == null)
							continue;
						else if (header.get(0) == CheckDelay) {
							writeToClient(os, bufs, inputIndex);
							continue;
						} else if (header.get(0) == FramebufferUpdate) {
							 //System.out.println("client "+ myId);
						}
						/*
						 * if(i%20==0){ sendDataCheckDelay(); } i++;
						 */
						writeToClient(os, bufs, inputIndex);
						writerRunning.set(1); // yes my client is awaking.
					}
				} catch (Exception e) {
					try {
						writerRunning.set(2);
						os.close();
					} catch (IOException e1) {
					}
					/* if socket closed cliList.remove(newCli); */
				}
			}

			public void writeToClient(final Writer os,
					LinkedList<ByteBuffer> bufs, int inputIndex)
					throws TransportException {
				while (inputIndex < bufs.size()) {
					ByteBuffer b = bufs.get(inputIndex++);
					os.write(b.array(), b.position(), b.limit());
				}
				os.flush();
				bufs = null;
				multicastqueue.heapAvailable();
			}
		};
		clients++;
		new Thread(sender, "writer-to-lower-node").start();

	}
	
	public void requestThreadNotify() {
		rThread.reStart();
	}
	
	private void sendRfbVersion(Writer writer) throws IOException, TransportException {
		writer.write(versionMsg_3_856.getBytes());
	}
	
	private int readVersionMsg(Reader reader, Writer writer) throws IOException, TransportException {

		byte[] b = new byte[12];

		reader.readBytes(b);

		if ((b[0] != 'R') || (b[1] != 'F') || (b[2] != 'B') || (b[3] != ' ')
				|| (b[4] < '0') || (b[4] > '9') || (b[5] < '0') || (b[5] > '9')
				|| (b[6] < '0') || (b[6] > '9') || (b[7] != '.')
				|| (b[8] < '0') || (b[8] > '9') || (b[9] < '0') || (b[9] > '9')
				|| (b[10] < '0') || (b[10] > '9') || (b[11] != '\n')) {
			throw new IOException("this is not an RFB server");
		}

		int rfbMajor = (b[4] - '0') * 100 + (b[5] - '0') * 10 + (b[6] - '0');
		int rfbMinor = (b[8] - '0') * 100 + (b[9] - '0') * 10 + (b[10] - '0');

		if (rfbMajor < 3) {
			throw new IOException(
					"RFB server does not support protocol version 3");
		}

		if (rfbMinor == 855) {
			sendProxyFlag(writer);
			if (proxyFlag)
				sendPortNumber(writer);
		}
		return rfbMinor;
	}
		
	
	private void sendProxyFlag(Writer writer) throws TransportException {
		if (proxyFlag)
			writer.writeInt(1);
		else
			writer.writeInt(0);
	}

	private void sendPortNumber(Writer writer) throws TransportException {
	    ByteBuffer b = ByteBuffer.allocate(4);
	    b.order(ByteOrder.BIG_ENDIAN);
	    b.putInt(9999);
		writer.write(b.array());
	}
	
	
	private void readSecType(Reader reader) throws TransportException {
		byte[] b = new byte[1];
		reader.read(b);
	}
	
	private void sendSecurityType(Writer os) throws TransportException {
		// number-of-security-types
		os.writeInt(1);
		// security-types
		// 1:None
		os.writeInt(1);

		/*
		 * os.write(4); os.write(30); os.write(31); os.write(32); os.write(35);
		 * os.flush();
		 */
	}
	
	private void sendSecResult(Writer os) throws TransportException {
	       ByteBuffer b = ByteBuffer.allocate(4);
	        b.order(ByteOrder.BIG_ENDIAN);
	        b.putInt(0);
	        os.write(b.array());
	}

	private void readClientInit(Reader in) throws TransportException {
		byte[] b = new byte[0];
		in.readBytes(b);
	}
	
	byte initData[] = {7, -128, 4, 56, 32, 24, 0, 1, 0, -1, 0, -1, 0, -1, 16, 8, 0, 0, 0, 0, 0, 0, 0, 7, 102, 105, 114, 101, 102, 108, 121};
	private void sendInitData(Writer os) throws TransportException {
		// In case of "-d" we have no context 
		if (context != null){
			os.write(context.getInitData());			
		} else {
			// Send dummy data
			os.write(initData);
		}
	}
	
    public void setProtocolContext(Protocol workingProtocol) {
        context = workingProtocol;
    }


	public Socket accept() throws IOException {
		return null;
	}

    public void initServSock(int port) throws IOException {
        servSock = new ServerSocket(port);
        acceptPort = port;
    }
	
    public int selectPort(int p) {
        int port = p;
        while (true) {
            try {
                initServSock(port);
                break;
            } catch (BindException e) {
                port++;
                continue;
            } catch (IOException e) {

            }
        }
        System.out.println("accept port = " + port);
        return port;
    }



	public void writeFramebufferUpdateRequest(int x, int y, int w, int h,
			boolean incremental) throws TransportException {
		byte[] b = new byte[10];

		b[0] = (byte) FramebufferUpdateRequest; // 3 is FrameBufferUpdateRequest
		b[1] = (byte) (incremental ? 1 : 0);
		b[2] = (byte) ((x >> 8) & 0xff);
		b[3] = (byte) (x & 0xff);
		b[4] = (byte) ((y >> 8) & 0xff);
		b[5] = (byte) (y & 0xff);
		b[6] = (byte) ((w >> 8) & 0xff);
		b[7] = (byte) (w & 0xff);
		b[8] = (byte) ((h >> 8) & 0xff);
		b[9] = (byte) (h & 0xff);

//		os.write(b);
	}
	
	public void notProxy() {
		proxyFlag = false;
	}

	public void setEcho(EchoClient _echo) {
		echo = _echo;
	}
	
	public void setViewer(ViewerImpl v) {
		echo.setViewer(v);
	}
	
	public ViewerImpl getViewer() {
		return echo.getViewer();
	}
	
	public EchoClient getEcho() {
		return echo;
	}

	public void setTerminationType(boolean setType) {
		/*nop*/
	}

	public boolean getTerminationType() {
		/*nop*/
		return true;
	}

	public void setProxyAddr(String proxyAddr) {
		this.proxyAddr = proxyAddr;
	}

    void sendProxyFlag(OutputStream os) throws IOException {
        if (proxyFlag)
            os.write(1);
        else
            os.write(0);
    }



	public void close() {
	    // none
	}
	
	public int getAcceptPort() {
		return acceptPort;
	}
	
	public boolean getReadyReconnect() {
		return readyReconnect;
	}


	public boolean getCuiVersion() {
		return cuiVersion;
	} 
	
	public void  setCuiVersion(boolean flag) {
		cuiVersion = flag;
	}

	public void readCheckDelay(Reader reader) throws TransportException {
		
	}
	
	public String getProxyAddr() {
		return proxyAddr;
	}
	
	public synchronized void setReadyReconnect(boolean ready) {
		readyReconnect = ready;
		if (ready) {
			notifyAll();
		}
	}	

	public synchronized void waitForReady(VncProxyService vncProxyService) throws InterruptedException {
		while (!readyReconnect) {
			wait();
		}
	}


	public void sendDesktopSizeChange() {
		LinkedList<ByteBuffer> desktopSize = new LinkedList<ByteBuffer>();
		int width = context.getFbWidth();
		int height = context.getFbHeight();
		desktopSize.add(new UpdateRectangleMessage(0,0, width, height, EncodingType.DESKTOP_SIZE).getMessage());
		addSerialNumber(desktopSize);
		multicastqueue.put(desktopSize);
	}


	public void addSerialNumber(LinkedList<ByteBuffer> bufs) {
		ByteBuffer serialNum = multicastqueue.allocate(8);
		serialNum.putLong(counter++);
		serialNum.flip();
		bufs.addFirst(serialNum);
	}


    public void resetDecoder() {
        context.resetDecoder();
    }

    public void stopReceiverTask() {
        if (context!=null)
            context.cleanUpSession(null);
        // cleanup zlib decoder for new VNCServer
        if (isRoot())
            inflater = new Inflater();
    }

    public String getMyAddress() {
        return echo.getMyAddress();
    }

    /**
     * gzip byte arrays
     * 
     * @param deflater
     * @param inputs
     *            byte data[]
     * @param inputIndex
     * @param outputs
     *            byte data[]
     * @return byte length in last byte array
     * @throws IOException
     */
    public int zip(Deflater deflater, LinkedList<ByteBuffer> inputs,
            int inputIndex, LinkedList<ByteBuffer> outputs) throws IOException {
        int len = 0;
        ByteBuffer c1 = multicastqueue.allocate(INFLATE_BUFSIZE);
        while (inputIndex < inputs.size()) {
            ByteBuffer b1 = inputs.get(inputIndex++);
            deflater.setInput(b1.array(), b1.position(), b1.remaining());
            /**
             * If we finish() stream and reset() it, Deflater start new gzip
             * stream, this makes continuous zlib reader unhappy. if we remove
             * finish(), Deflater.deflate() never flushes its output. The
             * original zlib deflate has flush flag. I'm pretty sure this a kind
             * of bug of Java library.
             */
            if (inputIndex == inputs.size())
                deflater.finish();
            int len1 = 0;
            do {
                len1 = deflater.deflate(c1.array(), c1.position(),
                        c1.remaining());
                if (len1 > 0) {
                    len += len1;
                    c1.position(c1.position() + len1);
                    if (c1.remaining() == 0) {
                        c1.flip();
                        outputs.addLast(c1);
                        c1 = multicastqueue.allocate(INFLATE_BUFSIZE);
                    }
                }
            } while (len1 > 0 || !deflater.needsInput()); // &&!deflater.finished());
        }
        if (c1.position() != 0) {
            c1.flip();
            outputs.addLast(c1);
        }
        deflater.reset();
        return len;
    }

    /**
     * gunzip byte arrays
     * 
     * @param inflater
     * @param inputs
     *            byte data[]
     * @param outputs
     *            byte data[]
     * @return number of total bytes
     * @throws IOException
     */
    public int unzip(Inflater inflater, LinkedList<ByteBuffer> inputs,
            int inputIndex, LinkedList<ByteBuffer> outputs, int bufSize)
            throws DataFormatException {
        int len = 0;
        ByteBuffer buf = multicastqueue.allocate(bufSize);
        while (inputIndex < inputs.size()) {
            ByteBuffer input = inputs.get(inputIndex++);
            inflater.setInput(input.array(), input.position(), input.limit());
            // if (inputIndex==inputs.size()) if inflater/deflater has symmetry,
            // we need this
            // inflater.end(); but this won't work
            do {
                int len0 = inflater.inflate(buf.array(), buf.position(),
                        buf.remaining());
                if (len0 > 0) {
                    buf.position(buf.position() + len0);
                    len += len0;
                    if (buf.remaining() == 0) {
                        buf.flip();
                        outputs.addLast(buf);
                        buf = multicastqueue.allocate(bufSize);
                    }
                }
            } while (!inflater.needsInput());
        }
        if (buf.position() != 0) {
            buf.flip();
            outputs.addLast(buf);
        }
        return len;
    }

    public void readSendData(int dataLen, Reader reader)
            throws TransportException {
        LinkedList<ByteBuffer> bufs = new LinkedList<ByteBuffer>();
        ByteBuffer header = multicastqueue.allocate(16);
        ByteBuffer serial = multicastqueue.allocate(8);
        if (!isRoot()) {
            reader.mark(dataLen+8); // +8 is serialnum    
            reader.readBytes(serial.array(),0,8);
            serial.limit(8);
        }
        reader.readBytes(header.array(), 0, 16);
        header.limit(16);
        if (header.get(0) == FramebufferUpdate) {
            int encoding = header.getInt(12);
            if (encoding == EncodingType.ZRLE.getId()
                    || encoding == EncodingType.ZLIB.getId()) { // ZRLEE is
                                                                // already
                                                                // recompressed
                ByteBuffer len = multicastqueue.allocate(4);
                reader.readBytes(len.array(), 0, 4);
                len.limit(4);
                ByteBuffer inputData = multicastqueue.allocate(dataLen - 20);
                reader.readBytes(inputData.array(), 0, inputData.capacity());
                inputData.limit(dataLen - 20);
                LinkedList<ByteBuffer> inputs = new LinkedList<ByteBuffer>();
                inputs.add(inputData);

                header.putInt(12, EncodingType.ZRLEE.getId()); // means
                                                                // recompress
                                                                // every time
                // using new Deflecter every time is incompatible with the
                // protocol, clients have to be modified.
                Deflater nDeflater = deflater; // new Deflater();
                LinkedList<ByteBuffer> out = new LinkedList<ByteBuffer>();
                try {
                    unzip(inflater, inputs, 0, out, INFLATE_BUFSIZE);
                    // dump32(inputs);
                    int len2 = zip(nDeflater, out, 0, bufs);
                    ByteBuffer blen = multicastqueue.allocate(4);
                    blen.putInt(len2);
                    blen.flip();
                    bufs.addFirst(blen);
                    bufs.addFirst(header);
                    multicastqueue.put(bufs);
                    if (!isRoot()) reader.reset();
                } catch (DataFormatException e) {
                    throw new TransportException(e);
                } catch (IOException e) {
                    throw new TransportException(e);
                }
                return;
            }
            if (!isRoot())
                bufs.add(serial);
            bufs.add(header);
            if (dataLen > 16) {
                ByteBuffer b = multicastqueue.allocate(dataLen - 16);
                reader.readBytes(b.array(), 0, dataLen - 16);
                b.limit(dataLen - 16);
                bufs.add(b);
            }
            multicastqueue.put(bufs);
            return;
        }
        if (isRoot())
            reader.reset();
        // It may be compressed. We can inflate here to avoid repeating clients
        // decompressing here,
        // but it may generate too many large data. It is better to do it in
        // each client.
        // But we have do inflation for all input data, so we have to do it
        // here.
    }
}