view src/myVncProxy/MyRfbProto.java @ 90:462bca4c8cec

ByteBuffer
author Shinji KONO <kono@ie.u-ryukyu.ac.jp>
date Wed, 03 Aug 2011 10:30:45 +0900
parents 9b3b1e3e7db5
children 4116c19cd76e
line wrap: on
line source

package myVncProxy;

import static org.junit.Assert.*;

import java.awt.Graphics;
import java.awt.Image;
import java.awt.image.BufferedImage;
import java.io.BufferedOutputStream;
import java.io.BufferedReader;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.net.BindException;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.ByteBuffer;
import java.util.LinkedList;

import javax.imageio.ImageIO;

import org.junit.Test;

import myVncProxy.MulticastQueue.Client;

import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.zip.DataFormatException;
import java.util.zip.Deflater;
import java.util.zip.Inflater;
import java.io.OutputStream;

/**
 * RFB protocol handler for the VNC proxy.
 *
 * <p>It reads updates from the upstream VNC server through the inherited
 * {@link RfbProto} stream. Each update is copied into a multicast queue, and the
 * stream is then rewound so the local viewer can decode the same bytes. Every
 * downstream client has a sender thread that replays the queued updates.
 */
public
class MyRfbProto extends RfbProto {
	final static String versionMsg_3_998 = "RFB 003.998\n";
	/**
	 * SpeedCheckMillis is a message type added by the RFB 3.998 extension.
	 */
	final static byte SpeedCheckMillis = 4;
	private static final int INFLATE_BUFSIZE = 1024*100;
	boolean printStatusFlag = false;
	long startCheckTime;

	// Fields of the most recent FramebufferUpdate header, filled by regiFramebufferUpdate().
	private int messageType;
	private int rectangles;
	private int rectX;
	private int rectY;
	private int rectW;
	private int rectH;
	private int encoding;
	/** Compressed-data length of a Zlib/ZRLE rectangle; 0 for every other encoding. */
	private int zLen;

	private ServerSocket servSock;
	private int acceptPort;
	/** Raw ServerInit message from the upstream server, replayed verbatim to each client. */
	private byte initData[];
	// Initialised here rather than in the constructors so that every constructor,
	// including the no-arg one, leaves usable lists behind.
	private LinkedList<Socket> cliListTmp = new LinkedList<Socket>();
	private LinkedList<Socket> cliList = new LinkedList<Socket>();
	private LinkedList<Thread> sendThreads;
	boolean createBimgFlag;

	ExecutorService executor;

	byte[] pngBytes;

	private MulticastQueue<LinkedList<ByteBuffer>> multicastqueue = new MostRecentMultiCast<LinkedList<ByteBuffer>>(10);
	private int clients = 0;
	/**
	 * Inflater for upstream Zlib/ZRLE data. It is shared by all updates and never reset,
	 * because RFB uses one zlib stream for the whole connection.
	 */
	private Inflater inflater = new Inflater();

	public
	MyRfbProto() throws IOException {
	}

	MyRfbProto(String h, int p, VncViewer v) throws IOException {
		super(h, p, v);
		createBimgFlag = false;
	}

	MyRfbProto(String h, int p) throws IOException {
		super(h, p);
		createBimgFlag = false;
	}

	/**
	 * Overrides the base-class version: sends the highest protocol version this proxy
	 * supports for the server's version. The 3.998 extension is sent when the server's
	 * minor version is at least 9.
	 */
	void writeVersionMsg() throws IOException {
		clientMajor = 3;
		if (serverMinor >= 9) {
			clientMinor = 9;
			os.write(versionMsg_3_998.getBytes());
		} else if (serverMajor > 3 || serverMinor >= 8) {
			clientMinor = 8;
			os.write(versionMsg_3_8.getBytes());
		} else if (serverMinor >= 7) {
			clientMinor = 7;
			os.write(versionMsg_3_7.getBytes());
		} else {
			clientMinor = 3;
			os.write(versionMsg_3_3.getBytes());
		}
		protocolTightVNC = false;
		initCapabilities();
	}

