diff --git a/agent/main.go b/agent/main.go index f2c2aee412e..d4dcb0687b2 100644 --- a/agent/main.go +++ b/agent/main.go @@ -1,7 +1,9 @@ package main import ( + "errors" "fmt" + "io" "net" "net/http" "os" @@ -9,14 +11,13 @@ import ( "strings" "time" - "github.com/shellhub-io/shellhub/pkg/loglevel" - "github.com/Masterminds/semver" "github.com/gorilla/mux" "github.com/kelseyhightower/envconfig" "github.com/shellhub-io/shellhub/agent/pkg/tunnel" "github.com/shellhub-io/shellhub/agent/selfupdater" "github.com/shellhub-io/shellhub/agent/server" + "github.com/shellhub-io/shellhub/pkg/loglevel" log "github.com/sirupsen/logrus" "github.com/spf13/cobra" ) @@ -65,7 +66,7 @@ type ConfigOptions struct { } // NewAgentServer creates a new agent server instance. -func NewAgentServer() *Agent { +func NewAgentServer() *Agent { // nolint:gocyclo opts := ConfigOptions{} // Process unprefixed env vars for backward compatibility @@ -170,6 +171,59 @@ func NewAgentServer() *Agent { conn.Close() } + tun.HTTPHandler = func(w http.ResponseWriter, r *http.Request) { + replyError := func(err error, msg string, code int) { + log.WithError(err).WithFields(log.Fields{ + "remote": r.RemoteAddr, + "namespace": r.Header.Get("X-Namespace"), + "path": r.Header.Get("X-Path"), + "version": AgentVersion, + }).Error(msg) + + http.Error(w, msg, code) + } + + in, err := net.Dial("tcp", ":8080") + if err != nil { + replyError(err, "failed to connect to HTTP the server on device", http.StatusInternalServerError) + + return + } + + defer in.Close() + + url, err := r.URL.Parse(r.Header.Get("X-Path")) + if err != nil { + replyError(err, "failed to parse URL", http.StatusInternalServerError) + + return + } + + r.URL.Scheme = "http" + r.URL = url + + if err := r.Write(in); err != nil { + replyError(err, "failed to write request to the server on device", http.StatusInternalServerError) + + return + } + + ctr := http.NewResponseController(w) + out, _, err := ctr.Hijack() + if err != nil { + replyError(err, "failed to hijack connection", http.StatusInternalServerError) + + return + } + + defer out.Close() // nolint:errcheck + + if _, err := io.Copy(out, in); errors.Is(err, io.ErrUnexpectedEOF) { + replyError(err, "failed to copy response from device service to client", http.StatusInternalServerError) + + return + } + } tun.CloseHandler = func(w http.ResponseWriter, r *http.Request) { vars := mux.Vars(r) serv.CloseSession(vars["id"]) diff --git a/agent/pkg/tunnel/tunnel.go b/agent/pkg/tunnel/tunnel.go index 55909b68cd9..9ad96c35e08 100644 --- a/agent/pkg/tunnel/tunnel.go +++ b/agent/pkg/tunnel/tunnel.go @@ -12,6 +12,7 @@ import ( type Tunnel struct { router *mux.Router srv *http.Server + HTTPHandler func(w http.ResponseWriter, r *http.Request) ConnHandler func(w http.ResponseWriter, r *http.Request) CloseHandler func(w http.ResponseWriter, r *http.Request) } @@ -27,6 +28,9 @@ func NewTunnel() *Tunnel { return context.WithValue(ctx, "http-conn", c) //nolint:revive }, }, + HTTPHandler: func(w http.ResponseWriter, r *http.Request) { + panic("HTTPHandler can not be nil") + }, ConnHandler: func(w http.ResponseWriter, r *http.Request) { panic("connHandler can not be nil") }, @@ -34,6 +38,9 @@ func NewTunnel() *Tunnel { panic("closeHandler can not be nil") }, } + t.router.HandleFunc("/ssh/http", func(w http.ResponseWriter, r *http.Request) { + t.HTTPHandler(w, r) + }) t.router.HandleFunc("/ssh/{id}", func(w http.ResponseWriter, r *http.Request) { t.ConnHandler(w, r) }) diff --git a/api/pkg/guard/actions.go b/api/pkg/guard/actions.go index 031ad8c7a0d..2fc6dff6af2 100644 --- a/api/pkg/guard/actions.go +++ b/api/pkg/guard/actions.go @@ -13,7 +13,7 @@ type AllActions struct { } type DeviceActions struct { - Accept, Reject, Remove, Connect, Rename, CreateTag, UpdateTag, RemoveTag, RenameTag, DeleteTag int + Accept, Reject, Update, Remove, Connect, Rename, CreateTag, UpdateTag, RemoveTag, RenameTag, DeleteTag int } type SessionActions struct { @@ -42,6 +42,7 @@ var Actions = AllActions{ Device: DeviceActions{ Accept: DeviceAccept, Reject: DeviceReject, + Update: DeviceUpdate, Remove: DeviceRemove, Connect: DeviceConnect, Rename: DeviceRename, diff --git a/api/pkg/guard/guard_test.go b/api/pkg/guard/guard_test.go index 7b33a70ef8f..31745318f5c 100644 --- a/api/pkg/guard/guard_test.go +++ b/api/pkg/guard/guard_test.go @@ -331,6 +331,7 @@ func TestCheckPermission(t *testing.T) { Actions.Device.Reject, Actions.Device.Connect, Actions.Device.Rename, + Actions.Device.Update, Actions.Device.CreateTag, Actions.Device.UpdateTag, @@ -353,6 +354,7 @@ func TestCheckPermission(t *testing.T) { Actions.Device.Remove, Actions.Device.Connect, Actions.Device.Rename, + Actions.Device.Update, Actions.Device.CreateTag, Actions.Device.UpdateTag, @@ -392,6 +394,7 @@ func TestCheckPermission(t *testing.T) { Actions.Device.Remove, Actions.Device.Connect, Actions.Device.Rename, + Actions.Device.Update, Actions.Device.CreateTag, Actions.Device.UpdateTag, diff --git a/api/pkg/guard/permissions.go b/api/pkg/guard/permissions.go index 60df39480d4..b379eeefbce 100644 --- a/api/pkg/guard/permissions.go +++ b/api/pkg/guard/permissions.go @@ -5,6 +5,7 @@ type Permissions []int const ( DeviceAccept = iota + 1 DeviceReject + DeviceUpdate DeviceRemove DeviceConnect DeviceRename @@ -66,6 +67,7 @@ var operatorPermissions = Permissions{ DeviceConnect, DeviceRename, DeviceDetails, + DeviceUpdate, DeviceCreateTag, DeviceUpdateTag, @@ -83,6 +85,7 @@ var adminPermissions = Permissions{ DeviceConnect, DeviceRename, DeviceDetails, + DeviceUpdate, DeviceCreateTag, DeviceUpdateTag, @@ -90,6 +93,8 @@ var adminPermissions = Permissions{ DeviceRenameTag, DeviceDeleteTag, + DeviceUpdate, + SessionPlay, SessionClose, SessionRemove, @@ -123,6 +128,7 @@ var ownerPermissions = Permissions{ DeviceConnect, DeviceRename, DeviceDetails, + DeviceUpdate, DeviceCreateTag, DeviceUpdateTag, @@ -130,6 +136,8 @@ var ownerPermissions = Permissions{ DeviceRenameTag, DeviceDeleteTag, + DeviceUpdate, + SessionPlay, SessionClose, SessionRemove, diff --git a/api/routes/device.go b/api/routes/device.go index 99cac8ff970..01406b6d9e3 100644 --- a/api/routes/device.go +++ b/api/routes/device.go @@ -25,6 +25,7 @@ const ( CreateTagURL = "/devices/:uid/tags" // Add a tag to a device. UpdateTagURL = "/devices/:uid/tags" // Update device's tags with a new set. RemoveTagURL = "/devices/:uid/tags/:tag" // Delete a tag from a device. + UpdateDevice = "/devices/:uid" ) const ( @@ -281,3 +282,27 @@ func (h *Handler) UpdateDeviceTag(c gateway.Context) error { return c.NoContent(http.StatusOK) } + +func (h *Handler) UpdateDevice(c gateway.Context) error { + var req request.DeviceUpdate + if err := c.Bind(&req); err != nil { + return err + } + + if err := c.Validate(&req); err != nil { + return err + } + + var tenant string + if c.Tenant() != nil { + tenant = c.Tenant().ID + } + + if err := guard.EvaluatePermission(c.Role(), guard.Actions.Device.Update, func() error { + return h.service.UpdateDevice(c.Ctx(), tenant, models.UID(req.UID), req.Name, req.PublicURL) + }); err != nil { + return err + } + + return c.NoContent(http.StatusOK) +} diff --git a/api/server.go b/api/server.go index e6a55849cee..3ed0de981e5 100644 --- a/api/server.go +++ b/api/server.go @@ -206,6 +206,7 @@ func startServer(cfg *config) error { publicAPI.GET(routes.GetDeviceURL, apiMiddleware.Authorize(gateway.Handler(handler.GetDevice))) publicAPI.DELETE(routes.DeleteDeviceURL, gateway.Handler(handler.DeleteDevice)) + publicAPI.PUT(routes.UpdateDevice, gateway.Handler(handler.UpdateDevice)) publicAPI.PATCH(routes.RenameDeviceURL, gateway.Handler(handler.RenameDevice)) internalAPI.POST(routes.OfflineDeviceURL, gateway.Handler(handler.OfflineDevice)) internalAPI.POST(routes.HeartbeatDeviceURL, gateway.Handler(handler.HeartbeatDevice)) diff --git a/api/services/device.go b/api/services/device.go index c18f073b4ca..0d40f28c373 100644 --- a/api/services/device.go +++ b/api/services/device.go @@ -2,6 +2,7 @@ package services import ( "context" + "fmt" "net" "strings" @@ -25,6 +26,7 @@ type DeviceService interface { UpdatePendingStatus(ctx context.Context, uid models.UID, status, tenant string) error SetDevicePosition(ctx context.Context, uid models.UID, ip string) error DeviceHeartbeat(ctx context.Context, uid models.UID) error + UpdateDevice(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error } func (s *service) ListDevices(ctx context.Context, pagination paginator.Query, filter []models.Filter, status string, sort string, order string) ([]models.Device, int, error) { @@ -225,3 +227,29 @@ func (s *service) DeviceHeartbeat(ctx context.Context, uid models.UID) error { return nil } + +func (s *service) UpdateDevice(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error { + device, err := s.store.DeviceGetByUID(ctx, uid, tenant) + if err != nil { + return NewErrDeviceNotFound(uid, err) + } + + if name != nil { + *name = strings.ToLower(*name) + + if device.Name == *name { + return nil + } + + otherDevice, err := s.store.DeviceGetByName(ctx, *name, tenant) + if err != nil && err != store.ErrNoDocuments { + return NewErrDeviceNotFound(models.UID(device.UID), fmt.Errorf("failed to get device by name: %w", err)) + } + + if otherDevice != nil { + return NewErrDeviceDuplicated(otherDevice.Name, err) + } + } + + return s.store.DeviceUpdate(ctx, uid, name, publicURL) +} diff --git a/api/services/mocks/services.go b/api/services/mocks/services.go index a3078ab751c..2cdeda54e9e 100644 --- a/api/services/mocks/services.go +++ b/api/services/mocks/services.go @@ -1099,6 +1099,20 @@ func (_m *Service) UpdateDataUser(ctx context.Context, id string, userData reque return r0, r1 } +// UpdateDevice provides a mock function with given fields: ctx, tenant, uid, name, publicURL +func (_m *Service) UpdateDevice(ctx context.Context, tenant string, uid models.UID, name *string, publicURL *bool) error { + ret := _m.Called(ctx, tenant, uid, name, publicURL) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, models.UID, *string, *bool) error); ok { + r0 = rf(ctx, tenant, uid, name, publicURL) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // UpdateDeviceStatus provides a mock function with given fields: ctx, uid, online func (_m *Service) UpdateDeviceStatus(ctx context.Context, uid models.UID, online bool) error { ret := _m.Called(ctx, uid, online) diff --git a/api/store/device_store.go b/api/store/device_store.go index b064c61ced8..e1076c743e8 100644 --- a/api/store/device_store.go +++ b/api/store/device_store.go @@ -11,6 +11,7 @@ import ( type DeviceStore interface { DeviceList(ctx context.Context, pagination paginator.Query, filters []models.Filter, status string, sort string, order string) ([]models.Device, int, error) DeviceGet(ctx context.Context, uid models.UID) (*models.Device, error) + DeviceUpdate(ctx context.Context, uid models.UID, name *string, publicURL *bool) error DeviceDelete(ctx context.Context, uid models.UID) error DeviceCreate(ctx context.Context, d models.Device, hostname string) error DeviceRename(ctx context.Context, uid models.UID, hostname string) error diff --git a/api/store/mocks/store.go b/api/store/mocks/store.go index 84f99ad818f..9691ac0f947 100644 --- a/api/store/mocks/store.go +++ b/api/store/mocks/store.go @@ -1,4 +1,4 @@ -// Code generated by mockery v2.10.4. DO NOT EDIT. +// Code generated by mockery v2.20.0. DO NOT EDIT. package mocks @@ -53,6 +53,10 @@ func (_m *Store) AnnouncementGet(ctx context.Context, uuid string) (*models.Anno ret := _m.Called(ctx, uuid) var r0 *models.Announcement + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Announcement, error)); ok { + return rf(ctx, uuid) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.Announcement); ok { r0 = rf(ctx, uuid) } else { @@ -61,7 +65,6 @@ func (_m *Store) AnnouncementGet(ctx context.Context, uuid string) (*models.Anno } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, uuid) } else { @@ -76,6 +79,11 @@ func (_m *Store) AnnouncementList(ctx context.Context, pagination paginator.Quer ret := _m.Called(ctx, pagination, ordination) var r0 []models.AnnouncementShort + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, order.Query) ([]models.AnnouncementShort, int, error)); ok { + return rf(ctx, pagination, ordination) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, order.Query) []models.AnnouncementShort); ok { r0 = rf(ctx, pagination, ordination) } else { @@ -84,14 +92,12 @@ func (_m *Store) AnnouncementList(ctx context.Context, pagination paginator.Quer } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query, order.Query) int); ok { r1 = rf(ctx, pagination, ordination) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query, order.Query) error); ok { r2 = rf(ctx, pagination, ordination) } else { @@ -176,6 +182,10 @@ func (_m *Store) DeviceGet(ctx context.Context, uid models.UID) (*models.Device, ret := _m.Called(ctx, uid) var r0 *models.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Device, error)); ok { + return rf(ctx, uid) + } if rf, ok := ret.Get(0).(func(context.Context, models.UID) *models.Device); ok { r0 = rf(ctx, uid) } else { @@ -184,7 +194,6 @@ func (_m *Store) DeviceGet(ctx context.Context, uid models.UID) (*models.Device, } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, models.UID) error); ok { r1 = rf(ctx, uid) } else { @@ -199,6 +208,10 @@ func (_m *Store) DeviceGetByMac(ctx context.Context, mac string, tenantID string ret := _m.Called(ctx, mac, tenantID, status) var r0 *models.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*models.Device, error)); ok { + return rf(ctx, mac, tenantID, status) + } if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *models.Device); ok { r0 = rf(ctx, mac, tenantID, status) } else { @@ -207,7 +220,6 @@ func (_m *Store) DeviceGetByMac(ctx context.Context, mac string, tenantID string } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { r1 = rf(ctx, mac, tenantID, status) } else { @@ -222,6 +234,10 @@ func (_m *Store) DeviceGetByName(ctx context.Context, name string, tenantID stri ret := _m.Called(ctx, name, tenantID) var r0 *models.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Device, error)); ok { + return rf(ctx, name, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.Device); ok { r0 = rf(ctx, name, tenantID) } else { @@ -230,7 +246,6 @@ func (_m *Store) DeviceGetByName(ctx context.Context, name string, tenantID stri } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, name, tenantID) } else { @@ -245,6 +260,10 @@ func (_m *Store) DeviceGetByUID(ctx context.Context, uid models.UID, tenantID st ret := _m.Called(ctx, uid, tenantID) var r0 *models.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) (*models.Device, error)); ok { + return rf(ctx, uid, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, models.UID, string) *models.Device); ok { r0 = rf(ctx, uid, tenantID) } else { @@ -253,7 +272,6 @@ func (_m *Store) DeviceGetByUID(ctx context.Context, uid models.UID, tenantID st } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, models.UID, string) error); ok { r1 = rf(ctx, uid, tenantID) } else { @@ -268,6 +286,11 @@ func (_m *Store) DeviceList(ctx context.Context, pagination paginator.Query, fil ret := _m.Called(ctx, pagination, filters, status, sort, _a5) var r0 []models.Device + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter, string, string, string) ([]models.Device, int, error)); ok { + return rf(ctx, pagination, filters, status, sort, _a5) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter, string, string, string) []models.Device); ok { r0 = rf(ctx, pagination, filters, status, sort, _a5) } else { @@ -276,14 +299,12 @@ func (_m *Store) DeviceList(ctx context.Context, pagination paginator.Query, fil } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query, []models.Filter, string, string, string) int); ok { r1 = rf(ctx, pagination, filters, status, sort, _a5) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query, []models.Filter, string, string, string) error); ok { r2 = rf(ctx, pagination, filters, status, sort, _a5) } else { @@ -298,6 +319,10 @@ func (_m *Store) DeviceListByUsage(ctx context.Context, tenantID string) ([]mode ret := _m.Called(ctx, tenantID) var r0 []models.UID + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]models.UID, error)); ok { + return rf(ctx, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, string) []models.UID); ok { r0 = rf(ctx, tenantID) } else { @@ -306,7 +331,6 @@ func (_m *Store) DeviceListByUsage(ctx context.Context, tenantID string) ([]mode } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, tenantID) } else { @@ -321,6 +345,10 @@ func (_m *Store) DeviceLookup(ctx context.Context, namespace string, hostname st ret := _m.Called(ctx, namespace, hostname) var r0 *models.Device + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Device, error)); ok { + return rf(ctx, namespace, hostname) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.Device); ok { r0 = rf(ctx, namespace, hostname) } else { @@ -329,7 +357,6 @@ func (_m *Store) DeviceLookup(ctx context.Context, namespace string, hostname st } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, namespace, hostname) } else { @@ -395,6 +422,20 @@ func (_m *Store) DeviceSetPosition(ctx context.Context, uid models.UID, position return r0 } +// DeviceUpdate provides a mock function with given fields: ctx, uid, name, publicURL +func (_m *Store) DeviceUpdate(ctx context.Context, uid models.UID, name *string, publicURL *bool) error { + ret := _m.Called(ctx, uid, name, publicURL) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID, *string, *bool) error); ok { + r0 = rf(ctx, uid, name, publicURL) + } else { + r0 = ret.Error(0) + } + + return r0 +} + // DeviceUpdateLastSeen provides a mock function with given fields: ctx, uid, ts func (_m *Store) DeviceUpdateLastSeen(ctx context.Context, uid models.UID, ts time.Time) error { ret := _m.Called(ctx, uid, ts) @@ -512,6 +553,10 @@ func (_m *Store) FirewallRuleGet(ctx context.Context, id string) (*models.Firewa ret := _m.Called(ctx, id) var r0 *models.FirewallRule + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.FirewallRule, error)); ok { + return rf(ctx, id) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.FirewallRule); ok { r0 = rf(ctx, id) } else { @@ -520,7 +565,6 @@ func (_m *Store) FirewallRuleGet(ctx context.Context, id string) (*models.Firewa } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, id) } else { @@ -535,6 +579,11 @@ func (_m *Store) FirewallRuleGetTags(ctx context.Context, tenant string) ([]stri ret := _m.Called(ctx, tenant) var r0 []string + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, int, error)); ok { + return rf(ctx, tenant) + } if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { r0 = rf(ctx, tenant) } else { @@ -543,14 +592,12 @@ func (_m *Store) FirewallRuleGetTags(ctx context.Context, tenant string) ([]stri } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, string) int); ok { r1 = rf(ctx, tenant) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { r2 = rf(ctx, tenant) } else { @@ -565,6 +612,11 @@ func (_m *Store) FirewallRuleList(ctx context.Context, pagination paginator.Quer ret := _m.Called(ctx, pagination) var r0 []models.FirewallRule + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) ([]models.FirewallRule, int, error)); ok { + return rf(ctx, pagination) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) []models.FirewallRule); ok { r0 = rf(ctx, pagination) } else { @@ -573,14 +625,12 @@ func (_m *Store) FirewallRuleList(ctx context.Context, pagination paginator.Quer } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query) int); ok { r1 = rf(ctx, pagination) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query) error); ok { r2 = rf(ctx, pagination) } else { @@ -623,6 +673,10 @@ func (_m *Store) FirewallRuleUpdate(ctx context.Context, id string, rule models. ret := _m.Called(ctx, id, rule) var r0 *models.FirewallRule + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, models.FirewallRuleUpdate) (*models.FirewallRule, error)); ok { + return rf(ctx, id, rule) + } if rf, ok := ret.Get(0).(func(context.Context, string, models.FirewallRuleUpdate) *models.FirewallRule); ok { r0 = rf(ctx, id, rule) } else { @@ -631,7 +685,6 @@ func (_m *Store) FirewallRuleUpdate(ctx context.Context, id string, rule models. } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, models.FirewallRuleUpdate) error); ok { r1 = rf(ctx, id, rule) } else { @@ -660,6 +713,10 @@ func (_m *Store) GetStats(ctx context.Context) (*models.Stats, error) { ret := _m.Called(ctx) var r0 *models.Stats + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*models.Stats, error)); ok { + return rf(ctx) + } if rf, ok := ret.Get(0).(func(context.Context) *models.Stats); ok { r0 = rf(ctx) } else { @@ -668,7 +725,6 @@ func (_m *Store) GetStats(ctx context.Context) (*models.Stats, error) { } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { @@ -683,6 +739,10 @@ func (_m *Store) LicenseLoad(ctx context.Context) (*models.License, error) { ret := _m.Called(ctx) var r0 *models.License + var r1 error + if rf, ok := ret.Get(0).(func(context.Context) (*models.License, error)); ok { + return rf(ctx) + } if rf, ok := ret.Get(0).(func(context.Context) *models.License); ok { r0 = rf(ctx) } else { @@ -691,7 +751,6 @@ func (_m *Store) LicenseLoad(ctx context.Context) (*models.License, error) { } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context) error); ok { r1 = rf(ctx) } else { @@ -720,6 +779,10 @@ func (_m *Store) NamespaceAddMember(ctx context.Context, tenantID string, member ret := _m.Called(ctx, tenantID, memberID, memberRole) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, string) (*models.Namespace, error)); ok { + return rf(ctx, tenantID, memberID, memberRole) + } if rf, ok := ret.Get(0).(func(context.Context, string, string, string) *models.Namespace); ok { r0 = rf(ctx, tenantID, memberID, memberRole) } else { @@ -728,7 +791,6 @@ func (_m *Store) NamespaceAddMember(ctx context.Context, tenantID string, member } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string, string) error); ok { r1 = rf(ctx, tenantID, memberID, memberRole) } else { @@ -743,6 +805,10 @@ func (_m *Store) NamespaceCreate(ctx context.Context, namespace *models.Namespac ret := _m.Called(ctx, namespace) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, *models.Namespace) (*models.Namespace, error)); ok { + return rf(ctx, namespace) + } if rf, ok := ret.Get(0).(func(context.Context, *models.Namespace) *models.Namespace); ok { r0 = rf(ctx, namespace) } else { @@ -751,7 +817,6 @@ func (_m *Store) NamespaceCreate(ctx context.Context, namespace *models.Namespac } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, *models.Namespace) error); ok { r1 = rf(ctx, namespace) } else { @@ -794,6 +859,10 @@ func (_m *Store) NamespaceGet(ctx context.Context, tenantID string) (*models.Nam ret := _m.Called(ctx, tenantID) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Namespace, error)); ok { + return rf(ctx, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.Namespace); ok { r0 = rf(ctx, tenantID) } else { @@ -802,7 +871,6 @@ func (_m *Store) NamespaceGet(ctx context.Context, tenantID string) (*models.Nam } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, tenantID) } else { @@ -817,6 +885,10 @@ func (_m *Store) NamespaceGetByName(ctx context.Context, name string) (*models.N ret := _m.Called(ctx, name) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Namespace, error)); ok { + return rf(ctx, name) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.Namespace); ok { r0 = rf(ctx, name) } else { @@ -825,7 +897,6 @@ func (_m *Store) NamespaceGetByName(ctx context.Context, name string) (*models.N } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, name) } else { @@ -840,6 +911,10 @@ func (_m *Store) NamespaceGetFirst(ctx context.Context, id string) (*models.Name ret := _m.Called(ctx, id) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.Namespace, error)); ok { + return rf(ctx, id) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.Namespace); ok { r0 = rf(ctx, id) } else { @@ -848,7 +923,6 @@ func (_m *Store) NamespaceGetFirst(ctx context.Context, id string) (*models.Name } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, id) } else { @@ -863,13 +937,16 @@ func (_m *Store) NamespaceGetSessionRecord(ctx context.Context, tenantID string) ret := _m.Called(ctx, tenantID) var r0 bool + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (bool, error)); ok { + return rf(ctx, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok { r0 = rf(ctx, tenantID) } else { r0 = ret.Get(0).(bool) } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, tenantID) } else { @@ -884,6 +961,11 @@ func (_m *Store) NamespaceList(ctx context.Context, pagination paginator.Query, ret := _m.Called(ctx, pagination, filters, export) var r0 []models.Namespace + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter, bool) ([]models.Namespace, int, error)); ok { + return rf(ctx, pagination, filters, export) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter, bool) []models.Namespace); ok { r0 = rf(ctx, pagination, filters, export) } else { @@ -892,14 +974,12 @@ func (_m *Store) NamespaceList(ctx context.Context, pagination paginator.Query, } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query, []models.Filter, bool) int); ok { r1 = rf(ctx, pagination, filters, export) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query, []models.Filter, bool) error); ok { r2 = rf(ctx, pagination, filters, export) } else { @@ -914,6 +994,10 @@ func (_m *Store) NamespaceRemoveMember(ctx context.Context, tenantID string, mem ret := _m.Called(ctx, tenantID, memberID) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Namespace, error)); ok { + return rf(ctx, tenantID, memberID) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.Namespace); ok { r0 = rf(ctx, tenantID, memberID) } else { @@ -922,7 +1006,6 @@ func (_m *Store) NamespaceRemoveMember(ctx context.Context, tenantID string, mem } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, tenantID, memberID) } else { @@ -937,6 +1020,10 @@ func (_m *Store) NamespaceRename(ctx context.Context, tenantID string, name stri ret := _m.Called(ctx, tenantID, name) var r0 *models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.Namespace, error)); ok { + return rf(ctx, tenantID, name) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.Namespace); ok { r0 = rf(ctx, tenantID, name) } else { @@ -945,7 +1032,6 @@ func (_m *Store) NamespaceRename(ctx context.Context, tenantID string, name stri } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, tenantID, name) } else { @@ -1002,6 +1088,10 @@ func (_m *Store) PrivateKeyGet(ctx context.Context, fingerprint string) (*models ret := _m.Called(ctx, fingerprint) var r0 *models.PrivateKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.PrivateKey, error)); ok { + return rf(ctx, fingerprint) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.PrivateKey); ok { r0 = rf(ctx, fingerprint) } else { @@ -1010,7 +1100,6 @@ func (_m *Store) PrivateKeyGet(ctx context.Context, fingerprint string) (*models } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, fingerprint) } else { @@ -1081,6 +1170,10 @@ func (_m *Store) PublicKeyGet(ctx context.Context, fingerprint string, tenantID ret := _m.Called(ctx, fingerprint, tenantID) var r0 *models.PublicKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string) (*models.PublicKey, error)); ok { + return rf(ctx, fingerprint, tenantID) + } if rf, ok := ret.Get(0).(func(context.Context, string, string) *models.PublicKey); ok { r0 = rf(ctx, fingerprint, tenantID) } else { @@ -1089,7 +1182,6 @@ func (_m *Store) PublicKeyGet(ctx context.Context, fingerprint string, tenantID } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string) error); ok { r1 = rf(ctx, fingerprint, tenantID) } else { @@ -1104,6 +1196,11 @@ func (_m *Store) PublicKeyGetTags(ctx context.Context, tenant string) ([]string, ret := _m.Called(ctx, tenant) var r0 []string + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, int, error)); ok { + return rf(ctx, tenant) + } if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { r0 = rf(ctx, tenant) } else { @@ -1112,14 +1209,12 @@ func (_m *Store) PublicKeyGetTags(ctx context.Context, tenant string) ([]string, } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, string) int); ok { r1 = rf(ctx, tenant) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { r2 = rf(ctx, tenant) } else { @@ -1134,6 +1229,11 @@ func (_m *Store) PublicKeyList(ctx context.Context, pagination paginator.Query) ret := _m.Called(ctx, pagination) var r0 []models.PublicKey + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) ([]models.PublicKey, int, error)); ok { + return rf(ctx, pagination) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) []models.PublicKey); ok { r0 = rf(ctx, pagination) } else { @@ -1142,14 +1242,12 @@ func (_m *Store) PublicKeyList(ctx context.Context, pagination paginator.Query) } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query) int); ok { r1 = rf(ctx, pagination) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query) error); ok { r2 = rf(ctx, pagination) } else { @@ -1192,6 +1290,10 @@ func (_m *Store) PublicKeyUpdate(ctx context.Context, fingerprint string, tenant ret := _m.Called(ctx, fingerprint, tenantID, key) var r0 *models.PublicKey + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string, string, *models.PublicKeyUpdate) (*models.PublicKey, error)); ok { + return rf(ctx, fingerprint, tenantID, key) + } if rf, ok := ret.Get(0).(func(context.Context, string, string, *models.PublicKeyUpdate) *models.PublicKey); ok { r0 = rf(ctx, fingerprint, tenantID, key) } else { @@ -1200,7 +1302,6 @@ func (_m *Store) PublicKeyUpdate(ctx context.Context, fingerprint string, tenant } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string, string, *models.PublicKeyUpdate) error); ok { r1 = rf(ctx, fingerprint, tenantID, key) } else { @@ -1229,6 +1330,10 @@ func (_m *Store) SessionCreate(ctx context.Context, session models.Session) (*mo ret := _m.Called(ctx, session) var r0 *models.Session + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.Session) (*models.Session, error)); ok { + return rf(ctx, session) + } if rf, ok := ret.Get(0).(func(context.Context, models.Session) *models.Session); ok { r0 = rf(ctx, session) } else { @@ -1237,7 +1342,6 @@ func (_m *Store) SessionCreate(ctx context.Context, session models.Session) (*mo } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, models.Session) error); ok { r1 = rf(ctx, session) } else { @@ -1294,6 +1398,10 @@ func (_m *Store) SessionGet(ctx context.Context, uid models.UID) (*models.Sessio ret := _m.Called(ctx, uid) var r0 *models.Session + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID) (*models.Session, error)); ok { + return rf(ctx, uid) + } if rf, ok := ret.Get(0).(func(context.Context, models.UID) *models.Session); ok { r0 = rf(ctx, uid) } else { @@ -1302,7 +1410,6 @@ func (_m *Store) SessionGet(ctx context.Context, uid models.UID) (*models.Sessio } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, models.UID) error); ok { r1 = rf(ctx, uid) } else { @@ -1317,6 +1424,11 @@ func (_m *Store) SessionGetRecordFrame(ctx context.Context, uid models.UID) ([]m ret := _m.Called(ctx, uid) var r0 []models.RecordedSession + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, models.UID) ([]models.RecordedSession, int, error)); ok { + return rf(ctx, uid) + } if rf, ok := ret.Get(0).(func(context.Context, models.UID) []models.RecordedSession); ok { r0 = rf(ctx, uid) } else { @@ -1325,14 +1437,12 @@ func (_m *Store) SessionGetRecordFrame(ctx context.Context, uid models.UID) ([]m } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, models.UID) int); ok { r1 = rf(ctx, uid) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, models.UID) error); ok { r2 = rf(ctx, uid) } else { @@ -1347,6 +1457,11 @@ func (_m *Store) SessionList(ctx context.Context, pagination paginator.Query) ([ ret := _m.Called(ctx, pagination) var r0 []models.Session + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) ([]models.Session, int, error)); ok { + return rf(ctx, pagination) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query) []models.Session); ok { r0 = rf(ctx, pagination) } else { @@ -1355,14 +1470,12 @@ func (_m *Store) SessionList(ctx context.Context, pagination paginator.Query) ([ } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query) int); ok { r1 = rf(ctx, pagination) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query) error); ok { r2 = rf(ctx, pagination) } else { @@ -1461,6 +1574,11 @@ func (_m *Store) TagsGet(ctx context.Context, tenant string) ([]string, int, err ret := _m.Called(ctx, tenant) var r0 []string + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string) ([]string, int, error)); ok { + return rf(ctx, tenant) + } if rf, ok := ret.Get(0).(func(context.Context, string) []string); ok { r0 = rf(ctx, tenant) } else { @@ -1469,14 +1587,12 @@ func (_m *Store) TagsGet(ctx context.Context, tenant string) ([]string, int, err } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, string) int); ok { r1 = rf(ctx, tenant) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string) error); ok { r2 = rf(ctx, tenant) } else { @@ -1547,6 +1663,10 @@ func (_m *Store) UserDetachInfo(ctx context.Context, id string) (map[string][]*m ret := _m.Called(ctx, id) var r0 map[string][]*models.Namespace + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (map[string][]*models.Namespace, error)); ok { + return rf(ctx, id) + } if rf, ok := ret.Get(0).(func(context.Context, string) map[string][]*models.Namespace); ok { r0 = rf(ctx, id) } else { @@ -1555,7 +1675,6 @@ func (_m *Store) UserDetachInfo(ctx context.Context, id string) (map[string][]*m } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, id) } else { @@ -1570,6 +1689,10 @@ func (_m *Store) UserGetByEmail(ctx context.Context, email string) (*models.User ret := _m.Called(ctx, email) var r0 *models.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.User, error)); ok { + return rf(ctx, email) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.User); ok { r0 = rf(ctx, email) } else { @@ -1578,7 +1701,6 @@ func (_m *Store) UserGetByEmail(ctx context.Context, email string) (*models.User } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, email) } else { @@ -1593,6 +1715,11 @@ func (_m *Store) UserGetByID(ctx context.Context, id string, ns bool) (*models.U ret := _m.Called(ctx, id, ns) var r0 *models.User + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, string, bool) (*models.User, int, error)); ok { + return rf(ctx, id, ns) + } if rf, ok := ret.Get(0).(func(context.Context, string, bool) *models.User); ok { r0 = rf(ctx, id, ns) } else { @@ -1601,14 +1728,12 @@ func (_m *Store) UserGetByID(ctx context.Context, id string, ns bool) (*models.U } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, string, bool) int); ok { r1 = rf(ctx, id, ns) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, string, bool) error); ok { r2 = rf(ctx, id, ns) } else { @@ -1623,6 +1748,10 @@ func (_m *Store) UserGetByUsername(ctx context.Context, username string) (*model ret := _m.Called(ctx, username) var r0 *models.User + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.User, error)); ok { + return rf(ctx, username) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.User); ok { r0 = rf(ctx, username) } else { @@ -1631,7 +1760,6 @@ func (_m *Store) UserGetByUsername(ctx context.Context, username string) (*model } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, username) } else { @@ -1646,6 +1774,10 @@ func (_m *Store) UserGetToken(ctx context.Context, id string) (*models.UserToken ret := _m.Called(ctx, id) var r0 *models.UserTokenRecover + var r1 error + if rf, ok := ret.Get(0).(func(context.Context, string) (*models.UserTokenRecover, error)); ok { + return rf(ctx, id) + } if rf, ok := ret.Get(0).(func(context.Context, string) *models.UserTokenRecover); ok { r0 = rf(ctx, id) } else { @@ -1654,7 +1786,6 @@ func (_m *Store) UserGetToken(ctx context.Context, id string) (*models.UserToken } } - var r1 error if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { r1 = rf(ctx, id) } else { @@ -1669,6 +1800,11 @@ func (_m *Store) UserList(ctx context.Context, pagination paginator.Query, filte ret := _m.Called(ctx, pagination, filters) var r0 []models.User + var r1 int + var r2 error + if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter) ([]models.User, int, error)); ok { + return rf(ctx, pagination, filters) + } if rf, ok := ret.Get(0).(func(context.Context, paginator.Query, []models.Filter) []models.User); ok { r0 = rf(ctx, pagination, filters) } else { @@ -1677,14 +1813,12 @@ func (_m *Store) UserList(ctx context.Context, pagination paginator.Query, filte } } - var r1 int if rf, ok := ret.Get(1).(func(context.Context, paginator.Query, []models.Filter) int); ok { r1 = rf(ctx, pagination, filters) } else { r1 = ret.Get(1).(int) } - var r2 error if rf, ok := ret.Get(2).(func(context.Context, paginator.Query, []models.Filter) error); ok { r2 = rf(ctx, pagination, filters) } else { @@ -1749,3 +1883,18 @@ func (_m *Store) UserUpdatePassword(ctx context.Context, newPassword string, id return r0 } + +type mockConstructorTestingTNewStore interface { + mock.TestingT + Cleanup(func()) +} + +// NewStore creates a new instance of Store. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewStore(t mockConstructorTestingTNewStore) *Store { + mock := &Store{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/api/store/mongo/device_store.go b/api/store/mongo/device_store.go index 18a1905b9b1..29011d4299d 100644 --- a/api/store/mongo/device_store.go +++ b/api/store/mongo/device_store.go @@ -12,6 +12,7 @@ import ( "github.com/shellhub-io/shellhub/pkg/models" "github.com/sirupsen/logrus" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" ) @@ -443,3 +444,30 @@ func (s *Store) DeviceChooser(ctx context.Context, tenantID string, chosen []str return nil } + +func (s *Store) DeviceUpdate(ctx context.Context, uid models.UID, name *string, publicURL *bool) error { + session, err := s.db.Client().StartSession() + if err != nil { + return err + } + + defer session.EndSession(ctx) + + err = mongo.WithSession(ctx, session, func(sessionContext mongo.SessionContext) error { + if name != nil { + if _, err := s.db.Collection("devices").UpdateOne(sessionContext, bson.M{"uid": uid}, bson.M{"$set": bson.M{"name": *name}}); err != nil { + return err + } + } + + if publicURL != nil { + if _, err := s.db.Collection("devices").UpdateOne(sessionContext, bson.M{"uid": uid}, bson.M{"$set": bson.M{"public_url": *publicURL}}); err != nil { + return err + } + } + + return nil + }) + + return FromMongoError(err) +} diff --git a/gateway/shellhub.conf b/gateway/shellhub.conf index 30eebaef0a0..524fa73d998 100644 --- a/gateway/shellhub.conf +++ b/gateway/shellhub.conf @@ -199,7 +199,7 @@ server { proxy_pass http://$upstream; } {{ end -}} - + {{ if bool (env.Getenv "SHELLHUB_ENTERPRISE") -}} location /api/register { set $upstream cloud-api:8080; @@ -391,3 +391,16 @@ server { } } } + +server { + listen 80; + server_name ~^(?.+)\.(?.+)\..+$; + + location / { #~ ^/(.*)$ { + rewrite ^/(.*)$ /ssh/http break; + proxy_set_header X-Namespace $namespace; + proxy_set_header X-Device $device; + proxy_set_header X-Path /$1$is_args$args; + proxy_pass http://ssh:8080; + } +} diff --git a/pkg/api/request/device.go b/pkg/api/request/device.go index 4cd214d54ef..e9e68933a60 100644 --- a/pkg/api/request/device.go +++ b/pkg/api/request/device.go @@ -89,3 +89,14 @@ type DeviceAuth struct { PublicKey string `json:"public_key" validate:"required"` TenantID string `json:"tenant_id" validate:"required"` } + +type DeviceGetPublicURL struct { + DeviceParam +} + +type DeviceUpdate struct { + DeviceParam + // NOTICE: the pointers here help to distinguish between the zero value and the absence of the field. + Name *string `json:"name"` + PublicURL *bool `json:"public_url"` +} diff --git a/pkg/models/device.go b/pkg/models/device.go index a2c96be2c68..98d4a0cbd54 100644 --- a/pkg/models/device.go +++ b/pkg/models/device.go @@ -22,6 +22,7 @@ type Device struct { RemoteAddr string `json:"remote_addr" bson:"remote_addr"` Position *DevicePosition `json:"position" bson:"position"` Tags []string `json:"tags" bson:"tags,omitempty"` + PublicURL bool `json:"public_url" bson:"public_url,omitempty"` } type DeviceAuthClaims struct { diff --git a/ssh/main.go b/ssh/main.go index 7e88a18ccc2..82c20ad9176 100644 --- a/ssh/main.go +++ b/ssh/main.go @@ -3,7 +3,9 @@ package main import ( "context" "encoding/json" + "errors" "fmt" + "io" "net/http" "github.com/gorilla/mux" @@ -71,6 +73,74 @@ func main() { } }) + router.HandleFunc("/ssh/http", func(w http.ResponseWriter, r *http.Request) { + replyError := func(err error, msg string, code int) { + log.WithError(err).WithFields(log.Fields{ + "remote": r.RemoteAddr, + "namespace": r.Header.Get("X-Namespace"), + "device": r.Header.Get("X-Device"), + "path": r.Header.Get("X-Path"), + }).Error(msg) + http.Error(w, msg, code) + } + + uid, errs := tunnel.API.Lookup( + map[string]string{ + "domain": r.Header.Get("X-Namespace"), + "name": r.Header.Get("X-Device"), + }, + ) + if len(errs) > 0 { + replyError(errs[0], "failed find the device on this namespace", http.StatusInternalServerError) + + return + } + + dev, err := tunnel.API.GetDevice(uid) + if err != nil { + replyError(err, "failed to get device data", http.StatusInternalServerError) + + return + } + + if !dev.PublicURL { + replyError(err, "this device is not accessible via public URL", http.StatusForbidden) + + return + } + + in, err := tunnel.Dial(r.Context(), dev.UID) + if err != nil { + replyError(err, "failed to connect to device", http.StatusInternalServerError) + + return + } + + defer in.Close() // nolint:errcheck + + if err := r.Write(in); err != nil { + replyError(err, "failed to write request to device", http.StatusInternalServerError) + + return + } + + ctr := http.NewResponseController(w) + out, _, err := ctr.Hijack() + if err != nil { + replyError(err, "failed to hijack response", http.StatusInternalServerError) + + return + } + + defer out.Close() // nolint:errcheck + + if _, err := io.Copy(out, in); errors.Is(err, io.ErrUnexpectedEOF) { + replyError(err, "failed to copy response from device service to client", http.StatusInternalServerError) + + return + } + }) + // TODO: add `/ws/ssh` route to OpenAPI repository. router.Handle("/ws/ssh", web.HandlerRestoreSession(web.RestoreSession, handler.WebSession)). Methods(http.MethodGet)