diff --git a/cmd/generatego/main-generatego.go b/cmd/generatego/main-generatego.go index c93faa352..62e65d4f7 100644 --- a/cmd/generatego/main-generatego.go +++ b/cmd/generatego/main-generatego.go @@ -30,6 +30,7 @@ func GenerateWshClient() error { "github.com/wavetermdev/waveterm/pkg/waveobj", "github.com/wavetermdev/waveterm/pkg/wps", "github.com/wavetermdev/waveterm/pkg/vdom", + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes", }) wshDeclMap := wshrpc.GenerateWshCommandDeclMap() for _, key := range utilfn.GetOrderedMapKeys(wshDeclMap) { diff --git a/cmd/wsh/cmd/wshcmd-file.go b/cmd/wsh/cmd/wshcmd-file.go index 091dd74ac..62e60474e 100644 --- a/cmd/wsh/cmd/wshcmd-file.go +++ b/cmd/wsh/cmd/wshcmd-file.go @@ -98,6 +98,7 @@ func init() { fileCmd.AddCommand(fileListCmd) fileCmd.AddCommand(fileCatCmd) fileCmd.AddCommand(fileWriteCmd) + fileRmCmd.Flags().BoolP("recursive", "r", false, "remove directories recursively") fileCmd.AddCommand(fileRmCmd) fileCmd.AddCommand(fileInfoCmd) fileCmd.AddCommand(fileAppendCmd) @@ -259,6 +260,10 @@ func fileRmRun(cmd *cobra.Command, args []string) error { if err != nil { return err } + recursive, err := cmd.Flags().GetBool("recursive") + if err != nil { + return err + } fileData := wshrpc.FileData{ Info: &wshrpc.FileInfo{ Path: path}} @@ -272,7 +277,7 @@ func fileRmRun(cmd *cobra.Command, args []string) error { return fmt.Errorf("getting file info: %w", err) } - err = wshclient.FileDeleteCommand(RpcClient, fileData, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) + err = wshclient.FileDeleteCommand(RpcClient, wshrpc.CommandDeleteFileData{Path: path, Recursive: recursive}, &wshrpc.RpcOpts{Timeout: DefaultFileTimeout}) if err != nil { return fmt.Errorf("removing file: %w", err) } diff --git a/frontend/app/store/wshclientapi.ts b/frontend/app/store/wshclientapi.ts index 2822ac298..9c614af4e 100644 --- a/frontend/app/store/wshclientapi.ts +++ b/frontend/app/store/wshclientapi.ts @@ -168,7 +168,7 @@ class RpcApiType { } // command "filedelete" [call] - FileDeleteCommand(client: WshClient, data: FileData, opts?: RpcOpts): Promise { + FileDeleteCommand(client: WshClient, data: CommandDeleteFileData, opts?: RpcOpts): Promise { return client.wshRpcCall("filedelete", data, opts); } @@ -203,7 +203,7 @@ class RpcApiType { } // command "filestreamtar" [responsestream] - FileStreamTarCommand(client: WshClient, data: CommandRemoteStreamTarData, opts?: RpcOpts): AsyncGenerator { + FileStreamTarCommand(client: WshClient, data: CommandRemoteStreamTarData, opts?: RpcOpts): AsyncGenerator { return client.wshRpcStream("filestreamtar", data, opts); } @@ -258,7 +258,7 @@ class RpcApiType { } // command "remotefiledelete" [call] - RemoteFileDeleteCommand(client: WshClient, data: string, opts?: RpcOpts): Promise { + RemoteFileDeleteCommand(client: WshClient, data: CommandDeleteFileData, opts?: RpcOpts): Promise { return client.wshRpcCall("remotefiledelete", data, opts); } @@ -313,7 +313,7 @@ class RpcApiType { } // command "remotetarstream" [responsestream] - RemoteTarStreamCommand(client: WshClient, data: CommandRemoteStreamTarData, opts?: RpcOpts): AsyncGenerator { + RemoteTarStreamCommand(client: WshClient, data: CommandRemoteStreamTarData, opts?: RpcOpts): AsyncGenerator { return client.wshRpcStream("remotetarstream", data, opts); } diff --git a/frontend/types/gotypes.d.ts b/frontend/types/gotypes.d.ts index 41423efb2..400e86935 100644 --- a/frontend/types/gotypes.d.ts +++ b/frontend/types/gotypes.d.ts @@ -166,6 +166,12 @@ declare global { blockid: string; }; + // wshrpc.CommandDeleteFileData + type CommandDeleteFileData = { + path: string; + recursive: boolean; + }; + // wshrpc.CommandDisposeData type CommandDisposeData = { routeid: string; @@ -583,6 +589,12 @@ declare global { // waveobj.ORef type ORef = string; + // iochantypes.Packet + type Packet = { + Data: string; + Checksum: string; + }; + // wshrpc.PathCommandData type PathCommandData = { pathtype: string; diff --git a/pkg/remote/awsconn/awsconn.go b/pkg/remote/awsconn/awsconn.go index 6b7471dbe..ff0deaeda 100644 --- a/pkg/remote/awsconn/awsconn.go +++ b/pkg/remote/awsconn/awsconn.go @@ -124,28 +124,13 @@ func ParseProfiles() map[string]struct{} { } func ListBuckets(ctx context.Context, client *s3.Client) ([]types.Bucket, error) { - var err error - var output *s3.ListBucketsOutput - var buckets []types.Bucket - region := client.Options().Region - bucketPaginator := s3.NewListBucketsPaginator(client, &s3.ListBucketsInput{BucketRegion: ®ion}) - for bucketPaginator.HasMorePages() { - output, err = bucketPaginator.NextPage(ctx) - log.Printf("output: %v", output) - if err != nil { - var apiErr smithy.APIError - if errors.As(err, &apiErr) && apiErr.ErrorCode() == "AccessDenied" { - fmt.Println("You don't have permission to list buckets for this account.") - err = apiErr - } else { - return nil, fmt.Errorf("Couldn't list buckets for your account. Here's why: %v\n", err) - } - break - } - if output == nil { - break + output, err := client.ListBuckets(ctx, &s3.ListBucketsInput{}) + if err != nil { + var apiErr smithy.APIError + if errors.As(err, &apiErr) { + return nil, fmt.Errorf("error listing buckets: %v", apiErr) } - buckets = append(buckets, output.Buckets...) + return nil, fmt.Errorf("error listing buckets: %v", err) } - return buckets, nil + return output.Buckets, nil } diff --git a/pkg/remote/connparse/connparse.go b/pkg/remote/connparse/connparse.go index e4cadba91..995106561 100644 --- a/pkg/remote/connparse/connparse.go +++ b/pkg/remote/connparse/connparse.go @@ -57,6 +57,10 @@ func (c *Connection) GetFullURI() string { return c.Scheme + "://" + c.GetPathWithHost() } +func (c *Connection) GetSchemeAndHost() string { + return c.Scheme + "://" + c.Host +} + func ParseURIAndReplaceCurrentHost(ctx context.Context, uri string) (*Connection, error) { conn, err := ParseURI(uri) if err != nil { @@ -148,7 +152,7 @@ func ParseURI(uri string) (*Connection, error) { } if strings.HasPrefix(remotePath, "/~") { remotePath = strings.TrimPrefix(remotePath, "/") - } else if len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") { + } else if len(remotePath) > 1 && !windowsDriveRegex.MatchString(remotePath) && !strings.HasPrefix(remotePath, "/") && !strings.HasPrefix(remotePath, "~") && !strings.HasPrefix(remotePath, "./") && !strings.HasPrefix(remotePath, "../") && !strings.HasPrefix(remotePath, ".\\") && !strings.HasPrefix(remotePath, "..\\") && remotePath != ".." { remotePath = "/" + remotePath } } diff --git a/pkg/remote/connparse/connparse_test.go b/pkg/remote/connparse/connparse_test.go index 5eae1ce88..c530c8e76 100644 --- a/pkg/remote/connparse/connparse_test.go +++ b/pkg/remote/connparse/connparse_test.go @@ -190,6 +190,54 @@ func TestParseURI_WSHCurrentPathShorthand(t *testing.T) { } } +func TestParseURI_WSHCurrentPath(t *testing.T) { + cstr := "./Documents/path/to/file" + c, err := connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected := "./Documents/path/to/file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/./Documents/path/to/file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } +} + +func TestParseURI_WSHCurrentPathWindows(t *testing.T) { + cstr := ".\\Documents\\path\\to\\file" + c, err := connparse.ParseURI(cstr) + if err != nil { + t.Fatalf("failed to parse URI: %v", err) + } + expected := ".\\Documents\\path\\to\\file" + if c.Path != expected { + t.Fatalf("expected path to be %q, got %q", expected, c.Path) + } + expected = "current" + if c.Host != expected { + t.Fatalf("expected host to be %q, got %q", expected, c.Host) + } + expected = "wsh" + if c.Scheme != expected { + t.Fatalf("expected scheme to be %q, got %q", expected, c.Scheme) + } + expected = "wsh://current/.\\Documents\\path\\to\\file" + if c.GetFullURI() != expected { + t.Fatalf("expected full URI to be %q, got %q", expected, c.GetFullURI()) + } +} + func TestParseURI_WSHLocalShorthand(t *testing.T) { t.Parallel() cstr := "/~/path/to/file" diff --git a/pkg/remote/fileshare/fileshare.go b/pkg/remote/fileshare/fileshare.go index 3b2458cad..9473db55a 100644 --- a/pkg/remote/fileshare/fileshare.go +++ b/pkg/remote/fileshare/fileshare.go @@ -5,12 +5,11 @@ import ( "fmt" "log" - "github.com/wavetermdev/waveterm/pkg/remote/awsconn" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" - "github.com/wavetermdev/waveterm/pkg/remote/fileshare/s3fs" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wavefs" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -29,12 +28,12 @@ func CreateFileShareClient(ctx context.Context, connection string) (fstype.FileS } conntype := conn.GetType() if conntype == connparse.ConnectionTypeS3 { - config, err := awsconn.GetConfig(ctx, connection) - if err != nil { - log.Printf("error getting aws config: %v", err) - return nil, nil - } - return s3fs.NewS3Client(config), conn + // config, err := awsconn.GetConfig(ctx, connection) + // if err != nil { + // log.Printf("error getting aws config: %v", err) + // return nil, nil + // } + return nil, nil } else if conntype == connparse.ConnectionTypeWave { return wavefs.NewWaveClient(), conn } else if conntype == connparse.ConnectionTypeWsh { @@ -61,10 +60,10 @@ func ReadStream(ctx context.Context, data wshrpc.FileData) <-chan wshrpc.RespOrE return client.ReadStream(ctx, conn, data) } -func ReadTarStream(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[[]byte] { +func ReadTarStream(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { client, conn := CreateFileShareClient(ctx, data.Path) if conn == nil || client == nil { - return wshutil.SendErrCh[[]byte](fmt.Errorf(ErrorParsingConnection, data.Path)) + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf(ErrorParsingConnection, data.Path)) } return client.ReadTarStream(ctx, conn, data.Opts) } @@ -110,35 +109,47 @@ func Mkdir(ctx context.Context, path string) error { } func Move(ctx context.Context, data wshrpc.CommandFileCopyData) error { - srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, data.SrcUri) - if err != nil { - return fmt.Errorf("error parsing source connection %s: %v", data.SrcUri, err) + srcClient, srcConn := CreateFileShareClient(ctx, data.SrcUri) + if srcConn == nil || srcClient == nil { + return fmt.Errorf("error creating fileshare client, could not parse source connection %s", data.SrcUri) } destClient, destConn := CreateFileShareClient(ctx, data.DestUri) if destConn == nil || destClient == nil { - return fmt.Errorf("error creating fileshare client, could not parse connection %s or %s", data.SrcUri, data.DestUri) + return fmt.Errorf("error creating fileshare client, could not parse destination connection %s", data.DestUri) + } + if srcConn.Host != destConn.Host { + err := destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", data.SrcUri, data.DestUri, err) + } + return srcClient.Delete(ctx, srcConn, data.Opts.Recursive) + } else { + return srcClient.MoveInternal(ctx, srcConn, destConn, data.Opts) } - return destClient.Move(ctx, srcConn, destConn, data.Opts) } func Copy(ctx context.Context, data wshrpc.CommandFileCopyData) error { - srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, data.SrcUri) - if err != nil { - return fmt.Errorf("error parsing source connection %s: %v", data.SrcUri, err) + srcClient, srcConn := CreateFileShareClient(ctx, data.SrcUri) + if srcConn == nil || srcClient == nil { + return fmt.Errorf("error creating fileshare client, could not parse source connection %s", data.SrcUri) } destClient, destConn := CreateFileShareClient(ctx, data.DestUri) if destConn == nil || destClient == nil { - return fmt.Errorf("error creating fileshare client, could not parse connection %s or %s", data.SrcUri, data.DestUri) + return fmt.Errorf("error creating fileshare client, could not parse destination connection %s", data.DestUri) + } + if srcConn.Host != destConn.Host { + return destClient.CopyRemote(ctx, srcConn, destConn, srcClient, data.Opts) + } else { + return srcClient.CopyInternal(ctx, srcConn, destConn, data.Opts) } - return destClient.Copy(ctx, srcConn, destConn, data.Opts) } -func Delete(ctx context.Context, path string) error { - client, conn := CreateFileShareClient(ctx, path) +func Delete(ctx context.Context, data wshrpc.CommandDeleteFileData) error { + client, conn := CreateFileShareClient(ctx, data.Path) if conn == nil || client == nil { - return fmt.Errorf(ErrorParsingConnection, path) + return fmt.Errorf(ErrorParsingConnection, data.Path) } - return client.Delete(ctx, conn) + return client.Delete(ctx, conn, data.Recursive) } func Join(ctx context.Context, path string, parts ...string) (string, error) { diff --git a/pkg/remote/fileshare/fstype/fstype.go b/pkg/remote/fileshare/fstype/fstype.go index 3d42c0b03..3c3d6fceb 100644 --- a/pkg/remote/fileshare/fstype/fstype.go +++ b/pkg/remote/fileshare/fstype/fstype.go @@ -7,6 +7,7 @@ import ( "context" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" ) @@ -18,7 +19,7 @@ type FileShareClient interface { // ReadStream returns a stream of file data at the given path. If it's a directory, then the list of entries ReadStream(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData] // ReadTarStream returns a stream of tar data at the given path - ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[[]byte] + ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] // ListEntries returns the list of entries at the given path, or nothing if the path is a file ListEntries(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) ([]*wshrpc.FileInfo, error) // ListEntriesStream returns a stream of entries at the given path @@ -29,12 +30,14 @@ type FileShareClient interface { AppendFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error // Mkdir creates a directory at the given path Mkdir(ctx context.Context, conn *connparse.Connection) error - // Move moves the file from srcConn to destConn - Move(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error - // Copy copies the file from srcConn to destConn - Copy(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error + // Move moves the file within the same connection + MoveInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error + // Copy copies the file within the same connection + CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error + // CopyRemote copies the file between different connections + CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient FileShareClient, opts *wshrpc.FileCopyOpts) error // Delete deletes the entry at the given path - Delete(ctx context.Context, conn *connparse.Connection) error + Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error // Join joins the given parts to the connection path Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) // GetConnectionType returns the type of connection for the fileshare diff --git a/pkg/remote/fileshare/s3fs/s3fs.go b/pkg/remote/fileshare/s3fs/s3fs.go index 32d0636d5..b406615d4 100644 --- a/pkg/remote/fileshare/s3fs/s3fs.go +++ b/pkg/remote/fileshare/s3fs/s3fs.go @@ -5,6 +5,7 @@ package s3fs import ( "context" + "errors" "log" "github.com/aws/aws-sdk-go-v2/aws" @@ -12,6 +13,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/awsconn" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) @@ -29,15 +31,15 @@ func NewS3Client(config *aws.Config) *S3Client { } func (c S3Client) Read(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) (*wshrpc.FileData, error) { - return nil, nil + return nil, errors.ErrUnsupported } func (c S3Client) ReadStream(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) <-chan wshrpc.RespOrErrorUnion[wshrpc.FileData] { - return nil + return wshutil.SendErrCh[wshrpc.FileData](errors.ErrUnsupported) } -func (c S3Client) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[[]byte] { - return nil +func (c S3Client) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + return wshutil.SendErrCh[iochantypes.Packet](errors.ErrUnsupported) } func (c S3Client) ListEntriesStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) <-chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { @@ -82,35 +84,39 @@ func (c S3Client) ListEntries(ctx context.Context, conn *connparse.Connection, o } func (c S3Client) Stat(ctx context.Context, conn *connparse.Connection) (*wshrpc.FileInfo, error) { - return nil, nil + return nil, errors.ErrUnsupported } func (c S3Client) PutFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error { - return nil + return errors.ErrUnsupported } func (c S3Client) AppendFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error { - return nil + return errors.ErrUnsupported } func (c S3Client) Mkdir(ctx context.Context, conn *connparse.Connection) error { - return nil + return errors.ErrUnsupported +} + +func (c S3Client) MoveInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { + return errors.ErrUnsupported } -func (c S3Client) Move(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { - return nil +func (c S3Client) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { + return errors.ErrUnsupported } -func (c S3Client) Copy(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { - return nil +func (c S3Client) CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { + return errors.ErrUnsupported } -func (c S3Client) Delete(ctx context.Context, conn *connparse.Connection) error { - return nil +func (c S3Client) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { + return errors.ErrUnsupported } func (c S3Client) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) { - return "", nil + return "", errors.ErrUnsupported } func (c S3Client) GetConnectionType() string { diff --git a/pkg/remote/fileshare/wavefs/wavefs.go b/pkg/remote/fileshare/wavefs/wavefs.go index 7be121547..63cbe36a1 100644 --- a/pkg/remote/fileshare/wavefs/wavefs.go +++ b/pkg/remote/fileshare/wavefs/wavefs.go @@ -4,17 +4,24 @@ package wavefs import ( + "archive/tar" "context" "encoding/base64" "errors" "fmt" + "io" "io/fs" + "log" "path" "strings" + "time" "github.com/wavetermdev/waveterm/pkg/filestore" "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/util/fileutil" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/tarcopy" "github.com/wavetermdev/waveterm/pkg/util/wavefileutil" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wps" @@ -22,6 +29,10 @@ import ( "github.com/wavetermdev/waveterm/pkg/wshutil" ) +const ( + DefaultTimeout = 30 * time.Second +) + type WaveClient struct{} var _ fstype.FileShareClient = WaveClient{} @@ -95,8 +106,60 @@ func (c WaveClient) Read(ctx context.Context, conn *connparse.Connection, data w return &wshrpc.FileData{Info: data.Info, Entries: list}, nil } -func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[[]byte] { - return nil +func (c WaveClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + log.Printf("ReadTarStream: conn: %v, opts: %v\n", conn, opts) + list, err := c.ListEntries(ctx, conn, nil) + if err != nil { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("error listing blockfiles: %w", err)) + } + + pathPrefix := getPathPrefix(conn) + schemeAndHost := conn.GetSchemeAndHost() + "/" + + timeout := DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readerCtx, cancel := context.WithTimeout(context.Background(), timeout) + rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix) + + go func() { + defer func() { + tarClose() + cancel() + }() + for _, file := range list { + if readerCtx.Err() != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](readerCtx.Err()) + return + } + file.Mode = 0644 + + if err = writeHeader(fileutil.ToFsFileInfo(file), file.Path); err != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("error writing tar header: %w", err)) + return + } + if file.IsDir { + continue + } + + log.Printf("ReadTarStream: reading file: %s\n", file.Path) + + internalPath := strings.TrimPrefix(file.Path, schemeAndHost) + + _, dataBuf, err := filestore.WFS.ReadFile(ctx, conn.Host, internalPath) + if err != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("error reading blockfile: %w", err)) + return + } + if _, err = fileWriter.Write(dataBuf); err != nil { + rtn <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("error writing tar data: %w", err)) + return + } + } + }() + + return rtn } func (c WaveClient) ListEntriesStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) <-chan wshrpc.RespOrErrorUnion[wshrpc.CommandRemoteListEntriesRtnData] { @@ -116,10 +179,14 @@ func (c WaveClient) ListEntriesStream(ctx context.Context, conn *connparse.Conne } func (c WaveClient) ListEntries(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileListOpts) ([]*wshrpc.FileInfo, error) { + log.Printf("ListEntries: conn: %v, opts: %v\n", conn, opts) zoneId := conn.Host if zoneId == "" { return nil, fmt.Errorf("zoneid not found in connection") } + if opts == nil { + opts = &wshrpc.FileListOpts{} + } prefix, err := cleanPath(conn.Path) if err != nil { return nil, fmt.Errorf("error cleaning path: %w", err) @@ -265,33 +332,6 @@ func (c WaveClient) PutFile(ctx context.Context, conn *connparse.Connection, dat return nil } -/* - - path := data.Info.Path - log.Printf("Append: path=%s", path) - client, conn := CreateFileShareClient(ctx, path) - if conn == nil || client == nil { - return fmt.Errorf(ErrorParsingConnection, path) - } - finfo, err := client.Stat(ctx, conn) - if err != nil { - return err - } - if data.Info == nil { - data.Info = &wshrpc.FileInfo{} - } - oldInfo := data.Info - data.Info = finfo - if oldInfo.Opts != nil { - data.Info.Opts = oldInfo.Opts - } - data.At = &wshrpc.FileDataAt{ - Offset: finfo.Size, - } - log.Printf("Append: offset=%d", data.At.Offset) - return client.PutFile(ctx, conn, data) -*/ - func (c WaveClient) AppendFile(ctx context.Context, conn *connparse.Connection, data wshrpc.FileData) error { dataBuf, err := base64.StdEncoding.DecodeString(data.Data64) if err != nil { @@ -346,39 +386,156 @@ func (c WaveClient) AppendFile(ctx context.Context, conn *connparse.Connection, // WaveFile does not support directories, only prefix-based listing func (c WaveClient) Mkdir(ctx context.Context, conn *connparse.Connection) error { - return nil + return errors.ErrUnsupported } -func (c WaveClient) Move(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { +func (c WaveClient) MoveInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { + if srcConn.Host != destConn.Host { + return fmt.Errorf("move internal, src and dest hosts do not match") + } + err := c.CopyInternal(ctx, srcConn, destConn, opts) + if err != nil { + return fmt.Errorf("error copying blockfile: %w", err) + } + err = c.Delete(ctx, srcConn, opts.Recursive) + if err != nil { + return fmt.Errorf("error deleting blockfile: %w", err) + } return nil } -func (c WaveClient) Copy(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { +func (c WaveClient) CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { + if srcConn.Host == destConn.Host { + host := srcConn.Host + srcFileName, err := cleanPath(srcConn.Path) + if err != nil { + return fmt.Errorf("error cleaning source path: %w", err) + } + destFileName, err := cleanPath(destConn.Path) + if err != nil { + return fmt.Errorf("error cleaning destination path: %w", err) + } + err = filestore.WFS.MakeFile(ctx, host, destFileName, wshrpc.FileMeta{}, wshrpc.FileOpts{}) + if err != nil { + return fmt.Errorf("error making source blockfile: %w", err) + } + _, dataBuf, err := filestore.WFS.ReadFile(ctx, host, srcFileName) + if err != nil { + return fmt.Errorf("error reading source blockfile: %w", err) + } + err = filestore.WFS.WriteFile(ctx, host, destFileName, dataBuf) + if err != nil { + return fmt.Errorf("error writing to destination blockfile: %w", err) + } + wps.Broker.Publish(wps.WaveEvent{ + Event: wps.Event_BlockFile, + Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, host).String()}, + Data: &wps.WSFileEventData{ + ZoneId: host, + FileName: destFileName, + FileOp: wps.FileOp_Invalidate, + }, + }) + return nil + } else { + return fmt.Errorf("copy between different hosts not supported") + } +} + +func (c WaveClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, srcClient fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { + zoneId := destConn.Host + if zoneId == "" { + return fmt.Errorf("zoneid not found in connection") + } + destPrefix := getPathPrefix(destConn) + destPrefix = strings.TrimPrefix(destPrefix, destConn.GetSchemeAndHost()+"/") + log.Printf("CopyRemote: srcConn: %v, destConn: %v, destPrefix: %s\n", srcConn, destConn, destPrefix) + readCtx, cancel := context.WithCancelCause(ctx) + ioch := srcClient.ReadTarStream(readCtx, srcConn, opts) + err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { + if next.Typeflag == tar.TypeDir { + return nil + } + fileName, err := cleanPath(path.Join(destPrefix, next.Name)) + if err != nil { + return fmt.Errorf("error cleaning path: %w", err) + } + _, err = filestore.WFS.Stat(ctx, zoneId, fileName) + if err != nil { + if !errors.Is(err, fs.ErrNotExist) { + return fmt.Errorf("error getting blockfile info: %w", err) + } + err := filestore.WFS.MakeFile(ctx, zoneId, fileName, nil, wshrpc.FileOpts{}) + if err != nil { + return fmt.Errorf("error making blockfile: %w", err) + } + } + log.Printf("CopyRemote: writing file: %s; size: %d\n", fileName, next.Size) + dataBuf := make([]byte, next.Size) + _, err = reader.Read(dataBuf) + if err != nil { + if !errors.Is(err, io.EOF) { + return fmt.Errorf("error reading tar data: %w", err) + } + } + err = filestore.WFS.WriteFile(ctx, zoneId, fileName, dataBuf) + if err != nil { + return fmt.Errorf("error writing to blockfile: %w", err) + } + wps.Broker.Publish(wps.WaveEvent{ + Event: wps.Event_BlockFile, + Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, zoneId).String()}, + Data: &wps.WSFileEventData{ + ZoneId: zoneId, + FileName: fileName, + FileOp: wps.FileOp_Invalidate, + }, + }) + return nil + }) + if err != nil { + return fmt.Errorf("error copying tar stream: %w", err) + } return nil } -func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection) error { +func (c WaveClient) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { zoneId := conn.Host if zoneId == "" { return fmt.Errorf("zoneid not found in connection") } - fileName, err := cleanPath(conn.Path) + schemeAndHost := conn.GetSchemeAndHost() + "/" + + entries, err := c.ListEntries(ctx, conn, nil) if err != nil { - return fmt.Errorf("error cleaning path: %w", err) + return fmt.Errorf("error listing blockfiles: %w", err) } - err = filestore.WFS.DeleteFile(ctx, zoneId, fileName) - if err != nil { - return fmt.Errorf("error deleting blockfile: %w", err) + if len(entries) > 0 { + if !recursive { + return fmt.Errorf("more than one entry, use recursive flag to delete") + } + errs := make([]error, 0) + for _, entry := range entries { + fileName := strings.TrimPrefix(entry.Path, schemeAndHost) + err = filestore.WFS.DeleteFile(ctx, zoneId, fileName) + if err != nil { + errs = append(errs, fmt.Errorf("error deleting blockfile %s/%s: %w", zoneId, fileName, err)) + continue + } + wps.Broker.Publish(wps.WaveEvent{ + Event: wps.Event_BlockFile, + Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, zoneId).String()}, + Data: &wps.WSFileEventData{ + ZoneId: zoneId, + FileName: fileName, + FileOp: wps.FileOp_Delete, + }, + }) + } + if len(errs) > 0 { + return fmt.Errorf("error deleting blockfiles: %v", errs) + } } - wps.Broker.Publish(wps.WaveEvent{ - Event: wps.Event_BlockFile, - Scopes: []string{waveobj.MakeORef(waveobj.OType_Block, zoneId).String()}, - Data: &wps.WSFileEventData{ - ZoneId: zoneId, - FileName: fileName, - FileOp: wps.FileOp_Delete, - }, - }) return nil } @@ -417,3 +574,13 @@ func cleanPath(path string) (string, error) { func (c WaveClient) GetConnectionType() string { return connparse.ConnectionTypeWave } + +func getPathPrefix(conn *connparse.Connection) string { + fullUri := conn.GetFullURI() + pathPrefix := fullUri + lastSlash := strings.LastIndex(fullUri, "/") + if lastSlash > 10 && lastSlash < len(fullUri)-1 { + pathPrefix = fullUri[:lastSlash+1] + } + return pathPrefix +} diff --git a/pkg/remote/fileshare/wshfs/wshfs.go b/pkg/remote/fileshare/wshfs/wshfs.go index 000c58b66..61816ea57 100644 --- a/pkg/remote/fileshare/wshfs/wshfs.go +++ b/pkg/remote/fileshare/wshfs/wshfs.go @@ -12,6 +12,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/connparse" "github.com/wavetermdev/waveterm/pkg/remote/fileshare/fstype" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" @@ -86,7 +87,7 @@ func (c WshClient) ReadStream(ctx context.Context, conn *connparse.Connection, d return wshclient.RemoteStreamFileCommand(RpcClient, streamFileData, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) } -func (c WshClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[[]byte] { +func (c WshClient) ReadTarStream(ctx context.Context, conn *connparse.Connection, opts *wshrpc.FileCopyOpts) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { timeout := opts.Timeout if timeout == 0 { timeout = ThirtySeconds @@ -145,7 +146,10 @@ func (c WshClient) Mkdir(ctx context.Context, conn *connparse.Connection) error return wshclient.RemoteMkdirCommand(RpcClient, conn.Path, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) } -func (c WshClient) Move(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { +func (c WshClient) MoveInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { + if srcConn.Host != destConn.Host { + return fmt.Errorf("move internal, src and dest hosts do not match") + } if opts == nil { opts = &wshrpc.FileCopyOpts{} } @@ -156,7 +160,11 @@ func (c WshClient) Move(ctx context.Context, srcConn, destConn *connparse.Connec return wshclient.RemoteFileMoveCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } -func (c WshClient) Copy(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { +func (c WshClient) CopyRemote(ctx context.Context, srcConn, destConn *connparse.Connection, _ fstype.FileShareClient, opts *wshrpc.FileCopyOpts) error { + return c.CopyInternal(ctx, srcConn, destConn, opts) +} + +func (c WshClient) CopyInternal(ctx context.Context, srcConn, destConn *connparse.Connection, opts *wshrpc.FileCopyOpts) error { if opts == nil { opts = &wshrpc.FileCopyOpts{} } @@ -167,8 +175,8 @@ func (c WshClient) Copy(ctx context.Context, srcConn, destConn *connparse.Connec return wshclient.RemoteFileCopyCommand(RpcClient, wshrpc.CommandRemoteFileCopyData{SrcUri: srcConn.GetFullURI(), DestUri: destConn.GetFullURI(), Opts: opts}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(destConn.Host), Timeout: timeout}) } -func (c WshClient) Delete(ctx context.Context, conn *connparse.Connection) error { - return wshclient.RemoteFileDeleteCommand(RpcClient, conn.Path, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) +func (c WshClient) Delete(ctx context.Context, conn *connparse.Connection, recursive bool) error { + return wshclient.RemoteFileDeleteCommand(RpcClient, wshrpc.CommandDeleteFileData{Path: conn.Path, Recursive: recursive}, &wshrpc.RpcOpts{Route: wshutil.MakeConnectionRouteId(conn.Host)}) } func (c WshClient) Join(ctx context.Context, conn *connparse.Connection, parts ...string) (string, error) { diff --git a/pkg/util/fileutil/fileutil.go b/pkg/util/fileutil/fileutil.go index b5bbfbb1b..18bb538d6 100644 --- a/pkg/util/fileutil/fileutil.go +++ b/pkg/util/fileutil/fileutil.go @@ -6,27 +6,27 @@ package fileutil import ( "io" "io/fs" - "log" "mime" "net/http" "os" "path/filepath" "regexp" "strings" + "time" "github.com/wavetermdev/waveterm/pkg/wavebase" + "github.com/wavetermdev/waveterm/pkg/wshrpc" ) func FixPath(path string) (string, error) { + var err error if strings.HasPrefix(path, "~") { path = filepath.Join(wavebase.GetHomeDir(), path[1:]) } else if !filepath.IsAbs(path) { - log.Printf("FixPath: path is not absolute: %s", path) - path, err := filepath.Abs(path) + path, err = filepath.Abs(path) if err != nil { return "", err } - log.Printf("FixPath: fixed path: %s", path) } return path, nil } @@ -164,3 +164,52 @@ func IsInitScriptPath(input string) bool { return true } + +type FsFileInfo struct { + NameInternal string + ModeInternal os.FileMode + SizeInternal int64 + ModTimeInternal int64 + IsDirInternal bool +} + +func (f FsFileInfo) Name() string { + return f.NameInternal +} + +func (f FsFileInfo) Size() int64 { + return f.SizeInternal +} + +func (f FsFileInfo) Mode() os.FileMode { + return f.ModeInternal +} + +func (f FsFileInfo) ModTime() time.Time { + return time.Unix(0, f.ModTimeInternal) +} + +func (f FsFileInfo) IsDir() bool { + return f.IsDirInternal +} + +func (f FsFileInfo) Sys() interface{} { + return nil +} + +var _ fs.FileInfo = FsFileInfo{} + +// ToFsFileInfo converts wshrpc.FileInfo to FsFileInfo. +// It panics if fi is nil. +func ToFsFileInfo(fi *wshrpc.FileInfo) FsFileInfo { + if fi == nil { + panic("ToFsFileInfo: nil FileInfo") + } + return FsFileInfo{ + NameInternal: fi.Name, + ModeInternal: fi.Mode, + SizeInternal: fi.Size, + ModTimeInternal: fi.ModTime, + IsDirInternal: fi.IsDir, + } +} diff --git a/pkg/util/iochan/iochan.go b/pkg/util/iochan/iochan.go index 4145837fe..98fb94a19 100644 --- a/pkg/util/iochan/iochan.go +++ b/pkg/util/iochan/iochan.go @@ -5,46 +5,49 @@ package iochan import ( + "bytes" "context" + "crypto/sha256" "errors" "fmt" "io" - "log" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/wshrpc" "github.com/wavetermdev/waveterm/pkg/wshutil" ) // ReaderChan reads from an io.Reader and sends the data to a channel -func ReaderChan(ctx context.Context, r io.Reader, chunkSize int64, callback func()) chan wshrpc.RespOrErrorUnion[[]byte] { - ch := make(chan wshrpc.RespOrErrorUnion[[]byte], 32) +func ReaderChan(ctx context.Context, r io.Reader, chunkSize int64, callback func()) chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + ch := make(chan wshrpc.RespOrErrorUnion[iochantypes.Packet], 32) go func() { defer func() { - log.Printf("ReaderChan: closing channel") close(ch) callback() }() - buf := make([]byte, chunkSize) + sha256Hash := sha256.New() for { select { case <-ctx.Done(): if ctx.Err() == context.Canceled { return } - log.Printf("ReaderChan: context error: %v", ctx.Err()) return default: + buf := make([]byte, chunkSize) if n, err := r.Read(buf); err != nil { if errors.Is(err, io.EOF) { - log.Printf("ReaderChan: EOF") + ch <- wshrpc.RespOrErrorUnion[iochantypes.Packet]{Response: iochantypes.Packet{Checksum: sha256Hash.Sum(nil)}} // send the checksum return } - ch <- wshutil.RespErr[[]byte](fmt.Errorf("ReaderChan: read error: %v", err)) - log.Printf("ReaderChan: read error: %v", err) + ch <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("ReaderChan: read error: %v", err)) return } else if n > 0 { - // log.Printf("ReaderChan: read %d bytes", n) - ch <- wshrpc.RespOrErrorUnion[[]byte]{Response: buf[:n]} + if _, err := sha256Hash.Write(buf[:n]); err != nil { + ch <- wshutil.RespErr[iochantypes.Packet](fmt.Errorf("ReaderChan: error writing to sha256 hash: %v", err)) + return + } + ch <- wshrpc.RespOrErrorUnion[iochantypes.Packet]{Response: iochantypes.Packet{Data: buf[:n]}} } } } @@ -53,13 +56,15 @@ func ReaderChan(ctx context.Context, r io.Reader, chunkSize int64, callback func } // WriterChan reads from a channel and writes the data to an io.Writer -func WriterChan(ctx context.Context, w io.Writer, ch <-chan wshrpc.RespOrErrorUnion[[]byte], callback func(), errCallback func(error)) { +func WriterChan(ctx context.Context, w io.Writer, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], callback func(), cancel context.CancelCauseFunc) { go func() { defer func() { - log.Printf("WriterChan: closing channel") + if ctx.Err() != nil { + drainChannel(ch) + } callback() - drainChannel(ch) }() + sha256Hash := sha256.New() for { select { case <-ctx.Done(): @@ -69,22 +74,33 @@ func WriterChan(ctx context.Context, w io.Writer, ch <-chan wshrpc.RespOrErrorUn return } if resp.Error != nil { - log.Printf("WriterChan: error: %v", resp.Error) - errCallback(resp.Error) + cancel(resp.Error) return } - if _, err := w.Write(resp.Response); err != nil { - log.Printf("WriterChan: write error: %v", err) - errCallback(err) + if _, err := sha256Hash.Write(resp.Response.Data); err != nil { + cancel(fmt.Errorf("WriterChan: error writing to sha256 hash: %v", err)) + return + } + // The checksum is sent as the last packet + if resp.Response.Checksum != nil { + localChecksum := sha256Hash.Sum(nil) + if !bytes.Equal(localChecksum, resp.Response.Checksum) { + cancel(fmt.Errorf("WriterChan: checksum mismatch")) + } + return + } + if _, err := w.Write(resp.Response.Data); err != nil { + cancel(fmt.Errorf("WriterChan: write error: %v", err)) return - } else { - // log.Printf("WriterChan: wrote %d bytes", n) } } } }() } -func drainChannel(ch <-chan wshrpc.RespOrErrorUnion[[]byte]) { - for range ch {} +func drainChannel(ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet]) { + go func() { + for range ch { + } + }() } diff --git a/pkg/util/iochan/iochantypes/iochantypes.go b/pkg/util/iochan/iochantypes/iochantypes.go new file mode 100644 index 000000000..00ea42b0b --- /dev/null +++ b/pkg/util/iochan/iochantypes/iochantypes.go @@ -0,0 +1,6 @@ +package iochantypes + +type Packet struct { + Data []byte + Checksum []byte +} diff --git a/pkg/util/tarcopy/tarcopy.go b/pkg/util/tarcopy/tarcopy.go new file mode 100644 index 000000000..825ac2b31 --- /dev/null +++ b/pkg/util/tarcopy/tarcopy.go @@ -0,0 +1,137 @@ +// Copyright 2025, Command Line Inc. +// SPDX-License-Identifier: Apache-2.0 + +// Package tarcopy provides functions for copying files over a channel via a tar stream. +package tarcopy + +import ( + "archive/tar" + "context" + "errors" + "fmt" + "io" + "io/fs" + "log" + "path/filepath" + "strings" + "time" + + "github.com/wavetermdev/waveterm/pkg/util/iochan" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/wshrpc" +) + +const ( + maxRetries = 5 + retryDelay = 10 * time.Millisecond + tarCopySrcName = "TarCopySrc" + tarCopyDestName = "TarCopyDest" + pipeReaderName = "pipe reader" + pipeWriterName = "pipe writer" + tarWriterName = "tar writer" +) + +// TarCopySrc creates a tar stream writer and returns a channel to send the tar stream to. +// writeHeader is a function that writes the tar header for the file. +// writer is the tar writer to write the file data to. +// close is a function that closes the tar writer and internal pipe writer. +func TarCopySrc(ctx context.Context, pathPrefix string) (outputChan chan wshrpc.RespOrErrorUnion[iochantypes.Packet], writeHeader func(fi fs.FileInfo, file string) error, writer io.Writer, close func()) { + pipeReader, pipeWriter := io.Pipe() + tarWriter := tar.NewWriter(pipeWriter) + rtnChan := iochan.ReaderChan(ctx, pipeReader, wshrpc.FileChunkSize, func() { + gracefulClose(pipeReader, tarCopySrcName, pipeReaderName) + }) + + return rtnChan, func(fi fs.FileInfo, file string) error { + // generate tar header + header, err := tar.FileInfoHeader(fi, file) + if err != nil { + return err + } + + header.Name = filepath.Clean(strings.TrimPrefix(file, pathPrefix)) + if err := validatePath(header.Name); err != nil { + return err + } + + // write header + if err := tarWriter.WriteHeader(header); err != nil { + return err + } + return nil + }, tarWriter, func() { + gracefulClose(tarWriter, tarCopySrcName, tarWriterName) + gracefulClose(pipeWriter, tarCopySrcName, pipeWriterName) + } +} + +func validatePath(path string) error { + if strings.Contains(path, "..") { + return fmt.Errorf("invalid tar path containing directory traversal: %s", path) + } + if strings.HasPrefix(path, "/") { + return fmt.Errorf("invalid tar path starting with /: %s", path) + } + return nil +} + +// TarCopyDest reads a tar stream from a channel and writes the files to the destination. +// readNext is a function that is called for each file in the tar stream to read the file data. It should return an error if the file cannot be read. +// The function returns an error if the tar stream cannot be read. +func TarCopyDest(ctx context.Context, cancel context.CancelCauseFunc, ch <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet], readNext func(next *tar.Header, reader *tar.Reader) error) error { + pipeReader, pipeWriter := io.Pipe() + iochan.WriterChan(ctx, pipeWriter, ch, func() { + gracefulClose(pipeWriter, tarCopyDestName, pipeWriterName) + cancel(nil) + }, cancel) + tarReader := tar.NewReader(pipeReader) + defer func() { + if !gracefulClose(pipeReader, tarCopyDestName, pipeReaderName) { + // If the pipe reader cannot be closed, cancel the context. This should kill the writer goroutine. + cancel(nil) + } + }() + for { + select { + case <-ctx.Done(): + if ctx.Err() != nil { + return context.Cause(ctx) + } + return nil + default: + next, err := tarReader.Next() + if err != nil { + // Do one more check for context error before returning + if ctx.Err() != nil { + return context.Cause(ctx) + } + if errors.Is(err, io.EOF) { + return nil + } else { + return err + } + } + err = readNext(next, tarReader) + if err != nil { + return err + } + } + } +} + +func gracefulClose(closer io.Closer, debugName string, closerName string) bool { + closed := false + for retries := 0; retries < maxRetries; retries++ { + if err := closer.Close(); err != nil { + log.Printf("%s: error closing %s: %v, trying again in %dms\n", debugName, closerName, err, retryDelay.Milliseconds()) + time.Sleep(retryDelay) + continue + } + closed = true + break + } + if !closed { + log.Printf("%s: unable to close %s after %d retries\n", debugName, closerName, maxRetries) + } + return closed +} diff --git a/pkg/web/web.go b/pkg/web/web.go index 35d9445c6..6cc1f999b 100644 --- a/pkg/web/web.go +++ b/pkg/web/web.go @@ -243,11 +243,12 @@ func handleLocalStreamFile(w http.ResponseWriter, r *http.Request, path string, } } -func handleRemoteStreamFile(w http.ResponseWriter, _ *http.Request, conn string, path string, no404 bool) error { +func handleRemoteStreamFile(w http.ResponseWriter, req *http.Request, conn string, path string, no404 bool) error { client := wshserver.GetMainRpcClient() streamFileData := wshrpc.CommandRemoteStreamFileData{Path: path} route := wshutil.MakeConnectionRouteId(conn) - rtnCh := wshclient.RemoteStreamFileCommand(client, streamFileData, &wshrpc.RpcOpts{Route: route}) + rpcOpts := &wshrpc.RpcOpts{Route: route, Timeout: 60 * 1000} + rtnCh := wshclient.RemoteStreamFileCommand(client, streamFileData, rpcOpts) firstPk := true var fileInfo *wshrpc.FileInfo loopDone := false @@ -261,45 +262,54 @@ func handleRemoteStreamFile(w http.ResponseWriter, _ *http.Request, conn string, } }() }() - for respUnion := range rtnCh { - if respUnion.Error != nil { - return respUnion.Error - } - if firstPk { - firstPk = false - if respUnion.Response.Info == nil { - return fmt.Errorf("stream file protocol error, fileinfo is empty") + ctx := req.Context() + for { + select { + case <-ctx.Done(): + rpcOpts.StreamCancelFn() + return ctx.Err() + case respUnion, ok := <-rtnCh: + if !ok { + loopDone = true + return nil + } + if respUnion.Error != nil { + return respUnion.Error } - fileInfo = respUnion.Response.Info - if fileInfo.NotFound { - if no404 { - serveTransparentGIF(w) - return nil - } else { - return fmt.Errorf("file not found: %q", path) + if firstPk { + firstPk = false + if respUnion.Response.Info == nil { + return fmt.Errorf("stream file protocol error, fileinfo is empty") + } + fileInfo = respUnion.Response.Info + if fileInfo.NotFound { + if no404 { + serveTransparentGIF(w) + return nil + } else { + return fmt.Errorf("file not found: %q", path) + } + } + if fileInfo.IsDir { + return fmt.Errorf("cannot stream directory: %q", path) } + w.Header().Set(ContentTypeHeaderKey, fileInfo.MimeType) + w.Header().Set(ContentLengthHeaderKey, fmt.Sprintf("%d", fileInfo.Size)) + continue } - if fileInfo.IsDir { - return fmt.Errorf("cannot stream directory: %q", path) + if respUnion.Response.Data64 == "" { + continue + } + decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(respUnion.Response.Data64))) + _, err := io.Copy(w, decoder) + if err != nil { + log.Printf("error streaming file %q: %v\n", path, err) + // not sure what to do here, the headers have already been sent. + // just return + return nil } - w.Header().Set(ContentTypeHeaderKey, fileInfo.MimeType) - w.Header().Set(ContentLengthHeaderKey, fmt.Sprintf("%d", fileInfo.Size)) - continue - } - if respUnion.Response.Data64 == "" { - continue - } - decoder := base64.NewDecoder(base64.StdEncoding, bytes.NewReader([]byte(respUnion.Response.Data64))) - _, err := io.Copy(w, decoder) - if err != nil { - log.Printf("error streaming file %q: %v\n", path, err) - // not sure what to do here, the headers have already been sent. - // just return - return nil } } - loopDone = true - return nil } func handleStreamFile(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/wshrpc/wshclient/wshclient.go b/pkg/wshrpc/wshclient/wshclient.go index 2480484ec..599cd79c5 100644 --- a/pkg/wshrpc/wshclient/wshclient.go +++ b/pkg/wshrpc/wshclient/wshclient.go @@ -12,6 +12,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wps" "github.com/wavetermdev/waveterm/pkg/vdom" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" ) // command "activity", wshserver.ActivityCommand @@ -207,7 +208,7 @@ func FileCreateCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.Rpc } // command "filedelete", wshserver.FileDeleteCommand -func FileDeleteCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOpts) error { +func FileDeleteCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteFileData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "filedelete", data, opts) return err } @@ -248,8 +249,8 @@ func FileReadCommand(w *wshutil.WshRpc, data wshrpc.FileData, opts *wshrpc.RpcOp } // command "filestreamtar", wshserver.FileStreamTarCommand -func FileStreamTarCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTarData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[[]uint8] { - return sendRpcRequestResponseStreamHelper[[]uint8](w, "filestreamtar", data, opts) +func FileStreamTarCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTarData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "filestreamtar", data, opts) } // command "filewrite", wshserver.FileWriteCommand @@ -313,7 +314,7 @@ func RemoteFileCopyCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteFileCopyD } // command "remotefiledelete", wshserver.RemoteFileDeleteCommand -func RemoteFileDeleteCommand(w *wshutil.WshRpc, data string, opts *wshrpc.RpcOpts) error { +func RemoteFileDeleteCommand(w *wshutil.WshRpc, data wshrpc.CommandDeleteFileData, opts *wshrpc.RpcOpts) error { _, err := sendRpcRequestCallHelper[any](w, "remotefiledelete", data, opts) return err } @@ -376,8 +377,8 @@ func RemoteStreamFileCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamF } // command "remotetarstream", wshserver.RemoteTarStreamCommand -func RemoteTarStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTarData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[[]uint8] { - return sendRpcRequestResponseStreamHelper[[]uint8](w, "remotetarstream", data, opts) +func RemoteTarStreamCommand(w *wshutil.WshRpc, data wshrpc.CommandRemoteStreamTarData, opts *wshrpc.RpcOpts) chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { + return sendRpcRequestResponseStreamHelper[iochantypes.Packet](w, "remotetarstream", data, opts) } // command "remotewritefile", wshserver.RemoteWriteFileCommand diff --git a/pkg/wshrpc/wshremote/wshremote.go b/pkg/wshrpc/wshremote/wshremote.go index df241dd26..f324f52fe 100644 --- a/pkg/wshrpc/wshremote/wshremote.go +++ b/pkg/wshrpc/wshremote/wshremote.go @@ -18,14 +18,21 @@ import ( "time" "github.com/wavetermdev/waveterm/pkg/remote/connparse" + "github.com/wavetermdev/waveterm/pkg/remote/fileshare/wshfs" "github.com/wavetermdev/waveterm/pkg/util/fileutil" - "github.com/wavetermdev/waveterm/pkg/util/iochan" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" + "github.com/wavetermdev/waveterm/pkg/util/tarcopy" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/wavebase" "github.com/wavetermdev/waveterm/pkg/wshrpc" + "github.com/wavetermdev/waveterm/pkg/wshrpc/wshclient" "github.com/wavetermdev/waveterm/pkg/wshutil" ) +const ( + DefaultTimeout = 30 * time.Second +) + type ServerImpl struct { LogWriter io.Writer } @@ -190,9 +197,6 @@ func (impl *ServerImpl) remoteStreamFileInternal(ctx context.Context, data wshrp if finfo.NotFound { return nil } - if finfo.Size > wshrpc.MaxFileSize { - return fmt.Errorf("file %q is too large to read, use /wave/stream-file", path) - } if finfo.IsDir { return impl.remoteStreamFileDir(ctx, path, byteRange, dataCallback) } else { @@ -220,7 +224,6 @@ func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc resp.Data64 = base64.StdEncoding.EncodeToString(data) resp.At = &wshrpc.FileDataAt{Offset: byteRange.Start, Size: len(data)} } - logPrintfDev("callback -- sending response %d\n", len(resp.Data64)) ch <- wshrpc.RespOrErrorUnion[wshrpc.FileData]{Response: resp} }) if err != nil { @@ -230,92 +233,83 @@ func (impl *ServerImpl) RemoteStreamFileCommand(ctx context.Context, data wshrpc return ch } -func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[[]byte] { +func (impl *ServerImpl) RemoteTarStreamCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { path := data.Path opts := data.Opts if opts == nil { opts = &wshrpc.FileCopyOpts{} } recursive := opts.Recursive - log.Printf("RemoteTarStreamCommand: path=%s\n", path) + logPrintfDev("RemoteTarStreamCommand: path=%s\n", path) path, err := wavebase.ExpandHomeDir(path) if err != nil { - return wshutil.SendErrCh[[]byte](fmt.Errorf("cannot expand path %q: %w", path, err)) + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot expand path %q: %w", path, err)) } cleanedPath := filepath.Clean(wavebase.ExpandHomeDirSafe(path)) finfo, err := os.Stat(cleanedPath) if err != nil { - return wshutil.SendErrCh[[]byte](fmt.Errorf("cannot stat file %q: %w", path, err)) - } - pipeReader, pipeWriter := io.Pipe() - tarWriter := tar.NewWriter(pipeWriter) - timeout := time.Millisecond * 100 - if opts.Timeout > 0 { - timeout = time.Duration(opts.Timeout) * time.Millisecond + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot stat file %q: %w", path, err)) } - readerCtx, _ := context.WithTimeout(context.Background(), timeout) - rtn := iochan.ReaderChan(readerCtx, pipeReader, wshrpc.FileChunkSize, func() { - pipeReader.Close() - pipeWriter.Close() - }) var pathPrefix string if finfo.IsDir() && strings.HasSuffix(cleanedPath, "/") { pathPrefix = cleanedPath } else { - pathPrefix = filepath.Dir(cleanedPath) + pathPrefix = filepath.Dir(cleanedPath) + "/" } - go func() { - if readerCtx.Err() != nil { - return + if finfo.IsDir() { + if !recursive { + return wshutil.SendErrCh[iochantypes.Packet](fmt.Errorf("cannot create tar stream for %q: %w", path, errors.New("directory copy requires recursive option"))) } - defer tarWriter.Close() - logPrintfDev("creating tar stream for %q\n", path) - if finfo.IsDir() { - logPrintfDev("%q is a directory, recursive: %v\n", path, recursive) - if !recursive { - rtn <- wshutil.RespErr[[]byte](fmt.Errorf("cannot create tar stream for %q: %w", path, errors.New("directory copy requires recursive option"))) - return + } + + timeout := DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readerCtx, cancel := context.WithTimeout(ctx, timeout) + rtn, writeHeader, fileWriter, tarClose := tarcopy.TarCopySrc(readerCtx, pathPrefix) + + go func() { + defer func() { + tarClose() + cancel() + }() + walkFunc := func(path string, info fs.FileInfo, err error) error { + if readerCtx.Err() != nil { + return readerCtx.Err() } - } - err := filepath.Walk(path, func(file string, fi os.FileInfo, err error) error { - // generate tar header - header, err := tar.FileInfoHeader(fi, file) if err != nil { return err } - - header.Name = strings.TrimPrefix(file, pathPrefix) - if header.Name == "" { - return nil - } - - // write header - if err := tarWriter.WriteHeader(header); err != nil { + if err = writeHeader(info, path); err != nil { return err } // if not a dir, write file content - if !fi.IsDir() { - data, err := os.Open(file) + if !info.IsDir() { + data, err := os.Open(path) if err != nil { return err } - if n, err := io.Copy(tarWriter, data); err != nil { - log.Printf("error copying file %q: %v\n", file, err) + if _, err := io.Copy(fileWriter, data); err != nil { return err - } else { - logPrintfDev("wrote %d bytes to tar stream\n", n) } } - time.Sleep(time.Millisecond * 10) return nil - }) + } + log.Printf("RemoteTarStreamCommand: starting\n") + err = nil + if finfo.IsDir() { + err = filepath.Walk(path, walkFunc) + } else { + err = walkFunc(path, finfo, nil) + } if err != nil { - rtn <- wshutil.RespErr[[]byte](fmt.Errorf("cannot create tar stream for %q: %w", path, err)) + rtn <- wshutil.RespErr[iochantypes.Packet](err) } - logPrintfDev("returning tar stream\n") + log.Printf("RemoteTarStreamCommand: done\n") }() - logPrintfDev("returning channel\n") + log.Printf("RemoteTarStreamCommand: returning channel\n") return rtn } @@ -327,7 +321,7 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } destUri := data.DestUri srcUri := data.SrcUri - // merge := opts.Merge + merge := opts.Merge overwrite := opts.Overwrite destConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, destUri) @@ -350,68 +344,46 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } else if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("cannot stat destination %q: %w", destPathCleaned, err) } - logPrintfDev("copying %q to %q\n", srcUri, destUri) srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) if err != nil { return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) } if srcConn.Host == destConn.Host { - logPrintfDev("same host, copying file\n") srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) err := os.Rename(srcPathCleaned, destPathCleaned) if err != nil { return fmt.Errorf("cannot copy file %q to %q: %w", srcPathCleaned, destPathCleaned, err) } } else { - return fmt.Errorf("cannot copy file %q to %q: source and destination must be on the same host", srcUri, destPathCleaned) - } - /* TODO: uncomment once ready for cross-connection copy - timeout := time.Millisecond * 100 - if opts.Timeout > 0 { - timeout = time.Duration(opts.Timeout) * time.Millisecond - } - readCtx, _ := context.WithTimeout(ctx, timeout) - readCtx, cancel := context.WithCancelCause(readCtx) - ioch := fileshare.ReadTarStream(readCtx, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}) - pipeReader, pipeWriter := io.Pipe() - iochan.WriterChan(readCtx, pipeWriter, ioch, func() { - log.Printf("closing pipe writer\n") - pipeWriter.Close() - pipeReader.Close() - }, cancel) - defer cancel(nil) - tarReader := tar.NewReader(pipeReader) - for { - select { - case <-readCtx.Done(): - if readCtx.Err() != nil { - return context.Cause(readCtx) - } - return nil - default: - next, err := tarReader.Next() - if err != nil { - if errors.Is(err, io.EOF) { - // Do one more check for context error before returning - if readCtx.Err() != nil { - return context.Cause(readCtx) - } - return nil - } - return fmt.Errorf("cannot read tar stream: %w", err) - } + timeout := DefaultTimeout + if opts.Timeout > 0 { + timeout = time.Duration(opts.Timeout) * time.Millisecond + } + readCtx, cancel := context.WithCancelCause(ctx) + readCtx, timeoutCancel := context.WithTimeoutCause(readCtx, timeout, fmt.Errorf("timeout copying file %q to %q", srcUri, destUri)) + defer timeoutCancel() + copyStart := time.Now() + ioch := wshclient.FileStreamTarCommand(wshfs.RpcClient, wshrpc.CommandRemoteStreamTarData{Path: srcUri, Opts: opts}, &wshrpc.RpcOpts{Timeout: opts.Timeout}) + numFiles := 0 + numSkipped := 0 + totalBytes := int64(0) + err := tarcopy.TarCopyDest(readCtx, cancel, ioch, func(next *tar.Header, reader *tar.Reader) error { // Check for directory traversal if strings.Contains(next.Name, "..") { log.Printf("skipping file with unsafe path: %q\n", next.Name) - continue + numSkipped++ + return nil } + numFiles++ finfo := next.FileInfo() nextPath := filepath.Join(destPathCleaned, next.Name) destinfo, err = os.Stat(nextPath) if err != nil && !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("cannot stat file %q: %w", nextPath, err) } - log.Printf("new file: name %q; dest %q\n", next.Name, nextPath) + if !finfo.IsDir() { + totalBytes += finfo.Size() + } if destinfo != nil { if destinfo.IsDir() { @@ -444,16 +416,10 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C } } else if !overwrite { return fmt.Errorf("cannot create file %q, file exists at path, overwrite not specified", nextPath) - } else { - err := os.Remove(nextPath) - if err != nil { - return fmt.Errorf("cannot remove file %q: %w", nextPath, err) - } } } } else { if finfo.IsDir() { - log.Printf("creating directory %q\n", nextPath) err := os.MkdirAll(nextPath, finfo.Mode()) if err != nil { return fmt.Errorf("cannot create directory %q: %w", nextPath, err) @@ -463,20 +429,30 @@ func (impl *ServerImpl) RemoteFileCopyCommand(ctx context.Context, data wshrpc.C if err != nil { return fmt.Errorf("cannot create parent directory %q: %w", filepath.Dir(nextPath), err) } - file, err := os.Create(nextPath) + file, err := os.OpenFile(nextPath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, finfo.Mode()) if err != nil { return fmt.Errorf("cannot create new file %q: %w", nextPath, err) } - _, err = io.Copy(file, tarReader) + _, err = io.Copy(file, reader) if err != nil { return fmt.Errorf("cannot write file %q: %w", nextPath, err) } - file.Chmod(finfo.Mode()) file.Close() } } + return nil + }) + if err != nil { + return fmt.Errorf("cannot copy %q to %q: %w", srcUri, destUri, err) + } + totalTime := time.Since(copyStart).Seconds() + totalMegaBytes := float64(totalBytes) / 1024 / 1024 + rate := float64(0) + if totalTime > 0 { + rate = totalMegaBytes / totalTime } - }*/ + log.Printf("RemoteFileCopyCommand: done; %d files copied in %.3fs, total of %.4f MB, %.2f MB/s, %d files skipped\n", numFiles, totalTime, totalMegaBytes, rate, numSkipped) + } return nil } @@ -696,22 +672,18 @@ func (impl *ServerImpl) RemoteFileMoveCommand(ctx context.Context, data wshrpc.C } else if !errors.Is(err, fs.ErrNotExist) { return fmt.Errorf("cannot stat destination %q: %w", destUri, err) } - logPrintfDev("moving %q to %q\n", srcUri, destUri) srcConn, err := connparse.ParseURIAndReplaceCurrentHost(ctx, srcUri) if err != nil { return fmt.Errorf("cannot parse source URI %q: %w", srcUri, err) } - logPrintfDev("source host: %q, destination host: %q\n", srcConn.Host, destConn.Host) if srcConn.Host == destConn.Host { - logPrintfDev("moving file on same host\n") srcPathCleaned := filepath.Clean(wavebase.ExpandHomeDirSafe(srcConn.Path)) - logPrintfDev("moving %q to %q\n", srcPathCleaned, destPathCleaned) err := os.Rename(srcPathCleaned, destPathCleaned) if err != nil { return fmt.Errorf("cannot move file %q to %q: %w", srcPathCleaned, destPathCleaned, err) } } else { - return fmt.Errorf("cannot move file %q to %q: source and destination must be on the same host", srcUri, destUri) + return fmt.Errorf("cannot move file %q to %q: different hosts", srcUri, destUri) } return nil } @@ -796,15 +768,27 @@ func (*ServerImpl) RemoteWriteFileCommand(ctx context.Context, data wshrpc.FileD return nil } -func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, path string) error { - expandedPath, err := wavebase.ExpandHomeDir(path) +func (*ServerImpl) RemoteFileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error { + expandedPath, err := wavebase.ExpandHomeDir(data.Path) if err != nil { - return fmt.Errorf("cannot delete file %q: %w", path, err) + return fmt.Errorf("cannot delete file %q: %w", data.Path, err) } cleanedPath := filepath.Clean(expandedPath) + err = os.Remove(cleanedPath) if err != nil { - return fmt.Errorf("cannot delete file %q: %w", path, err) + finfo, _ := os.Stat(cleanedPath) + if finfo != nil && finfo.IsDir() { + if !data.Recursive { + return fmt.Errorf("cannot delete directory %q, recursive option not specified", data.Path) + } + err = os.RemoveAll(cleanedPath) + if err != nil { + return fmt.Errorf("cannot delete directory %q: %w", data.Path, err) + } + } else { + return fmt.Errorf("cannot delete file %q: %w", data.Path, err) + } } return nil } diff --git a/pkg/wshrpc/wshrpctypes.go b/pkg/wshrpc/wshrpctypes.go index 7fc5121dc..06de27f5c 100644 --- a/pkg/wshrpc/wshrpctypes.go +++ b/pkg/wshrpc/wshrpctypes.go @@ -13,6 +13,7 @@ import ( "reflect" "github.com/wavetermdev/waveterm/pkg/ijson" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/vdom" "github.com/wavetermdev/waveterm/pkg/waveobj" "github.com/wavetermdev/waveterm/pkg/wconfig" @@ -25,7 +26,7 @@ const ( // MaxDirSize is the maximum number of entries that can be read in a directory MaxDirSize = 1024 // FileChunkSize is the size of the file chunk to read - FileChunkSize = 16 * 1024 + FileChunkSize = 64 * 1024 // DirChunkSize is the size of the directory chunk to read DirChunkSize = 128 ) @@ -65,6 +66,7 @@ const ( Command_FileRead = "fileread" Command_FileMove = "filemove" Command_FileCopy = "filecopy" + Command_FileStreamTar = "filestreamtar" Command_EventPublish = "eventpublish" Command_EventRecv = "eventrecv" Command_EventSub = "eventsub" @@ -150,12 +152,12 @@ type WshRpcInterface interface { WaitForRouteCommand(ctx context.Context, data CommandWaitForRouteData) (bool, error) FileMkdirCommand(ctx context.Context, data FileData) error FileCreateCommand(ctx context.Context, data FileData) error - FileDeleteCommand(ctx context.Context, data FileData) error + FileDeleteCommand(ctx context.Context, data CommandDeleteFileData) error FileAppendCommand(ctx context.Context, data FileData) error FileAppendIJsonCommand(ctx context.Context, data CommandAppendIJsonData) error FileWriteCommand(ctx context.Context, data FileData) error FileReadCommand(ctx context.Context, data FileData) (*FileData, error) - FileStreamTarCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[[]byte] + FileStreamTarCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[iochantypes.Packet] FileMoveCommand(ctx context.Context, data CommandFileCopyData) error FileCopyCommand(ctx context.Context, data CommandFileCopyData) error FileInfoCommand(ctx context.Context, data FileData) (*FileInfo, error) @@ -199,13 +201,13 @@ type WshRpcInterface interface { // remotes RemoteStreamFileCommand(ctx context.Context, data CommandRemoteStreamFileData) chan RespOrErrorUnion[FileData] - RemoteTarStreamCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[[]byte] + RemoteTarStreamCommand(ctx context.Context, data CommandRemoteStreamTarData) <-chan RespOrErrorUnion[iochantypes.Packet] RemoteFileCopyCommand(ctx context.Context, data CommandRemoteFileCopyData) error RemoteListEntriesCommand(ctx context.Context, data CommandRemoteListEntriesData) chan RespOrErrorUnion[CommandRemoteListEntriesRtnData] RemoteFileInfoCommand(ctx context.Context, path string) (*FileInfo, error) RemoteFileTouchCommand(ctx context.Context, path string) error RemoteFileMoveCommand(ctx context.Context, data CommandRemoteFileCopyData) error - RemoteFileDeleteCommand(ctx context.Context, path string) error + RemoteFileDeleteCommand(ctx context.Context, data CommandDeleteFileData) error RemoteWriteFileCommand(ctx context.Context, data FileData) error RemoteFileJoinCommand(ctx context.Context, paths []string) (*FileInfo, error) RemoteMkdirCommand(ctx context.Context, path string) error @@ -498,6 +500,11 @@ type CpuDataType struct { Value float64 `json:"value"` } +type CommandDeleteFileData struct { + Path string `json:"path"` + Recursive bool `json:"recursive"` +} + type CommandFileCopyData struct { SrcUri string `json:"srcuri"` DestUri string `json:"desturi"` diff --git a/pkg/wshrpc/wshserver/wshserver.go b/pkg/wshrpc/wshserver/wshserver.go index fd575872b..ff403cd73 100644 --- a/pkg/wshrpc/wshserver/wshserver.go +++ b/pkg/wshrpc/wshserver/wshserver.go @@ -28,6 +28,7 @@ import ( "github.com/wavetermdev/waveterm/pkg/remote/fileshare" "github.com/wavetermdev/waveterm/pkg/telemetry" "github.com/wavetermdev/waveterm/pkg/util/envutil" + "github.com/wavetermdev/waveterm/pkg/util/iochan/iochantypes" "github.com/wavetermdev/waveterm/pkg/util/shellutil" "github.com/wavetermdev/waveterm/pkg/util/utilfn" "github.com/wavetermdev/waveterm/pkg/util/wavefileutil" @@ -307,8 +308,8 @@ func (ws *WshServer) FileMkdirCommand(ctx context.Context, data wshrpc.FileData) return fileshare.Mkdir(ctx, data.Info.Path) } -func (ws *WshServer) FileDeleteCommand(ctx context.Context, data wshrpc.FileData) error { - return fileshare.Delete(ctx, data.Info.Path) +func (ws *WshServer) FileDeleteCommand(ctx context.Context, data wshrpc.CommandDeleteFileData) error { + return fileshare.Delete(ctx, data) } func (ws *WshServer) FileInfoCommand(ctx context.Context, data wshrpc.FileData) (*wshrpc.FileInfo, error) { @@ -339,7 +340,7 @@ func (ws *WshServer) FileMoveCommand(ctx context.Context, data wshrpc.CommandFil return fileshare.Move(ctx, data) } -func (ws *WshServer) FileStreamTarCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[[]byte] { +func (ws *WshServer) FileStreamTarCommand(ctx context.Context, data wshrpc.CommandRemoteStreamTarData) <-chan wshrpc.RespOrErrorUnion[iochantypes.Packet] { return fileshare.ReadTarStream(ctx, data) } diff --git a/pkg/wshutil/wshrpcio.go b/pkg/wshutil/wshrpcio.go index 67dbe48d6..9aa5f1609 100644 --- a/pkg/wshutil/wshrpcio.go +++ b/pkg/wshutil/wshrpcio.go @@ -22,12 +22,23 @@ func AdaptStreamToMsgCh(input io.Reader, output chan []byte) error { } func AdaptOutputChToStream(outputCh chan []byte, output io.Writer) error { + drain := false + defer func() { + if drain { + go func() { + for range outputCh { + } + }() + } + }() for msg := range outputCh { if _, err := output.Write(msg); err != nil { + drain = true return fmt.Errorf("error writing to output (AdaptOutputChToStream): %w", err) } // write trailing newline if _, err := output.Write([]byte{'\n'}); err != nil { + drain = true return fmt.Errorf("error writing trailing newline to output (AdaptOutputChToStream): %w", err) } } diff --git a/pkg/wshutil/wshutil.go b/pkg/wshutil/wshutil.go index 871fd72d1..17a422eeb 100644 --- a/pkg/wshutil/wshutil.go +++ b/pkg/wshutil/wshutil.go @@ -484,6 +484,8 @@ func handleDomainSocketClient(conn net.Conn) { }() defer func() { conn.Close() + close(proxy.FromRemoteCh) + close(proxy.ToRemoteCh) routeIdPtr := routeIdContainer.Load() if routeIdPtr != nil && *routeIdPtr != "" { DefaultRouter.UnregisterRoute(*routeIdPtr)