	void initServSock(int port) throws IOException {
		servSock = new ServerSocket(port);
		acceptPort = port;
	}

	/**
	 * Opens the listening socket on port {@code p}. If that port is already in use,
	 * it tries p+1, p+2, ... until one binds.
	 *
	 * @throws IllegalStateException if the socket cannot be opened for any reason other
	 *         than the port being in use. Retrying the same port would never succeed.
	 */
	void selectPort(int p) {
		int port = p;
		while (true) {
			try {
				initServSock(port);
				break;
			} catch (BindException e) {
				// Port taken: move on to the next one.
				port++;
			} catch (IOException e) {
				throw new IllegalStateException("cannot open server socket on port " + port, e);
			}
		}
		System.out.println("accept port = " + port);
	}

	int getAcceptPort() {
		return acceptPort;
	}

	void setSoTimeout(int num) throws IOException {
		servSock.setSoTimeout(num);
	}

	Socket accept() throws IOException {
		return servSock.accept();
	}

	void addSock(Socket sock) {
		cliList.add(sock);
	}

	void addSockTmp(Socket sock) {
		System.out.println("connected " + sock.getInetAddress());
		cliListTmp.add(sock);
	}

	boolean markSupported() {
		return is.markSupported();
	}

	/**
	 * Reads the upstream ServerInit message.
	 *
	 * <p>It first peeks at the name length (offset 20) and captures the whole raw
	 * message into {@link #initData} for replay to clients. It then rewinds and parses
	 * the fields into the inherited framebuffer and pixel-format state.
	 */
	void readServerInit() throws IOException {

		// Peek: 20 bytes of fixed fields, then the U32 desktop-name length.
		is.mark(255);
		skipBytes(20);
		int nlen = readU32();
		int blen = 20 + 4 + nlen;
		initData = new byte[blen];
		is.reset();

		// Capture the raw message, then rewind again to parse it field by field.
		is.mark(blen);
		readFully(initData);
		is.reset();

		framebufferWidth = readU16();
		framebufferHeight = readU16();
		bitsPerPixel = readU8();
		depth = readU8();
		bigEndian = (readU8() != 0);
		trueColour = (readU8() != 0);
		redMax = readU16();
		greenMax = readU16();
		blueMax = readU16();
		redShift = readU8();
		greenShift = readU8();
		blueShift = readU8();
		byte[] pad = new byte[3];
		readFully(pad);
		int nameLength = readU32();
		byte[] name = new byte[nameLength];
		readFully(name);
		desktopName = new String(name);

		// Read interaction capabilities (TightVNC protocol extensions)
		if (protocolTightVNC) {
			int nServerMessageTypes = readU16();
			int nClientMessageTypes = readU16();
			int nEncodingTypes = readU16();
			readU16();
			readCapabilityList(serverMsgCaps, nServerMessageTypes);
			readCapabilityList(clientMsgCaps, nClientMessageTypes);
			readCapabilityList(encodingCaps, nEncodingTypes);
		}

		inNormalProtocol = true;
	}

	/**
	 * Reads exactly {@code b.length} bytes, looping over short reads.
	 *
	 * @throws IOException if the stream ends before the buffer is full
	 */
	private static void readExactly(InputStream in, byte[] b) throws IOException {
		int off = 0;
		while (off < b.length) {
			int n = in.read(b, off, b.length - off);
			if (n < 0) {
				throw new IOException("Unexpected end of stream after " + off
						+ " of " + b.length + " bytes");
			}
			off += n;
		}
	}

	/** Writes the remaining bytes of a heap buffer without changing the buffer's position. */
	private static void writeBuffer(OutputStream out, ByteBuffer b) throws IOException {
		out.write(b.array(), b.arrayOffset() + b.position(), b.remaining());
	}

	/** Sends this proxy's protocol version (3.998) to a downstream client. */
	void sendRfbVersion(OutputStream os) throws IOException {
		os.write(versionMsg_3_998.getBytes());
	}

