diff --git a/terminal/terminal.go b/terminal/terminal.go new file mode 100644 index 0000000..c127e1f --- /dev/null +++ b/terminal/terminal.go @@ -0,0 +1,113 @@ +package terminal + +import ( + "encoding/json" + "fmt" + + "github.com/gorilla/websocket" +) + +// Terminal 接口定义平台特定的终端操作 +type Terminal interface { + Close() error + Read(p []byte) (int, error) + Write(p []byte) (int, error) + Resize(cols, rows int) error + Wait() error +} + +// terminalImpl 封装终端和平台特定逻辑 +type terminalImpl struct { + shell string + workingDir string + term Terminal +} + +// StartTerminal 启动终端并处理 WebSocket 通信 +func StartTerminal(conn *websocket.Conn) { + impl, err := newTerminalImpl() + if err != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + return + } + + errChan := make(chan error, 1) + defer impl.term.Close() + // 从 WebSocket 读取消息并写入终端 + go handleWebSocketInput(conn, impl.term, errChan) + + // 从终端读取输出并写入 WebSocket + go handleTerminalOutput(conn, impl.term, errChan) + + // 错误处理和清理 + go func() { + err := <-errChan + if err != nil && conn != nil { + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) + conn.Close() + } + impl.term.Close() + }() + // 等待终端进程结束 + if err := impl.term.Wait(); err != nil { + select { + case errChan <- err: + // 错误已发送 + default: + // 错误通道已满或已关闭 + conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Terminal exited with error: %v\r\n", err))) + } + } +} + +// handleWebSocketInput 处理 WebSocket 输入 +func handleWebSocketInput(conn *websocket.Conn, term Terminal, errChan chan<- error) { + for { + t, p, err := conn.ReadMessage() + if err != nil { + errChan <- err + return + } + if t == websocket.TextMessage { + var cmd struct { + Type string `json:"type"` + Cols int `json:"cols,omitempty"` + Rows int `json:"rows,omitempty"` + Input string `json:"input,omitempty"` + } + if err := json.Unmarshal(p, &cmd); err == nil { + switch cmd.Type { + case "resize": + if cmd.Cols > 0 && cmd.Rows > 0 { + term.Resize(cmd.Cols, cmd.Rows) + } + case "input": + if cmd.Input != "" { + term.Write([]byte(cmd.Input)) + } + } + } else { + term.Write(p) + } + } + if t == websocket.BinaryMessage { + term.Write(p) + } + } +} + +// handleTerminalOutput 处理终端输出 +func handleTerminalOutput(conn *websocket.Conn, term Terminal, errChan chan<- error) { + buf := make([]byte, 4096) + for { + n, err := term.Read(buf) + if err != nil { + errChan <- err + return + } + if err := conn.WriteMessage(websocket.BinaryMessage, buf[:n]); err != nil { + errChan <- err + return + } + } +} diff --git a/terminal/terminal_unix.go b/terminal/terminal_unix.go index 32a3d12..3b20a6d 100644 --- a/terminal/terminal_unix.go +++ b/terminal/terminal_unix.go @@ -3,117 +3,73 @@ package terminal import ( - "encoding/json" "fmt" + "os" "os/exec" "syscall" "github.com/creack/pty" - "github.com/gorilla/websocket" ) -// StartTerminal 在Unix/Linux系统上启动终端 -func StartTerminal(conn *websocket.Conn) { - // 获取shell - defalut_shell := []string{"zsh", "bash", "sh"} +func newTerminalImpl() (*terminalImpl, error) { + // 查找可用 shell + defaultShells := []string{"zsh", "bash", "sh"} shell := "" - for _, s := range defalut_shell { + for _, s := range defaultShells { if _, err := exec.LookPath(s); err == nil { shell = s break } } if shell == "" { - conn.WriteMessage(websocket.TextMessage, []byte("No supported shell found.")) - return + return nil, fmt.Errorf("no supported shell found") } + // 创建进程 cmd := exec.Command(shell) cmd.Env = append(cmd.Env, "TERM=xterm-256color") tty, err := pty.Start(cmd) if err != nil { - conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) - return + return nil, fmt.Errorf("failed to start pty: %v", err) } - defer tty.Close() - // 设置终端大小 - pty.Setsize(tty, &pty.Winsize{ - Rows: 24, - Cols: 80, - X: 0, - Y: 0, - }) - terminateConn := func() { - pgid, err := syscall.Getpgid(cmd.Process.Pid) - if err != nil { - cmd.Process.Kill() - } - syscall.Kill(-pgid, syscall.SIGKILL) - if conn != nil { - conn.Close() - } - } - err_chan := make(chan error, 1) - // 从WebSocket读取数据并写入pty - go func() { - for { - t, p, err := conn.ReadMessage() - if err != nil { - err_chan <- err - return - } - if t == websocket.TextMessage { - var cmd struct { - Type string `json:"type"` - Cols int `json:"cols,omitempty"` - Rows int `json:"rows,omitempty"` - Input string `json:"input,omitempty"` - } - if err := json.Unmarshal(p, &cmd); err == nil { - switch cmd.Type { - case "resize": - if cmd.Cols > 0 && cmd.Rows > 0 { - pty.Setsize(tty, &pty.Winsize{ - Rows: uint16(cmd.Rows), - Cols: uint16(cmd.Cols), - }) - } - case "input": - if cmd.Input != "" { - tty.Write([]byte(cmd.Input)) - } - } - } else { - tty.Write(p) - } - } - if t == websocket.BinaryMessage { - tty.Write(p) - } - } - }() + // 设置初始终端大小 + pty.Setsize(tty, &pty.Winsize{Rows: 24, Cols: 80}) - go func() { - buf := make([]byte, 4096) - for { - n, err := tty.Read(buf) - if err != nil { - err_chan <- err - return - } - - err = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) - if err != nil { - err_chan <- err - return - } - } - }() - - err = <-err_chan - if err != nil && conn != nil { - conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) - } - terminateConn() + return &terminalImpl{ + shell: shell, + term: &unixTerminal{ + tty: tty, + cmd: cmd, + }, + }, nil +} + +type unixTerminal struct { + tty *os.File + cmd *exec.Cmd +} + +func (t *unixTerminal) Close() error { + pgid, err := syscall.Getpgid(t.cmd.Process.Pid) + if err != nil { + return t.cmd.Process.Kill() + } + return syscall.Kill(-pgid, syscall.SIGKILL) +} + +func (t *unixTerminal) Read(p []byte) (int, error) { + return t.tty.Read(p) +} + +func (t *unixTerminal) Write(p []byte) (int, error) { + return t.tty.Write(p) +} + +func (t *unixTerminal) Resize(cols, rows int) error { + return pty.Setsize(t.tty, &pty.Winsize{Rows: uint16(rows), Cols: uint16(cols)}) +} + +func (t *unixTerminal) Wait() error { + return t.cmd.Wait() } diff --git a/terminal/terminal_windows.go b/terminal/terminal_windows.go index 526d0e9..6d431b7 100644 --- a/terminal/terminal_windows.go +++ b/terminal/terminal_windows.go @@ -4,105 +4,77 @@ package terminal import ( "context" - "encoding/json" "fmt" "os" "os/exec" "path/filepath" "github.com/UserExistsError/conpty" - "github.com/gorilla/websocket" ) -// StartTerminal 在Windows系统上启动终端 -func StartTerminal(conn *websocket.Conn) { - // 创建进程 +func newTerminalImpl() (*terminalImpl, error) { + // 查找 shell shell, err := exec.LookPath("powershell.exe") if err != nil || shell == "" { shell = "cmd.exe" } - current_dir := "." - executable, err := os.Executable() - if err == nil { - current_dir = filepath.Dir(executable) - } - if shell == "" || current_dir == "" { - conn.WriteMessage(websocket.TextMessage, []byte("No supported shell found.")) - return + if shell == "" { + return nil, fmt.Errorf("no supported shell found") } - tty, err := conpty.Start(shell, conpty.ConPtyWorkDir(current_dir)) - if err != nil { - conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) - return + // 获取工作目录 + workingDir := "." + if executable, err := os.Executable(); err == nil { + workingDir = filepath.Dir(executable) } - defer tty.Close() - err_chan := make(chan error, 1) - // 设置终端大小 + + // 启动 ConPTY + tty, err := conpty.Start(shell, conpty.ConPtyWorkDir(workingDir)) + if err != nil { + return nil, fmt.Errorf("failed to start conpty: %v", err) + } + + // 设置初始终端大小 tty.Resize(80, 24) - go func() { - for { - t, p, err := conn.ReadMessage() - if err != nil { - err_chan <- err - return - } - if t == websocket.TextMessage { - var cmd struct { - Type string `json:"type"` - Cols int `json:"cols,omitempty"` - Rows int `json:"rows,omitempty"` - Input string `json:"input,omitempty"` - } - - if err := json.Unmarshal(p, &cmd); err == nil { - switch cmd.Type { - case "resize": - if cmd.Cols > 0 && cmd.Rows > 0 { - tty.Resize(cmd.Cols, cmd.Rows) - } - case "input": - if cmd.Input != "" { - tty.Write([]byte(cmd.Input)) - } - } - } else { - tty.Write(p) - } - } - if t == websocket.BinaryMessage { - tty.Write(p) - } - } - }() - - go func() { - buf := make([]byte, 4096) - for { - n, err := tty.Read(buf) - if err != nil { - err_chan <- err - return - } - - err = conn.WriteMessage(websocket.BinaryMessage, buf[:n]) - if err != nil { - err_chan <- err - return - } - } - }() - - go func() { - err := <-err_chan - if err != nil && tty != nil { - conn.WriteMessage(websocket.TextMessage, []byte(fmt.Sprintf("Error: %v\r\n", err))) - } - conn.Close() - tty.Close() - }() - tty.Wait(context.Background()) - tty.Close() - + return &terminalImpl{ + shell: shell, + workingDir: workingDir, + term: &windowsTerminal{ + tty: tty, + }, + }, nil +} + +type windowsTerminal struct { + tty *conpty.ConPty + closed bool +} + +func (t *windowsTerminal) Close() error { + if t.closed { + return nil + } + if err := t.tty.Close(); err != nil { + return err + } + t.closed = true + return nil +} + +func (t *windowsTerminal) Read(p []byte) (int, error) { + return t.tty.Read(p) +} + +func (t *windowsTerminal) Write(p []byte) (int, error) { + return t.tty.Write(p) +} + +func (t *windowsTerminal) Resize(cols, rows int) error { + return t.tty.Resize(cols, rows) +} + +func (t *windowsTerminal) Wait() error { + _, err := t.tty.Wait(context.Background()) + return err }