Skip to content

Commit

Permalink
WSL Integration (#1031)
Browse files Browse the repository at this point in the history
Adds support for connecting to local WSL installations on Windows.

(also adds wshrpcmmultiproxy / connserver router)
  • Loading branch information
oneirocosm authored Oct 24, 2024
1 parent 4e86b67 commit 8248637
Show file tree
Hide file tree
Showing 31 changed files with 2,101 additions and 75 deletions.
2 changes: 1 addition & 1 deletion .gitattributes
Original file line number Diff line number Diff line change
@@ -1 +1 @@
* text=auto
* text=auto eol=lf
4 changes: 2 additions & 2 deletions cmd/server/main-server.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,11 +159,11 @@ func shutdownActivityUpdate() {

func createMainWshClient() {
rpc := wshserver.GetMainRpcClient()
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc)
wshutil.DefaultRouter.RegisterRoute(wshutil.DefaultRoute, rpc, true)
wps.Broker.SetClient(wshutil.DefaultRouter)
localConnWsh := wshutil.MakeWshRpc(nil, nil, wshrpc.RpcContext{Conn: wshrpc.LocalConnName}, &wshremote.ServerImpl{})
go wshremote.RunSysInfoLoop(localConnWsh, wshrpc.LocalConnName)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh)
wshutil.DefaultRouter.RegisterRoute(wshutil.MakeConnectionRouteId(wshrpc.LocalConnName), localConnWsh, true)
}