	/**
	 * Reads and validates the 12-byte "RFB xxx.yyy\n" version message from a client.
	 * It sets serverMajor and serverMinor from the message.
	 */
	void readVersionMsg(InputStream is) throws IOException {

		byte[] b = new byte[12];

		readExactly(is, b);

		if ((b[0] != 'R') || (b[1] != 'F') || (b[2] != 'B') || (b[3] != ' ')
				|| (b[4] < '0') || (b[4] > '9') || (b[5] < '0') || (b[5] > '9')
				|| (b[6] < '0') || (b[6] > '9') || (b[7] != '.')
				|| (b[8] < '0') || (b[8] > '9') || (b[9] < '0') || (b[9] > '9')
				|| (b[10] < '0') || (b[10] > '9') || (b[11] != '\n')) {
			throw new IOException("Host " + host + " port " + port
					+ " is not an RFB server");
		}

		serverMajor = (b[4] - '0') * 100 + (b[5] - '0') * 10 + (b[6] - '0');
		serverMinor = (b[8] - '0') * 100 + (b[9] - '0') * 10 + (b[10] - '0');

		if (serverMajor < 3) {
			throw new IOException(
					"RFB server does not support protocol version 3");
		}

	}

	/** Offers exactly one security type, 1 ("None"). */
	void sendSecurityType(OutputStream os) throws IOException {
		// number-of-security-types
		os.write(1);
		// security-types
		// 1:None
		os.write(1);
	}

	/** Reads (and ignores) the client's 1-byte security-type choice. */
	void readSecType(InputStream is) throws IOException {
		byte[] b = new byte[1];
		readExactly(is, b);
	}

	/** Sends SecurityResult as a U32 0 (OK). */
	void sendSecResult(OutputStream os) throws IOException {
		byte[] b = castIntByte(0);
		os.write(b);
	}

	/** Reads (and ignores) the client's 1-byte ClientInit (shared-flag) message. */
	void readClientInit(InputStream in) throws IOException {
		byte[] b = new byte[1];
		readExactly(in, b);
	}

	void sendInitData(OutputStream os) throws IOException {
		os.write(initData);
	}


	/**
	 * Sends the cached PNG to every pending client and promotes each successful one into
	 * the active client list. Clients whose socket fails are dropped. The pending list is
	 * always cleared afterwards.
	 */
	void sendPngImage() {
		try {
			for (Socket cli : cliListTmp) {
				try {
					sendPngData(cli);
					addSock(cli);
				} catch (IOException e) {
					// Socket closed: skip this client. Removing it here would throw
					// ConcurrentModificationException; the list is cleared below anyway.
				}
			}
		} finally {
			cliListTmp.clear();
		}
	}

	/** Returns true when input from the upstream server can be read without blocking. */
	boolean ready() throws IOException {
		return is.available() > 0;
	}

	int cliSize() {
		return cliList.size();
	}

	void printNumBytesRead() {
		System.out.println("numBytesRead=" + numBytesRead);
	}


	/**
	 * Peeks at the next FramebufferUpdate header and its first rectangle header without
	 * consuming them. The stream is marked and reset. For Zlib and ZRLE rectangles it
	 * also peeks at the following U32 compressed length.
	 */
	void regiFramebufferUpdate() throws IOException {
		is.mark(20);
		messageType = readU8();     // 0
		skipBytes(1);               // 1  padding
		rectangles = readU16();     // 2
		rectX = readU16();          // 4
		rectY = readU16();          // 6
		rectW = readU16();          // 8
		rectH = readU16();          // 10
		encoding = readU32();       // 12
		System.out.println("encoding = "+encoding);
		// Both Zlib and ZRLE rectangles start with a U32 length of the compressed data.
		if (encoding == EncodingZRLE || encoding == EncodingZlib)
			zLen = readU32();       // 16
		else
			zLen = 0;
		is.reset();
	}

