356 lines
7.4 KiB
Go
356 lines
7.4 KiB
Go
|
package main
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"fmt"
|
||
|
"github.com/valyala/fasthttp"
|
||
|
"github.com/valyala/fasthttp/fasthttpproxy"
|
||
|
"go.uber.org/automaxprocs/maxprocs"
|
||
|
"io/ioutil"
|
||
|
"net"
|
||
|
"net/http"
|
||
|
_ "net/http/pprof"
|
||
|
url2 "net/url"
|
||
|
"os"
|
||
|
"os/signal"
|
||
|
"strconv"
|
||
|
"strings"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
"syscall"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
// startTime is the reference instant that request latencies are measured
// against. It is set when the package is loaded and reset by Requester.Run
// just before the workers start; DoRequest records costs as differences of
// time.Since(startTime).
var startTime = time.Now()
|
||
|
|
||
|
// ReportRecord is the outcome of one request attempt, published by
// Requester.Run on its record channel.
type ReportRecord struct {
	cost       time.Duration // elapsed time measured by DoRequest (0 when the body file could not be opened)
	code       string        // status class such as "2xx"; empty when the attempt failed
	error      string        // error text; empty on success
	readBytes  int64         // snapshot of the requester's cumulative bytes read
	writeBytes int64         // snapshot of the requester's cumulative bytes written
}

// recordPool recycles ReportRecord values to cut per-request allocations.
// New allocates a zeroed *ReportRecord when the pool is empty; records taken
// from the pool are not reset, so users must overwrite every field they rely on.
var recordPool = sync.Pool{
	New: func() interface{} {
		return &ReportRecord{}
	},
}
|
||
|
|
||
|
func init() {
|
||
|
go func() {
|
||
|
http.ListenAndServe("0.0.0.0:6060", nil)
|
||
|
}()
|
||
|
_, _ = maxprocs.Set()
|
||
|
}
|
||
|
|
||
|
// MyConn wraps a net.Conn and adds every byte passing through Read and Write
// to shared counters, so traffic can be totalled across many connections.
type MyConn struct {
	net.Conn
	// r and w point at the shared read/write byte counters; they are updated
	// with atomic adds because several connections may share them.
	r, w *int64
}

// NewMyConn wraps conn so that bytes read are added to *r and bytes written
// are added to *w. The returned error is always nil.
func NewMyConn(conn net.Conn, r, w *int64) (*MyConn, error) {
	return &MyConn{Conn: conn, r: r, w: w}, nil
}

// Read reads from the underlying connection and records the bytes received.
// Bytes are counted even when an error accompanies them: io.Reader allows
// n > 0 together with a non-nil error (e.g. io.EOF), and that data was
// still transferred.
func (c *MyConn) Read(b []byte) (n int, err error) {
	n, err = c.Conn.Read(b)
	if n > 0 {
		atomic.AddInt64(c.r, int64(n))
	}
	return n, err
}