func main() {
Expand Down
18 changes: 13 additions & 5 deletions cmd/wsh/cmd/wshcmd-conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ package cmd

import (
"fmt"
"strings"

"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/remote"
Expand All @@ -25,17 +26,24 @@ func init() {
}

func connStatus() error {
resp, err := wshclient.ConnStatusCommand(RpcClient, nil)
var allResp []wshrpc.ConnStatus
sshResp, err := wshclient.ConnStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting connection status: %w", err)
return fmt.Errorf("getting ssh connection status: %w", err)
}
if len(resp) == 0 {
allResp = append(allResp, sshResp...)
wslResp, err := wshclient.WslStatusCommand(RpcClient, nil)
if err != nil {
return fmt.Errorf("getting wsl connection status: %w", err)
}
allResp = append(allResp, wslResp...)
if len(allResp) == 0 {
WriteStdout("no connections\n")
return nil
}
WriteStdout("%-30s %-12s\n", "connection", "status")
WriteStdout("----------------------------------------------\n")
for _, conn := range resp {
for _, conn := range allResp {
str := fmt.Sprintf("%-30s %-12s", conn.Connection, conn.Status)
if conn.Error != "" {
str += fmt.Sprintf(" (%s)", conn.Error)
Expand Down Expand Up @@ -110,7 +118,7 @@ func connRun(cmd *cobra.Command, args []string) error {
}
connName = args[1]
_, err := remote.ParseOpts(connName)
if err != nil {
if err != nil && !strings.HasPrefix(connName, "wsl://") {
return fmt.Errorf("cannot parse connection name: %w", err)
}
}
Expand Down
175 changes: 166 additions & 9 deletions cmd/wsh/cmd/wshcmd-connserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,29 +4,186 @@
package cmd

import (
"encoding/json"
"fmt"
"io"
"log"
"net"
"os"
"sync/atomic"
"time"

"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/util/packetparser"
"github.com/wavetermdev/waveterm/pkg/wavebase"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshremote"
"github.com/wavetermdev/waveterm/pkg/wshutil"
)

var serverCmd = &cobra.Command{
Use: "connserver",
Hidden: true,
Short: "remote server to power wave blocks",
Args: cobra.NoArgs,
Run: serverRun,
PreRunE: preRunSetupRpcClient,
Use: "connserver",
Hidden: true,
Short: "remote server to power wave blocks",
Args: cobra.NoArgs,
RunE: serverRun,
}

var connServerRouter bool

func init() {
serverCmd.Flags().BoolVar(&connServerRouter, "router", false, "run in local router mode")
rootCmd.AddCommand(serverCmd)
}

func serverRun(cmd *cobra.Command, args []string) {
func MakeRemoteUnixListener() (net.Listener, error) {
serverAddr := wavebase.GetRemoteDomainSocketName()
os.Remove(serverAddr) // ignore error
rtn, err := net.Listen("unix", serverAddr)
if err != nil {
return nil, fmt.Errorf("error creating listener at %v: %v", serverAddr, err)
}
os.Chmod(serverAddr, 0700)
log.Printf("Server [unix-domain] listening on %s\n", serverAddr)
return rtn, nil
}

func handleNewListenerConn(conn net.Conn, router *wshutil.WshRouter) {
var routeIdContainer atomic.Pointer[string]
proxy := wshutil.MakeRpcProxy()
go func() {
writeErr := wshutil.AdaptOutputChToStream(proxy.ToRemoteCh, conn)
if writeErr != nil {
log.Printf("error writing to domain socket: %v\n", writeErr)
}
}()
go func() {
// when input is closed, close the connection
defer func() {
conn.Close()
routeIdPtr := routeIdContainer.Load()
if routeIdPtr != nil && *routeIdPtr != "" {
router.UnregisterRoute(*routeIdPtr)
disposeMsg := &wshutil.RpcMessage{
Command: wshrpc.Command_Dispose,
Data: wshrpc.CommandDisposeData{
RouteId: *routeIdPtr,
},
Source: *routeIdPtr,
AuthToken: proxy.GetAuthToken(),
}
disposeBytes, _ := json.Marshal(disposeMsg)
router.InjectMessage(disposeBytes, *routeIdPtr)
}
}()
wshutil.AdaptStreamToMsgCh(conn, proxy.FromRemoteCh)
}()
routeId, err := proxy.HandleClientProxyAuth(router)
if err != nil {
log.Printf("error handling client proxy auth: %v\n", err)
conn.Close()
return
}
router.RegisterRoute(routeId, proxy, false)
routeIdContainer.Store(&routeId)
}

func runListener(listener net.Listener, router *wshutil.WshRouter) {
defer func() {
log.Printf("listener closed, exiting\n")
time.Sleep(500 * time.Millisecond)
wshutil.DoShutdown("", 1, true)
}()
for {
conn, err := listener.Accept()
if err == io.EOF {
break
}
if err != nil {
log.Printf("error accepting connection: %v\n", err)
continue
}
go handleNewListenerConn(conn, router)
}
}

func setupConnServerRpcClientWithRouter(router *wshutil.WshRouter) (*wshutil.WshRpc, error) {
jwtToken := os.Getenv(wshutil.WaveJwtTokenVarName)
if jwtToken == "" {
return nil, fmt.Errorf("no jwt token found for connserver")
}
rpcCtx, err := wshutil.ExtractUnverifiedRpcContext(jwtToken)
if err != nil {
return nil, fmt.Errorf("error extracting rpc context from %s: %v", wshutil.WaveJwtTokenVarName, err)
}
authRtn, err := router.HandleProxyAuth(jwtToken)
if err != nil {
return nil, fmt.Errorf("error handling proxy auth: %v", err)
}
inputCh := make(chan []byte, wshutil.DefaultInputChSize)
outputCh := make(chan []byte, wshutil.DefaultOutputChSize)
connServerClient := wshutil.MakeWshRpc(inputCh, outputCh, *rpcCtx, &wshremote.ServerImpl{LogWriter: os.Stdout})
connServerClient.SetAuthToken(authRtn.AuthToken)
router.RegisterRoute(authRtn.RouteId, connServerClient, false)
wshclient.RouteAnnounceCommand(connServerClient, nil)
return connServerClient, nil
}

func serverRunRouter() error {
router := wshutil.NewWshRouter()
termProxy := wshutil.MakeRpcProxy()
rawCh := make(chan []byte, wshutil.DefaultOutputChSize)
go packetparser.Parse(os.Stdin, termProxy.FromRemoteCh, rawCh)
go func() {
for msg := range termProxy.ToRemoteCh {
packetparser.WritePacket(os.Stdout, msg)
}
}()
go func() {
// just ignore and drain the rawCh (stdin)
// when stdin is closed, shutdown
defer wshutil.DoShutdown("", 0, true)
for range rawCh {
// ignore
}
}()
go func() {
for msg := range termProxy.FromRemoteCh {
// send this to the router
router.InjectMessage(msg, wshutil.UpstreamRoute)
}
}()
router.SetUpstreamClient(termProxy)
// now set up the domain socket
unixListener, err := MakeRemoteUnixListener()
if err != nil {
return fmt.Errorf("cannot create unix listener: %v", err)
}
client, err := setupConnServerRpcClientWithRouter(router)
if err != nil {
return fmt.Errorf("error setting up connserver rpc client: %v", err)
}
go runListener(unixListener, router)
// run the sysinfo loop
wshremote.RunSysInfoLoop(client, client.GetRpcContext().Conn)
select {}
}

func serverRunNormal() error {
err := setupRpcClient(&wshremote.ServerImpl{LogWriter: os.Stdout})
if err != nil {
return err
}
WriteStdout("running wsh connserver (%s)\n", RpcContext.Conn)
go wshremote.RunSysInfoLoop(RpcClient, RpcContext.Conn)
RpcClient.SetServerImpl(&wshremote.ServerImpl{LogWriter: os.Stdout})

select {} // run forever
}

func serverRun(cmd *cobra.Command, args []string) error {
if connServerRouter {
return serverRunRouter()
} else {
return serverRunNormal()
}
}
60 changes: 60 additions & 0 deletions cmd/wsh/cmd/wshcmd-wsl.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
// Copyright 2024, Command Line Inc.
// SPDX-License-Identifier: Apache-2.0

package cmd

import (
"strings"

"github.com/spf13/cobra"
"github.com/wavetermdev/waveterm/pkg/waveobj"
"github.com/wavetermdev/waveterm/pkg/wshrpc"
"github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient"
)

var distroName string

var wslCmd = &cobra.Command{
Use: "wsl [-d <Distro>]",
Short: "connect this terminal to a local wsl connection",
Args: cobra.NoArgs,
Run: wslRun,
PreRunE: preRunSetupRpcClient,
}

func init() {
wslCmd.Flags().StringVarP(&distroName, "distribution", "d", "", "Run the specified distribution")
rootCmd.AddCommand(wslCmd)
}

func wslRun(cmd *cobra.Command, args []string) {
var err error
if distroName == "" {
// get default distro from the host
distroName, err = wshclient.WslDefaultDistroCommand(RpcClient, nil)
if err != nil {
WriteStderr("[error] %s\n", err)
return
}
}
if !strings.HasPrefix(distroName, "wsl://") {
distroName = "wsl://" + distroName
}
blockId := RpcContext.BlockId
if blockId == "" {
WriteStderr("[error] cannot determine blockid (not in JWT)\n")
return
}
data := wshrpc.CommandSetMetaData{
ORef: waveobj.MakeORef(waveobj.OType_Block, blockId),
Meta: map[string]any{
waveobj.MetaKey_Connection: distroName,
},
}
err = wshclient.SetMetaCommand(RpcClient, data, nil)
if err != nil {
WriteStderr("[error] setting switching connection: %v\n", err)
return
}
WriteStderr("switched connection to %q\n", distroName)
}
36 changes: 36 additions & 0 deletions frontend/app/block/blockframe.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,7 @@ const ChangeConnectionBlockModal = React.memo(
const connStatusAtom = getConnStatusAtom(connection);
const connStatus = jotai.useAtomValue(connStatusAtom);
const [connList, setConnList] = React.useState<Array<string>>([]);
const [wslList, setWslList] = React.useState<Array<string>>([]);
const allConnStatus = jotai.useAtomValue(atoms.allConnStatus);
const [rowIndex, setRowIndex] = React.useState(0);
const connStatusMap = new Map<string, ConnStatus>();
Expand All @@ -540,6 +541,18 @@ const ChangeConnectionBlockModal = React.memo(
prtn.then((newConnList) => {
setConnList(newConnList ?? []);
}).catch((e) => console.log("unable to load conn list from backend. using blank list: ", e));
const p2rtn = RpcApi.WslListCommand(TabRpcClient, { timeout: 2000 });
p2rtn
.then((newWslList) => {
console.log(newWslList);
setWslList(newWslList ?? []);
})
.catch((e) => {
// removing this log and failing silentyly since it will happen
// if a system isn't using the wsl. and would happen every time the
// typeahead was opened. good candidate for verbose log level.
//console.log("unable to load wsl list from backend. using blank list: ", e)
});
}, [changeConnModalOpen, setConnList]);

const changeConnection = React.useCallback(
Expand Down Expand Up @@ -588,6 +601,15 @@ const ChangeConnectionBlockModal = React.memo(
filteredList.push(conn);
}
}
const filteredWslList: Array<string> = [];
for (const conn of wslList) {
if (conn === connSelected) {
createNew = false;
}
if (conn.includes(connSelected)) {
filteredWslList.push(conn);
}
}
// priority handles special suggestions when necessary
// for instance, when reconnecting
const newConnectionSuggestion: SuggestionConnectionItem = {
Expand Down Expand Up @@ -637,6 +659,20 @@ const ChangeConnectionBlockModal = React.memo(
label: localName,
});
}
for (const wslConn of filteredWslList) {
const connStatus = connStatusMap.get(wslConn);
const connColorNum = computeConnColorNum(connStatus);
localSuggestion.items.push({
status: "connected",
icon: "arrow-right-arrow-left",
iconColor:
connStatus?.status == "connected"
? `var(--conn-icon-color-${connColorNum})`
: "var(--grey-text-color)",
value: "wsl://" + wslConn,
label: "wsl://" + wslConn,
});
}
const remoteItems = filteredList.map((connName) => {
const connStatus = connStatusMap.get(connName);
const connColorNum = computeConnColorNum(connStatus);
Expand Down
Loading

0 comments on commit 8248637

Please sign in to comment.