	/**
	 * Estimates the byte length of the pending update from the header peeked by
	 * {@link #regiFramebufferUpdate()}, then marks the stream with that read limit.
	 * Only the first rectangle is considered. The result includes the 16-byte
	 * message-plus-rectangle header.
	 *
	 * NOTE(review): RRE, CoRRE, Hextile and Tight are sized as header + 4 only. Unknown
	 * encodings use 1000000 bytes. Both are placeholders and will under- or over-read.
	 *
	 * @return the number of bytes {@link #readSendData(int)} should copy
	 */
	int checkAndMark() throws IOException {
		int dataLen;
		switch (encoding) {
		case RfbProto.EncodingRaw:
			// Assumes 4 bytes per pixel.
			dataLen = rectW * rectH * 4 + 16;
			break;
		case RfbProto.EncodingCopyRect:
			dataLen = 16 + 4;
			break;
		case RfbProto.EncodingRRE:
		case RfbProto.EncodingCoRRE:
		case RfbProto.EncodingHextile:
		case RfbProto.EncodingTight:
		case RfbProto.EncodingZlib:
		case RfbProto.EncodingZRLE:
			// header (16) + U32 length (4) + compressed data (zLen, 0 if not Zlib/ZRLE)
			dataLen = zLen + 20;
			break;
		case RfbProto.EncodingXCursor: {
			// If w*h > 0: 6 colour bytes, then a bitmap and a mask of ceil(w/8)*h bytes each.
			int maskBytes = ((rectW + 7) / 8) * rectH;
			dataLen = 16 + (rectW * rectH == 0 ? 0 : 6 + 2 * maskBytes);
			printFramebufferUpdate();
			break;
		}
		case RfbProto.EncodingRichCursor: {
			// w*h pixels (assumes 4 bytes per pixel), then a mask of ceil(w/8)*h bytes.
			int maskBytes = ((rectW + 7) / 8) * rectH;
			dataLen = 16 + rectW * rectH * 4 + maskBytes;
			printFramebufferUpdate();
			break;
		}
		default:
			dataLen = 1000000;
		}
		is.mark(dataLen);
		return dataLen;
	}


	/** Peeks at the next update, sizes it, copies it to the multicast queue, and rewinds. */
	void sendDataToClient() throws Exception {
		regiFramebufferUpdate();
		int dataLen = checkAndMark();
		readSendData(dataLen);
	}

	BufferedImage createBufferedImage(Image img) {
		BufferedImage bimg = new BufferedImage(img.getWidth(null),
				img.getHeight(null), BufferedImage.TYPE_INT_RGB);

		Graphics g = bimg.getGraphics();
		g.drawImage(img, 0, 0, null);
		g.dispose();
		return bimg;
	}

	void createPngBytes(BufferedImage bimg) throws IOException {
		pngBytes = getImageBytes(bimg, "png");
	}

	byte[] getBytes(BufferedImage img) throws IOException {
		byte[] b = getImageBytes(img, "png");
		return b;
	}

	/** Encodes {@code image} in {@code imageFormat} (via ImageIO) and returns the bytes. */
	byte[] getImageBytes(BufferedImage image, String imageFormat)
			throws IOException {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		// Named "out" so it does not shadow the inherited "os" field.
		BufferedOutputStream out = new BufferedOutputStream(bos);
		try {
			image.flush();
			ImageIO.write(image, imageFormat, out);
			out.flush();
		} finally {
			out.close();
		}
		return bos.toByteArray();
	}

	/** Sends the cached PNG as a U32 big-endian length followed by the PNG bytes. */
	void sendPngData(Socket sock) throws IOException {
		byte[] dataLength = castIntByte(pngBytes.length);
		sock.getOutputStream().write(dataLength);
		sock.getOutputStream().write(pngBytes);
	}

	/** Encodes {@code len} as 4 big-endian bytes. */
	byte[] castIntByte(int len) {
		byte[] b = new byte[4];
		b[0] = (byte) ((len >>> 24) & 0xFF);
		b[1] = (byte) ((len >>> 16) & 0xFF);
		b[2] = (byte) ((len >>> 8) & 0xFF);
		b[3] = (byte) ((len >>> 0) & 0xFF);
		return b;
	}

