Path: blob/main/vendor/github.com/chzyer/readline/remote.go
2875 views
package readline12import (3"bufio"4"bytes"5"encoding/binary"6"fmt"7"io"8"net"9"os"10"sync"11"sync/atomic"12)1314type MsgType int161516const (17T_DATA = MsgType(iota)18T_WIDTH19T_WIDTH_REPORT20T_ISTTY_REPORT21T_RAW22T_ERAW // exit raw23T_EOF24)2526type RemoteSvr struct {27eof int3228closed int3229width int3230reciveChan chan struct{}31writeChan chan *writeCtx32conn net.Conn33isTerminal bool34funcWidthChan func()35stopChan chan struct{}3637dataBufM sync.Mutex38dataBuf bytes.Buffer39}4041type writeReply struct {42n int43err error44}4546type writeCtx struct {47msg *Message48reply chan *writeReply49}5051func newWriteCtx(msg *Message) *writeCtx {52return &writeCtx{53msg: msg,54reply: make(chan *writeReply),55}56}5758func NewRemoteSvr(conn net.Conn) (*RemoteSvr, error) {59rs := &RemoteSvr{60width: -1,61conn: conn,62writeChan: make(chan *writeCtx),63reciveChan: make(chan struct{}),64stopChan: make(chan struct{}),65}66buf := bufio.NewReader(rs.conn)6768if err := rs.init(buf); err != nil {69return nil, err70}7172go rs.readLoop(buf)73go rs.writeLoop()74return rs, nil75}7677func (r *RemoteSvr) init(buf *bufio.Reader) error {78m, err := ReadMessage(buf)79if err != nil {80return err81}82// receive isTerminal83if m.Type != T_ISTTY_REPORT {84return fmt.Errorf("unexpected init message")85}86r.GotIsTerminal(m.Data)8788// receive width89m, err = ReadMessage(buf)90if err != nil {91return err92}93if m.Type != T_WIDTH_REPORT {94return fmt.Errorf("unexpected init message")95}96r.GotReportWidth(m.Data)9798return nil99}100101func (r *RemoteSvr) HandleConfig(cfg *Config) {102cfg.Stderr = r103cfg.Stdout = r104cfg.Stdin = r105cfg.FuncExitRaw = r.ExitRawMode106cfg.FuncIsTerminal = r.IsTerminal107cfg.FuncMakeRaw = r.EnterRawMode108cfg.FuncExitRaw = r.ExitRawMode109cfg.FuncGetWidth = r.GetWidth110cfg.FuncOnWidthChanged = func(f func()) {111r.funcWidthChan = f112}113}114115func (r *RemoteSvr) IsTerminal() bool {116return r.isTerminal117}118119func (r *RemoteSvr) checkEOF() error {120if atomic.LoadInt32(&r.eof) == 1 {121return io.EOF122}123return nil124}125126func (r *RemoteSvr) Read(b []byte) (int, error) {127r.dataBufM.Lock()128n, err := r.dataBuf.Read(b)129r.dataBufM.Unlock()130if n == 0 {131if err := r.checkEOF(); err != nil {132return 0, err133}134}135136if n == 0 && err == io.EOF {137<-r.reciveChan138r.dataBufM.Lock()139n, err = r.dataBuf.Read(b)140r.dataBufM.Unlock()141}142if n == 0 {143if err := r.checkEOF(); err != nil {144return 0, err145}146}147148return n, err149}150151func (r *RemoteSvr) writeMsg(m *Message) error {152ctx := newWriteCtx(m)153r.writeChan <- ctx154reply := <-ctx.reply155return reply.err156}157158func (r *RemoteSvr) Write(b []byte) (int, error) {159ctx := newWriteCtx(NewMessage(T_DATA, b))160r.writeChan <- ctx161reply := <-ctx.reply162return reply.n, reply.err163}164165func (r *RemoteSvr) EnterRawMode() error {166return r.writeMsg(NewMessage(T_RAW, nil))167}168169func (r *RemoteSvr) ExitRawMode() error {170return r.writeMsg(NewMessage(T_ERAW, nil))171}172173func (r *RemoteSvr) writeLoop() {174defer r.Close()175176loop:177for {178select {179case ctx, ok := <-r.writeChan:180if !ok {181break182}183n, err := ctx.msg.WriteTo(r.conn)184ctx.reply <- &writeReply{n, err}185case <-r.stopChan:186break loop187}188}189}190191func (r *RemoteSvr) Close() error {192if atomic.CompareAndSwapInt32(&r.closed, 0, 1) {193close(r.stopChan)194r.conn.Close()195}196return nil197}198199func (r *RemoteSvr) readLoop(buf *bufio.Reader) {200defer r.Close()201for {202m, err := ReadMessage(buf)203if err != nil {204break205}206switch m.Type {207case T_EOF:208atomic.StoreInt32(&r.eof, 1)209select {210case r.reciveChan <- struct{}{}:211default:212}213case T_DATA:214r.dataBufM.Lock()215r.dataBuf.Write(m.Data)216r.dataBufM.Unlock()217select {218case r.reciveChan <- struct{}{}:219default:220}221case T_WIDTH_REPORT:222r.GotReportWidth(m.Data)223case T_ISTTY_REPORT:224r.GotIsTerminal(m.Data)225}226}227}228229func (r *RemoteSvr) GotIsTerminal(data []byte) {230if binary.BigEndian.Uint16(data) == 0 {231r.isTerminal = false232} else {233r.isTerminal = true234}235}236237func (r *RemoteSvr) GotReportWidth(data []byte) {238atomic.StoreInt32(&r.width, int32(binary.BigEndian.Uint16(data)))239if r.funcWidthChan != nil {240r.funcWidthChan()241}242}243244func (r *RemoteSvr) GetWidth() int {245return int(atomic.LoadInt32(&r.width))246}247248// -----------------------------------------------------------------------------249250type Message struct {251Type MsgType252Data []byte253}254255func ReadMessage(r io.Reader) (*Message, error) {256m := new(Message)257var length int32258if err := binary.Read(r, binary.BigEndian, &length); err != nil {259return nil, err260}261if err := binary.Read(r, binary.BigEndian, &m.Type); err != nil {262return nil, err263}264m.Data = make([]byte, int(length)-2)265if _, err := io.ReadFull(r, m.Data); err != nil {266return nil, err267}268return m, nil269}270271func NewMessage(t MsgType, data []byte) *Message {272return &Message{t, data}273}274275func (m *Message) WriteTo(w io.Writer) (int, error) {276buf := bytes.NewBuffer(make([]byte, 0, len(m.Data)+2+4))277binary.Write(buf, binary.BigEndian, int32(len(m.Data)+2))278binary.Write(buf, binary.BigEndian, m.Type)279buf.Write(m.Data)280n, err := buf.WriteTo(w)281return int(n), err282}283284// -----------------------------------------------------------------------------285286type RemoteCli struct {287conn net.Conn288raw RawMode289receiveChan chan struct{}290inited int32291isTerminal *bool292293data bytes.Buffer294dataM sync.Mutex295}296297func NewRemoteCli(conn net.Conn) (*RemoteCli, error) {298r := &RemoteCli{299conn: conn,300receiveChan: make(chan struct{}),301}302return r, nil303}304305func (r *RemoteCli) MarkIsTerminal(is bool) {306r.isTerminal = &is307}308309func (r *RemoteCli) init() error {310if !atomic.CompareAndSwapInt32(&r.inited, 0, 1) {311return nil312}313314if err := r.reportIsTerminal(); err != nil {315return err316}317318if err := r.reportWidth(); err != nil {319return err320}321322// register sig for width changed323DefaultOnWidthChanged(func() {324r.reportWidth()325})326return nil327}328329func (r *RemoteCli) writeMsg(m *Message) error {330r.dataM.Lock()331_, err := m.WriteTo(r.conn)332r.dataM.Unlock()333return err334}335336func (r *RemoteCli) Write(b []byte) (int, error) {337m := NewMessage(T_DATA, b)338r.dataM.Lock()339_, err := m.WriteTo(r.conn)340r.dataM.Unlock()341return len(b), err342}343344func (r *RemoteCli) reportWidth() error {345screenWidth := GetScreenWidth()346data := make([]byte, 2)347binary.BigEndian.PutUint16(data, uint16(screenWidth))348msg := NewMessage(T_WIDTH_REPORT, data)349350if err := r.writeMsg(msg); err != nil {351return err352}353return nil354}355356func (r *RemoteCli) reportIsTerminal() error {357var isTerminal bool358if r.isTerminal != nil {359isTerminal = *r.isTerminal360} else {361isTerminal = DefaultIsTerminal()362}363data := make([]byte, 2)364if isTerminal {365binary.BigEndian.PutUint16(data, 1)366} else {367binary.BigEndian.PutUint16(data, 0)368}369msg := NewMessage(T_ISTTY_REPORT, data)370if err := r.writeMsg(msg); err != nil {371return err372}373return nil374}375376func (r *RemoteCli) readLoop() {377buf := bufio.NewReader(r.conn)378for {379msg, err := ReadMessage(buf)380if err != nil {381break382}383switch msg.Type {384case T_ERAW:385r.raw.Exit()386case T_RAW:387r.raw.Enter()388case T_DATA:389os.Stdout.Write(msg.Data)390}391}392}393394func (r *RemoteCli) ServeBy(source io.Reader) error {395if err := r.init(); err != nil {396return err397}398399go func() {400defer r.Close()401for {402n, _ := io.Copy(r, source)403if n == 0 {404break405}406}407}()408defer r.raw.Exit()409r.readLoop()410return nil411}412413func (r *RemoteCli) Close() {414r.writeMsg(NewMessage(T_EOF, nil))415}416417func (r *RemoteCli) Serve() error {418return r.ServeBy(os.Stdin)419}420421func ListenRemote(n, addr string, cfg *Config, h func(*Instance), onListen ...func(net.Listener) error) error {422ln, err := net.Listen(n, addr)423if err != nil {424return err425}426if len(onListen) > 0 {427if err := onListen[0](ln); err != nil {428return err429}430}431for {432conn, err := ln.Accept()433if err != nil {434break435}436go func() {437defer conn.Close()438rl, err := HandleConn(*cfg, conn)439if err != nil {440return441}442h(rl)443}()444}445return nil446}447448func HandleConn(cfg Config, conn net.Conn) (*Instance, error) {449r, err := NewRemoteSvr(conn)450if err != nil {451return nil, err452}453r.HandleConfig(&cfg)454455rl, err := NewEx(&cfg)456if err != nil {457return nil, err458}459return rl, nil460}461462func DialRemote(n, addr string) error {463conn, err := net.Dial(n, addr)464if err != nil {465return err466}467defer conn.Close()468469cli, err := NewRemoteCli(conn)470if err != nil {471return err472}473return cli.Serve()474}475476477