// Write writes to the underlying connection and records the bytes sent.
// On a short write (n < len(b) with a non-nil error) the bytes that did go
// out are still counted.
func (c *MyConn) Write(b []byte) (n int, err error) {
	n, err = c.Conn.Write(b)
	if n > 0 {
		atomic.AddInt64(c.w, int64(n))
	}
	return n, err
}
|
||
|
|
||
|
func ThroughputInterceptorDial(dial fasthttp.DialFunc, r *int64, w *int64) fasthttp.DialFunc {
|
||
|
return func(addr string) (net.Conn, error) {
|
||
|
conn, err := dial(addr)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
return NewMyConn(conn, r, w)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
type Requester struct {
|
||
|
concurrency int
|
||
|
requests int64
|
||
|
duration time.Duration
|
||
|
clientOpt *ClientOpt
|
||
|
httpClient *fasthttp.HostClient
|
||
|
httpHeader *fasthttp.RequestHeader
|
||
|
|
||
|
recordChan chan *ReportRecord
|
||
|
report *StreamReport
|
||
|
errCount int64
|
||
|
wg sync.WaitGroup
|
||
|
|
||
|
readBytes int64
|
||
|
writeBytes int64
|
||
|
|
||
|
cancel func()
|
||
|
}
|
||
|
|
||
|
// ClientOpt holds the HTTP client and request settings consumed by
// NewRequester and buildRequestClient.
type ClientOpt struct {
	// Request.
	url       string   // target URL; its scheme selects TLS and its host the dial address
	method    string   // HTTP method
	headers   []string // extra headers, each in "Name: value" form
	bodyBytes []byte   // request body, used when bodyFile is empty
	bodyFile  string   // when set, this file is re-opened and streamed as every request's body

	// Connection limits and timeouts.
	maxConns     int           // HostClient.MaxConns
	doTimeout    time.Duration // per-request deadline via DoTimeout when > 0
	readTimeout  time.Duration // HostClient.ReadTimeout
	writeTimeout time.Duration // HostClient.WriteTimeout
	dialTimeout  time.Duration // dial timeout; only used when no SOCKS5 proxy is configured

	// Routing and header overrides.
	socks5Proxy string // proxy address; "socks5://" is assumed when no scheme is given
	contentType string // Content-Type header, applied when non-empty
	host        string // Host header override; defaults to the URL's host
}
|
||
|
|
||
|
func NewRequester(concurrency int, requests int64, duration time.Duration, clientOpt *ClientOpt) (*Requester, error) {
|
||
|
maxResult := concurrency * 100
|
||
|
if maxResult > 8192 {
|
||
|
maxResult = 8192
|
||
|
}
|
||
|
r := &Requester{
|
||
|
concurrency: concurrency,
|
||
|
requests: requests,
|
||
|
duration: duration,
|
||
|
clientOpt: clientOpt,
|
||
|
recordChan: make(chan *ReportRecord, maxResult),
|
||
|
}
|
||
|
client, header, err := buildRequestClient(clientOpt, &r.readBytes, &r.writeBytes)
|
||
|
if err != nil {
|
||
|
return nil, err
|
||
|
}
|
||
|
r.httpClient = client
|
||
|
r.httpHeader = header
|
||
|
return r, nil
|
||
|
}
|
||
|
|
||
|
// addMissingPort returns addr with the default HTTP port appended when it has
// none: 443 when isTLS is true, 80 otherwise. Addresses that already contain
// a port are returned unchanged.
//
// Bracketed IPv6 literals such as "[::1]" contain colons but no port, so they
// are recognised before the generic colon check; net.JoinHostPort re-adds the
// brackets ("[::1]:80").
func addMissingPort(addr string, isTLS bool) string {
	defaultPort := 80
	if isTLS {
		defaultPort = 443
	}
	port := strconv.Itoa(defaultPort)

	if len(addr) >= 2 && addr[0] == '[' && addr[len(addr)-1] == ']' {
		return net.JoinHostPort(addr[1:len(addr)-1], port)
	}
	if strings.Contains(addr, ":") {
		return addr
	}
	return net.JoinHostPort(addr, port)
}
|
||
|
|
||
|
func buildRequestClient(opt *ClientOpt, r *int64, w *int64) (*fasthttp.HostClient, *fasthttp.RequestHeader, error) {
|
||
|
u, err := url2.Parse(opt.url)
|
||
|
if err != nil {
|
||
|
return nil, nil, err
|
||
|
}
|
||
|
httpClient := &fasthttp.HostClient{
|
||
|
Addr: addMissingPort(u.Host, u.Scheme == "https"),
|
||
|
IsTLS: u.Scheme == "https",
|
||
|
Name: "plow",
|
||
|
MaxConns: opt.maxConns,
|
||
|
ReadTimeout: opt.readTimeout,
|
||
|
WriteTimeout: opt.writeTimeout,
|
||
|
DisableHeaderNamesNormalizing: true,
|
||
|
}
|
||
|
if opt.socks5Proxy != "" {
|
||
|
if strings.Index(opt.socks5Proxy, "://") == -1 {
|
||
|
opt.socks5Proxy = "socks5://" + opt.socks5Proxy
|
||
|
}
|
||
|
httpClient.Dial = fasthttpproxy.FasthttpSocksDialer(opt.socks5Proxy)
|
||
|
} else {
|
||
|
httpClient.Dial = fasthttpproxy.FasthttpProxyHTTPDialerTimeout(opt.dialTimeout)
|
||
|
}
|
||
|
httpClient.Dial = ThroughputInterceptorDial(httpClient.Dial, r, w)
|
||
|
|
||
|
var requestHeader fasthttp.RequestHeader
|
||
|
if opt.contentType != "" {
|
||
|
requestHeader.SetContentType(opt.contentType)
|
||
|
}
|
||
|
if opt.host != "" {
|
||
|
requestHeader.SetHost(opt.host)
|
||
|
} else {
|
||
|
requestHeader.SetHost(u.Host)
|
||
|
}
|
||
|
requestHeader.SetMethod(opt.method)
|
||
|
requestHeader.SetRequestURI(u.RequestURI())
|
||
|
for _, h := range opt.headers {
|
||
|
n := strings.SplitN(h, ":", 2)
|
||
|
if len(n) != 2 {
|
||
|
return nil, nil, fmt.Errorf("invalid header: %s", h)
|
||
|
}
|
||
|
requestHeader.Set(n[0], n[1])
|
||
|
}
|
||
|
|
||
|
return httpClient, &requestHeader, nil
|
||
|
}
|
||
|
|
||
|
func (r *Requester) Cancel() {
|
||
|
r.cancel()
|
||
|
}
|
||
|
|
||
|
func (r *Requester) RecordChan() <-chan *ReportRecord {
|
||
|
return r.recordChan
|
||
|
}
|
||
|
|
||
|
func getErrorType(err error) string {
|
||
|
switch err {
|
||
|
case fasthttp.ErrTimeout:
|
||
|
return "Timeout"
|
||
|
case fasthttp.ErrNoFreeConns:
|
||
|
return "NoFreeConns"
|
||
|
case fasthttp.ErrConnectionClosed:
|
||
|
return "ConnClosed"
|
||
|
case fasthttp.ErrDialTimeout:
|
||
|
return "DialTimeout"
|
||
|
default:
|
||
|
if opErr, ok := err.(*net.OpError); ok {
|
||
|
err = opErr.Err
|
||
|
}
|
||
|
switch t := err.(type) {
|
||
|
case *net.DNSError:
|
||
|
return "DNS"
|
||
|
case *os.SyscallError:
|
||
|
if errno, ok := t.Err.(syscall.Errno); ok {
|
||
|
switch errno {
|
||
|
case syscall.ECONNREFUSED:
|
||
|
return "ConnRefused"
|
||
|
case syscall.ETIMEDOUT:
|
||
|
return "Timeout"
|
||
|
case syscall.EADDRNOTAVAIL:
|
||
|
return "AddrNotAvail"
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
return "Unknown"
|
||
|
}
|
||
|
|
||
|
func (r *Requester) DoRequest(req *fasthttp.Request, resp *fasthttp.Response, rr *ReportRecord) {
|
||
|
t1 := time.Since(startTime)
|
||
|
var err error
|
||
|
if r.clientOpt.doTimeout > 0 {
|
||
|
err = r.httpClient.DoTimeout(req, resp, r.clientOpt.doTimeout)
|
||
|
} else {
|
||
|
err = r.httpClient.Do(req, resp)
|
||
|
}
|
||
|
var code string
|
||
|
|
||
|
if err != nil {
|
||
|
rr.cost = time.Since(startTime) - t1
|
||
|
rr.code = ""
|
||
|
rr.error = err.Error()
|
||
|
return
|
||
|
} else {
|
||
|
switch resp.StatusCode() / 100 {
|
||
|
case 1:
|
||
|
code = "1xx"
|
||
|
case 2:
|
||
|
code = "2xx"
|
||
|
case 3:
|
||
|
code = "3xx"
|
||
|
case 4:
|
||
|
code = "4xx"
|
||
|
case 5:
|
||
|
code = "5xx"
|
||
|
}
|
||
|
err = resp.BodyWriteTo(ioutil.Discard)
|
||
|
if err != nil {
|
||
|
rr.cost = time.Since(startTime) - t1
|
||
|
rr.code = ""
|
||
|
rr.error = err.Error()
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
|
||
|
rr.cost = time.Since(startTime) - t1
|
||
|
rr.code = code
|
||
|
rr.error = ""
|
||
|
}
|
||
|
|
||
|
func (r *Requester) Run() {
|
||
|
// handle ctrl-c
|
||
|
sigs := make(chan os.Signal, 1)
|
||
|
signal.Notify(sigs, os.Interrupt, syscall.SIGTERM)
|
||
|
defer signal.Stop(sigs)
|
||
|
|
||
|
ctx, cancelFunc := context.WithCancel(context.Background())
|
||
|
r.cancel = cancelFunc
|
||
|
go func() {
|
||
|
<-sigs
|
||
|
cancelFunc()
|
||
|
}()
|
||
|
if r.duration > 0 {
|
||
|
time.AfterFunc(r.duration, func() {
|
||
|
cancelFunc()
|
||
|
})
|
||
|
}
|
||
|
|
||
|
startTime = time.Now()
|
||
|
semaphore := r.requests
|
||
|
for i := 0; i < r.concurrency; i++ {
|
||
|
r.wg.Add(1)
|
||
|
go func() {
|
||
|
defer r.wg.Done()
|
||
|
req := &fasthttp.Request{}
|
||
|
resp := &fasthttp.Response{}
|
||
|
r.httpHeader.CopyTo(&req.Header)
|
||
|
if r.httpClient.IsTLS {
|
||
|
req.URI().SetScheme("https")
|
||
|
req.URI().SetHostBytes(req.Header.Host())
|
||
|
}
|
||
|
|
||
|
for {
|
||
|
select {
|
||
|
case <-ctx.Done():
|
||
|
return
|
||
|
default:
|
||
|
}
|
||
|
|
||
|
if r.requests > 0 && atomic.AddInt64(&semaphore, -1) < 0 {
|
||
|
cancelFunc()
|
||
|
return
|
||
|
}
|
||
|
|
||
|
if r.clientOpt.bodyFile != "" {
|
||
|
file, err := os.Open(r.clientOpt.bodyFile)
|
||
|
if err != nil {
|
||
|
rr := recordPool.Get().(*ReportRecord)
|
||
|
rr.cost = 0
|
||
|
rr.error = err.Error()
|
||
|
rr.readBytes = atomic.LoadInt64(&r.readBytes)
|
||
|
rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
|
||
|
r.recordChan <- rr
|
||
|
continue
|
||
|
}
|
||
|
req.SetBodyStream(file, -1)
|
||
|
} else {
|
||
|
req.SetBodyRaw(r.clientOpt.bodyBytes)
|
||
|
}
|
||
|
resp.Reset()
|
||
|
rr := recordPool.Get().(*ReportRecord)
|
||
|
r.DoRequest(req, resp, rr)
|
||
|
rr.readBytes = atomic.LoadInt64(&r.readBytes)
|
||
|
rr.writeBytes = atomic.LoadInt64(&r.writeBytes)
|
||
|
r.recordChan <- rr
|
||
|
}
|
||
|
}()
|
||
|
}
|
||
|
|
||
|
r.wg.Wait()
|
||
|
close(r.recordChan)
|
||
|
}
|