	BufferedImage createBimg() throws IOException {
		BufferedImage bimg = ImageIO.read(new ByteArrayInputStream(pngBytes));
		return bimg;
	}

	/** Debug dump of the most recently peeked FramebufferUpdate header. */
	void printFramebufferUpdate() {

		System.out.println("messageType=" + messageType);
		System.out.println("rectangles=" + rectangles);
		System.out.println("encoding=" + encoding);
		System.out.println("rectX = "+rectX+": rectY = "+rectY);
		System.out.println("rectW = "+rectW+": rectH = "+rectH);
		if (encoding == RfbProto.EncodingRaw) {
			// Parenthesised so the sum is computed before string concatenation.
			System.out.println("rectW * rectH * 4 + 16 =" + (rectW * rectH * 4 + 16));
		}
	}

	void readSpeedCheck() throws IOException {
		byte[] b = new byte[1];
		readFully(b);
	}

	/** Records the start time and broadcasts a 1-byte SpeedCheckMillis message to all clients. */
	void startSpeedCheck() {
		ByteBuffer b = ByteBuffer.allocate(10);
		b.put((byte)SpeedCheckMillis);
		b.flip();
		startCheckTime = System.currentTimeMillis();
		System.out.println("startChckTime = "+ startCheckTime);
		LinkedList<ByteBuffer>bufs = new LinkedList<ByteBuffer>();
		bufs.add(b);
		multicastqueue.put(bufs);
	}

	void endSpeedCheck() {
		long accTime = System.currentTimeMillis();
		long time = accTime - startCheckTime;
		System.out.println("checkMillis: " + time);
	}

	void printStatus() {
		System.out.println();
	}

	synchronized void changeStatusFlag() {
		printStatusFlag = true;
	}

	/** Sets the status flag. No timing output is printed yet. */
	void printMills() {
		changeStatusFlag();
	}

	/**
	 * Starts a background thread that reads stdin. It calls {@link #startSpeedCheck()}
	 * for every character except 's'.
	 * NOTE(review): 's' itself is ignored, while the newline that follows it does trigger
	 * a check. Confirm this key mapping is intended.
	 */
	void speedCheckMillis() {
		Runnable stdin = new Runnable() {
			public void run() {
				int c;
				try {
					while ((c = System.in.read()) != -1) {
						if (c != 's') {
							startSpeedCheck();
						}
					}
				} catch (IOException e) {
					System.out.println(e);
				}
			}
		};

		new Thread(stdin).start();
	}

	/**
	 * Returns true when {@code header} starts a FramebufferUpdate whose first rectangle
	 * uses Zlib or ZRLE. Such updates are stored decompressed and recompressed per client.
	 */
	private static boolean isCompressedUpdate(ByteBuffer header) {
		if (header.limit() < 16 || header.get(0) != RfbProto.FramebufferUpdate) {
			return false;
		}
		int rectEncoding = header.getInt(12);
		return rectEncoding == RfbProto.EncodingZlib || rectEncoding == RfbProto.EncodingZRLE;
	}

	/** Runs one deflate step into a fresh buffer; appends it to outputs if non-empty. */
	private static int drainDeflater(Deflater deflater, LinkedList<ByteBuffer> outputs) {
		ByteBuffer chunk = ByteBuffer.allocate(INFLATE_BUFSIZE);
		int n = deflater.deflate(chunk.array(), 0, chunk.capacity());
		if (n > 0) {
			chunk.limit(n);
			outputs.addLast(chunk);
		}
		return n;
	}

