package com.ora.rmibook.chapter22.tunneler;
import java.io.*;
import java.net.*;
import java.util.*;
import java.rmi.registry.*;
import java.rmi.*;
import java.rmi.server.*;
import javax.servlet.http.*;
import javax.servlet.*;
import sun.rmi.server.MarshalInputStream;
public class ServletForwardCommand {
public static void execute(HttpServletRequest request, HttpServletResponse response, String stringifiedPort)
throws ServletClientException, ServletServerException, IOException {
int port = convertStringToPort(stringifiedPort);
Socket connectionToLocalServer = null;
try {
connectionToLocalServer = connectToLocalServer(port);
forwardRequest(request, connectionToLocalServer);
forwardResponse(response, connectionToLocalServer);
} finally {
if (null != connectionToLocalServer) {
connectionToLocalServer.close();
}
}
}
private static int convertStringToPort(String stringfiedPort) throws ServletClientException {
int returnValue;
try {
returnValue = Integer.parseInt(stringfiedPort);
} catch (NumberFormatException e) {
throw new ServletClientException("invalid port number: " + stringfiedPort);
}
if (returnValue <= 0 || returnValue > 0xFFFF) {
throw new ServletClientException("invalid port: " + returnValue);
}
if (returnValue < 1024) {
throw new ServletClientException("permission denied for port: " + returnValue);
}
return returnValue;
}
private static Socket connectToLocalServer(int port) throws ServletServerException {
Socket returnValue;
try {
returnValue = new Socket(InetAddress.getLocalHost(), port);
} catch (IOException e) {
throw new ServletServerException("could not connect to " + "local port");
}
return returnValue;
}
private static void forwardRequest(HttpServletRequest request, Socket connectionToLocalServer)
throws IOException, ServletClientException, ServletServerException {
byte buffer[];
DataInputStream clientIn = new DataInputStream(request.getInputStream());
buffer = new byte[request.getContentLength()];
try {
clientIn.readFully(buffer);
} catch (EOFException e) {
throw new ServletClientException("unexpected EOF " + "reading request body");
} catch (IOException e) {
throw new ServletClientException("error reading request" + " body");
}
DataOutputStream socketOut = null;
// send to local server in HTTP
try {
socketOut = new DataOutputStream(connectionToLocalServer.getOutputStream());
socketOut.writeBytes("POST / HTTP/1.0\r\n");
socketOut.writeBytes("Content-length: " + request.getContentLength() + "\r\n\r\n");
socketOut.write(buffer);
socketOut.flush();
} catch (IOException e) {
throw new ServletServerException("error writing to server");
}
}
private static void forwardResponse(HttpServletResponse response, Socket connectionToLocalServer)
throws IOException, ServletClientException, ServletServerException {
byte[] buffer;
DataInputStream socketIn;
try {
socketIn = new DataInputStream(connectionToLocalServer.getInputStream());
} catch (IOException e) {
throw new ServletServerException("error reading from " + "server");
}
String key = "Content-length:".toLowerCase();
boolean contentLengthFound = false;
String line;
int responseContentLength = -1;
do {
try {
line = socketIn.readLine();
} catch (IOException e) {
throw new ServletServerException("error reading from server");
}
if (line == null) {
throw new ServletServerException("unexpected EOF reading server response");
}
if (line.toLowerCase().startsWith(key)) {
responseContentLength = Integer.parseInt(line.substring(key.length()).trim());
contentLengthFound = true;
}
}
while ((line.length() != 0) &&
(line.charAt(0) != '\r') && (line.charAt(0) != '\n'));
if (!contentLengthFound || responseContentLength < 0)
throw new ServletServerException("missing or invalid content length in server response");
buffer = new byte[responseContentLength];
try {
socketIn.readFully(buffer);
} catch (EOFException e) {
throw new ServletServerException("unexpected EOF reading server response");
} catch (IOException e) {
throw new ServletServerException("error reading from server");
}
response.setStatus(HttpServletResponse.SC_OK);
response.setContentType("application/octet-stream");
response.setContentLength(buffer.length);
try {
OutputStream out = response.getOutputStream();
out.write(buffer);
out.flush();
} catch (IOException e) {
throw new ServletServerException("error writing response");
}
}
}
|