diff --git a/api/dataservices/customtemplate/customtemplate.go b/api/dataservices/customtemplate/customtemplate.go index 9cc54493e..de239487f 100644 --- a/api/dataservices/customtemplate/customtemplate.go +++ b/api/dataservices/customtemplate/customtemplate.go @@ -28,13 +28,12 @@ func NewService(connection portainer.Connection) (*Service, error) { }, nil } -// CreateCustomTemplate uses the existing id and saves it. -// TODO: where does the ID come from, and is it safe? -func (service *Service) Create(customTemplate *portainer.CustomTemplate) error { - return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate) -} - -// GetNextIdentifier returns the next identifier for a custom template. func (service *Service) GetNextIdentifier() int { return service.Connection.GetNextIdentifier(BucketName) } + +func (service *Service) Create(customTemplate *portainer.CustomTemplate) error { + return service.Connection.UpdateTx(func(tx portainer.Transaction) error { + return service.Tx(tx).Create(customTemplate) + }) +} diff --git a/api/dataservices/customtemplate/customtemplate_test.go b/api/dataservices/customtemplate/customtemplate_test.go new file mode 100644 index 000000000..fed6083e7 --- /dev/null +++ b/api/dataservices/customtemplate/customtemplate_test.go @@ -0,0 +1,19 @@ +package customtemplate_test + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/datastore" + "github.com/stretchr/testify/require" +) + +func TestCustomTemplateCreate(t *testing.T) { + _, ds := datastore.MustNewTestStore(t, true, false) + require.NotNil(t, ds) + + require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})) + e, err := ds.CustomTemplate().Read(1) + require.NoError(t, err) + require.Equal(t, portainer.CustomTemplateID(1), e.ID) +} diff --git a/api/dataservices/customtemplate/tx.go b/api/dataservices/customtemplate/tx.go new file mode 100644 index 000000000..60b0120e4 --- /dev/null +++ b/api/dataservices/customtemplate/tx.go @@ -0,0 +1,31 @@ +package customtemplate + +import ( + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" +) + +// Service represents a service for managing custom template data. +type ServiceTx struct { + dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID] +} + +func (service *Service) Tx(tx portainer.Transaction) ServiceTx { + return ServiceTx{ + BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{ + Bucket: BucketName, + Connection: service.Connection, + Tx: tx, + }, + } +} + +func (service ServiceTx) GetNextIdentifier() int { + return service.Tx.GetNextIdentifier(BucketName) +} + +// CreateCustomTemplate uses the existing id and saves it. +// TODO: where does the ID come from, and is it safe? +func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error { + return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate) +} diff --git a/api/dataservices/customtemplate/tx_test.go b/api/dataservices/customtemplate/tx_test.go new file mode 100644 index 000000000..43afde3cd --- /dev/null +++ b/api/dataservices/customtemplate/tx_test.go @@ -0,0 +1,28 @@ +package customtemplate_test + +import ( + "testing" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/api/datastore" + "github.com/stretchr/testify/require" +) + +func TestCustomTemplateCreateTx(t *testing.T) { + _, ds := datastore.MustNewTestStore(t, true, false) + require.NotNil(t, ds) + + require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error { + return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}) + })) + + var template *portainer.CustomTemplate + require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error { + var err error + template, err = tx.CustomTemplate().Read(1) + return err + })) + + require.Equal(t, portainer.CustomTemplateID(1), template.ID) +} diff --git a/api/dataservices/pendingactions/pendingactions.go b/api/dataservices/pendingactions/pendingactions.go index 6fd2eddeb..5fd279301 100644 --- a/api/dataservices/pendingactions/pendingactions.go +++ b/api/dataservices/pendingactions/pendingactions.go @@ -1,13 +1,8 @@ package pendingactions import ( - "fmt" - "time" - portainer "github.com/portainer/portainer/api" "github.com/portainer/portainer/api/dataservices" - - "github.com/rs/zerolog/log" ) const BucketName = "pending_actions" @@ -16,10 +11,6 @@ type Service struct { dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID] } -type ServiceTx struct { - dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID] -} - func NewService(connection portainer.Connection) (*Service, error) { err := connection.SetServiceName(BucketName) if err != nil { @@ -34,6 +25,11 @@ func NewService(connection portainer.Connection) (*Service, error) { }, nil } +// GetNextIdentifier returns the next identifier for a custom template. +func (service *Service) GetNextIdentifier() int { + return service.Connection.GetNextIdentifier(BucketName) +} + func (s Service) Create(config *portainer.PendingAction) error { return s.Connection.UpdateTx(func(tx portainer.Transaction) error { return s.Tx(tx).Create(config) @@ -61,43 +57,3 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx { }, } } - -func (s ServiceTx) Create(config *portainer.PendingAction) error { - return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) { - config.ID = portainer.PendingActionID(id) - config.CreatedAt = time.Now().Unix() - - return int(config.ID), config - }) -} - -func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error { - return s.BaseDataServiceTx.Update(ID, config) -} - -func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error { - log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint") - pendingActions, err := s.ReadAll() - if err != nil { - return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err) - } - - for _, pendingAction := range pendingActions { - if pendingAction.EndpointID == ID { - if err := s.Delete(pendingAction.ID); err != nil { - log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err) - } - } - } - return nil -} - -// GetNextIdentifier returns the next identifier for a custom template. -func (service ServiceTx) GetNextIdentifier() int { - return service.Tx.GetNextIdentifier(BucketName) -} - -// GetNextIdentifier returns the next identifier for a custom template. -func (service *Service) GetNextIdentifier() int { - return service.Connection.GetNextIdentifier(BucketName) -} diff --git a/api/dataservices/pendingactions/tx.go b/api/dataservices/pendingactions/tx.go new file mode 100644 index 000000000..40090a219 --- /dev/null +++ b/api/dataservices/pendingactions/tx.go @@ -0,0 +1,49 @@ +package pendingactions + +import ( + "fmt" + "time" + + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/rs/zerolog/log" +) + +type ServiceTx struct { + dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID] +} + +func (s ServiceTx) Create(config *portainer.PendingAction) error { + return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) { + config.ID = portainer.PendingActionID(id) + config.CreatedAt = time.Now().Unix() + + return int(config.ID), config + }) +} + +func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error { + return s.BaseDataServiceTx.Update(ID, config) +} + +func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error { + log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint") + pendingActions, err := s.ReadAll() + if err != nil { + return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err) + } + + for _, pendingAction := range pendingActions { + if pendingAction.EndpointID == ID { + if err := s.Delete(pendingAction.ID); err != nil { + log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err) + } + } + } + return nil +} + +// GetNextIdentifier returns the next identifier for a custom template. +func (service ServiceTx) GetNextIdentifier() int { + return service.Tx.GetNextIdentifier(BucketName) +} diff --git a/api/datastore/services_tx.go b/api/datastore/services_tx.go index cf9f868f4..e77fe11cc 100644 --- a/api/datastore/services_tx.go +++ b/api/datastore/services_tx.go @@ -14,7 +14,9 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool { return tx.store.IsErrObjectNotFound(err) } -func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil } +func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { + return tx.store.CustomTemplateService.Tx(tx.tx) +} func (tx *StoreTx) PendingActions() dataservices.PendingActionsService { return tx.store.PendingActionsService.Tx(tx.tx) diff --git a/api/http/handler/customtemplates/customtemplate_inspect.go b/api/http/handler/customtemplates/customtemplate_inspect.go index 1aa0ffa7b..bd07b196c 100644 --- a/api/http/handler/customtemplates/customtemplate_inspect.go +++ b/api/http/handler/customtemplates/customtemplate_inspect.go @@ -5,8 +5,11 @@ import ( "strconv" portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" httperrors "github.com/portainer/portainer/api/http/errors" "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/internal/authorization" + "github.com/portainer/portainer/api/slicesx" httperror "github.com/portainer/portainer/pkg/libhttp/error" "github.com/portainer/portainer/pkg/libhttp/request" "github.com/portainer/portainer/pkg/libhttp/response" @@ -32,31 +35,45 @@ func (handler *Handler) customTemplateInspect(w http.ResponseWriter, r *http.Req return httperror.BadRequest("Invalid Custom template identifier route variable", err) } - customTemplate, err := handler.DataStore.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID)) - if handler.DataStore.IsErrObjectNotFound(err) { - return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err) - } else if err != nil { - return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err) - } + var customTemplate *portainer.CustomTemplate + err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error { + customTemplate, err = tx.CustomTemplate().Read(portainer.CustomTemplateID(customTemplateID)) + if handler.DataStore.IsErrObjectNotFound(err) { + return httperror.NotFound("Unable to find a custom template with the specified identifier inside the database", err) + } else if err != nil { + return httperror.InternalServerError("Unable to find a custom template with the specified identifier inside the database", err) + } - securityContext, err := security.RetrieveRestrictedRequestContext(r) - if err != nil { - return httperror.InternalServerError("Unable to retrieve user info from request context", err) - } + resourceControl, err := tx.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl) + if err != nil { + return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err) + } - resourceControl, err := handler.DataStore.ResourceControl().ResourceControlByResourceIDAndType(strconv.Itoa(customTemplateID), portainer.CustomTemplateResourceControl) - if err != nil { - return httperror.InternalServerError("Unable to retrieve a resource control associated to the custom template", err) - } + securityContext, err := security.RetrieveRestrictedRequestContext(r) + if err != nil { + return httperror.InternalServerError("Unable to retrieve user info from request context", err) + } + + canEdit := userCanEditTemplate(customTemplate, securityContext) + hasAccess := false + + if resourceControl != nil { + customTemplate.ResourceControl = resourceControl + + teamIDs := slicesx.Map(securityContext.UserMemberships, func(m portainer.TeamMembership) portainer.TeamID { + return m.TeamID + }) + + hasAccess = authorization.UserCanAccessResource(securityContext.UserID, teamIDs, resourceControl) + + } + + if canEdit || hasAccess { + return nil + } - access := userCanEditTemplate(customTemplate, securityContext) - if !access { return httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied) - } + }) - if resourceControl != nil { - customTemplate.ResourceControl = resourceControl - } - - return response.JSON(w, customTemplate) + return response.TxResponse(w, customTemplate, err) } diff --git a/api/http/handler/customtemplates/customtemplate_inspect_test.go b/api/http/handler/customtemplates/customtemplate_inspect_test.go new file mode 100644 index 000000000..169301b59 --- /dev/null +++ b/api/http/handler/customtemplates/customtemplate_inspect_test.go @@ -0,0 +1,100 @@ +package customtemplates + +import ( + "net/http" + "net/http/httptest" + "testing" + + "github.com/gorilla/mux" + portainer "github.com/portainer/portainer/api" + "github.com/portainer/portainer/api/dataservices" + "github.com/portainer/portainer/api/datastore" + "github.com/portainer/portainer/api/http/security" + "github.com/portainer/portainer/api/internal/testhelpers" + httperror "github.com/portainer/portainer/pkg/libhttp/error" + "github.com/segmentio/encoding/json" + "github.com/stretchr/testify/require" +) + +func TestInspectHandler(t *testing.T) { + _, ds := datastore.MustNewTestStore(t, true, false) + require.NotNil(t, ds) + + require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error { + require.NoError(t, tx.User().Create(&portainer.User{ID: 1, Username: "admin", Role: portainer.AdministratorRole})) + require.NoError(t, tx.User().Create(&portainer.User{ID: 2, Username: "std2", Role: portainer.StandardUserRole})) + require.NoError(t, tx.User().Create(&portainer.User{ID: 3, Username: "std3", Role: portainer.StandardUserRole})) + require.NoError(t, tx.User().Create(&portainer.User{ID: 4, Username: "std4", Role: portainer.StandardUserRole})) + require.NoError(t, tx.Endpoint().Create(&portainer.Endpoint{ID: 1, + UserAccessPolicies: portainer.UserAccessPolicies{ + 2: portainer.AccessPolicy{RoleID: 0}, + 3: portainer.AccessPolicy{RoleID: 0}, + }})) + require.NoError(t, tx.Team().Create(&portainer.Team{ID: 1})) + require.NoError(t, tx.TeamMembership().Create(&portainer.TeamMembership{ID: 1, UserID: 3, TeamID: 1, Role: portainer.TeamMember})) + + require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})) + require.NoError(t, tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 2})) + require.NoError(t, tx.ResourceControl().Create(&portainer.ResourceControl{ID: 1, ResourceID: "2", Type: portainer.CustomTemplateResourceControl, + UserAccesses: []portainer.UserResourceAccess{{UserID: 2}}, + TeamAccesses: []portainer.TeamResourceAccess{{TeamID: 1}}, + })) + return nil + })) + + handler := NewHandler(testhelpers.NewTestRequestBouncer(), ds, &TestFileService{}, nil) + + test := func(templateID string, restrictedContext *security.RestrictedRequestContext) (*httptest.ResponseRecorder, *httperror.HandlerError) { + r := httptest.NewRequest(http.MethodGet, "/custom_templates/"+templateID, nil) + r = mux.SetURLVars(r, map[string]string{"id": templateID}) + ctx := security.StoreRestrictedRequestContext(r, restrictedContext) + r = r.WithContext(ctx) + rr := httptest.NewRecorder() + return rr, handler.customTemplateInspect(rr, r) + } + + t.Run("unknown id should get not found error", func(t *testing.T) { + _, r := test("0", &security.RestrictedRequestContext{UserID: 1}) + require.NotNil(t, r) + require.Equal(t, http.StatusNotFound, r.StatusCode) + }) + + t.Run("admin should access adminonly template", func(t *testing.T) { + rr, r := test("1", &security.RestrictedRequestContext{UserID: 1, IsAdmin: true}) + require.Nil(t, r) + require.Equal(t, http.StatusOK, rr.Result().StatusCode) + var template portainer.CustomTemplate + require.NoError(t, json.NewDecoder(rr.Body).Decode(&template)) + require.Equal(t, portainer.CustomTemplateID(1), template.ID) + }) + + t.Run("std should not access adminonly template", func(t *testing.T) { + _, r := test("1", &security.RestrictedRequestContext{UserID: 2}) + require.NotNil(t, r) + require.Equal(t, http.StatusForbidden, r.StatusCode) + }) + + t.Run("std should access template via direct user access", func(t *testing.T) { + rr, r := test("2", &security.RestrictedRequestContext{UserID: 2}) + require.Nil(t, r) + require.Equal(t, http.StatusOK, rr.Result().StatusCode) + var template portainer.CustomTemplate + require.NoError(t, json.NewDecoder(rr.Body).Decode(&template)) + require.Equal(t, portainer.CustomTemplateID(2), template.ID) + }) + + t.Run("std should access template via team access", func(t *testing.T) { + rr, r := test("2", &security.RestrictedRequestContext{UserID: 3, UserMemberships: []portainer.TeamMembership{{ID: 1, UserID: 3, TeamID: 1}}}) + require.Nil(t, r) + require.Equal(t, http.StatusOK, rr.Result().StatusCode) + var template portainer.CustomTemplate + require.NoError(t, json.NewDecoder(rr.Body).Decode(&template)) + require.Equal(t, portainer.CustomTemplateID(2), template.ID) + }) + + t.Run("std should not access template without access", func(t *testing.T) { + _, r := test("2", &security.RestrictedRequestContext{UserID: 4}) + require.NotNil(t, r) + require.Equal(t, http.StatusForbidden, r.StatusCode) + }) +} diff --git a/api/http/security/bouncer.go b/api/http/security/bouncer.go index b4285c603..72dfecad5 100644 --- a/api/http/security/bouncer.go +++ b/api/http/security/bouncer.go @@ -2,6 +2,7 @@ package security import ( "net/http" + "slices" "strings" "sync" "time" @@ -555,12 +556,9 @@ func (bouncer *RequestBouncer) newRestrictedContextRequest(userID portainer.User return nil, err } - isTeamLeader := false - for _, membership := range memberships { - if membership.Role == portainer.TeamLeader { - isTeamLeader = true - } - } + isTeamLeader := slices.ContainsFunc(memberships, func(m portainer.TeamMembership) bool { + return m.Role == portainer.TeamLeader + }) return &RestrictedRequestContext{ IsAdmin: false, diff --git a/pkg/libhttp/response/txresponse.go b/pkg/libhttp/response/txresponse.go new file mode 100644 index 000000000..631dfde78 --- /dev/null +++ b/pkg/libhttp/response/txresponse.go @@ -0,0 +1,47 @@ +package response + +import ( + "errors" + "net/http" + + httperror "github.com/portainer/portainer/pkg/libhttp/error" +) + +func TxResponse(w http.ResponseWriter, r any, err error) *httperror.HandlerError { + return TxFuncResponse(err, func() *httperror.HandlerError { return JSON(w, r) }) +} + +func TxEmptyResponse(w http.ResponseWriter, err error) *httperror.HandlerError { + if err != nil { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return Empty(w) +} + +func TxFuncResponse(err error, validResponse func() *httperror.HandlerError) *httperror.HandlerError { + if err != nil { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) + } + + return validResponse() +} + +func TxErrorResponse(err error) *httperror.HandlerError { + var handlerError *httperror.HandlerError + if errors.As(err, &handlerError) { + return handlerError + } + + return httperror.InternalServerError("Unexpected error", err) +} diff --git a/pkg/libhttp/response/txresponse_test.go b/pkg/libhttp/response/txresponse_test.go new file mode 100644 index 000000000..5094b4c4f --- /dev/null +++ b/pkg/libhttp/response/txresponse_test.go @@ -0,0 +1,86 @@ +package response + +import ( + "errors" + "net/http" + "net/http/httptest" + "testing" + + httperrors "github.com/portainer/portainer/api/http/errors" + httperror "github.com/portainer/portainer/pkg/libhttp/error" + "github.com/stretchr/testify/require" +) + +func TestTxResponse(t *testing.T) { + type sample struct { + Name string `json:"name"` + } + + w := httptest.NewRecorder() + got := TxResponse(w, sample{Name: "Alice"}, nil) + require.Nil(t, got) + require.Equal(t, http.StatusOK, w.Result().StatusCode) + + w = httptest.NewRecorder() + got = TxResponse(w, sample{}, httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + w = httptest.NewRecorder() + got = TxResponse(w, sample{}, errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxEmptyResponse(t *testing.T) { + w := httptest.NewRecorder() + got := TxEmptyResponse(w, nil) + require.Nil(t, got) + require.Equal(t, http.StatusNoContent, w.Result().StatusCode) + + w = httptest.NewRecorder() + got = TxEmptyResponse(w, httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + w = httptest.NewRecorder() + got = TxEmptyResponse(w, errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxFuncResponse(t *testing.T) { + got := TxFuncResponse(nil, func() *httperror.HandlerError { return nil }) + require.Nil(t, got) + + got = TxFuncResponse(httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied), func() *httperror.HandlerError { return nil }) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + got = TxFuncResponse(errors.New("Some error"), func() *httperror.HandlerError { return nil }) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +} + +func TestTxErrorResponse(t *testing.T) { + got := TxErrorResponse(nil) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) + + got = TxErrorResponse(httperror.Forbidden("Access denied to resource", httperrors.ErrResourceAccessDenied)) + require.NotNil(t, got) + require.Equal(t, http.StatusForbidden, got.StatusCode) + require.Equal(t, "Access denied to resource", got.Message) + + got = TxErrorResponse(errors.New("Some error")) + require.NotNil(t, got) + require.Equal(t, http.StatusInternalServerError, got.StatusCode) + require.Equal(t, "Unexpected error", got.Message) +}