	/**
	 * Compresses byte buffers into one complete zlib (Deflater) stream.
	 *
	 * @param deflater reset before use and left finished on return
	 * @param inputs   uncompressed buffers (heap-backed); consumed (emptied) by this call
	 * @param outputs  receives a 4-byte big-endian length of the compressed data at the
	 *                 front, and the compressed data in buffers of at most INFLATE_BUFSIZE
	 *                 bytes appended at the end
	 * @return total compressed byte count, excluding the 4-byte length prefix
	 * @throws IOException declared for compatibility; not thrown by this implementation
	 */
	public int zip(Deflater deflater,LinkedList<ByteBuffer> inputs, LinkedList<ByteBuffer> outputs) throws IOException {
		deflater.reset();
		int total = 0;
		ByteBuffer in;
		while ((in = inputs.poll()) != null) {
			// setInput takes (offset, length), so pass remaining(), not limit().
			deflater.setInput(in.array(), in.arrayOffset() + in.position(), in.remaining());
			while (!deflater.needsInput()) {
				total += drainDeflater(deflater, outputs);
			}
		}
		deflater.finish();
		while (!deflater.finished()) {
			total += drainDeflater(deflater, outputs);
		}
		outputs.addFirst(ByteBuffer.wrap(castIntByte(total)));
		return total;
	}

	/**
	 * Inflates raw zlib data (no length prefix) with {@code inflater}. The inflater is not
	 * reset, so consecutive calls continue the same zlib stream.
	 *
	 * @param inflater inflater whose state carries over between calls
	 * @param inputs   compressed buffers (heap-backed); consumed (emptied) by this call
	 * @param outputs  receives decompressed data in buffers of at most INFLATE_BUFSIZE bytes
	 * @return total number of decompressed bytes
	 * @throws DataFormatException if the input is not valid zlib data
	 */
	public int unzip(Inflater inflater, LinkedList<ByteBuffer> inputs, LinkedList<ByteBuffer> outputs)
																	throws DataFormatException {
		int total = 0;
		ByteBuffer input;
		while ((input = inputs.poll()) != null) {
			inflater.setInput(input.array(), input.arrayOffset() + input.position(), input.remaining());
			int n;
			do {
				ByteBuffer buf = ByteBuffer.allocate(INFLATE_BUFSIZE);
				n = inflater.inflate(buf.array(), 0, buf.capacity());
				if (n > 0) {
					buf.limit(n);
					outputs.addLast(buf);
					total += n;
				}
				// A full buffer means more output may be pending for this input.
			} while (n == INFLATE_BUFSIZE);
		}
		return total;
	}

	/**
	 * Copies the next {@code dataLen} bytes of the marked upstream stream into the
	 * multicast queue.
	 *
	 * <p>The stream is always rewound afterwards, even on failure, so the local viewer
	 * can decode the same bytes. Zlib/ZRLE payloads are queued decompressed as
	 * header + data, without the U32 length.
	 */
	void readSendData(int dataLen) throws IOException, DataFormatException {
		try {
			LinkedList<ByteBuffer> bufs = new LinkedList<ByteBuffer>();
			ByteBuffer header = ByteBuffer.allocate(16);
			readFully(header.array(), 0, 16);
			if (isCompressedUpdate(header)) {
				// The server's U32 compressed length is discarded; zip() recomputes it per client.
				ByteBuffer len = ByteBuffer.allocate(4);
				readFully(len.array(), 0, 4);
				ByteBuffer inputData = ByteBuffer.allocate(dataLen - 20);
				readFully(inputData.array(), 0, inputData.capacity());
				LinkedList<ByteBuffer> inputs = new LinkedList<ByteBuffer>();
				inputs.add(inputData);
				// Inflating once here avoids every client thread inflating the same data.
				unzip(inflater, inputs, bufs);
				bufs.addFirst(header);
			} else {
				bufs.add(header);
				if (dataLen > 16) {
					ByteBuffer body = ByteBuffer.allocate(dataLen - 16);
					readFully(body.array(), 0, dataLen - 16);
					bufs.add(body);
				}
			}
			multicastqueue.put(bufs);
		} finally {
			is.reset();
		}
	}

