图灵手写tomcat

多线程BIO方式(阻塞accept + 线程池)获取并处理socket连接

package org.malred;

import java.io.IOException;
import java.io.InputStream;
import java.net.ServerSocket;
import java.net.Socket;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;

public class Tomcat {
    public void start() {
        // socket连接 TCP
        try {
            // 20个线程的线程池
            ExecutorService executorService = Executors.newFixedThreadPool(20);
            // 创建socket
            ServerSocket serverSocket = new ServerSocket(8080);
            // 可以处理多次连接
            while (true) {
                // accept(会阻塞)等待并处理连接
                Socket socket = serverSocket.accept();
                // 多线程处理,防止阻塞
                executorService.execute(new SocketProcessor(socket));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    public void start(int port) {
        // socket连接 TCP
        try {
            // 20个线程的线程池
            ExecutorService executorService = Executors.newFixedThreadPool(20);
            // 创建socket
            ServerSocket serverSocket = new ServerSocket(port);
            // 可以处理多次连接
            while (true) {
                // accept(会阻塞)等待并处理连接
                Socket socket = serverSocket.accept();
                // 多线程处理,防止阻塞
                executorService.execute(new SocketProcessor(socket));
            }
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }
}
package org.malred;

import java.net.Socket;

/**
 * Runnable wrapper around one accepted client socket.
 * At this stage of the tutorial the handler only logs that a connection arrived.
 */
public class SocketProcessor implements Runnable {
    private Socket socket;

    public SocketProcessor(Socket socket) {
        this.socket = socket;
    }

    @Override
    public void run() {
        processSocket(socket);
    }

    // Placeholder handler: prints a marker line and leaves the socket untouched.
    private void processSocket(Socket client) {
        String marker = "process";
        System.out.println(marker);
    }
}
import org.junit.Test;
import org.malred.Tomcat;

public class tomcat {
    /**
     * Manual launcher rather than a real unit test: boots the server on port 8888.
     * {@link Tomcat#start(int)} runs an endless accept loop, so this method does not return
     * normally and must be stopped by hand.
     */
    @Test
    public void start() {
        int port = 8888;
        new Tomcat().start(port);
    }
}
<?xml version="1.0" encoding="UTF-8"?>
<!-- Maven build for the hand-written Tomcat ("itomcat") tutorial project. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>org.malred.itomcat</groupId>
    <artifactId>itomcat</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <!-- Compile sources and emit bytecode for Java 11. -->
        <maven.compiler.source>11</maven.compiler.source>
        <maven.compiler.target>11</maven.compiler.target>
        <!-- Sources contain Chinese comments and log strings, so pin UTF-8. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
    </properties>
    <dependencies>
        <!-- Servlet 3.1 API (HttpServlet, HttpServletRequest/Response, ServletOutputStream, ...).
             No <scope>, i.e. compile scope: this project implements the container side itself. -->
        <dependency>
            <groupId>javax.servlet</groupId>
            <artifactId>javax.servlet-api</artifactId>
            <version>3.1.0</version>
        </dependency>
        <!-- JUnit 4, used only by the test class that launches the server. -->
        <dependency>
            <groupId>junit</groupId>
            <artifactId>junit</artifactId>
            <version>4.13.2</version>
            <scope>test</scope>
        </dependency>
    </dependencies>
</project>

按http协议解析字节流

package org.malred;

import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;

/**
 * Handles one accepted client connection by parsing the HTTP request line
 * ({@code METHOD SP request-target SP HTTP-version CRLF}) into a {@link Request}.
 *
 * <p>The socket is closed once processing finishes.
 */
public class SocketProcessor implements Runnable {
    /** Size of the buffer used for the single read of the request head. */
    private static final int BUFFER_SIZE = 1024;

    private Socket socket;

    public SocketProcessor(Socket socket) {
        this.socket = socket;
    }

    @Override
    public void run() {
        processSocket(socket);
    }

    /**
     * Reads the request head and parses its request line.
     *
     * @param socket the accepted client connection; closed before this method returns
     * @throws RuntimeException wrapping any IOException raised while reading
     */
    private void processSocket(Socket socket) {
        System.out.println("接收到请求,开始处理......");
        // try-with-resources closes the connection even when parsing fails, so worker threads
        // never leak sockets.
        try (Socket client = socket; InputStream inputStream = client.getInputStream()) {
            byte[] bytes = new byte[BUFFER_SIZE];
            // Only the first `length` bytes hold request data; the rest of the buffer is zero-filled
            // and must not be parsed.
            // NOTE: a single read() may return only part of the request line on slow connections.
            int length = inputStream.read(bytes);
            if (length <= 0) {
                // The client closed the connection without sending a request.
                return;
            }

            // Everything before the first space is the request method.
            int firstSpace = indexOf(bytes, (byte) ' ', 0, length);
            String method = slice(bytes, 0, firstSpace);
            System.out.println("该请求的方法是:" + method);

            // Between the first and second space is the request target (url).
            int secondSpace = indexOf(bytes, (byte) ' ', firstSpace + 1, length);
            String url = slice(bytes, firstSpace + 1, secondSpace);
            System.out.println("该请求的url是:" + url);

            // Between the second space and CRLF is the protocol version.
            int lineEnd = indexOfCrlf(bytes, secondSpace + 1, length);
            String protoc = slice(bytes, secondSpace + 1, lineEnd);
            System.out.println("该请求的http协议版本是:" + protoc);

            Request request = new Request(method, url, protoc);
            // Next step of the tutorial: match a servlet for this request.
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the index of the first {@code target} in {@code buf[from, limit)}, or {@code limit} if absent. */
    private static int indexOf(byte[] buf, byte target, int from, int limit) {
        for (int i = from; i < limit; i++) {
            if (buf[i] == target) {
                return i;
            }
        }
        return limit;
    }

    /** Returns the index of the first CR LF pair lying fully in {@code buf[from, limit)}, or {@code limit}. */
    private static int indexOfCrlf(byte[] buf, int from, int limit) {
        for (int i = from; i + 1 < limit; i++) {
            if (buf[i] == '\r' && buf[i + 1] == '\n') {
                return i;
            }
        }
        return limit;
    }

    /** Decodes {@code buf[from, to)} as ISO-8859-1 (one char per byte); empty when the range is empty. */
    private static String slice(byte[] buf, int from, int to) {
        if (from >= to) {
            return "";
        }
        return new String(buf, from, to - from, java.nio.charset.StandardCharsets.ISO_8859_1);
    }
}

按servlet规范实现Request和Response

package org.malred;

/**
 * Immutable holder for the three parts of an HTTP request line.
 */
public class Request {
    private final String method;
    private final String url;
    private final String protoc;

    /**
     * @param method request method, e.g. {@code GET}
     * @param url    request target exactly as it appeared on the request line
     * @param protoc protocol version string, e.g. {@code HTTP/1.1}
     */
    public Request(String method, String url, String protoc) {
        this.method = method;
        this.url = url;
        this.protoc = protoc;
    }

    /** @return the request method given to the constructor */
    public String getMethod() {
        return this.method;
    }

    /** @return the request target given to the constructor */
    public String getUrl() {
        return this.url;
    }

    /** @return the protocol version given to the constructor */
    public String getProtoc() {
        return this.protoc;
    }
}

package org.malred;

import javax.servlet.ServletOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Minimal HTTP response that records the status line and headers in memory.
 * All other HttpServletResponse methods fall back to the empty stubs of
 * {@link AbstractHttpServletResponse}.
 */
public class IResponse extends AbstractHttpServletResponse {

    // Status code and reason phrase of the status line; default "200 OK".
    private int status = 200;
    private String message = "OK";
    // One value per header name: a later add/set for the same name replaces the earlier value.
    private Map<String, String> headers = new HashMap<>();

    /** Returns the status code set via {@code setStatus} (200 by default). */
    @Override
    public int getStatus() {
        // The base-class stub always returns 0; report the status actually stored here.
        return status;
    }

    @Override
    public void setStatus(int sc) {
        this.status = sc;
    }

    @Override
    public void setStatus(int sc, String s) {
        this.status = sc;
        this.message = s;
    }

    /** Returns the reason phrase for the status line ("OK" by default). */
    public String getMessage() {
        return message;
    }

    /**
     * Stores a header. NOTE: the map keeps a single value per name, so this replaces any
     * previous value instead of appending a second one.
     */
    @Override
    public void addHeader(String name, String value) {
        headers.put(name, value);
    }

    /** Sets (replaces) a header value. */
    @Override
    public void setHeader(String name, String value) {
        headers.put(name, value);
    }

    /** Returns the stored value for {@code name}, or null if the header was never set. */
    @Override
    public String getHeader(String name) {
        return headers.get(name);
    }

    @Override
    public boolean containsHeader(String name) {
        return headers.containsKey(name);
    }
}
package org.malred;

import javax.servlet.ServletException;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;

/**
 * Handles one accepted socket: turns the parsed request line into an IRequest, creates an
 * IResponse and lets IServlet handle the request.
 */
public class SocketProcessor implements Runnable {
    private Socket socket;

    public SocketProcessor(Socket socket) {
        this.socket = socket;
    }

    @Override
    public void run() {
        processSocket(socket);
    }

    // Process the socket. Checked exceptions are rethrown as RuntimeException on the worker thread.
    // NOTE(review): nothing in this method closes the socket.
    private void processSocket(Socket socket) {
        System.out.println("接收到请求,开始处理......");
        try {
            // ... (request-line parsing elided in this excerpt; it defines method, url and protoc)

            IRequest request = new IRequest(method.toString(), url.toString(), protoc.toString());
            IResponse response = new IResponse();

            // Servlet matching: currently every request goes to IServlet.
            IServlet servlet = new IServlet();
            // HttpServlet.service inspects the request method and calls the matching doXxx method.
            servlet.service(request, response);
        } catch (IOException e) {
            throw new RuntimeException(e);
        } catch (ServletException e) {
            throw new RuntimeException(e);
        }
    }
}
package org.malred;

import javax.servlet.ServletOutputStream;
import javax.servlet.http.Cookie;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.Collection;
import java.util.Locale;

/**
 * Empty base implementation of {@link HttpServletResponse}. Every method is a stub: void methods
 * do nothing, and the others return {@code null}, {@code 0} or {@code false}. Subclasses such as
 * IResponse override only what they need.
 *
 * <p>Beware: {@link #getOutputStream()} and {@link #getWriter()} return {@code null}. A servlet
 * that writes a body through a response that does not override them fails with a
 * NullPointerException.
 *
 * <p>NOTE(review): despite its name this class is not declared {@code abstract}, so it can be
 * instantiated directly.
 */
public class AbstractHttpServletResponse implements HttpServletResponse {
    @Override
    public void addCookie(Cookie cookie) {

    }

    @Override
    public boolean containsHeader(String name) {
        return false;
    }

    @Override
    public String encodeURL(String url) {
        return null;
    }

    @Override
    public String encodeRedirectURL(String url) {
        return null;
    }

    @Override
    public String encodeUrl(String url) {
        return null;
    }

    @Override
    public String encodeRedirectUrl(String url) {
        return null;
    }

    // Error/redirect stubs do nothing: such requests are silently ignored unless a subclass overrides them.
    @Override
    public void sendError(int sc, String msg) throws IOException {

    }

    @Override
    public void sendError(int sc) throws IOException {

    }

    @Override
    public void sendRedirect(String location) throws IOException {

    }

    // Header setters are no-ops here; the header getters below always report "absent".
    @Override
    public void setDateHeader(String name, long date) {

    }

    @Override
    public void addDateHeader(String name, long date) {

    }

    @Override
    public void setHeader(String name, String value) {

    }

    @Override
    public void addHeader(String name, String value) {

    }

    @Override
    public void setIntHeader(String name, int value) {

    }

    @Override
    public void addIntHeader(String name, int value) {

    }

    @Override
    public void setStatus(int sc) {

    }

    @Override
    public void setStatus(int sc, String sm) {

    }

    // Stub: always 0. Subclasses that store a status code must override this too.
    @Override
    public int getStatus() {
        return 0;
    }

    @Override
    public String getHeader(String name) {
        return null;
    }

    @Override
    public Collection<String> getHeaders(String name) {
        return null;
    }

    @Override
    public Collection<String> getHeaderNames() {
        return null;
    }

    @Override
    public String getCharacterEncoding() {
        return null;
    }

    @Override
    public String getContentType() {
        return null;
    }

    // Stub returns null; see the class note about NullPointerException.
    @Override
    public ServletOutputStream getOutputStream() throws IOException {
        return null;
    }

    // Stub returns null; see the class note about NullPointerException.
    @Override
    public PrintWriter getWriter() throws IOException {
        return null;
    }

    @Override
    public void setCharacterEncoding(String charset) {

    }

    @Override
    public void setContentLength(int len) {

    }

    @Override
    public void setContentLengthLong(long len) {

    }

    @Override
    public void setContentType(String type) {

    }

    @Override
    public void setBufferSize(int size) {

    }

    @Override
    public int getBufferSize() {
        return 0;
    }

    @Override
    public void flushBuffer() throws IOException {

    }

    @Override
    public void resetBuffer() {

    }

    @Override
    public boolean isCommitted() {
        return false;
    }

    @Override
    public void reset() {

    }

    @Override
    public void setLocale(Locale loc) {

    }

    @Override
    public Locale getLocale() {
        return null;
    }
}
package org.malred;

/**
 * HTTP request built from the parsed request line. Only the method, request target and protocol
 * are known; every other HttpServletRequest method falls back to the empty stubs of
 * {@link AbstractHttpServletRequest}.
 */
public class IRequest extends AbstractHttpServletRequest {
    private final String method;
    // Request target exactly as it appeared on the request line (path plus optional "?query").
    private final String url;
    private final String protoc;

    /**
     * @param method request method, e.g. {@code GET}
     * @param url    request target from the request line, e.g. {@code /index?a=1}
     * @param protoc protocol version, e.g. {@code HTTP/1.1}
     */
    public IRequest(String method, String url, String protoc) {
        this.method = method;
        this.url = url;
        this.protoc = protoc;
    }

    @Override
    public String getMethod() {
        return method;
    }

    /**
     * Returns the raw request target wrapped in a StringBuffer.
     * NOTE: the Servlet API specifies a full URL here (scheme, host, port, path); only the
     * request-line target is available, so that is what is returned.
     */
    @Override
    public StringBuffer getRequestURL() {
        return new StringBuffer(url);
    }

    /** Returns the request target up to (not including) the query string, or null if unknown. */
    @Override
    public String getRequestURI() {
        if (url == null) {
            return null;
        }
        int query = url.indexOf('?');
        return query < 0 ? url : url.substring(0, query);
    }

    /** Returns the text after the first '?', or null when the target has no query string. */
    @Override
    public String getQueryString() {
        if (url == null) {
            return null;
        }
        int query = url.indexOf('?');
        return query < 0 ? null : url.substring(query + 1);
    }

    @Override
    public String getProtocol() {
        return protoc;
    }
}
package org.malred;

import javax.servlet.*;
import javax.servlet.http.*;
import java.io.BufferedReader;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.security.Principal;
import java.util.Collection;
import java.util.Enumeration;
import java.util.Locale;
import java.util.Map;

/**
 * Empty base implementation of {@link HttpServletRequest}. Every method is a stub so that
 * subclasses (e.g. IRequest) override only the methods they actually support.
 *
 * <p>Most stubs return {@code null}, {@code 0} or {@code false}. {@link #getCookies()} and
 * {@link #getParameterValues(String)} return empty arrays instead.
 *
 * <p>NOTE(review): the Servlet API documents {@code null} for those two methods when nothing is
 * present; confirm callers tolerate the difference. Also, despite its name this class is not
 * declared {@code abstract}.
 */
public class AbstractHttpServletRequest implements HttpServletRequest {
    @Override
    public String getAuthType() {
        return null;
    }

    // Stub returns an empty array rather than null (see class note).
    @Override
    public Cookie[] getCookies() {
        return new Cookie[0];
    }

    @Override
    public long getDateHeader(String name) {
        return 0;
    }

    @Override
    public String getHeader(String name) {
        return null;
    }

    @Override
    public Enumeration<String> getHeaders(String name) {
        return null;
    }

    @Override
    public Enumeration<String> getHeaderNames() {
        return null;
    }

    @Override
    public int getIntHeader(String name) {
        return 0;
    }

    @Override
    public String getMethod() {
        return null;
    }

    @Override
    public String getPathInfo() {
        return null;
    }

    @Override
    public String getPathTranslated() {
        return null;
    }

    @Override
    public String getContextPath() {
        return null;
    }

    @Override
    public String getQueryString() {
        return null;
    }

    @Override
    public String getRemoteUser() {
        return null;
    }

    @Override
    public boolean isUserInRole(String role) {
        return false;
    }

    @Override
    public Principal getUserPrincipal() {
        return null;
    }

    @Override
    public String getRequestedSessionId() {
        return null;
    }

    @Override
    public String getRequestURI() {
        return null;
    }

    @Override
    public StringBuffer getRequestURL() {
        return null;
    }

    @Override
    public String getServletPath() {
        return null;
    }

    // Session support is not implemented: both variants return null.
    @Override
    public HttpSession getSession(boolean create) {
        return null;
    }

    @Override
    public HttpSession getSession() {
        return null;
    }

    @Override
    public String changeSessionId() {
        return null;
    }

    @Override
    public boolean isRequestedSessionIdValid() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromCookie() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromURL() {
        return false;
    }

    @Override
    public boolean isRequestedSessionIdFromUrl() {
        return false;
    }

    @Override
    public boolean authenticate(HttpServletResponse response) throws IOException, ServletException {
        return false;
    }

    @Override
    public void login(String username, String password) throws ServletException {

    }

    @Override
    public void logout() throws ServletException {

    }

    @Override
    public Collection<Part> getParts() throws IOException, ServletException {
        return null;
    }

    @Override
    public Part getPart(String name) throws IOException, ServletException {
        return null;
    }

    @Override
    public <T extends HttpUpgradeHandler> T upgrade(Class<T> handlerClass) throws IOException, ServletException {
        return null;
    }

    @Override
    public Object getAttribute(String name) {
        return null;
    }

    @Override
    public Enumeration<String> getAttributeNames() {
        return null;
    }

    @Override
    public String getCharacterEncoding() {
        return null;
    }

    @Override
    public void setCharacterEncoding(String env) throws UnsupportedEncodingException {

    }

    @Override
    public int getContentLength() {
        return 0;
    }

    @Override
    public long getContentLengthLong() {
        return 0;
    }

    @Override
    public String getContentType() {
        return null;
    }

    // Stub returns null: request bodies cannot be read through this base class.
    @Override
    public ServletInputStream getInputStream() throws IOException {
        return null;
    }

    @Override
    public String getParameter(String name) {
        return null;
    }

    @Override
    public Enumeration<String> getParameterNames() {
        return null;
    }

    // Stub returns an empty array rather than null (see class note).
    @Override
    public String[] getParameterValues(String name) {
        return new String[0];
    }

    @Override
    public Map<String, String[]> getParameterMap() {
        return null;
    }

    @Override
    public String getProtocol() {
        return null;
    }

    @Override
    public String getScheme() {
        return null;
    }

    @Override
    public String getServerName() {
        return null;
    }

    @Override
    public int getServerPort() {
        return 0;
    }

    @Override
    public BufferedReader getReader() throws IOException {
        return null;
    }

    @Override
    public String getRemoteAddr() {
        return null;
    }

    @Override
    public String getRemoteHost() {
        return null;
    }

    // Attribute storage is not implemented: set/remove are no-ops and getAttribute returns null.
    @Override
    public void setAttribute(String name, Object o) {

    }

    @Override
    public void removeAttribute(String name) {

    }

    @Override
    public Locale getLocale() {
        return null;
    }

    @Override
    public Enumeration<Locale> getLocales() {
        return null;
    }

    @Override
    public boolean isSecure() {
        return false;
    }

    @Override
    public RequestDispatcher getRequestDispatcher(String path) {
        return null;
    }

    @Override
    public String getRealPath(String path) {
        return null;
    }

    @Override
    public int getRemotePort() {
        return 0;
    }

    @Override
    public String getLocalName() {
        return null;
    }

    @Override
    public String getLocalAddr() {
        return null;
    }

    @Override
    public int getLocalPort() {
        return 0;
    }

    @Override
    public ServletContext getServletContext() {
        return null;
    }

    // Async processing is not supported (isAsyncSupported returns false).
    @Override
    public AsyncContext startAsync() throws IllegalStateException {
        return null;
    }

    @Override
    public AsyncContext startAsync(ServletRequest servletRequest, ServletResponse servletResponse) throws IllegalStateException {
        return null;
    }

    @Override
    public boolean isAsyncStarted() {
        return false;
    }

    @Override
    public boolean isAsyncSupported() {
        return false;
    }

    @Override
    public AsyncContext getAsyncContext() {
        return null;
    }

    @Override
    public DispatcherType getDispatcherType() {
        return null;
    }
}
package org.malred;

import javax.servlet.ServletException;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;

/**
 * Demo servlet that every request is dispatched to. GET writes a fixed body; POST, PUT and
 * DELETE are left to HttpServlet's default implementations.
 */
public class IServlet extends HttpServlet {
    /** Writes {@code hello world} to the response body. */
    @Override
    protected void doGet(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("收到GET请求,doGet进行处理");
        // Explicit charset: String.getBytes() without one depends on the platform default.
        resp.getOutputStream().write("hello world".getBytes(java.nio.charset.StandardCharsets.UTF_8));
    }

    /** Logs the call, then delegates to HttpServlet's default, which rejects the method. */
    @Override
    protected void doPost(HttpServletRequest req, HttpServletResponse resp) throws ServletException, IOException {
        System.out.println("收到POST请求,doPOST进行处理");
        super.doPost(req, resp);
    }

    // doPut/doDelete are intentionally not overridden: the previous overrides only called super.
}

暂存响应体

当前IResponse还没有重写getOutputStream方法(父类AbstractHttpServletResponse的空实现返回null),所以运行时IServlet调用resp.getOutputStream().write(...)会报空指针异常

package org.malred;

import javax.servlet.ServletOutputStream;
import javax.servlet.WriteListener;
import java.io.IOException; 
import java.util.Arrays;

/**
 * In-memory {@link ServletOutputStream} that buffers the response body. A servlet may call
 * {@code write} many times, so bytes are collected here and sent only when the response is
 * completed, instead of being written straight to the socket.
 */
public class ResponseServletOutputStream extends ServletOutputStream {
    // Growable body buffer; only bytes[0, pos) hold data.
    private byte[] bytes = new byte[1024];
    private int pos = 0;

    /** Always true: writes go to an in-memory buffer and never block. */
    @Override
    public boolean isReady() {
        return true;
    }

    /** Non-blocking write listeners are not supported; the listener is ignored. */
    @Override
    public void setWriteListener(WriteListener writeListener) {

    }

    /**
     * Returns the internal buffer itself (not a copy). Only the first {@link #getPos()} bytes are
     * body data; the remainder is unused capacity.
     */
    public byte[] getBytes() {
        return bytes;
    }

    /** Returns the number of body bytes written so far. */
    public int getPos() {
        return pos;
    }

    @Override
    public void write(int b) throws IOException {
        ensureCapacity(pos + 1);
        bytes[pos++] = (byte) b;
    }

    /** Bulk write: copies {@code b[off, off + len)} in one step instead of byte by byte. */
    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        if (off < 0 || len < 0 || len > b.length - off) {
            throw new IndexOutOfBoundsException("off=" + off + ", len=" + len + ", length=" + b.length);
        }
        ensureCapacity(pos + len);
        System.arraycopy(b, off, bytes, pos, len);
        pos += len;
    }

    /** Grows the buffer (at least doubling it) so that it can hold {@code minCapacity} bytes. */
    private void ensureCapacity(int minCapacity) {
        if (minCapacity > bytes.length) {
            bytes = Arrays.copyOf(bytes, Math.max(bytes.length * 2, minCapacity));
        }
    }
}
package org.malred;

import javax.servlet.ServletOutputStream;
import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

/**
 * Minimal HTTP response holding the status line and headers; other HttpServletResponse methods
 * come from the empty stubs of AbstractHttpServletResponse.
 */
public class IResponse extends AbstractHttpServletResponse {

    // Status code and reason phrase; default "200 OK".
    private int status = 200;
    private String message = "OK";
    // Response headers, one value per name.
    private Map<String, String> headers = new HashMap<>();

    // ... (status/header methods elided in this excerpt)

    // Called by SocketProcessor after the servlet has run, to send the response.
    // Not implemented yet at this stage: the body is empty, so nothing is sent.
    public void complete(){

    }
}
package org.malred;

import javax.servlet.ServletException;
import java.io.IOException;
import java.io.InputStream;
import java.net.Socket;

/**
 * Handles one accepted client connection. It parses the HTTP request line
 * ({@code METHOD SP request-target SP HTTP-version CRLF}), dispatches the request to
 * {@link IServlet} and completes the response.
 *
 * <p>The socket is closed once processing finishes.
 */
public class SocketProcessor implements Runnable {
    /** Size of the buffer used for the single read of the request head. */
    private static final int BUFFER_SIZE = 1024;

    private Socket socket;

    public SocketProcessor(Socket socket) {
        this.socket = socket;
    }

    @Override
    public void run() {
        processSocket(socket);
    }

    /**
     * Reads and parses the request line, runs the servlet and completes the response.
     *
     * @param socket the accepted client connection; closed before this method returns
     * @throws RuntimeException wrapping any IOException or ServletException
     */
    private void processSocket(Socket socket) {
        System.out.println("接收到请求,开始处理......");
        // try-with-resources closes the connection even when parsing or the servlet fails, so
        // worker threads never leak sockets. The response is completed before the close.
        try (Socket client = socket; InputStream inputStream = client.getInputStream()) {
            byte[] bytes = new byte[BUFFER_SIZE];
            // Only the first `length` bytes hold request data; the rest of the buffer is zero-filled
            // and must not be parsed.
            // NOTE: a single read() may return only part of the request line on slow connections.
            int length = inputStream.read(bytes);
            if (length <= 0) {
                // The client closed the connection without sending a request.
                return;
            }

            // Everything before the first space is the request method.
            int firstSpace = indexOf(bytes, (byte) ' ', 0, length);
            String method = slice(bytes, 0, firstSpace);
            System.out.println("该请求的方法是:" + method);

            // Between the first and second space is the request target (url).
            int secondSpace = indexOf(bytes, (byte) ' ', firstSpace + 1, length);
            String url = slice(bytes, firstSpace + 1, secondSpace);
            System.out.println("该请求的url是:" + url);

            // Between the second space and CRLF is the protocol version.
            int lineEnd = indexOfCrlf(bytes, secondSpace + 1, length);
            String protoc = slice(bytes, secondSpace + 1, lineEnd);
            System.out.println("该请求的http协议版本是:" + protoc);

            IRequest request = new IRequest(method, url, protoc);
            IResponse response = new IResponse();

            // Servlet matching is not implemented yet: every request goes to IServlet.
            IServlet servlet = new IServlet();
            // HttpServlet.service inspects the request method and calls the matching doXxx method.
            servlet.service(request, response);

            // Send the response.
            response.complete();
        } catch (IOException | ServletException e) {
            throw new RuntimeException(e);
        }
    }

    /** Returns the index of the first {@code target} in {@code buf[from, limit)}, or {@code limit} if absent. */
    private static int indexOf(byte[] buf, byte target, int from, int limit) {
        for (int i = from; i < limit; i++) {
            if (buf[i] == target) {
                return i;
            }
        }
        return limit;
    }

    /** Returns the index of the first CR LF pair lying fully in {@code buf[from, limit)}, or {@code limit}. */
    private static int indexOfCrlf(byte[] buf, int from, int limit) {
        for (int i = from; i + 1 < limit; i++) {
            if (buf[i] == '\r' && buf[i + 1] == '\n') {
                return i;
            }
        }
        return limit;
    }

    /** Decodes {@code buf[from, to)} as ISO-8859-1 (one char per byte); empty when the range is empty. */
    private static String slice(byte[] buf, int from, int to) {
        if (from >= to) {
            return "";
        }
        return new String(buf, from, to - from, java.nio.charset.StandardCharsets.ISO_8859_1);
    }
}

按http协议发送响应数据


  转载请注明: malred-blog 图灵手写tomcat

  目录