	/**
	 * Starts a sender thread for a newly accepted downstream client.
	 *
	 * <p>The thread does the RFB handshake (version, security "None", ServerInit replay).
	 * It then forwards every update from the multicast queue until an I/O error occurs,
	 * recompressing Zlib/ZRLE updates per client. On exit it releases the deflater and
	 * closes the socket.
	 *
	 * NOTE(review): zip() resets and finishes the deflater for every update, so each
	 * update is an independent zlib stream. RFB clients normally expect one continuous
	 * stream per connection; confirm the receiving viewer handles this.
	 */
	void newClient(AcceptThread acceptThread, final Socket newCli,
			final OutputStream os, final InputStream is) throws IOException {
		final Client <LinkedList<ByteBuffer>> c = multicastqueue.newClient();
		Runnable sender = new Runnable() {
			public void run() {

				Deflater deflater = new Deflater();
				try {
					/**
					 *  initial connection of RFB protocol
					 */
					sendRfbVersion(os);
					readVersionMsg(is);
					sendSecurityType(os);
					readSecType(is);
					sendSecResult(os);
					readClientInit(is);
					sendInitData(os);

					for (;;) {
						// The queued list is shared by all clients: copy it so polling here
						// does not remove buffers other clients still need.
						LinkedList<ByteBuffer> bufs = new LinkedList<ByteBuffer>(c.poll());
						ByteBuffer header = bufs.poll();
						if (isCompressedUpdate(header)) {
							LinkedList<ByteBuffer> outs = new LinkedList<ByteBuffer>();
							// zip() already prepends the 4-byte compressed length.
							zip(deflater, bufs, outs);
							outs.addFirst(header);
							for (ByteBuffer out : outs) {
								writeBuffer(os, out);
							}
						} else {
							writeBuffer(os, header);
							for (ByteBuffer b : bufs) {
								writeBuffer(os, b);
							}
						}
						os.flush();
					}
				} catch (IOException e) {
					// Client disconnected or the handshake failed: stop serving this client.
				} finally {
					deflater.end();
					try {
						newCli.close();
					} catch (IOException ignored) {
						// Nothing more can be done for a socket that fails to close.
					}
				}
			}
		};
		clients++;
		new Thread(sender).start();

	}

	/** Concatenates the remaining bytes of each buffer without consuming them. */
	private static byte[] concat(Iterable<ByteBuffer> bufs) {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		for (ByteBuffer b : bufs) {
			bos.write(b.array(), b.arrayOffset() + b.position(), b.remaining());
		}
		return bos.toByteArray();
	}

	/** Round-trips data through zip() and unzip() and checks the bytes and length prefix. */
	@Test
	public void test1() {
		try {
			LinkedList<ByteBuffer> in = new LinkedList<ByteBuffer>();
			for(int i=0;i<10;i++) {
				in.add(ByteBuffer.wrap("test1".getBytes()));
				in.add(ByteBuffer.wrap("test2".getBytes()));
				in.add(ByteBuffer.wrap("test3".getBytes()));
				in.add(ByteBuffer.wrap("test4".getBytes()));
			}
			// zip() consumes its input, so capture the expected bytes first.
			byte[] expected = concat(in);

			Deflater deflater = new Deflater();
			LinkedList<ByteBuffer> out = new LinkedList<ByteBuffer>();
			int zipped = zip(deflater, in, out);
			deflater.end();
			ByteBuffer prefix = out.poll();
			assertEquals(zipped, prefix.getInt(0));
			assertEquals(zipped, concat(out).length);

			// A private inflater keeps the proxy's persistent upstream inflater untouched.
			Inflater testInflater = new Inflater();
			LinkedList<ByteBuffer> out2 = new LinkedList<ByteBuffer>();
			int unzipped = unzip(testInflater, out, out2);
			testInflater.end();
			assertEquals(expected.length, unzipped);
			assertArrayEquals(expected, concat(out2));
			System.out.println("Test Ok.");
		} catch (Exception e) {
			throw new AssertionError(e);
		}
	}

}