Compare commits

..

62 Commits

Author SHA1 Message Date
Steven Kang
8332232840 chore: version bump 2.33.1 (#1108) 2025-08-27 11:30:12 +12:00
andres-portainer
816a6f9bef chore(bbolt): upgrade bbolt to v1.4.3 BE-12193 (#1104) 2025-08-25 17:59:33 -03:00
Devon Steenberg
e86ea22900 fix(sslflags): Deprecate ssl flags [BE-12168] (#1076) 2025-08-25 20:25:07 +12:00
Malcolm Lockyer
12b2acbc00 fix(standard): manual endpoint refresh fails to save new status [be-12188] (#1096) 2025-08-25 13:49:04 +12:00
Ali
4a8b42928e fix(environments): create k8s specific edge agent before connecting [r8s-438] (#1086)
Merging because this change is unrelated to the failing kubernetes/tests/helm-oci.spec.ts tests
2025-08-25 09:32:16 +12:00
Oscar Zhou
2e828b39da fix(autoupdate): update tooltips in edge stack gitops update [BE-12177] (#1080) 2025-08-23 10:55:57 +12:00
Steven Kang
49c6521c23 fix: GHSA-2464-8j7c-4cjm - release 2.33 [R8S-495] (#1089) 2025-08-22 14:03:16 +12:00
Steven Kang
debf1a742b chore: version bump 2.33.0 (#1065) 2025-08-20 11:28:05 +12:00
James Player
5d3708ec3e fix(UI): add experimental features back in [r8s-483] (#1060) 2025-08-19 17:07:27 +12:00
Steven Kang
9320fd4c50 fix: cve-2025-55198 and cve-2025-55199 - release 2.33 [R8S-482] (#1058) 2025-08-19 16:22:54 +12:00
Steven Kang
974682bd98 chore: version bump to 2.33.0-rc2 (#1054) 2025-08-19 11:04:56 +12:00
Ali
631f1deb2e fix(helm): support http and custom tls helm registries, give help when misconfigured [r8s-472] (#1032)
Co-authored-by: JamesPlayer <james.player@portainer.io>
2025-08-18 12:07:41 +12:00
LP B
4169b045fb fix(api/edge-stacks): avoid overriding updates with old values (#1048) 2025-08-16 03:52:21 +02:00
andres-portainer
0a2a786aa3 fix(migrator): rewrite a migration so it is idempotent BE-12053 (#1043) 2025-08-15 09:18:31 -03:00
James Player
808f87206e fix(ui): Fixed react-select TooManyResultsSelector filter and improved scrolling (#1028) 2025-08-15 15:33:43 +12:00
Cara Ryan
ed6fa82904 fix(pending-actions): Small improvements to pending actions (R8S-350) (#1025) 2025-08-15 10:51:45 +12:00
andres-portainer
9fc301110b fix(crypto): replace fips140 calls with fips calls BE-11979 (#1035) 2025-08-14 19:36:05 -03:00
andres-portainer
69101ac89a feat(openai): remove OpenAI BE-12018 (#1034) 2025-08-14 19:35:43 -03:00
Malcolm Lockyer
69d33dd432 fix(fips): use standard lib pbkdf2 [be-12164] (#1037) 2025-08-15 09:45:49 +12:00
Ali
389cbf748c fix(logs): improve log rendering performance [r8s-437] (#1023)
Merging because the same tests are failing in CE develop https://github.com/portainer/system-tests/actions/runs/16953578581
2025-08-14 13:53:35 +12:00
LP B
d01b31f707 feat(api): Permissions-Policy header deny all (#1022) 2025-08-13 22:07:52 +02:00
Andrew Amesbury
3ade5cdf19 bump version to 2.33.0-rc1 (#1019) 2025-08-13 14:40:34 +12:00
LP B
5f6fa4d79f fix(app/update_schedule): create schedule performance issues at scale (#1002) 2025-08-12 16:50:11 +02:00
Ali
3ee20863d6 fix(editor): remove yaml specific highlighting [r8s-441] (#1010) 2025-08-12 11:53:31 +12:00
Steven Kang
8fe5eaee29 feat(ui): Kubernetes - Create from Manifest - tidy up [R8S-67] (#971) 2025-08-12 11:49:33 +12:00
Cara Ryan
208534c9d9 fix(helm): helm apps do not combine in applications view if different namespace [R8S-420] (#988) 2025-08-12 10:23:27 +12:00
Steven Kang
3f030394c6 fix(security): remediation of cve-2025-54338 and cve-2025-8556 (#989) 2025-08-12 09:08:29 +12:00
Devon Steenberg
6ca0085ec8 fix(stackbuilders): swarm and k8s deploys [BE-12138] (#1003) 2025-08-11 15:44:36 +12:00
Malcolm Lockyer
2cf1649c67 fix(encryption): in fips mode, use pbkdf2 for db password [be-11933] (#985)
Co-authored-by: andres-portainer <andres-portainer@users.noreply.github.com>
2025-08-11 12:03:38 +12:00
andres-portainer
64ed988169 fix(linters): upgrade golangci-lint to v2.3.1 BE-12136 (#997) 2025-08-08 21:39:21 -03:00
LP B
85b7e881eb docs(api/dashboard): docker/{envId}/dashboard incorrectly marked as POST instead of GET (#996) 2025-08-08 09:31:34 +02:00
andres-portainer
9325cb2872 fix(all): avoid using pointers to zero sized structs BE-12129 (#986) 2025-08-07 09:47:42 -03:00
Steven Kang
e39dcc458b fix(security): ghsa-fv92-fjc5-jj9h [R8S-449] (#979) 2025-08-07 12:21:31 +12:00
Devon Steenberg
84b4b30f21 fix(rand): Use crypto/rand instead of math/rand in FIPS mode [BE-12071] (#961)
Co-authored-by: codecov-ai[bot] <156709835+codecov-ai[bot]@users.noreply.github.com>
2025-08-06 10:19:15 +12:00
andres-portainer
6c47598cd9 fix(apikey): use HMAC-SHA256 for FIPS mode API keys BE-11936 (#980) 2025-08-05 13:09:35 -03:00
andres-portainer
d00d71ecbf fix(linter): add linter rules to reduce the chance for invalid FIPS settings BE-11979 (#975) 2025-08-05 09:23:07 -03:00
Ali
dc273b2d63 fix(helm): don't block install with dry-run errors [r8s-454] (#976)
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2025-08-05 18:53:41 +12:00
James Carppe
497b16e942 Chore update readme graphic (#963)
Co-authored-by: Phil Calder <4473109+predlac@users.noreply.github.com>
2025-08-05 17:14:54 +12:00
LP B
a472de1919 fix(app/edge-jobs): edge job results page crash at scale (#954) 2025-08-04 17:10:46 +02:00
Malcolm Lockyer
d306d7a983 fix(encryption): replace encryption related methods for fips mode [be-11933] (#919)
Co-authored-by: andres-portainer <andres-portainer@users.noreply.github.com>
2025-08-04 17:04:03 +12:00
andres-portainer
163aa57e5c fix(tls): centralize the TLS configuration to ensure FIPS compliance BE-11979 (#960) 2025-08-01 22:23:59 -03:00
andres-portainer
3eab294908 fix(linters): add the bodyclose linter BE-12112 (#959) 2025-07-30 11:35:30 -03:00
Viktor Pettersson
da30780ac2 feat(autopatch): implement patch finder for retrieving latest patches from GitHub (#957) BE-12085 2025-07-30 15:57:32 +12:00
Ali
ef53354193 fix(snapshot): show snapshot stats [r8s-432] (#952) 2025-07-29 22:51:05 +12:00
andres-portainer
e9ce3d2213 fix(endpointedge): optimize buildSchedules() BE-12099 (#955) 2025-07-28 19:19:07 -03:00
andres-portainer
a46db61c4c fix(endpointrelation): optimize updateEdgeStacksAfterRelationChange() BE-12092 (#941) 2025-07-28 13:19:05 -03:00
Steven Kang
5e271fd4a4 feat(ui): reordered kubernetes create from code options [R8S-429] (#951) 2025-07-28 15:41:12 +12:00
James Player
6481483074 fix(app/sidebar): Custom logo UI issue [r8s-435] (#939) 2025-07-25 15:29:06 +12:00
James Player
7bcb37c761 feat(app/kubernetes): Popout kubectl shell into new window [r8s-307] (#922) 2025-07-25 15:24:32 +12:00
LP B
e7d97d7a2b fix(app/edge-configs): high numbers UI overlap (#931) 2025-07-24 16:37:07 +02:00
James Carppe
1afae99345 Updates for release 2.32.0 (#936) 2025-07-24 14:30:37 +12:00
Steven Kang
bdb2e2f417 fix(transport): portainer generated kubeconfig causes kubectl exec fail [R8S-430] (#929) 2025-07-24 13:11:13 +12:00
andres-portainer
bba3751268 fix(roar): return empty slices instead of nil for easier API compatibility BE-12053 (#932) 2025-07-23 14:06:20 -03:00
Ali
60bc04bc33 feat(helm): show manifest previews/changes when installing and upgrading a helm chart [r8s-405] (#898) 2025-07-23 10:52:58 +12:00
andres-portainer
a4cff13531 fix(bouncer): add missing domain to CSP header BE-12067 (#916) 2025-07-21 21:32:50 -03:00
andres-portainer
937456596a fix(edgegroups): convert the related endpoint IDs to roaring bitmaps to increase performance BE-12053 (#903) 2025-07-21 21:31:13 -03:00
Devon Steenberg
caf382b64c feat(git): support bearer token auth for git [BE-11770] (#879) 2025-07-22 08:36:08 +12:00
Ali
55cc250d2e fix(pods): represent pod container statuses correctly [r8s-416] (#910) 2025-07-21 15:05:08 +12:00
Ali
eaa2be017d fix(helm): ensure the form is not 'dirty', when the values are unchanged [r8s-421] (#901) 2025-07-17 12:07:11 +12:00
James Player
4e4c5ffdb6 fix(app/kubernetes): Fix listing of secrets and configmaps with same name [r8s-288] (#897) 2025-07-16 16:37:59 +12:00
James Player
383bcc4113 fix(docker/images): Fix image detail actions icon colours [be-12044] (#892) 2025-07-15 13:57:43 +12:00
James Player
9f906b7417 refactor(app/tests): Make createMockUsers more deterministic [r8s-406] (#887) 2025-07-14 17:16:33 +12:00
305 changed files with 13231 additions and 1710 deletions

View File

@@ -94,6 +94,7 @@ body:
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
multiple: false
options:
- '2.32.0'
- '2.31.3'
- '2.31.2'
- '2.31.1'

View File

@@ -1,41 +1,59 @@
version: "2"
linters:
# Disable all linters, the defaults don't pass on our code yet
disable-all: true
# Enable these for now
default: none
enable:
- unused
- depguard
- gosimple
- govet
- errorlint
- bodyclose
- copyloopvar
- depguard
- errorlint
- forbidigo
- govet
- ineffassign
- intrange
- perfsprint
- ineffassign
linters-settings:
depguard:
rules:
main:
deny:
- pkg: 'encoding/json'
desc: 'use github.com/segmentio/encoding/json'
- pkg: 'golang.org/x/exp'
desc: 'exp is not allowed'
- pkg: 'github.com/portainer/libcrypto'
desc: 'use github.com/portainer/portainer/pkg/libcrypto'
- pkg: 'github.com/portainer/libhttp'
desc: 'use github.com/portainer/portainer/pkg/libhttp'
files:
- '!**/*_test.go'
- '!**/base.go'
- '!**/base_tx.go'
# errorlint is causing a typecheck error for some reason. The go compiler will report these
# anyway, so ignore them from the linter
issues:
exclude-rules:
- path: ./
linters:
- typecheck
- staticcheck
- unused
settings:
staticcheck:
checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
depguard:
rules:
main:
files:
- '!**/*_test.go'
- '!**/base.go'
- '!**/base_tx.go'
deny:
- pkg: encoding/json
desc: use github.com/segmentio/encoding/json
- pkg: golang.org/x/exp
desc: exp is not allowed
- pkg: github.com/portainer/libcrypto
desc: use github.com/portainer/portainer/pkg/libcrypto
- pkg: github.com/portainer/libhttp
desc: use github.com/portainer/portainer/pkg/libhttp
forbidigo:
forbid:
- pattern: ^tls\.Config$
msg: Use crypto.CreateTLSConfiguration() instead
- pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
analyze-types: true
exclusions:
generated: lax
presets:
- comments
- common-false-positives
- legacy
- std-error-handling
paths:
- third_party$
- builtin$
- examples$
formatters:
exclusions:
generated: lax
paths:
- third_party$
- builtin$
- examples$

View File

@@ -8,9 +8,9 @@ Portainer consists of a single container that can run on any cluster. It can be
**Portainer Business Edition** builds on the open-source base and includes a range of advanced features and functions (like RBAC and Support) that are specific to the needs of business users.
- [Compare Portainer CE and Compare Portainer BE](https://portainer.io/products)
- [Compare Portainer CE and Compare Portainer BE](https://www.portainer.io/features)
- [Take3 get 3 free nodes of Portainer Business for as long as you want them](https://www.portainer.io/take-3)
- [Portainer BE install guide](https://install.portainer.io)
- [Portainer BE install guide](https://academy.portainer.io/install/)
## Latest Version
@@ -20,22 +20,19 @@ Portainer CE is updated regularly. We aim to do an update release every couple o
## Getting started
- [Deploy Portainer](https://docs.portainer.io/start/install)
- [Deploy Portainer](https://docs.portainer.io/start/install-ce)
- [Documentation](https://docs.portainer.io)
- [Contribute to the project](https://docs.portainer.io/contribute/contribute)
## Features & Functions
View [this](https://www.portainer.io/products) table to see all of the Portainer CE functionality and compare to Portainer Business.
- [Portainer CE for Docker / Docker Swarm](https://www.portainer.io/solutions/docker)
- [Portainer CE for Kubernetes](https://www.portainer.io/solutions/kubernetes-ui)
View [this](https://www.portainer.io/features) table to see all of the Portainer CE functionality and compare to Portainer Business.
## Getting help
Portainer CE is an open source project and is supported by the community. You can buy a supported version of Portainer at portainer.io
Learn more about Portainer's community support channels [here.](https://www.portainer.io/get-support-for-portainer)
Learn more about Portainer's community support channels [here.](https://www.portainer.io/resources/get-help/get-support)
- Issues: https://github.com/portainer/portainer/issues
- Slack (chat): [https://portainer.io/slack](https://portainer.io/slack)
@@ -53,13 +50,13 @@ You can join the Portainer Community by visiting [https://www.portainer.io/join-
## Work for us
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to info@portainer.io with your details and/or visit our [careers page](https://portainer.io/careers).
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to success@portainer.io with your details and/or visit our [careers page](https://apply.workable.com/portainer/).
## Privacy
**To make sure we focus our development effort in the right places we need to know which features get used most often. To give us this information we use [Matomo Analytics](https://matomo.org/), which is hosted in Germany and is fully GDPR compliant.**
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/legal/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
## Limitations

View File

@@ -16,7 +16,7 @@ import (
// GetAgentVersionAndPlatform returns the agent version and platform
//
// it sends a ping to the agent and parses the version and platform from the headers
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) {
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
httpCli := &http.Client{
Timeout: 3 * time.Second,
}

View File

@@ -54,8 +54,8 @@ func ecdsaGenerateKey(c elliptic.Curve, rand io.Reader) (*ecdsa.PrivateKey, erro
}
priv := new(ecdsa.PrivateKey)
priv.PublicKey.Curve = c
priv.Curve = c
priv.D = k
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
priv.X, priv.Y = c.ScalarBaseMult(k.Bytes())
return priv, nil
}

View File

@@ -9,10 +9,15 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require"
)
func init() {
fips.InitFIPS(false)
}
func TestPingAgentPanic(t *testing.T) {
endpoint := &portainer.Endpoint{
ID: 1,

View File

@@ -4,7 +4,6 @@ import (
"encoding/base64"
"errors"
"fmt"
"math/rand"
"net"
"strings"
"time"
@@ -14,6 +13,7 @@ import (
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/pkg/libcrypto"
"github.com/portainer/portainer/pkg/librand"
"github.com/dchest/uniuri"
"github.com/rs/zerolog/log"
@@ -200,7 +200,9 @@ func (service *Service) getUnusedPort() int {
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
if err == nil {
conn.Close()
if err := conn.Close(); err != nil {
log.Warn().Msg("failed to close tcp connection that checks if port is free")
}
log.Debug().
Int("port", port).
@@ -213,7 +215,7 @@ func (service *Service) getUnusedPort() int {
}
func randomInt(min, max int) int {
return min + rand.Intn(max-min)
return min + librand.Intn(max-min)
}
func generateRandomCredentials() (string, string) {

79
api/chisel/tunnel_test.go Normal file
View File

@@ -0,0 +1,79 @@
package chisel
import (
"net"
"strings"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
)
type testSettingsService struct {
dataservices.SettingsService
}
func (s *testSettingsService) Settings() (*portainer.Settings, error) {
return &portainer.Settings{
EdgeAgentCheckinInterval: 1,
}, nil
}
type testStore struct {
dataservices.DataStore
}
func (s *testStore) Settings() dataservices.SettingsService {
return &testSettingsService{}
}
func TestGetUnusedPort(t *testing.T) {
testCases := []struct {
name string
existingTunnels map[portainer.EndpointID]*portainer.TunnelDetails
expectedError error
}{
{
name: "simple case",
},
{
name: "existing tunnels",
existingTunnels: map[portainer.EndpointID]*portainer.TunnelDetails{
portainer.EndpointID(1): {
Port: 53072,
},
portainer.EndpointID(2): {
Port: 63072,
},
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
store := &testStore{}
s := NewService(store, nil, nil)
s.activeTunnels = tc.existingTunnels
port := s.getUnusedPort()
if port < 49152 || port > 65535 {
t.Fatalf("Expected port to be inbetween 49152 and 65535 but got %d", port)
}
for _, tun := range tc.existingTunnels {
if tun.Port == port {
t.Fatalf("returned port %d already has an existing tunnel", port)
}
}
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
if err == nil {
// Ignore error
_ = conn.Close()
t.Fatalf("expected port %d to be unused", port)
} else if !strings.Contains(err.Error(), "connection refused") {
t.Fatalf("unexpected error: %v", err)
}
})
}
}

View File

@@ -9,8 +9,8 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/alecthomas/kingpin/v2"
"github.com/rs/zerolog/log"
"gopkg.in/alecthomas/kingpin.v2"
)
// Service implements the CLIService interface
@@ -35,16 +35,9 @@ func CLIFlags() *portainer.CLIFlags {
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
@@ -67,11 +60,40 @@ func CLIFlags() *portainer.CLIFlags {
}
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
sslFlag := kingpin.Flag(
"ssl",
"Secure Portainer instance using SSL (deprecated)",
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
ssl := sslFlag.Bool()
sslCertFlag := kingpin.Flag(
"sslcert",
"Path to the SSL certificate used to secure the Portainer instance",
).IsSetByUser(&hasSSLCertFlag)
sslCert := sslCertFlag.String()
sslKeyFlag := kingpin.Flag(
"sslkey",
"Path to the SSL key used to secure the Portainer instance",
).IsSetByUser(&hasSSLKeyFlag)
sslKey := sslKeyFlag.String()
flags := CLIFlags()
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
flags.TLS = tlsFlag.Bool()
tlsCertFlag := kingpin.Flag(
"tlscert",
"Path to the TLS certificate file",
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
flags.TLSCert = tlsCertFlag.String()
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
flags.TLSKey = tlsKeyFlag.String()
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
kingpin.Parse()
if !filepath.IsAbs(*flags.Assets) {
@@ -83,11 +105,46 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
}
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
if !hasTLSFlag {
if !hasTLSCertFlag {
*flags.TLSCert = ""
}
if !hasTLSKeyFlag {
*flags.TLSKey = ""
}
}
if hasSSLFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
if !hasTLSFlag {
flags.TLS = ssl
}
}
if hasSSLCertFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
if !hasTLSCertFlag {
flags.TLSCert = sslCert
}
}
if hasSSLKeyFlag {
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
if !hasTLSKeyFlag {
flags.TLSKey = sslKey
}
}
return flags, nil
}
// ValidateFlags validates the values of the flags.
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
func (Service) ValidateFlags(flags *portainer.CLIFlags) error {
displayDeprecationWarnings(flags)
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
@@ -109,10 +166,6 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
if *flags.NoAnalytics {
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
}
if *flags.SSL {
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
}
}
func validateEndpointURL(endpointURL string) error {

209
api/cli/cli_test.go Normal file
View File

@@ -0,0 +1,209 @@
package cli
import (
"io"
"os"
"strings"
"testing"
zerolog "github.com/rs/zerolog/log"
"github.com/stretchr/testify/require"
)
func TestOptionParser(t *testing.T) {
p := Service{}
require.NotNil(t, p)
a := os.Args
defer func() { os.Args = a }()
os.Args = []string{"portainer", "--edge-compute"}
opts, err := p.ParseFlags("2.34.5")
require.NoError(t, err)
require.False(t, *opts.HTTPDisabled)
require.True(t, *opts.EnableEdgeComputeFeatures)
}
func TestParseTLSFlags(t *testing.T) {
testCases := []struct {
name string
args []string
expectedTLSFlag bool
expectedTLSCertFlag string
expectedTLSKeyFlag string
expectedLogMessages []string
}{
{
name: "no flags",
expectedTLSFlag: false,
expectedTLSCertFlag: "",
expectedTLSKeyFlag: "",
},
{
name: "only ssl flag",
args: []string{
"portainer",
"--ssl",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "",
expectedTLSKeyFlag: "",
},
{
name: "only tls flag",
args: []string{
"portainer",
"--tlsverify",
},
expectedTLSFlag: true,
expectedTLSCertFlag: defaultTLSCertPath,
expectedTLSKeyFlag: defaultTLSKeyPath,
},
{
name: "partial ssl flags",
args: []string{
"portainer",
"--ssl",
"--sslcert=ssl-cert-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "ssl-cert-flag-value",
expectedTLSKeyFlag: "",
},
{
name: "partial tls flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: defaultTLSKeyPath,
},
{
name: "partial tls and ssl flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
},
{
name: "partial tls and ssl flags 2",
args: []string{
"portainer",
"--ssl",
"--tlscert=tls-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
},
{
name: "ssl flags",
args: []string{
"portainer",
"--ssl",
"--sslcert=ssl-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "ssl-cert-flag-value",
expectedTLSKeyFlag: "ssl-key-flag-value",
expectedLogMessages: []string{
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
},
},
{
name: "tls flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--tlskey=tls-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "tls-key-flag-value",
},
{
name: "tls and ssl flags",
args: []string{
"portainer",
"--tlsverify",
"--tlscert=tls-cert-flag-value",
"--tlskey=tls-key-flag-value",
"--ssl",
"--sslcert=ssl-cert-flag-value",
"--sslkey=ssl-key-flag-value",
},
expectedTLSFlag: true,
expectedTLSCertFlag: "tls-cert-flag-value",
expectedTLSKeyFlag: "tls-key-flag-value",
expectedLogMessages: []string{
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
var logOutput strings.Builder
setupLogOutput(t, &logOutput)
if tc.args == nil {
tc.args = []string{"portainer"}
}
setOsArgs(t, tc.args)
s := Service{}
flags, err := s.ParseFlags("test-version")
if err != nil {
t.Fatalf("error parsing flags: %v", err)
}
if flags.TLS == nil {
t.Fatal("TLS flag was nil")
}
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
for _, expectedLogMessage := range tc.expectedLogMessages {
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
}
})
}
}
func setOsArgs(t *testing.T, args []string) {
t.Helper()
previousArgs := os.Args
os.Args = args
t.Cleanup(func() {
os.Args = previousArgs
})
}
func setupLogOutput(t *testing.T, w io.Writer) {
t.Helper()
oldLogger := zerolog.Logger
zerolog.Logger = zerolog.Output(w)
t.Cleanup(func() {
zerolog.Logger = oldLogger
})
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"strings"
"gopkg.in/alecthomas/kingpin.v2"
"github.com/alecthomas/kingpin/v2"
)
type pairList []portainer.Pair

View File

@@ -49,6 +49,7 @@ import (
"github.com/portainer/portainer/api/stacks/deployments"
"github.com/portainer/portainer/pkg/build"
"github.com/portainer/portainer/pkg/featureflags"
"github.com/portainer/portainer/pkg/fips"
"github.com/portainer/portainer/pkg/libhelm"
libhelmtypes "github.com/portainer/portainer/pkg/libhelm/types"
"github.com/portainer/portainer/pkg/libstack/compose"
@@ -59,7 +60,7 @@ import (
)
func initCLI() *portainer.CLIFlags {
cliService := &cli.Service{}
cliService := cli.Service{}
flags, err := cliService.ParseFlags(portainer.APIVersion)
if err != nil {
@@ -306,8 +307,19 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
return generateAndStoreKeyPair(fileService, signatureService)
}
// dbSecretPath build the path to the file that contains the db encryption
// secret. Normally in Docker this is built from the static path inside
// /run/portainer for example: /run/portainer/<keyFilenameFlag> but for ease of
// use outside Docker it also accepts an absolute path
func dbSecretPath(keyFilenameFlag string) string {
if path.IsAbs(keyFilenameFlag) {
return keyFilenameFlag
}
return path.Join("/run/portainer", keyFilenameFlag)
}
func loadEncryptionSecretKey(keyfilename string) []byte {
content, err := os.ReadFile(path.Join("/run/secrets", keyfilename))
content, err := os.ReadFile(keyfilename)
if err != nil {
if os.IsNotExist(err) {
log.Info().Str("filename", keyfilename).Msg("encryption key file not present")
@@ -319,6 +331,7 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
}
// return a 32 byte hash of the secret (required for AES)
// fips compliant version of this is not implemented in -ce
hash := sha256.Sum256(content)
return hash[:]
@@ -343,8 +356,11 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
}
}
// -ce can not ever be run in FIPS mode
fips.InitFIPS(false)
fileService := initFileService(*flags.Data)
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
if encryptionKey == nil {
log.Info().Msg("proceeding without encryption key")
}
@@ -377,7 +393,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed initializing JWT service")
}
ldapService := &ldap.Service{}
ldapService := ldap.Service{}
oauthService := oauth.NewService()
@@ -386,13 +402,13 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
// Setting insecureSkipVerify to true to preserve the old behaviour.
openAMTService := openamt.NewService(true)
cryptoService := &crypto.Service{}
cryptoService := crypto.Service{}
signatureService := initDigitalSignatureService()
edgeStacksService := edgestacks.NewService(dataStore)
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
if err != nil {
log.Fatal().Err(err).Msg("")
}
@@ -451,7 +467,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
snapshotService.Start()
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService, jwtService)
helmPackageManager, err := initHelmPackageManager()
if err != nil {

View File

@@ -0,0 +1,57 @@
package main
import (
"os"
"path"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const secretFileName = "secret.txt"
func createPasswordFile(t *testing.T, secretPath, password string) string {
err := os.WriteFile(secretPath, []byte(password), 0600)
require.NoError(t, err)
return secretPath
}
func TestLoadEncryptionSecretKey(t *testing.T) {
tempDir := t.TempDir()
secretPath := path.Join(tempDir, secretFileName)
// first pointing to file that does not exist, gives nil hash (no encryption)
encryptionKey := loadEncryptionSecretKey(secretPath)
require.Nil(t, encryptionKey)
// point to a directory instead of a file
encryptionKey = loadEncryptionSecretKey(tempDir)
require.Nil(t, encryptionKey)
password := "portainer@1234"
createPasswordFile(t, secretPath, password)
encryptionKey = loadEncryptionSecretKey(secretPath)
require.NotNil(t, encryptionKey)
// should be 32 bytes for aes256 encryption
require.Len(t, encryptionKey, 32)
}
func TestDBSecretPath(t *testing.T) {
tests := []struct {
keyFilenameFlag string
expected string
}{
{keyFilenameFlag: "secret.txt", expected: "/run/portainer/secret.txt"},
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
{keyFilenameFlag: "/run/portainer/secret.txt", expected: "/run/portainer/secret.txt"},
{keyFilenameFlag: "./secret.txt", expected: "/run/portainer/secret.txt"},
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/portainer/foo/bar/secret.txt"},
}
for _, test := range tests {
assert.Equal(t, test.expected, dbSecretPath(test.keyFilenameFlag))
}
}

View File

@@ -5,10 +5,15 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/pbkdf2"
"crypto/rand"
"crypto/sha256"
"errors"
"fmt"
"io"
"strings"
"github.com/portainer/portainer/pkg/fips"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/scrypt"
@@ -19,20 +24,32 @@ const (
aesGcmHeader = "AES256-GCM" // The encrypted file header
aesGcmBlockSize = 1024 * 1024 // 1MB block for aes gcm
aesGcmFIPSHeader = "FIPS-AES256-GCM"
aesGcmFIPSBlockSize = 16 * 1024 * 1024 // 16MB block for aes gcm
// Argon2 settings
// Recommded settings lower memory hardware according to current OWASP recommendations
// Recommended settings lower memory hardware according to current OWASP recommendations
// Considering some people run portainer on a NAS I think it's prudent not to assume we're on server grade hardware
// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id
argon2MemoryCost = 12 * 1024
argon2TimeCost = 3
argon2Threads = 1
argon2KeyLength = 32
pbkdf2Iterations = 600_000 // use recommended iterations from https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 a little overkill for this use
pbkdf2SaltLength = 32
)
// AesEncrypt reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
if err := aesEncryptGCM(input, output, passphrase); err != nil {
return fmt.Errorf("error encrypting file: %w", err)
if fips.FIPSMode() {
if err := aesEncryptGCMFIPS(input, output, passphrase); err != nil {
return fmt.Errorf("error encrypting file: %w", err)
}
} else {
if err := aesEncryptGCM(input, output, passphrase); err != nil {
return fmt.Errorf("error encrypting file: %w", err)
}
}
return nil
@@ -40,14 +57,35 @@ func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
return aesDecrypt(input, passphrase, fips.FIPSMode())
}
func aesDecrypt(input io.Reader, passphrase []byte, fipsMode bool) (io.Reader, error) {
// Read file header to determine how it was encrypted
inputReader := bufio.NewReader(input)
header, err := inputReader.Peek(len(aesGcmHeader))
header, err := inputReader.Peek(len(aesGcmFIPSHeader))
if err != nil {
return nil, fmt.Errorf("error reading encrypted backup file header: %w", err)
}
if string(header) == aesGcmHeader {
if strings.HasPrefix(string(header), aesGcmFIPSHeader) {
if !fipsMode {
return nil, errors.New("fips encrypted file detected but fips mode is not enabled")
}
reader, err := aesDecryptGCMFIPS(inputReader, passphrase)
if err != nil {
return nil, fmt.Errorf("error decrypting file: %w", err)
}
return reader, nil
}
if strings.HasPrefix(string(header), aesGcmHeader) {
if fipsMode {
return nil, errors.New("fips mode is enabled but non-fips encrypted file detected")
}
reader, err := aesDecryptGCM(inputReader, passphrase)
if err != nil {
return nil, fmt.Errorf("error decrypting file: %w", err)
@@ -114,15 +152,14 @@ func aesEncryptGCM(input io.Reader, output io.Writer, passphrase []byte) error {
break // end of plaintext input
}
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
return err
}
// Seal encrypts the plaintext using the nonce returning the updated slice.
ciphertext = aesgcm.Seal(ciphertext[:0], nonce.Value(), buf[:n], nil)
_, err = output.Write(ciphertext)
if err != nil {
if _, err := output.Write(ciphertext); err != nil {
return err
}
@@ -183,7 +220,7 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
break // end of ciphertext
}
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
return nil, err
}
@@ -203,15 +240,138 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
return &buf, nil
}
// aesEncryptGCMFIPS reads from input, encrypts with AES-256 in a fips compliant
// way and writes to output. passphrase is used to generate an encryption key.
func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) error {
salt := make([]byte, pbkdf2SaltLength)
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return err
}
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
if err != nil {
return fmt.Errorf("error deriving key: %w", err)
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
// write the header
if _, err := output.Write([]byte(aesGcmFIPSHeader)); err != nil {
return err
}
// Write nonce and salt to the output file
if _, err := output.Write(salt); err != nil {
return err
}
// Buffer for reading plaintext blocks
buf := make([]byte, aesGcmFIPSBlockSize)
// Encrypt plaintext in blocks
for {
// new random nonce for each block
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return fmt.Errorf("error creating gcm: %w", err)
}
n, err := io.ReadFull(input, buf)
if n == 0 {
break // end of plaintext input
}
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
return err
}
// Seal encrypts the plaintext
ciphertext := aesgcm.Seal(nil, nil, buf[:n], nil)
if _, err := output.Write(ciphertext); err != nil {
return err
}
}
return nil
}
// aesDecryptGCMFIPS reads from input, decrypts with AES-256 in a fips compliant
// way and returns the reader to read the decrypted content from.
func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
// Reader & verify header
header := make([]byte, len(aesGcmFIPSHeader))
if _, err := io.ReadFull(input, header); err != nil {
return nil, err
}
if string(header) != aesGcmFIPSHeader {
return nil, errors.New("invalid header")
}
// Read salt
salt := make([]byte, pbkdf2SaltLength)
if _, err := io.ReadFull(input, salt); err != nil {
return nil, err
}
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
if err != nil {
return nil, fmt.Errorf("error deriving key: %w", err)
}
// Initialize AES cipher block
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// Initialize a buffer to store decrypted data
buf := bytes.Buffer{}
// Decrypt the ciphertext in blocks
for {
// Create GCM mode with the cipher block
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return nil, err
}
// Read a block of ciphertext from the input reader
ciphertextBlock := make([]byte, aesGcmFIPSBlockSize+aesgcm.Overhead())
n, err := io.ReadFull(input, ciphertextBlock)
if n == 0 {
break // end of ciphertext
}
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
return nil, err
}
// Decrypt the block of ciphertext
plaintext, err := aesgcm.Open(nil, nil, ciphertextBlock[:n], nil)
if err != nil {
return nil, err
}
if _, err := buf.Write(plaintext); err != nil {
return nil, err
}
}
return &buf, nil
}
// aesDecryptOFB reads from input, decrypts with AES-256 and returns the reader to a read decrypted content from.
// passphrase is used to generate an encryption key.
// note: This function used to decrypt files that were encrypted without a header i.e. old archives
func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
var emptySalt []byte = make([]byte, 0)
// making a 32 bytes key that would correspond to AES-256
// don't necessarily need a salt, so just kept in empty
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
if err != nil {
return nil, err
}

View File

@@ -1,15 +1,25 @@
package crypto
import (
"crypto/aes"
"crypto/cipher"
"io"
"math/rand"
"os"
"path/filepath"
"testing"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/scrypt"
)
func init() {
fips.InitFIPS(false)
}
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randBytes(n int) []byte {
@@ -17,201 +27,326 @@ func randBytes(n int) []byte {
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return b
}
type encryptFunc func(input io.Reader, output io.Writer, passphrase []byte) error
type decryptFunc func(input io.Reader, passphrase []byte) (io.Reader, error)
func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
const passphrase = "passphrase"
tmpdir := t.TempDir()
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc, decryptShouldSucceed bool) {
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1024*1024*100 + 523)
os.WriteFile(originFilePath, content, 0600)
content := randBytes(1024*1024*100 + 523)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
require.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedContent, err := os.ReadFile(encryptedFilePath)
require.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file")
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
if !decryptShouldSucceed {
require.Error(t, err, "Failed to decrypt file as indicated by decryptShouldSucceed")
} else {
require.NoError(t, err, "Failed to decrypt file indicated by decryptShouldSucceed")
io.Copy(decryptedFileWriter, decryptedReader)
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
}
t.Run("fips", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS, true)
})
t.Run("non_fips", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCM, true)
})
t.Run("system_fips_mode_public_entry_points", func(t *testing.T) {
// use the init mode, public entry points
testFunc(t, AesEncrypt, AesDecrypt, true)
})
t.Run("fips_encrypted_file_header_fails_in_non_fips_mode", func(t *testing.T) {
// use aesDecrypt which checks the header, confirm that it fails
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
return aesDecrypt(input, passphrase, false)
}
testFunc(t, aesEncryptGCMFIPS, decrypt, false)
})
t.Run("non_fips_encrypted_file_header_fails_in_fips_mode", func(t *testing.T) {
// use aesDecrypt which checks the header, confirm that it fails
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
return aesDecrypt(input, passphrase, true)
}
testFunc(t, aesEncryptGCM, decrypt, false)
})
t.Run("fips_encrypted_file_fails_in_non_fips_mode", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCM, false)
})
t.Run("non_fips_encrypted_file_with_fips_mode_should_fail", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCMFIPS, false)
})
t.Run("fips_with_base_aesDecrypt", func(t *testing.T) {
// maximize coverage, use the base aesDecrypt function with valid fips mode
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
return aesDecrypt(input, passphrase, true)
}
testFunc(t, aesEncryptGCMFIPS, decrypt, true)
})
t.Run("legacy", func(t *testing.T) {
testFunc(t, legacyAesEncrypt, aesDecryptOFB, true)
})
}
func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
const passphrase = "A strong passphrase with special characters: !@#$%^&*()_+"
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
tmpdir := t.TempDir()
content := randBytes(500)
os.WriteFile(originFilePath, content, 0600)
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
content := randBytes(500)
os.WriteFile(originFilePath, content, 0600)
encryptedFileWriter, _ := os.Create(encryptedFilePath)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
err := encrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file")
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
io.Copy(decryptedFileWriter, decryptedReader)
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file")
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
t.Run("fips", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
})
t.Run("non_fips", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCM)
})
}
func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
tmpdir := t.TempDir()
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
content := randBytes(500)
os.WriteFile(originFilePath, content, 0600)
content := randBytes(500)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("passphrase"))
assert.Nil(t, err, "Failed to decrypt file")
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
assert.Nil(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader)
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
t.Run("fips", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
})
t.Run("non_fips", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCM)
})
}
func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
tmpdir := t.TempDir()
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1024 * 50)
os.WriteFile(originFilePath, content, 0600)
content := randBytes(1024 * 50)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
defer encryptedFileWriter.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
defer encryptedFileWriter.Close()
err := AesEncrypt(originFile, encryptedFileWriter, []byte(""))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
err := encrypt(originFile, encryptedFileWriter, []byte(""))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(""))
assert.Nil(t, err, "Failed to decrypt file")
decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
assert.Nil(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader)
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
t.Run("fips", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
})
t.Run("non_fips", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCM)
})
}
func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T) {
tmpdir := t.TempDir()
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
var (
originFilePath = filepath.Join(tmpdir, "origin")
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1034)
os.WriteFile(originFilePath, content, 0600)
content := randBytes(1034)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
defer encryptedFileWriter.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
defer encryptedFileWriter.Close()
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
err := encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
_, err = AesDecrypt(encryptedFileReader, []byte("garbage"))
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
_, err = decrypt(encryptedFileReader, []byte("garbage"))
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
}
t.Run("fips", func(t *testing.T) {
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
})
t.Run("non_fips", func(t *testing.T) {
testFunc(t, aesEncryptGCM, aesDecryptGCM)
})
}
func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
if err != nil {
return err
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
var iv [aes.BlockSize]byte
stream := cipher.NewOFB(block, iv[:])
writer := &cipher.StreamWriter{S: stream, W: output}
if _, err := io.Copy(writer, input); err != nil {
return err
}
return nil
}

View File

@@ -112,7 +112,7 @@ func (service *ECDSAService) CreateSignature(message string) (string, error) {
message = service.secret
}
hash := libcrypto.HashFromBytes([]byte(message))
hash := libcrypto.InsecureHashFromBytes([]byte(message))
r, s, err := ecdsa.Sign(rand.Reader, service.privateKey, hash)
if err != nil {

22
api/crypto/ecdsa_test.go Normal file
View File

@@ -0,0 +1,22 @@
package crypto
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestCreateSignature(t *testing.T) {
var s = NewECDSAService("secret")
privKey, pubKey, err := s.GenerateKeyPair()
require.NoError(t, err)
require.Greater(t, len(privKey), 0)
require.Greater(t, len(pubKey), 0)
m := "test message"
r, err := s.CreateSignature(m)
require.NoError(t, err)
require.NotEqual(t, r, m)
require.Greater(t, len(r), 0)
}

View File

@@ -8,15 +8,16 @@ import (
type Service struct{}
// Hash hashes a string using the bcrypt algorithm
func (*Service) Hash(data string) (string, error) {
func (Service) Hash(data string) (string, error) {
bytes, err := bcrypt.GenerateFromPassword([]byte(data), bcrypt.DefaultCost)
if err != nil {
return "", err
}
return string(bytes), err
}
// CompareHashAndData compares a hash to clear data and returns an error if the comparison fails.
func (*Service) CompareHashAndData(hash string, data string) error {
func (Service) CompareHashAndData(hash string, data string) error {
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(data))
}

View File

@@ -2,10 +2,12 @@ package crypto
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestService_Hash(t *testing.T) {
var s = &Service{}
var s = Service{}
type args struct {
hash string
@@ -51,3 +53,11 @@ func TestService_Hash(t *testing.T) {
})
}
}
func TestHash(t *testing.T) {
s := Service{}
hash, err := s.Hash("Passw0rd!")
require.NoError(t, err)
require.NotEmpty(t, hash)
}

View File

@@ -15,7 +15,7 @@ func NewNonce(size int) *Nonce {
}
// NewRandomNonce generates a new initial nonce with the lower byte set to a random value
// This ensures there are plenty of nonce values availble before rolling over
// This ensures there are plenty of nonce values available before rolling over
// Based on ideas from the Secure Programming Cookbook for C and C++ by John Viega, Matt Messier
// https://www.oreilly.com/library/view/secure-programming-cookbook/0596003943/ch04s09.html
func NewRandomNonce(size int) (*Nonce, error) {

View File

@@ -4,11 +4,32 @@ import (
"crypto/tls"
"crypto/x509"
"os"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
)
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
func CreateTLSConfiguration() *tls.Config {
return &tls.Config{
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
}
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
if fipsEnabled {
return &tls.Config{ //nolint:forbidigo
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS13,
CipherSuites: []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
},
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521},
}
}
return &tls.Config{ //nolint:forbidigo
MinVersion: tls.VersionTLS12,
CipherSuites: []uint16{
tls.TLS_AES_128_GCM_SHA256,
@@ -29,24 +50,33 @@ func CreateTLSConfiguration() *tls.Config {
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
},
InsecureSkipVerify: insecureSkipVerify, //nolint:forbidigo
}
}
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
// loaded from memory.
func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) {
config := CreateTLSConfiguration()
config.InsecureSkipVerify = skipServerVerification
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
}
if !skipClientVerification {
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
if !useTLS {
return nil, nil
}
config := createTLSConfiguration(fipsEnabled, skipServerVerification)
if !skipClientVerification || fipsEnabled {
certificate, err := tls.X509KeyPair(cert, key)
if err != nil {
return nil, err
}
config.Certificates = []tls.Certificate{certificate}
}
if !skipServerVerification {
if !skipServerVerification || fipsEnabled {
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
config.RootCAs = caCertPool
@@ -57,29 +87,36 @@ func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerific
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
// loaded from disk.
func CreateTLSConfigurationFromDisk(caCertPath, certPath, keyPath string, skipServerVerification bool) (*tls.Config, error) {
config := CreateTLSConfiguration()
config.InsecureSkipVerify = skipServerVerification
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
}
if certPath != "" && keyPath != "" {
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
if !config.TLS {
return nil, nil
}
tlsConfig := createTLSConfiguration(fipsEnabled, config.TLSSkipVerify)
if config.TLSCertPath != "" && config.TLSKeyPath != "" {
cert, err := tls.LoadX509KeyPair(config.TLSCertPath, config.TLSKeyPath)
if err != nil {
return nil, err
}
config.Certificates = []tls.Certificate{cert}
tlsConfig.Certificates = []tls.Certificate{cert}
}
if !skipServerVerification && caCertPath != "" {
caCert, err := os.ReadFile(caCertPath)
if !tlsConfig.InsecureSkipVerify && config.TLSCACertPath != "" { //nolint:forbidigo
caCert, err := os.ReadFile(config.TLSCACertPath)
if err != nil {
return nil, err
}
caCertPool := x509.NewCertPool()
caCertPool.AppendCertsFromPEM(caCert)
config.RootCAs = caCertPool
tlsConfig.RootCAs = caCertPool
}
return config, nil
return tlsConfig, nil
}

87
api/crypto/tls_test.go Normal file
View File

@@ -0,0 +1,87 @@
package crypto
import (
"crypto/tls"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
)
func TestCreateTLSConfiguration(t *testing.T) {
// InsecureSkipVerify = false
config := CreateTLSConfiguration(false)
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
// InsecureSkipVerify = true
config = CreateTLSConfiguration(true)
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
require.True(t, config.InsecureSkipVerify) //nolint:forbidigo
}
func TestCreateTLSConfigurationFIPS(t *testing.T) {
fips := true
fipsCipherSuites := []uint16{
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
}
fipsCurvePreferences := []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521}
config := createTLSConfiguration(fips, false)
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
require.Equal(t, config.MaxVersion, uint16(tls.VersionTLS13)) //nolint:forbidigo
require.Equal(t, config.CipherSuites, fipsCipherSuites) //nolint:forbidigo
require.Equal(t, config.CurvePreferences, fipsCurvePreferences) //nolint:forbidigo
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
}
func TestCreateTLSConfigurationFromBytes(t *testing.T) {
// No TLS
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
require.Nil(t, err)
require.Nil(t, config)
// Skip TLS client/server verifications
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, true, true)
require.NoError(t, err)
require.NotNil(t, config)
// Empty TLS
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, false, false)
require.Error(t, err)
require.Nil(t, config)
}
func TestCreateTLSConfigurationFromDisk(t *testing.T) {
// No TLS
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
require.Nil(t, err)
require.Nil(t, config)
// Skip TLS verifications
config, err = CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
require.NotNil(t, config)
}
func TestCreateTLSConfigurationFromDiskFIPS(t *testing.T) {
fips := true
// Skipping TLS verifications cannot be done in FIPS mode
config, err := createTLSConfigurationFromDisk(fips, portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
})
require.NoError(t, err)
require.NotNil(t, config)
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
}

View File

@@ -4,8 +4,6 @@ import (
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"io"
"github.com/pkg/errors"
"github.com/segmentio/encoding/json"
@@ -65,18 +63,18 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
// https://gist.github.com/atoponce/07d8d4c833873be2f68c34f9afc5a78a#symmetric-encryption
func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
block, _ := aes.NewCipher(passphrase)
gcm, err := cipher.NewGCM(block)
block, err := aes.NewCipher(passphrase)
if err != nil {
return encrypted, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
// NewGCMWithRandomNonce in go 1.24 handles setting up the nonce and adding it to the encrypted output
gcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return encrypted, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
return gcm.Seal(nil, nil, plaintext, nil), nil
}
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
@@ -89,19 +87,17 @@ func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err err
return encrypted, errors.Wrap(err, "Error creating cypher block")
}
gcm, err := cipher.NewGCM(block)
// NewGCMWithRandomNonce in go 1.24 handles reading the nonce from the encrypted input for us
gcm, err := cipher.NewGCMWithRandomNonce(block)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating GCM")
}
nonceSize := gcm.NonceSize()
if len(encrypted) < nonceSize {
if len(encrypted) < gcm.NonceSize() {
return encrypted, errEncryptedStringTooShort
}
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
plaintextByte, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
plaintextByte, err = gcm.Open(nil, nil, encrypted, nil)
if err != nil {
return encrypted, errors.Wrap(err, "Error decrypting text")
}

View File

@@ -1,12 +1,19 @@
package boltdb
import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"testing"
"github.com/gofrs/uuid"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
@@ -160,7 +167,7 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
}
key := secretToEncryptionKey(passphrase)
conn := DbConnection{EncryptionKey: key}
conn := DbConnection{EncryptionKey: key, isEncrypted: true}
for _, test := range tests {
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
@@ -175,3 +182,94 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
})
}
}
func Test_NonceSources(t *testing.T) {
// ensure that the new go 1.24 NewGCMWithRandomNonce works correctly with
// the old way of creating and including the nonce
encryptOldFn := func(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
block, _ := aes.NewCipher(passphrase)
gcm, err := cipher.NewGCM(block)
if err != nil {
return encrypted, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
return encrypted, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
}
decryptOldFn := func(encrypted []byte, passphrase []byte) (plaintext []byte, err error) {
block, err := aes.NewCipher(passphrase)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating cypher block")
}
gcm, err := cipher.NewGCM(block)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating GCM")
}
nonceSize := gcm.NonceSize()
if len(encrypted) < nonceSize {
return encrypted, errEncryptedStringTooShort
}
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
plaintext, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
if err != nil {
return encrypted, errors.Wrap(err, "Error decrypting text")
}
return plaintext, err
}
encryptNewFn := encrypt
decryptNewFn := decrypt
passphrase := make([]byte, 32)
_, err := io.ReadFull(rand.Reader, passphrase)
require.NoError(t, err)
junk := make([]byte, 1024)
_, err = io.ReadFull(rand.Reader, junk)
require.NoError(t, err)
junkEnc := make([]byte, base64.StdEncoding.EncodedLen(len(junk)))
base64.StdEncoding.Encode(junkEnc, junk)
cases := [][]byte{
[]byte("test"),
[]byte("35"),
[]byte("9ca4a1dd-a439-4593-b386-a7dfdc2e9fc6"),
[]byte(jsonobject),
passphrase,
junk,
junkEnc,
}
for _, plain := range cases {
var enc, dec []byte
var err error
enc, err = encryptOldFn(plain, passphrase)
require.NoError(t, err)
dec, err = decryptNewFn(enc, passphrase)
require.NoError(t, err)
require.Equal(t, plain, dec)
enc, err = encryptNewFn(plain, passphrase)
require.NoError(t, err)
dec, err = decryptOldFn(enc, passphrase)
require.NoError(t, err)
require.Equal(t, plain, dec)
}
}

View File

@@ -17,11 +17,29 @@ func (service ServiceTx) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFun
}
func (service ServiceTx) Create(group *portainer.EdgeGroup) error {
return service.Tx.CreateObject(
es := group.Endpoints
group.Endpoints = nil // Clear deprecated field
err := service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
group.ID = portainer.EdgeGroupID(id)
return int(group.ID), group
},
)
group.Endpoints = es // Restore endpoints after create
return err
}
func (service ServiceTx) Update(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error {
es := group.Endpoints
group.Endpoints = nil // Clear deprecated field
err := service.BaseDataServiceTx.Update(ID, group)
group.Endpoints = es // Restore endpoints after update
return err
}

View File

@@ -0,0 +1,50 @@
package edgestack
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/boltdb"
"github.com/stretchr/testify/require"
)
func TestUpdate(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn, func(portainer.Transaction, portainer.EdgeStackID) {})
require.NoError(t, err)
const edgeStackID = 1
edgeStack := &portainer.EdgeStack{
ID: edgeStackID,
Name: "Test Stack",
}
err = service.Create(edgeStackID, edgeStack)
require.NoError(t, err)
err = service.UpdateEdgeStackFunc(edgeStackID, func(edgeStack *portainer.EdgeStack) {
edgeStack.Name = "Updated Stack"
})
require.NoError(t, err)
updatedStack, err := service.EdgeStack(edgeStackID)
require.NoError(t, err)
require.Equal(t, "Updated Stack", updatedStack.Name)
err = conn.UpdateTx(func(tx portainer.Transaction) error {
return service.UpdateEdgeStackFuncTx(tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
edgeStack.Name = "Updated Stack Again"
})
})
require.NoError(t, err)
updatedStack, err = service.EdgeStack(edgeStackID)
require.NoError(t, err)
require.Equal(t, "Updated Stack Again", updatedStack.Name)
}

View File

@@ -6,8 +6,6 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/rs/zerolog/log"
)
// BucketName represents the name of the bucket where this service stores data.
@@ -16,7 +14,6 @@ const BucketName = "endpoint_relations"
// Service represents a service for managing environment(endpoint) relation data.
type Service struct {
connection portainer.Connection
updateStackFn func(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
updateStackFnTx func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
endpointRelationsCache []portainer.EndpointRelation
mu sync.Mutex
@@ -29,10 +26,8 @@ func (service *Service) BucketName() string {
}
func (service *Service) RegisterUpdateStackFunction(
updateFunc func(portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
updateFuncTx func(portainer.Transaction, portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
) {
service.updateStackFn = updateFunc
service.updateStackFnTx = updateFuncTx
}
@@ -91,29 +86,14 @@ func (service *Service) Create(endpointRelation *portainer.EndpointRelation) err
// UpdateEndpointRelation updates an Environment(Endpoint) relation object
func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error {
previousRelationState, _ := service.EndpointRelation(endpointID)
identifier := service.connection.ConvertToKey(int(endpointID))
err := service.connection.UpdateObject(BucketName, identifier, endpointRelation)
cache.Del(endpointID)
if err != nil {
return err
}
updatedRelationState, _ := service.EndpointRelation(endpointID)
service.mu.Lock()
service.endpointRelationsCache = nil
service.mu.Unlock()
service.updateEdgeStacksAfterRelationChange(previousRelationState, updatedRelationState)
return nil
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).UpdateEndpointRelation(endpointID, endpointRelation)
})
}
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
})
}
@@ -125,72 +105,7 @@ func (service *Service) RemoveEndpointRelationsForEdgeStack(endpointIDs []portai
// DeleteEndpointRelation deletes an Environment(Endpoint) relation object
func (service *Service) DeleteEndpointRelation(endpointID portainer.EndpointID) error {
deletedRelation, _ := service.EndpointRelation(endpointID)
identifier := service.connection.ConvertToKey(int(endpointID))
err := service.connection.DeleteObject(BucketName, identifier)
cache.Del(endpointID)
if err != nil {
return err
}
service.mu.Lock()
service.endpointRelationsCache = nil
service.mu.Unlock()
service.updateEdgeStacksAfterRelationChange(deletedRelation, nil)
return nil
}
func (service *Service) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
relations, _ := service.EndpointRelations()
stacksToUpdate := map[portainer.EdgeStackID]bool{}
if previousRelationState != nil {
for stackId, enabled := range previousRelationState.EdgeStacks {
// flag stack for update if stack is not in the updated relation state
// = stack has been removed for this relation
// or this relation has been deleted
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
stacksToUpdate[stackId] = true
}
}
}
if updatedRelationState != nil {
for stackId, enabled := range updatedRelationState.EdgeStacks {
// flag stack for update if stack is not in the previous relation state
// = stack has been added for this relation
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
stacksToUpdate[stackId] = true
}
}
}
// for each stack referenced by the updated relation
// list how many time this stack is referenced in all relations
// in order to update the stack deployments count
for refStackId, refStackEnabled := range stacksToUpdate {
if !refStackEnabled {
continue
}
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
}
}
}
if err := service.updateStackFn(refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
}
}
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
return service.Tx(tx).DeleteEndpointRelation(endpointID)
})
}

View File

@@ -0,0 +1,140 @@
package endpointrelation
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/boltdb"
"github.com/portainer/portainer/api/dataservices/edgestack"
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/stretchr/testify/require"
)
func TestUpdateRelation(t *testing.T) {
const endpointID = 1
const edgeStackID1 = 1
const edgeStackID2 = 2
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn)
require.NoError(t, err)
updateStackFnTxCalled := false
edgeStacks := make(map[portainer.EdgeStackID]portainer.EdgeStack)
edgeStacks[edgeStackID1] = portainer.EdgeStack{ID: edgeStackID1}
edgeStacks[edgeStackID2] = portainer.EdgeStack{ID: edgeStackID2}
service.RegisterUpdateStackFunction(func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error {
updateStackFnTxCalled = true
s, ok := edgeStacks[ID]
require.True(t, ok)
updateFunc(&s)
edgeStacks[ID] = s
return nil
})
// Nil relation
cache.Set(endpointID, []byte("value"))
err = service.UpdateEndpointRelation(endpointID, nil)
_, cacheKeyExists := cache.Get(endpointID)
require.NoError(t, err)
require.False(t, updateStackFnTxCalled)
require.False(t, cacheKeyExists)
// Add a relation to two edge stacks
cache.Set(endpointID, []byte("value"))
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
EndpointID: endpointID,
EdgeStacks: map[portainer.EdgeStackID]bool{
edgeStackID1: true,
edgeStackID2: true,
},
})
_, cacheKeyExists = cache.Get(endpointID)
require.NoError(t, err)
require.True(t, updateStackFnTxCalled)
require.False(t, cacheKeyExists)
require.Equal(t, 1, edgeStacks[edgeStackID1].NumDeployments)
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
// Remove a relation to one edge stack
updateStackFnTxCalled = false
cache.Set(endpointID, []byte("value"))
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
EndpointID: endpointID,
EdgeStacks: map[portainer.EdgeStackID]bool{
2: true,
},
})
_, cacheKeyExists = cache.Get(endpointID)
require.NoError(t, err)
require.True(t, updateStackFnTxCalled)
require.False(t, cacheKeyExists)
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
// Delete the relation
updateStackFnTxCalled = false
cache.Set(endpointID, []byte("value"))
err = service.DeleteEndpointRelation(endpointID)
_, cacheKeyExists = cache.Get(endpointID)
require.NoError(t, err)
require.True(t, updateStackFnTxCalled)
require.False(t, cacheKeyExists)
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
}
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn)
require.NoError(t, err)
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
require.NoError(t, err)
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
}
func TestEndpointRelations(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
service, err := NewService(conn)
require.NoError(t, err)
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
rels, err := service.EndpointRelations()
require.NoError(t, err)
require.Equal(t, 1, len(rels))
}

View File

@@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
return nil
}
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
for _, endpointID := range endpointIDs {
rel, err := service.EndpointRelation(endpointID)
if err != nil {
return err
}
rel.EdgeStacks[edgeStackID] = true
rel.EdgeStacks[edgeStack.ID] = true
identifier := service.service.connection.ConvertToKey(int(endpointID))
err = service.tx.UpdateObject(BucketName, identifier, rel)
@@ -97,8 +97,12 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
service.service.endpointRelationsCache = nil
service.service.mu.Unlock()
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments += len(endpointIDs)
if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
es.NumDeployments += len(endpointIDs)
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
// to avoid overriding with the previous values
edgeStack.NumDeployments = es.NumDeployments
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
}
@@ -186,53 +190,49 @@ func (service ServiceTx) cachedEndpointRelations() ([]portainer.EndpointRelation
}
func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
relations, _ := service.EndpointRelations()
stacksToUpdate := map[portainer.EdgeStackID]bool{}
if previousRelationState != nil {
for stackId, enabled := range previousRelationState.EdgeStacks {
// flag stack for update if stack is not in the updated relation state
// = stack has been removed for this relation
// or this relation has been deleted
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
stacksToUpdate[stackId] = true
}
}
}
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
// Sanity check
if edgeStack.NumDeployments <= 0 {
log.Error().
Int("edgestack_id", int(edgeStack.ID)).
Int("endpoint_id", int(previousRelationState.EndpointID)).
Int("num_deployments", edgeStack.NumDeployments).
Msg("cannot decrement the number of deployments for an edge stack with zero deployments")
if updatedRelationState != nil {
for stackId, enabled := range updatedRelationState.EdgeStacks {
// flag stack for update if stack is not in the previous relation state
// = stack has been added for this relation
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
stacksToUpdate[stackId] = true
}
}
}
return
}
// for each stack referenced by the updated relation
// list how many time this stack is referenced in all relations
// in order to update the stack deployments count
for refStackId, refStackEnabled := range stacksToUpdate {
if !refStackEnabled {
continue
}
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
edgeStack.NumDeployments--
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
}
cache.Del(previousRelationState.EndpointID)
}
}
}
if err := service.service.updateStackFnTx(service.tx, refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
if updatedRelationState == nil {
return
}
for stackId, enabled := range updatedRelationState.EdgeStacks {
// flag stack for update if stack is not in the previous relation state
// = stack has been added for this relation
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments++
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
}
cache.Del(updatedRelationState.EndpointID)
}
}
}

View File

@@ -126,7 +126,7 @@ type (
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
Create(endpointRelation *portainer.EndpointRelation) error
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
DeleteEndpointRelation(EndpointID portainer.EndpointID) error
BucketName() string

View File

@@ -6,12 +6,11 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
const (
BucketName = "pending_actions"
)
const BucketName = "pending_actions"
type Service struct {
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
@@ -78,15 +77,14 @@ func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.Pendin
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
pendingActions, err := s.BaseDataServiceTx.ReadAll()
pendingActions, err := s.ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
}
for _, pendingAction := range pendingActions {
if pendingAction.EndpointID == ID {
err := s.BaseDataServiceTx.Delete(pendingAction.ID)
if err != nil {
if err := s.Delete(pendingAction.ID); err != nil {
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
}
}

View File

@@ -0,0 +1,32 @@
package pendingactions_test
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestDeleteByEndpoint(t *testing.T) {
_, store := datastore.MustNewTestStore(t, false, false)
// Create Endpoint 1
err := store.PendingActions().Create(&portainer.PendingAction{EndpointID: 1})
require.NoError(t, err)
// Create Endpoint 2
err = store.PendingActions().Create(&portainer.PendingAction{EndpointID: 2})
require.NoError(t, err)
// Delete Endpoint 1
err = store.PendingActions().DeleteByEndpointID(1)
require.NoError(t, err)
// Check that only Endpoint 2 remains
pendingActions, err := store.PendingActions().ReadAll()
require.NoError(t, err)
require.Len(t, pendingActions, 1)
require.Equal(t, portainer.EndpointID(2), pendingActions[0].EndpointID)
}

View File

@@ -232,7 +232,7 @@ func (store *Store) createAccount(username, password string, role portainer.User
user := &portainer.User{Username: username, Role: role}
// encrypt the password
cs := &crypto.Service{}
cs := crypto.Service{}
user.Password, err = cs.Hash(password)
if err != nil {
return err
@@ -259,7 +259,7 @@ func (store *Store) checkAccount(username, expectPassword string, expectRole por
}
// Check the password
cs := &crypto.Service{}
cs := crypto.Service{}
expectPasswordHash, err := cs.Hash(expectPassword)
if err != nil {
return errors.Wrap(err, "hash failed")

View File

@@ -85,6 +85,7 @@ func (store *Store) newMigratorParameters(version *models.Version, flags *portai
EdgeStackService: store.EdgeStackService,
EdgeStackStatusService: store.EdgeStackStatusService,
EdgeJobService: store.EdgeJobService,
EdgeGroupService: store.EdgeGroupService,
TunnelServerService: store.TunnelServerService,
PendingActionsService: store.PendingActionsService,
}

View File

@@ -0,0 +1,25 @@
package migrator
import (
"github.com/portainer/portainer/api/roar"
)
func (m *Migrator) migrateEdgeGroupEndpointsToRoars_2_33_0() error {
egs, err := m.edgeGroupService.ReadAll()
if err != nil {
return err
}
for _, eg := range egs {
if eg.EndpointIDs.Len() == 0 {
eg.EndpointIDs = roar.FromSlice(eg.Endpoints)
eg.Endpoints = nil
}
if err := m.edgeGroupService.Update(eg.ID, &eg); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,55 @@
package migrator
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/boltdb"
"github.com/portainer/portainer/api/dataservices/edgegroup"
"github.com/stretchr/testify/require"
)
func TestMigrateEdgeGroupEndpointsToRoars_2_33_0Idempotency(t *testing.T) {
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
err := conn.Open()
require.NoError(t, err)
defer conn.Close()
edgeGroupService, err := edgegroup.NewService(conn)
require.NoError(t, err)
edgeGroup := &portainer.EdgeGroup{
ID: 1,
Name: "test-edge-group",
Endpoints: []portainer.EndpointID{1, 2, 3},
}
err = conn.CreateObjectWithId(edgegroup.BucketName, int(edgeGroup.ID), edgeGroup)
require.NoError(t, err)
m := NewMigrator(&MigratorParameters{EdgeGroupService: edgeGroupService})
// Run migration once
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
require.NoError(t, err)
migratedEdgeGroup, err := edgeGroupService.Read(edgeGroup.ID)
require.NoError(t, err)
require.Len(t, migratedEdgeGroup.Endpoints, 0)
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
// Run migration again to ensure the results didn't change
err = m.migrateEdgeGroupEndpointsToRoars_2_33_0()
require.NoError(t, err)
migratedEdgeGroup, err = edgeGroupService.Read(edgeGroup.ID)
require.NoError(t, err)
require.Len(t, migratedEdgeGroup.Endpoints, 0)
require.Equal(t, len(edgeGroup.Endpoints), migratedEdgeGroup.EndpointIDs.Len())
}

View File

@@ -6,6 +6,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/models"
"github.com/portainer/portainer/api/dataservices/dockerhub"
"github.com/portainer/portainer/api/dataservices/edgegroup"
"github.com/portainer/portainer/api/dataservices/edgejob"
"github.com/portainer/portainer/api/dataservices/edgestack"
"github.com/portainer/portainer/api/dataservices/edgestackstatus"
@@ -60,6 +61,7 @@ type (
edgeStackService *edgestack.Service
edgeStackStatusService *edgestackstatus.Service
edgeJobService *edgejob.Service
edgeGroupService *edgegroup.Service
TunnelServerService *tunnelserver.Service
pendingActionsService *pendingactions.Service
}
@@ -89,6 +91,7 @@ type (
EdgeStackService *edgestack.Service
EdgeStackStatusService *edgestackstatus.Service
EdgeJobService *edgejob.Service
EdgeGroupService *edgegroup.Service
TunnelServerService *tunnelserver.Service
PendingActionsService *pendingactions.Service
}
@@ -120,11 +123,13 @@ func NewMigrator(parameters *MigratorParameters) *Migrator {
edgeStackService: parameters.EdgeStackService,
edgeStackStatusService: parameters.EdgeStackStatusService,
edgeJobService: parameters.EdgeJobService,
edgeGroupService: parameters.EdgeGroupService,
TunnelServerService: parameters.TunnelServerService,
pendingActionsService: parameters.PendingActionsService,
}
migrator.initMigrations()
return migrator
}
@@ -251,6 +256,8 @@ func (m *Migrator) initMigrations() {
m.addMigrations("2.32.0", m.addEndpointRelationForEdgeAgents_2_32_0)
m.addMigrations("2.33.1", m.migrateEdgeGroupEndpointsToRoars_2_33_0)
// Add new migrations above...
// One function per migration, each versions migration funcs in the same file.
}

View File

@@ -2,6 +2,7 @@ package postinit
import (
"context"
"fmt"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
@@ -83,17 +84,27 @@ func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
// try to create a post init migration pending action. If it already exists, do nothing
// this function exists for readability, not reusability
// TODO: This should be moved into pending actions as part of the pending action migration
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
// If there are no pending actions for the given endpoint, create one
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
action := portainer.PendingAction{
EndpointID: environmentID,
Action: actions.PostInitMigrateEnvironment,
})
if err != nil {
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
}
return nil
pendingActions, err := postInitMigrator.dataStore.PendingActions().ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending actions: %w", err)
}
for _, dba := range pendingActions {
if dba.EndpointID == action.EndpointID && dba.Action == action.Action {
log.Debug().
Str("action", action.Action).
Int("endpoint_id", int(action.EndpointID)).
Msg("pending action already exists for environment, skipping...")
return nil
}
}
return postInitMigrator.dataStore.PendingActions().Create(&action)
}
// MigrateEnvironment runs migrations on a single environment
@@ -149,7 +160,7 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
return migrator.dataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
environment, err := tx.Endpoint().Endpoint(e.ID)
if err != nil {
log.Error().Err(err).Msgf("Error getting environment %d", environment.ID)
log.Error().Err(err).Msgf("Error getting environment %d", e.ID)
return err
}
// Early exit if we do not need to migrate!
@@ -175,10 +186,11 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
continue
}
deviceRequests := containerDetails.HostConfig.Resources.DeviceRequests
deviceRequests := containerDetails.HostConfig.DeviceRequests
for _, deviceRequest := range deviceRequests {
if deviceRequest.Driver == "nvidia" {
environment.EnableGPUManagement = true
break containersLoop
}
}
@@ -186,8 +198,7 @@ func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient
// set the MigrateGPUs flag to false so we don't run this again
environment.PostInitMigrations.MigrateGPUs = false
err = tx.Endpoint().UpdateEndpoint(environment.ID, environment)
if err != nil {
if err := tx.Endpoint().UpdateEndpoint(environment.ID, environment); err != nil {
log.Error().Err(err).Msgf("Error updating EnableGPUManagement flag for environment %d", environment.ID)
return err
}

View File

@@ -0,0 +1,170 @@
package postinit
import (
"net/http"
"net/http/httptest"
"strings"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestMigrateGPUs(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if strings.HasSuffix(r.URL.Path, "/containers/json") {
containerSummary := []container.Summary{{ID: "container1"}}
err := json.NewEncoder(w).Encode(containerSummary)
require.NoError(t, err)
return
}
container := container.InspectResponse{
ContainerJSONBase: &container.ContainerJSONBase{
ID: "container1",
HostConfig: &container.HostConfig{
Resources: container.Resources{
DeviceRequests: []container.DeviceRequest{
{Driver: "nvidia"},
},
},
},
},
}
err := json.NewEncoder(w).Encode(container)
require.NoError(t, err)
}))
defer srv.Close()
_, store := datastore.MustNewTestStore(t, true, false)
migrator := &PostInitMigrator{dataStore: store}
dockerCli, err := client.NewClientWithOpts(client.WithHost(srv.URL), client.WithHTTPClient(http.DefaultClient))
require.NoError(t, err)
// Nonexistent endpoint
err = migrator.MigrateGPUs(portainer.Endpoint{}, dockerCli)
require.Error(t, err)
// Valid endpoint
endpoint := portainer.Endpoint{ID: 1, PostInitMigrations: portainer.EndpointPostInitMigrations{MigrateGPUs: true}}
err = store.Endpoint().Create(&endpoint)
require.NoError(t, err)
err = migrator.MigrateGPUs(endpoint, dockerCli)
require.NoError(t, err)
migratedEndpoint, err := store.Endpoint().Endpoint(endpoint.ID)
require.NoError(t, err)
require.Equal(t, endpoint.ID, migratedEndpoint.ID)
require.False(t, migratedEndpoint.PostInitMigrations.MigrateGPUs)
require.True(t, migratedEndpoint.EnableGPUManagement)
}
func TestPostInitMigrate_PendingActionsCreated(t *testing.T) {
tests := []struct {
name string
existingPendingActions []*portainer.PendingAction
expectedPendingActions int
expectedAction string
}{
{
name: "when existing non-matching action exists, should add migration action",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: "some-other-action",
},
},
expectedPendingActions: 2,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when matching action exists, should not add duplicate",
existingPendingActions: []*portainer.PendingAction{
{
EndpointID: 7,
Action: actions.PostInitMigrateEnvironment,
},
},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
{
name: "when no actions exist, should add migration action",
existingPendingActions: []*portainer.PendingAction{},
expectedPendingActions: 1,
expectedAction: actions.PostInitMigrateEnvironment,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
is := assert.New(t)
_, store := datastore.MustNewTestStore(t, true, true)
// Create test endpoint
endpoint := &portainer.Endpoint{
ID: 7,
UserTrusted: true,
Type: portainer.EdgeAgentOnDockerEnvironment,
Edge: portainer.EnvironmentEdgeSettings{
AsyncMode: false,
},
EdgeID: "edgeID",
}
err := store.Endpoint().Create(endpoint)
is.NoError(err, "error creating endpoint")
// Create any existing pending actions
for _, action := range tt.existingPendingActions {
err = store.PendingActions().Create(action)
is.NoError(err, "error creating pending action")
}
migrator := NewPostInitMigrator(
nil, // kubeFactory not needed for this test
nil, // dockerFactory not needed for this test
store,
"", // assetsPath not needed for this test
nil, // kubernetesDeployer not needed for this test
)
err = migrator.PostInitMigrate()
is.NoError(err, "PostInitMigrate should not return error")
// Verify the results
pendingActions, err := store.PendingActions().ReadAll()
is.NoError(err, "error reading pending actions")
is.Len(pendingActions, tt.expectedPendingActions, "unexpected number of pending actions")
// If we expect any actions, verify at least one has the expected action type
if tt.expectedPendingActions > 0 {
hasExpectedAction := false
for _, action := range pendingActions {
if action.Action == tt.expectedAction {
hasExpectedAction = true
is.Equal(endpoint.ID, action.EndpointID, "action should reference correct endpoint")
break
}
}
is.True(hasExpectedAction, "should have found action of expected type")
}
})
}
}

View File

@@ -111,7 +111,7 @@ func (store *Store) initServices() error {
return err
}
store.EdgeStackService = edgeStackService
endpointRelationService.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFunc, edgeStackService.UpdateEdgeStackFuncTx)
endpointRelationService.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
edgeStackStatusService, err := edgestackstatus.NewService(store.connection)
if err != nil {

View File

@@ -615,7 +615,7 @@
"RequiredPasswordLength": 12
},
"KubeconfigExpiry": "0",
"KubectlShellImage": "portainer/kubectl-shell:2.32.0",
"KubectlShellImage": "portainer/kubectl-shell:2.33.1",
"LDAPSettings": {
"AnonymousMode": true,
"AutoCreateUsers": true,
@@ -944,7 +944,7 @@
}
],
"version": {
"VERSION": "{\"SchemaVersion\":\"2.32.0\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
"VERSION": "{\"SchemaVersion\":\"2.33.1\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
},
"webhooks": null
}

View File

@@ -24,6 +24,8 @@ const (
dockerClientVersion = "1.37"
)
type NodeNamesCtxKey struct{}
// ClientFactory is used to create Docker clients
type ClientFactory struct {
signatureService portainer.DigitalSignatureService
@@ -162,7 +164,7 @@ func (t *NodeNameTransport) RoundTrip(req *http.Request) (*http.Response, error)
return resp, nil
}
nodeNames, ok := req.Context().Value("nodeNames").(map[string]string)
nodeNames, ok := req.Context().Value(NodeNamesCtxKey{}).(map[string]string)
if ok {
for idx, r := range rs {
// as there is no way to differentiate the same image available in multiple nodes only by their ID
@@ -181,10 +183,11 @@ func httpClient(endpoint *portainer.Endpoint, timeout *time.Duration) (*http.Cli
}
if endpoint.TLSConfig.TLS {
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig.TLSCACertPath, endpoint.TLSConfig.TLSCertPath, endpoint.TLSConfig.TLSKeyPath, endpoint.TLSConfig.TLSSkipVerify)
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(endpoint.TLSConfig)
if err != nil {
return nil, err
}
transport.TLSClientConfig = tlsConfig
}

View File

@@ -0,0 +1,30 @@
package client
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require"
)
func TestHttpClient(t *testing.T) {
fips.InitFIPS(false)
// Valid TLS configuration
endpoint := &portainer.Endpoint{}
endpoint.TLSConfig = portainer.TLSConfiguration{TLS: true}
cli, err := httpClient(endpoint, nil)
require.NoError(t, err)
require.NotNil(t, cli)
// Invalid TLS configuration
endpoint.TLSConfig.TLSCertPath = "/invalid/path/client.crt"
endpoint.TLSConfig.TLSKeyPath = "/invalid/path/client.key"
cli, err = httpClient(endpoint, nil)
require.Error(t, err)
require.Nil(t, cli)
}

View File

@@ -9,7 +9,7 @@ import (
"github.com/containers/image/v5/docker"
imagetypes "github.com/containers/image/v5/types"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/image"
"github.com/opencontainers/go-digest"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
@@ -85,7 +85,7 @@ func (c *DigestClient) RemoteDigest(image Image) (digest.Digest, error) {
return rmDigest, nil
}
func ParseLocalImage(inspect types.ImageInspect) (*Image, error) {
func ParseLocalImage(inspect image.InspectResponse) (*Image, error) {
if IsLocalImage(inspect) || IsDanglingImage(inspect) {
return nil, errors.New("the image is not regular")
}

View File

@@ -0,0 +1,33 @@
package images
import (
"testing"
"github.com/docker/docker/api/types/image"
"github.com/stretchr/testify/require"
)
func TestParseLocalImage(t *testing.T) {
// Test with a regular image
img, err := ParseLocalImage(image.InspectResponse{
ID: "sha256:9234e8fb04c47cfe0f49931e4ac7eb76fa904e33b7f8576aec0501c085f02516",
RepoTags: []string{"myimage:latest"},
RepoDigests: []string{"myimage@sha256:4bcff63911fcb4448bd4fdacec207030997caf25e9bea4045fa6c8c44de311d1"},
})
require.NoError(t, err)
require.NotNil(t, img)
require.Equal(t, "library/myimage", img.Path)
require.Equal(t, "latest", img.Tag)
require.Equal(t, "sha256:4bcff63911fcb4448bd4fdacec207030997caf25e9bea4045fa6c8c44de311d1", img.Digest.String())
// Test with a dangling image
img, err = ParseLocalImage(image.InspectResponse{
ID: "sha256:abcdef1234567890",
RepoTags: []string{"<none>:<none>"},
RepoDigests: []string{"<none>@<none>"},
})
require.Error(t, err)
require.Nil(t, img)
}

View File

@@ -7,9 +7,9 @@ import (
"strings"
"text/template"
"github.com/docker/docker/api/types"
"github.com/containers/image/v5/docker/reference"
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/image"
"github.com/opencontainers/go-digest"
"github.com/pkg/errors"
)
@@ -179,11 +179,11 @@ func IsLocalImage(image types.ImageInspect) bool {
// IsDanglingImage returns whether the given image is "dangling" which means
// that there are no repository references to the given image and it has no
// child images
func IsDanglingImage(image types.ImageInspect) bool {
func IsDanglingImage(image image.InspectResponse) bool {
return len(image.RepoTags) == 1 && image.RepoTags[0] == "<none>:<none>" && len(image.RepoDigests) == 1 && image.RepoDigests[0] == "<none>@<none>"
}
// IsNoTagImage returns whether the given image is damaged, has no tags
func IsNoTagImage(image types.ImageInspect) bool {
func IsNoTagImage(image image.InspectResponse) bool {
return len(image.RepoTags) == 0
}

View File

@@ -60,15 +60,9 @@ func NewAzureClient() *azureClient {
}
func newHttpClientForAzure(insecureSkipVerify bool) *http.Client {
tlsConfig := crypto.CreateTLSConfiguration()
if insecureSkipVerify {
tlsConfig.InsecureSkipVerify = true
}
httpsCli := &http.Client{
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
Proxy: http.ProxyFromEnvironment,
},
Timeout: 300 * time.Second,

View File

@@ -58,7 +58,15 @@ func TestService_ClonePublicRepository_Azure(t *testing.T) {
t.Run(tt.name, func(t *testing.T) {
dst := t.TempDir()
repositoryUrl := fmt.Sprintf(tt.args.repositoryURLFormat, tt.args.password)
err := service.CloneRepository(dst, repositoryUrl, tt.args.referenceName, "", "", false)
err := service.CloneRepository(
dst,
repositoryUrl,
tt.args.referenceName,
"",
"",
gittypes.GitCredentialAuthType_Basic,
false,
)
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md"))
})
@@ -73,7 +81,15 @@ func TestService_ClonePrivateRepository_Azure(t *testing.T) {
dst := t.TempDir()
err := service.CloneRepository(dst, privateAzureRepoURL, "refs/heads/main", "", pat, false)
err := service.CloneRepository(
dst,
privateAzureRepoURL,
"refs/heads/main",
"",
pat,
gittypes.GitCredentialAuthType_Basic,
false,
)
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md"))
}
@@ -84,7 +100,14 @@ func TestService_LatestCommitID_Azure(t *testing.T) {
pat := getRequiredValue(t, "AZURE_DEVOPS_PAT")
service := NewService(context.TODO())
id, err := service.LatestCommitID(privateAzureRepoURL, "refs/heads/main", "", pat, false)
id, err := service.LatestCommitID(
privateAzureRepoURL,
"refs/heads/main",
"",
pat,
gittypes.GitCredentialAuthType_Basic,
false,
)
assert.NoError(t, err)
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
}
@@ -96,7 +119,14 @@ func TestService_ListRefs_Azure(t *testing.T) {
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
service := NewService(context.TODO())
refs, err := service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
refs, err := service.ListRefs(
privateAzureRepoURL,
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1)
}
@@ -108,8 +138,8 @@ func TestService_ListRefs_Azure_Concurrently(t *testing.T) {
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
go service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
service.ListRefs(privateAzureRepoURL, username, accessToken, false, false)
go service.ListRefs(privateAzureRepoURL, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
service.ListRefs(privateAzureRepoURL, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
time.Sleep(2 * time.Second)
}
@@ -247,7 +277,17 @@ func TestService_ListFiles_Azure(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
paths, err := service.ListFiles(tt.args.repositoryUrl, tt.args.referenceName, tt.args.username, tt.args.password, false, false, tt.extensions, false)
paths, err := service.ListFiles(
tt.args.repositoryUrl,
tt.args.referenceName,
tt.args.username,
tt.args.password,
gittypes.GitCredentialAuthType_Basic,
false,
false,
tt.extensions,
false,
)
if tt.expect.shouldFail {
assert.Error(t, err)
if tt.expect.err != nil {
@@ -270,8 +310,28 @@ func TestService_ListFiles_Azure_Concurrently(t *testing.T) {
username := getRequiredValue(t, "AZURE_DEVOPS_USERNAME")
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
go service.ListFiles(privateAzureRepoURL, "refs/heads/main", username, accessToken, false, false, []string{}, false)
service.ListFiles(privateAzureRepoURL, "refs/heads/main", username, accessToken, false, false, []string{}, false)
go service.ListFiles(
privateAzureRepoURL,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
service.ListFiles(
privateAzureRepoURL,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
time.Sleep(2 * time.Second)
}

View File

@@ -8,6 +8,7 @@ import (
"testing"
gittypes "github.com/portainer/portainer/api/git/types"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/assert"
)
@@ -234,6 +235,8 @@ func Test_isAzureUrl(t *testing.T) {
}
func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
fips.InitFIPS(false)
type args struct {
options baseOption
}
@@ -308,6 +311,8 @@ func Test_azureDownloader_downloadZipFromAzureDevOps(t *testing.T) {
}
func Test_azureDownloader_latestCommitID(t *testing.T) {
fips.InitFIPS(false)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
response := `{
"count": 1,

View File

@@ -19,6 +19,7 @@ type CloneOptions struct {
ReferenceName string
Username string
Password string
AuthType gittypes.GitCredentialAuthType
// TLSSkipVerify skips SSL verification when cloning the Git repository
TLSSkipVerify bool `example:"false"`
}
@@ -42,7 +43,15 @@ func CloneWithBackup(gitService portainer.GitService, fileService portainer.File
cleanUp = true
if err := gitService.CloneRepository(options.ProjectPath, options.URL, options.ReferenceName, options.Username, options.Password, options.TLSSkipVerify); err != nil {
if err := gitService.CloneRepository(
options.ProjectPath,
options.URL,
options.ReferenceName,
options.Username,
options.Password,
options.AuthType,
options.TLSSkipVerify,
); err != nil {
cleanUp = false
if err := filesystem.MoveDirectory(backupProjectPath, options.ProjectPath, false); err != nil {
log.Warn().Err(err).Msg("failed restoring backup folder")

View File

@@ -7,12 +7,14 @@ import (
"strings"
gittypes "github.com/portainer/portainer/api/git/types"
"github.com/rs/zerolog/log"
"github.com/go-git/go-git/v5"
"github.com/go-git/go-git/v5/config"
"github.com/go-git/go-git/v5/plumbing"
"github.com/go-git/go-git/v5/plumbing/filemode"
"github.com/go-git/go-git/v5/plumbing/object"
"github.com/go-git/go-git/v5/plumbing/transport"
githttp "github.com/go-git/go-git/v5/plumbing/transport/http"
"github.com/go-git/go-git/v5/storage/memory"
"github.com/pkg/errors"
@@ -33,7 +35,7 @@ func (c *gitClient) download(ctx context.Context, dst string, opt cloneOption) e
URL: opt.repositoryUrl,
Depth: opt.depth,
InsecureSkipTLS: opt.tlsSkipVerify,
Auth: getAuth(opt.username, opt.password),
Auth: getAuth(opt.authType, opt.username, opt.password),
Tags: git.NoTags,
}
@@ -51,7 +53,10 @@ func (c *gitClient) download(ctx context.Context, dst string, opt cloneOption) e
}
if !c.preserveGitDirectory {
os.RemoveAll(filepath.Join(dst, ".git"))
err := os.RemoveAll(filepath.Join(dst, ".git"))
if err != nil {
log.Error().Err(err).Msg("failed to remove .git directory")
}
}
return nil
@@ -64,7 +69,7 @@ func (c *gitClient) latestCommitID(ctx context.Context, opt fetchOption) (string
})
listOptions := &git.ListOptions{
Auth: getAuth(opt.username, opt.password),
Auth: getAuth(opt.authType, opt.username, opt.password),
InsecureSkipTLS: opt.tlsSkipVerify,
}
@@ -94,7 +99,23 @@ func (c *gitClient) latestCommitID(ctx context.Context, opt fetchOption) (string
return "", errors.Errorf("could not find ref %q in the repository", opt.referenceName)
}
func getAuth(username, password string) *githttp.BasicAuth {
func getAuth(authType gittypes.GitCredentialAuthType, username, password string) transport.AuthMethod {
if password == "" {
return nil
}
switch authType {
case gittypes.GitCredentialAuthType_Basic:
return getBasicAuth(username, password)
case gittypes.GitCredentialAuthType_Token:
return getTokenAuth(password)
default:
log.Warn().Msg("unknown git credentials authorization type, defaulting to None")
return nil
}
}
func getBasicAuth(username, password string) *githttp.BasicAuth {
if password != "" {
if username == "" {
username = "token"
@@ -108,6 +129,15 @@ func getAuth(username, password string) *githttp.BasicAuth {
return nil
}
func getTokenAuth(token string) *githttp.TokenAuth {
if token != "" {
return &githttp.TokenAuth{
Token: token,
}
}
return nil
}
func (c *gitClient) listRefs(ctx context.Context, opt baseOption) ([]string, error) {
rem := git.NewRemote(memory.NewStorage(), &config.RemoteConfig{
Name: "origin",
@@ -115,7 +145,7 @@ func (c *gitClient) listRefs(ctx context.Context, opt baseOption) ([]string, err
})
listOptions := &git.ListOptions{
Auth: getAuth(opt.username, opt.password),
Auth: getAuth(opt.authType, opt.username, opt.password),
InsecureSkipTLS: opt.tlsSkipVerify,
}
@@ -143,7 +173,7 @@ func (c *gitClient) listFiles(ctx context.Context, opt fetchOption) ([]string, e
Depth: 1,
SingleBranch: true,
ReferenceName: plumbing.ReferenceName(opt.referenceName),
Auth: getAuth(opt.username, opt.password),
Auth: getAuth(opt.authType, opt.username, opt.password),
InsecureSkipTLS: opt.tlsSkipVerify,
Tags: git.NoTags,
}

View File

@@ -2,6 +2,8 @@ package git
import (
"context"
"net/http"
"net/http/httptest"
"path/filepath"
"testing"
"time"
@@ -24,7 +26,15 @@ func TestService_ClonePrivateRepository_GitHub(t *testing.T) {
dst := t.TempDir()
repositoryUrl := privateGitRepoURL
err := service.CloneRepository(dst, repositoryUrl, "refs/heads/main", username, accessToken, false)
err := service.CloneRepository(
dst,
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
)
assert.NoError(t, err)
assert.FileExists(t, filepath.Join(dst, "README.md"))
}
@@ -37,7 +47,14 @@ func TestService_LatestCommitID_GitHub(t *testing.T) {
service := newService(context.TODO(), 0, 0)
repositoryUrl := privateGitRepoURL
id, err := service.LatestCommitID(repositoryUrl, "refs/heads/main", username, accessToken, false)
id, err := service.LatestCommitID(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
)
assert.NoError(t, err)
assert.NotEmpty(t, id, "cannot guarantee commit id, but it should be not empty")
}
@@ -50,7 +67,7 @@ func TestService_ListRefs_GitHub(t *testing.T) {
service := newService(context.TODO(), 0, 0)
repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1)
}
@@ -63,8 +80,8 @@ func TestService_ListRefs_Github_Concurrently(t *testing.T) {
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
repositoryUrl := privateGitRepoURL
go service.ListRefs(repositoryUrl, username, accessToken, false, false)
service.ListRefs(repositoryUrl, username, accessToken, false, false)
go service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
time.Sleep(2 * time.Second)
}
@@ -202,7 +219,17 @@ func TestService_ListFiles_GitHub(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
paths, err := service.ListFiles(tt.args.repositoryUrl, tt.args.referenceName, tt.args.username, tt.args.password, false, false, tt.extensions, false)
paths, err := service.ListFiles(
tt.args.repositoryUrl,
tt.args.referenceName,
tt.args.username,
tt.args.password,
gittypes.GitCredentialAuthType_Basic,
false,
false,
tt.extensions,
false,
)
if tt.expect.shouldFail {
assert.Error(t, err)
if tt.expect.err != nil {
@@ -226,8 +253,28 @@ func TestService_ListFiles_Github_Concurrently(t *testing.T) {
username := getRequiredValue(t, "GITHUB_USERNAME")
service := newService(context.TODO(), repositoryCacheSize, 200*time.Millisecond)
go service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
go service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
time.Sleep(2 * time.Second)
}
@@ -240,8 +287,18 @@ func TestService_purgeCache_Github(t *testing.T) {
username := getRequiredValue(t, "GITHUB_USERNAME")
service := NewService(context.TODO())
service.ListRefs(repositoryUrl, username, accessToken, false, false)
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
assert.Equal(t, 1, service.repoRefCache.Len())
assert.Equal(t, 1, service.repoFileCache.Len())
@@ -261,8 +318,18 @@ func TestService_purgeCacheByTTL_Github(t *testing.T) {
// 40*timeout is designed for giving enough time for ListRefs and ListFiles to cache the result
service := newService(context.TODO(), 2, 40*timeout)
service.ListRefs(repositoryUrl, username, accessToken, false, false)
service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
assert.Equal(t, 1, service.repoRefCache.Len())
assert.Equal(t, 1, service.repoFileCache.Len())
@@ -293,12 +360,12 @@ func TestService_HardRefresh_ListRefs_GitHub(t *testing.T) {
service := newService(context.TODO(), 2, 0)
repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1)
assert.Equal(t, 1, service.repoRefCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", false, false)
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
assert.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len())
}
@@ -311,26 +378,46 @@ func TestService_HardRefresh_ListRefs_And_RemoveAllCaches_GitHub(t *testing.T) {
service := newService(context.TODO(), 2, 0)
repositoryUrl := privateGitRepoURL
refs, err := service.ListRefs(repositoryUrl, username, accessToken, false, false)
refs, err := service.ListRefs(repositoryUrl, username, accessToken, gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(refs), 1)
assert.Equal(t, 1, service.repoRefCache.Len())
files, err := service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
files, err := service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 1, service.repoFileCache.Len())
files, err = service.ListFiles(repositoryUrl, "refs/heads/test", username, accessToken, false, false, []string{}, false)
files, err = service.ListFiles(
repositoryUrl,
"refs/heads/test",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 2, service.repoFileCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", false, false)
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, false, false)
assert.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len())
_, err = service.ListRefs(repositoryUrl, username, "fake-token", true, false)
_, err = service.ListRefs(repositoryUrl, username, "fake-token", gittypes.GitCredentialAuthType_Basic, true, false)
assert.Error(t, err)
assert.Equal(t, 1, service.repoRefCache.Len())
// The relevant file caches should be removed too
@@ -344,12 +431,72 @@ func TestService_HardRefresh_ListFiles_GitHub(t *testing.T) {
accessToken := getRequiredValue(t, "GITHUB_PAT")
username := getRequiredValue(t, "GITHUB_USERNAME")
repositoryUrl := privateGitRepoURL
files, err := service.ListFiles(repositoryUrl, "refs/heads/main", username, accessToken, false, false, []string{}, false)
files, err := service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{},
false,
)
assert.NoError(t, err)
assert.GreaterOrEqual(t, len(files), 1)
assert.Equal(t, 1, service.repoFileCache.Len())
_, err = service.ListFiles(repositoryUrl, "refs/heads/main", username, "fake-token", false, true, []string{}, false)
_, err = service.ListFiles(
repositoryUrl,
"refs/heads/main",
username,
"fake-token",
gittypes.GitCredentialAuthType_Basic,
false,
true,
[]string{},
false,
)
assert.Error(t, err)
assert.Equal(t, 0, service.repoFileCache.Len())
}
func TestService_CloneRepository_TokenAuth(t *testing.T) {
ensureIntegrationTest(t)
service := newService(context.TODO(), 2, 0)
var requests []*http.Request
testServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
requests = append(requests, r)
}))
accessToken := "test_access_token"
username := "test_username"
repositoryUrl := testServer.URL
// Since we aren't hitting a real git server we ignore the error
_ = service.CloneRepository(
"test_dir",
repositoryUrl,
"refs/heads/main",
username,
accessToken,
gittypes.GitCredentialAuthType_Token,
false,
)
testServer.Close()
if len(requests) != 1 {
t.Fatalf("expected 1 request sent but got %d", len(requests))
}
gotAuthHeader := requests[0].Header.Get("Authorization")
if gotAuthHeader == "" {
t.Fatal("no Authorization header in git request")
}
expectedAuthHeader := "Bearer test_access_token"
if gotAuthHeader != expectedAuthHeader {
t.Fatalf("expected Authorization header %q but got %q", expectedAuthHeader, gotAuthHeader)
}
}

View File

@@ -38,9 +38,9 @@ func Test_ClonePublicRepository_Shallow(t *testing.T) {
dir := t.TempDir()
t.Logf("Cloning into %s", dir)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", false)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err)
assert.Equal(t, 1, getCommitHistoryLength(t, err, dir), "cloned repo has incorrect depth")
assert.Equal(t, 1, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
}
func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
@@ -50,7 +50,7 @@ func Test_ClonePublicRepository_NoGitDirectory(t *testing.T) {
dir := t.TempDir()
t.Logf("Cloning into %s", dir)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", false)
err := service.CloneRepository(dir, repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err)
assert.NoDirExists(t, filepath.Join(dir, ".git"))
}
@@ -75,7 +75,7 @@ func Test_cloneRepository(t *testing.T) {
})
assert.NoError(t, err)
assert.Equal(t, 4, getCommitHistoryLength(t, err, dir), "cloned repo has incorrect depth")
assert.Equal(t, 4, getCommitHistoryLength(t, dir), "cloned repo has incorrect depth")
}
func Test_latestCommitID(t *testing.T) {
@@ -84,7 +84,7 @@ func Test_latestCommitID(t *testing.T) {
repositoryURL := setup(t)
referenceName := "refs/heads/main"
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", false)
id, err := service.LatestCommitID(repositoryURL, referenceName, "", "", gittypes.GitCredentialAuthType_Basic, false)
assert.NoError(t, err)
assert.Equal(t, "68dcaa7bd452494043c64252ab90db0f98ecf8d2", id)
@@ -95,7 +95,7 @@ func Test_ListRefs(t *testing.T) {
repositoryURL := setup(t)
fs, err := service.ListRefs(repositoryURL, "", "", false, false)
fs, err := service.ListRefs(repositoryURL, "", "", gittypes.GitCredentialAuthType_Basic, false, false)
assert.NoError(t, err)
assert.Equal(t, []string{"refs/heads/main"}, fs)
@@ -107,13 +107,23 @@ func Test_ListFiles(t *testing.T) {
repositoryURL := setup(t)
referenceName := "refs/heads/main"
fs, err := service.ListFiles(repositoryURL, referenceName, "", "", false, false, []string{".yml"}, false)
fs, err := service.ListFiles(
repositoryURL,
referenceName,
"",
"",
gittypes.GitCredentialAuthType_Basic,
false,
false,
[]string{".yml"},
false,
)
assert.NoError(t, err)
assert.Equal(t, []string{"docker-compose.yml"}, fs)
}
func getCommitHistoryLength(t *testing.T, err error, dir string) int {
func getCommitHistoryLength(t *testing.T, dir string) int {
repo, err := git.PlainOpen(dir)
if err != nil {
t.Fatalf("can't open a git repo at %s with error %v", dir, err)
@@ -125,11 +135,10 @@ func getCommitHistoryLength(t *testing.T, err error, dir string) int {
}
count := 0
err = iter.ForEach(func(_ *object.Commit) error {
if err := iter.ForEach(func(_ *object.Commit) error {
count++
return nil
})
if err != nil {
}); err != nil {
t.Fatalf("can't iterate over the commit history with error %v", err)
}
@@ -255,7 +264,7 @@ func Test_listFilesPrivateRepository(t *testing.T) {
name: "list tree with real repository and head ref but no credential",
args: fetchOption{
baseOption: baseOption{
repositoryUrl: privateGitRepoURL + "fake",
repositoryUrl: privateGitRepoURL,
username: "",
password: "",
},

View File

@@ -8,6 +8,7 @@ import (
"time"
lru "github.com/hashicorp/golang-lru"
gittypes "github.com/portainer/portainer/api/git/types"
"github.com/rs/zerolog/log"
"golang.org/x/sync/singleflight"
)
@@ -22,6 +23,7 @@ type baseOption struct {
repositoryUrl string
username string
password string
authType gittypes.GitCredentialAuthType
tlsSkipVerify bool
}
@@ -123,13 +125,22 @@ func (service *Service) timerHasStopped() bool {
// CloneRepository clones a git repository using the specified URL in the specified
// destination folder.
func (service *Service) CloneRepository(destination, repositoryURL, referenceName, username, password string, tlsSkipVerify bool) error {
func (service *Service) CloneRepository(
destination,
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) error {
options := cloneOption{
fetchOption: fetchOption{
baseOption: baseOption{
repositoryUrl: repositoryURL,
username: username,
password: password,
authType: authType,
tlsSkipVerify: tlsSkipVerify,
},
referenceName: referenceName,
@@ -155,12 +166,20 @@ func (service *Service) cloneRepository(destination string, options cloneOption)
}
// LatestCommitID returns SHA1 of the latest commit of the specified reference
func (service *Service) LatestCommitID(repositoryURL, referenceName, username, password string, tlsSkipVerify bool) (string, error) {
func (service *Service) LatestCommitID(
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) (string, error) {
options := fetchOption{
baseOption: baseOption{
repositoryUrl: repositoryURL,
username: username,
password: password,
authType: authType,
tlsSkipVerify: tlsSkipVerify,
},
referenceName: referenceName,
@@ -170,7 +189,14 @@ func (service *Service) LatestCommitID(repositoryURL, referenceName, username, p
}
// ListRefs will list target repository's references without cloning the repository
func (service *Service) ListRefs(repositoryURL, username, password string, hardRefresh bool, tlsSkipVerify bool) ([]string, error) {
func (service *Service) ListRefs(
repositoryURL,
username,
password string,
authType gittypes.GitCredentialAuthType,
hardRefresh bool,
tlsSkipVerify bool,
) ([]string, error) {
refCacheKey := generateCacheKey(repositoryURL, username, password, strconv.FormatBool(tlsSkipVerify))
if service.cacheEnabled && hardRefresh {
// Should remove the cache explicitly, so that the following normal list can show the correct result
@@ -196,6 +222,7 @@ func (service *Service) ListRefs(repositoryURL, username, password string, hardR
repositoryUrl: repositoryURL,
username: username,
password: password,
authType: authType,
tlsSkipVerify: tlsSkipVerify,
}
@@ -215,18 +242,62 @@ var singleflightGroup = &singleflight.Group{}
// ListFiles will list all the files of the target repository with specific extensions.
// If extension is not provided, it will list all the files under the target repository
func (service *Service) ListFiles(repositoryURL, referenceName, username, password string, dirOnly, hardRefresh bool, includedExts []string, tlsSkipVerify bool) ([]string, error) {
repoKey := generateCacheKey(repositoryURL, referenceName, username, password, strconv.FormatBool(tlsSkipVerify), strconv.FormatBool(dirOnly))
func (service *Service) ListFiles(
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
dirOnly,
hardRefresh bool,
includedExts []string,
tlsSkipVerify bool,
) ([]string, error) {
repoKey := generateCacheKey(
repositoryURL,
referenceName,
username,
password,
strconv.FormatBool(tlsSkipVerify),
strconv.Itoa(int(authType)),
strconv.FormatBool(dirOnly),
)
fs, err, _ := singleflightGroup.Do(repoKey, func() (any, error) {
return service.listFiles(repositoryURL, referenceName, username, password, dirOnly, hardRefresh, tlsSkipVerify)
return service.listFiles(
repositoryURL,
referenceName,
username,
password,
authType,
dirOnly,
hardRefresh,
tlsSkipVerify,
)
})
return filterFiles(fs.([]string), includedExts), err
}
func (service *Service) listFiles(repositoryURL, referenceName, username, password string, dirOnly, hardRefresh bool, tlsSkipVerify bool) ([]string, error) {
repoKey := generateCacheKey(repositoryURL, referenceName, username, password, strconv.FormatBool(tlsSkipVerify), strconv.FormatBool(dirOnly))
func (service *Service) listFiles(
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
dirOnly,
hardRefresh bool,
tlsSkipVerify bool,
) ([]string, error) {
repoKey := generateCacheKey(
repositoryURL,
referenceName,
username,
password,
strconv.FormatBool(tlsSkipVerify),
strconv.Itoa(int(authType)),
strconv.FormatBool(dirOnly),
)
if service.cacheEnabled && hardRefresh {
// Should remove the cache explicitly, so that the following normal list can show the correct result
@@ -247,6 +318,7 @@ func (service *Service) listFiles(repositoryURL, referenceName, username, passwo
repositoryUrl: repositoryURL,
username: username,
password: password,
authType: authType,
tlsSkipVerify: tlsSkipVerify,
},
referenceName: referenceName,

View File

@@ -1,12 +1,21 @@
package gittypes
import "errors"
import (
"errors"
)
var (
ErrIncorrectRepositoryURL = errors.New("git repository could not be found, please ensure that the URL is correct")
ErrAuthenticationFailure = errors.New("authentication failed, please ensure that the git credentials are correct")
)
type GitCredentialAuthType int
const (
GitCredentialAuthType_Basic GitCredentialAuthType = iota
GitCredentialAuthType_Token
)
// RepoConfig represents a configuration for a repo
type RepoConfig struct {
// The repo url
@@ -24,10 +33,11 @@ type RepoConfig struct {
}
type GitAuthentication struct {
Username string
Password string
Username string
Password string
AuthorizationType GitCredentialAuthType
// Git credentials identifier when the value is not 0
// When the value is 0, Username and Password are set without using saved credential
// When the value is 0, Username, Password, and Authtype are set without using saved credential
// This is introduced since 2.15.0
GitCredentialID int `example:"0"`
}

View File

@@ -29,7 +29,14 @@ func UpdateGitObject(gitService portainer.GitService, objId string, gitConfig *g
return false, "", errors.WithMessagef(err, "failed to get credentials for %v", objId)
}
newHash, err := gitService.LatestCommitID(gitConfig.URL, gitConfig.ReferenceName, username, password, gitConfig.TLSSkipVerify)
newHash, err := gitService.LatestCommitID(
gitConfig.URL,
gitConfig.ReferenceName,
username,
password,
gittypes.GitCredentialAuthType_Basic,
gitConfig.TLSSkipVerify,
)
if err != nil {
return false, "", errors.WithMessagef(err, "failed to fetch latest commit id of %v", objId)
}
@@ -62,6 +69,7 @@ func UpdateGitObject(gitService portainer.GitService, objId string, gitConfig *g
cloneParams.auth = &gitAuth{
username: username,
password: password,
authType: gitConfig.Authentication.AuthorizationType,
}
}
@@ -89,14 +97,31 @@ type cloneRepositoryParameters struct {
}
type gitAuth struct {
authType gittypes.GitCredentialAuthType
username string
password string
}
func cloneGitRepository(gitService portainer.GitService, cloneParams *cloneRepositoryParameters) error {
if cloneParams.auth != nil {
return gitService.CloneRepository(cloneParams.toDir, cloneParams.url, cloneParams.ref, cloneParams.auth.username, cloneParams.auth.password, cloneParams.tlsSkipVerify)
return gitService.CloneRepository(
cloneParams.toDir,
cloneParams.url,
cloneParams.ref,
cloneParams.auth.username,
cloneParams.auth.password,
cloneParams.auth.authType,
cloneParams.tlsSkipVerify,
)
}
return gitService.CloneRepository(cloneParams.toDir, cloneParams.url, cloneParams.ref, "", "", cloneParams.tlsSkipVerify)
return gitService.CloneRepository(
cloneParams.toDir,
cloneParams.url,
cloneParams.ref,
"",
"",
gittypes.GitCredentialAuthType_Basic,
cloneParams.tlsSkipVerify,
)
}

View File

@@ -33,14 +33,11 @@ type Service struct {
// NewService initializes a new service.
func NewService(insecureSkipVerify bool) *Service {
tlsConfig := crypto.CreateTLSConfiguration()
tlsConfig.InsecureSkipVerify = insecureSkipVerify
return &Service{
httpsClient: &http.Client{
Timeout: httpClientTimeout,
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
TLSClientConfig: crypto.CreateTLSConfiguration(insecureSkipVerify),
},
},
}

View File

@@ -0,0 +1,18 @@
package openamt
import (
"net/http"
"testing"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require"
)
func TestNewService(t *testing.T) {
fips.InitFIPS(false)
service := NewService(true)
require.NotNil(t, service)
require.True(t, service.httpsClient.Transport.(*http.Transport).TLSClientConfig.InsecureSkipVerify) //nolint:forbidigo
}

View File

@@ -1,7 +1,6 @@
package client
import (
"crypto/tls"
"errors"
"fmt"
"io"
@@ -11,6 +10,7 @@ import (
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/crypto"
"github.com/rs/zerolog/log"
"github.com/segmentio/encoding/json"
@@ -105,21 +105,28 @@ func Get(url string, timeout int) ([]byte, error) {
// ExecutePingOperation will send a SystemPing operation HTTP request to a Docker environment(endpoint)
// using the specified host and optional TLS configuration.
// It uses a new Http.Client for each operation.
func ExecutePingOperation(host string, tlsConfig *tls.Config) (bool, error) {
func ExecutePingOperation(host string, tlsConfiguration portainer.TLSConfiguration) (bool, error) {
transport := &http.Transport{}
scheme := "http"
if tlsConfig != nil {
if tlsConfiguration.TLS {
tlsConfig, err := crypto.CreateTLSConfigurationFromDisk(tlsConfiguration)
if err != nil {
return false, err
}
transport.TLSClientConfig = tlsConfig
scheme = "https"
}
client := &http.Client{
Timeout: time.Second * 3,
Timeout: 3 * time.Second,
Transport: transport,
}
target := strings.Replace(host, "tcp://", scheme+"://", 1)
return pingOperation(client, target)
}
@@ -134,10 +141,7 @@ func pingOperation(client *http.Client, target string) (bool, error) {
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
agentOnDockerEnvironment := false
if resp.Header.Get(portainer.PortainerAgentHeader) != "" {
agentOnDockerEnvironment = true
}
agentOnDockerEnvironment := resp.Header.Get(portainer.PortainerAgentHeader) != ""
return agentOnDockerEnvironment, nil
}

View File

@@ -0,0 +1,47 @@
package client
import (
"net/http"
"net/http/httptest"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/require"
)
func TestExecutePingOperationFailure(t *testing.T) {
fips.InitFIPS(false)
host := "http://localhost:1"
config := portainer.TLSConfiguration{
TLS: true,
TLSSkipVerify: true,
}
// Invalid host
ok, err := ExecutePingOperation(host, config)
require.False(t, ok)
require.Error(t, err)
// Invalid TLS configuration
config.TLSCertPath = "/invalid/path/to/cert"
config.TLSKeyPath = "/invalid/path/to/key"
ok, err = ExecutePingOperation(host, config)
require.False(t, ok)
require.Error(t, err)
}
func TestPingOperation(t *testing.T) {
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Add(portainer.PortainerAgentHeader, "1")
}))
defer srv.Close()
agentOnDockerEnv, err := pingOperation(http.DefaultClient, srv.URL)
require.NoError(t, err)
require.True(t, agentOnDockerEnv)
}

View File

@@ -18,10 +18,15 @@ import (
"github.com/portainer/portainer/api/crypto"
"github.com/portainer/portainer/api/http/offlinegate"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/assert"
)
func init() {
fips.InitFIPS(false)
}
func listFiles(dir string) []string {
items := make([]string, 0)
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {

View File

@@ -20,12 +20,17 @@ import (
"github.com/portainer/portainer/api/internal/authorization"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/jwt"
"github.com/portainer/portainer/pkg/fips"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
)
func init() {
fips.InitFIPS(false)
}
var testFileContent = "abcdefg"
type TestGitService struct {
@@ -33,13 +38,28 @@ type TestGitService struct {
targetFilePath string
}
func (g *TestGitService) CloneRepository(destination string, repositoryURL, referenceName string, username, password string, tlsSkipVerify bool) error {
func (g *TestGitService) CloneRepository(
destination string,
repositoryURL,
referenceName string,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) error {
time.Sleep(100 * time.Millisecond)
return createTestFile(g.targetFilePath)
}
func (g *TestGitService) LatestCommitID(repositoryURL, referenceName, username, password string, tlsSkipVerify bool) (string, error) {
func (g *TestGitService) LatestCommitID(
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) (string, error) {
return "", nil
}
@@ -56,11 +76,26 @@ type InvalidTestGitService struct {
targetFilePath string
}
func (g *InvalidTestGitService) CloneRepository(dest, repoUrl, refName, username, password string, tlsSkipVerify bool) error {
func (g *InvalidTestGitService) CloneRepository(
dest,
repoUrl,
refName,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) error {
return errors.New("simulate network error")
}
func (g *InvalidTestGitService) LatestCommitID(repositoryURL, referenceName, username, password string, tlsSkipVerify bool) (string, error) {
func (g *InvalidTestGitService) LatestCommitID(
repositoryURL,
referenceName,
username,
password string,
authType gittypes.GitCredentialAuthType,
tlsSkipVerify bool,
) (string, error) {
return "", nil
}

View File

@@ -37,14 +37,16 @@ type customTemplateUpdatePayload struct {
RepositoryURL string `example:"https://github.com/openfaas/faas" validate:"required"`
// Reference name of a Git repository hosting the Stack file
RepositoryReferenceName string `example:"refs/heads/master"`
// Use basic authentication to clone the Git repository
// Use authentication to clone the Git repository
RepositoryAuthentication bool `example:"true"`
// Username used in basic authentication. Required when RepositoryAuthentication is true
// and RepositoryGitCredentialID is 0
// and RepositoryGitCredentialID is 0. Ignored if RepositoryAuthType is token
RepositoryUsername string `example:"myGitUsername"`
// Password used in basic authentication. Required when RepositoryAuthentication is true
// and RepositoryGitCredentialID is 0
// Password used in basic authentication or token used in token authentication.
// Required when RepositoryAuthentication is true and RepositoryGitCredentialID is 0
RepositoryPassword string `example:"myGitPassword"`
// RepositoryAuthorizationType is the authorization type to use
RepositoryAuthorizationType gittypes.GitCredentialAuthType `example:"0"`
// GitCredentialID used to identify the bound git credential. Required when RepositoryAuthentication
// is true and RepositoryUsername/RepositoryPassword are not provided
RepositoryGitCredentialID int `example:"0"`
@@ -182,12 +184,15 @@ func (handler *Handler) customTemplateUpdate(w http.ResponseWriter, r *http.Requ
repositoryUsername := ""
repositoryPassword := ""
repositoryAuthType := gittypes.GitCredentialAuthType_Basic
if payload.RepositoryAuthentication {
repositoryUsername = payload.RepositoryUsername
repositoryPassword = payload.RepositoryPassword
repositoryAuthType = payload.RepositoryAuthorizationType
gitConfig.Authentication = &gittypes.GitAuthentication{
Username: payload.RepositoryUsername,
Password: payload.RepositoryPassword,
Username: payload.RepositoryUsername,
Password: payload.RepositoryPassword,
AuthorizationType: payload.RepositoryAuthorizationType,
}
}
@@ -197,6 +202,7 @@ func (handler *Handler) customTemplateUpdate(w http.ResponseWriter, r *http.Requ
ReferenceName: gitConfig.ReferenceName,
Username: repositoryUsername,
Password: repositoryPassword,
AuthType: repositoryAuthType,
TLSSkipVerify: gitConfig.TLSSkipVerify,
})
if err != nil {
@@ -205,7 +211,14 @@ func (handler *Handler) customTemplateUpdate(w http.ResponseWriter, r *http.Requ
defer cleanBackup()
commitHash, err := handler.GitService.LatestCommitID(gitConfig.URL, gitConfig.ReferenceName, repositoryUsername, repositoryPassword, gitConfig.TLSSkipVerify)
commitHash, err := handler.GitService.LatestCommitID(
gitConfig.URL,
gitConfig.ReferenceName,
repositoryUsername,
repositoryPassword,
repositoryAuthType,
gitConfig.TLSSkipVerify,
)
if err != nil {
return httperror.InternalServerError("Unable get latest commit id", fmt.Errorf("failed to fetch latest commit id of the template %v: %w", customTemplate.ID, err))
}

View File

@@ -45,7 +45,7 @@ type dashboardResponse struct {
// @success 200 {object} dashboardResponse "Success"
// @failure 400 "Bad request"
// @failure 500 "Internal server error"
// @router /docker/{environmentId}/dashboard [post]
// @router /docker/{environmentId}/dashboard [get]
func (h *Handler) dashboard(w http.ResponseWriter, r *http.Request) *httperror.HandlerError {
var resp dashboardResponse
err := h.dataStore.ViewTx(func(tx dataservices.DataStoreTx) error {

View File

@@ -6,6 +6,7 @@ import (
"net/http"
"strings"
"github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/http/handler/docker/utils"
"github.com/portainer/portainer/api/set"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
@@ -50,7 +51,7 @@ func (handler *Handler) imagesList(w http.ResponseWriter, r *http.Request) *http
nodeNames := make(map[string]string)
// Pass the node names map to the context so the custom NodeNameTransport can use it
ctx := context.WithValue(r.Context(), "nodeNames", nodeNames)
ctx := context.WithValue(r.Context(), client.NodeNamesCtxKey{}, nodeNames)
images, err := cli.ImageList(ctx, image.ListOptions{})
if err != nil {

View File

@@ -6,14 +6,13 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/swarm"
portainer "github.com/portainer/portainer/api"
portaineree "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
dockerconsts "github.com/portainer/portainer/api/docker/consts"
"github.com/portainer/portainer/api/http/security"
)
type StackViewModel struct {
InternalStack *portaineree.Stack
InternalStack *portainer.Stack
ID portainer.StackID
Name string
@@ -23,7 +22,6 @@ type StackViewModel struct {
// GetDockerStacks retrieves all the stacks associated to a specific environment filtered by the user's access.
func GetDockerStacks(tx dataservices.DataStoreTx, securityContext *security.RestrictedRequestContext, environmentID portainer.EndpointID, containers []types.Container, services []swarm.Service) ([]StackViewModel, error) {
stacks, err := tx.Stack().ReadAll()
if err != nil {
return nil, fmt.Errorf("Unable to retrieve stacks: %w", err)

View File

@@ -6,15 +6,15 @@ import (
"github.com/docker/docker/api/types"
"github.com/docker/docker/api/types/swarm"
portainer "github.com/portainer/portainer/api"
portaineree "github.com/portainer/portainer/api"
dockerconsts "github.com/portainer/portainer/api/docker/consts"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert"
)
func TestHandler_getDockerStacks(t *testing.T) {
environment := &portaineree.Endpoint{
environment := &portainer.Endpoint{
ID: 1,
SecuritySettings: portainer.EndpointSecuritySettings{
AllowStackManagementForRegularUsers: true,
@@ -46,7 +46,7 @@ func TestHandler_getDockerStacks(t *testing.T) {
},
}
stack1 := portaineree.Stack{
stack1 := portainer.Stack{
ID: 1,
Name: "stack1",
EndpointID: 1,
@@ -54,8 +54,8 @@ func TestHandler_getDockerStacks(t *testing.T) {
}
datastore := testhelpers.NewDatastore(
testhelpers.WithEndpoints([]portaineree.Endpoint{*environment}),
testhelpers.WithStacks([]portaineree.Stack{
testhelpers.WithEndpoints([]portainer.Endpoint{*environment}),
testhelpers.WithStacks([]portainer.Stack{
stack1,
{
ID: 2,

View File

@@ -4,6 +4,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/roar"
)
type endpointSetType map[portainer.EndpointID]bool
@@ -49,22 +50,29 @@ func GetEndpointsByTags(tx dataservices.DataStoreTx, tagIDs []portainer.TagID, p
return results, nil
}
func getTrustedEndpoints(tx dataservices.DataStoreTx, endpointIDs []portainer.EndpointID) ([]portainer.EndpointID, error) {
func getTrustedEndpoints(tx dataservices.DataStoreTx, endpointIDs roar.Roar[portainer.EndpointID]) ([]portainer.EndpointID, error) {
var innerErr error
results := []portainer.EndpointID{}
for _, endpointID := range endpointIDs {
endpointIDs.Iterate(func(endpointID portainer.EndpointID) bool {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return nil, err
innerErr = err
return false
}
if !endpoint.UserTrusted {
continue
return true
}
results = append(results, endpoint.ID)
}
return results, nil
return true
})
return results, innerErr
}
func mapEndpointGroupToEndpoints(endpoints []portainer.Endpoint) map[portainer.EndpointGroupID]endpointSetType {

View File

@@ -7,6 +7,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
)
@@ -52,6 +53,7 @@ func calculateEndpointsOrTags(tx dataservices.DataStoreTx, edgeGroup *portainer.
}
edgeGroup.Endpoints = endpointIDs
edgeGroup.EndpointIDs = roar.FromSlice(endpointIDs)
return nil
}
@@ -94,6 +96,7 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
Dynamic: payload.Dynamic,
TagIDs: []portainer.TagID{},
Endpoints: []portainer.EndpointID{},
EndpointIDs: roar.Roar[portainer.EndpointID]{},
PartialMatch: payload.PartialMatch,
}
@@ -108,5 +111,5 @@ func (handler *Handler) edgeGroupCreate(w http.ResponseWriter, r *http.Request)
return nil
})
return txResponse(w, edgeGroup, err)
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
}

View File

@@ -0,0 +1,62 @@
package edgegroups
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/require"
)
func TestEdgeGroupCreateHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
for i := range 3 {
err = store.Endpoint().Create(&portainer.Endpoint{
ID: portainer.EndpointID(i + 1),
Name: "Test Endpoint " + strconv.Itoa(i+1),
Type: portainer.EdgeAgentOnDockerEnvironment,
GroupID: 1,
})
require.NoError(t, err)
err = store.EndpointRelation().Create(&portainer.EndpointRelation{
EndpointID: portainer.EndpointID(i + 1),
EdgeStacks: map[portainer.EdgeStackID]bool{},
})
require.NoError(t, err)
}
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodPost,
"/edge_groups",
strings.NewReader(`{"Name": "New Edge Group", "Endpoints": [1, 2, 3]}`),
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroup portainer.EdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroup)
require.NoError(t, err)
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroup.Endpoints)
}

View File

@@ -5,6 +5,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
)
@@ -33,7 +34,9 @@ func (handler *Handler) edgeGroupInspect(w http.ResponseWriter, r *http.Request)
return err
})
return txResponse(w, edgeGroup, err)
edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice()
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
}
func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*portainer.EdgeGroup, error) {
@@ -50,7 +53,7 @@ func getEdgeGroup(tx dataservices.DataStoreTx, ID portainer.EdgeGroupID) (*porta
return nil, httperror.InternalServerError("Unable to retrieve environments and environment groups for Edge group", err)
}
edgeGroup.Endpoints = endpoints
edgeGroup.EndpointIDs = roar.FromSlice(endpoints)
}
return edgeGroup, err

View File

@@ -0,0 +1,176 @@
package edgegroups
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestEdgeGroupInspectHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
for i := range 3 {
err = store.Endpoint().Create(&portainer.Endpoint{
ID: portainer.EndpointID(i + 1),
Name: "Test Endpoint " + strconv.Itoa(i+1),
Type: portainer.EdgeAgentOnDockerEnvironment,
GroupID: 1,
})
require.NoError(t, err)
err = store.EndpointRelation().Create(&portainer.EndpointRelation{
EndpointID: portainer.EndpointID(i + 1),
EdgeStacks: map[portainer.EdgeStackID]bool{},
})
require.NoError(t, err)
}
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Test Edge Group",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1, 2, 3}),
})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/edge_groups/1",
nil,
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroup portainer.EdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroup)
require.NoError(t, err)
assert.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroup.Endpoints)
}
func TestEmptyEdgeGroupInspectHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Test Edge Group",
EndpointIDs: roar.Roar[portainer.EndpointID]{},
})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/edge_groups/1",
nil,
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroup portainer.EdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroup)
require.NoError(t, err)
// Make sure the frontend does not get a null value but a [] instead
require.NotNil(t, responseGroup.Endpoints)
require.Len(t, responseGroup.Endpoints, 0)
}
func TestDynamicEdgeGroupInspectHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
err = store.Tag().Create(&portainer.Tag{
ID: 1,
Name: "Test Tag",
Endpoints: map[portainer.EndpointID]bool{
1: true,
2: true,
3: true,
},
})
require.NoError(t, err)
for i := range 3 {
err = store.Endpoint().Create(&portainer.Endpoint{
ID: portainer.EndpointID(i + 1),
Name: "Test Endpoint " + strconv.Itoa(i+1),
Type: portainer.EdgeAgentOnDockerEnvironment,
GroupID: 1,
TagIDs: []portainer.TagID{1},
UserTrusted: true,
})
require.NoError(t, err)
err = store.EndpointRelation().Create(&portainer.EndpointRelation{
EndpointID: portainer.EndpointID(i + 1),
EdgeStacks: map[portainer.EdgeStackID]bool{},
})
require.NoError(t, err)
}
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Test Edge Group",
Dynamic: true,
TagIDs: []portainer.TagID{1},
})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/edge_groups/1",
nil,
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroup portainer.EdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroup)
require.NoError(t, err)
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroup.Endpoints)
}

View File

@@ -7,11 +7,17 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/roar"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
)
type decoratedEdgeGroup struct {
type shadowedEdgeGroup struct {
portainer.EdgeGroup
EndpointIds int `json:"EndpointIds,omitempty"` // Shadow to avoid exposing in the API
}
type decoratedEdgeGroup struct {
shadowedEdgeGroup
HasEdgeStack bool `json:"HasEdgeStack"`
HasEdgeJob bool `json:"HasEdgeJob"`
EndpointTypes []portainer.EndpointType
@@ -76,8 +82,8 @@ func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error)
}
edgeGroup := decoratedEdgeGroup{
EdgeGroup: orgEdgeGroup,
EndpointTypes: []portainer.EndpointType{},
shadowedEdgeGroup: shadowedEdgeGroup{EdgeGroup: orgEdgeGroup},
EndpointTypes: []portainer.EndpointType{},
}
if edgeGroup.Dynamic {
endpointIDs, err := GetEndpointsByTags(tx, edgeGroup.TagIDs, edgeGroup.PartialMatch)
@@ -88,15 +94,16 @@ func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error)
edgeGroup.Endpoints = endpointIDs
edgeGroup.TrustedEndpoints = endpointIDs
} else {
trustedEndpoints, err := getTrustedEndpoints(tx, edgeGroup.Endpoints)
trustedEndpoints, err := getTrustedEndpoints(tx, edgeGroup.EndpointIDs)
if err != nil {
return nil, httperror.InternalServerError("Unable to retrieve environments for Edge group", err)
}
edgeGroup.Endpoints = edgeGroup.EndpointIDs.ToSlice()
edgeGroup.TrustedEndpoints = trustedEndpoints
}
endpointTypes, err := getEndpointTypes(tx, edgeGroup.Endpoints)
endpointTypes, err := getEndpointTypes(tx, edgeGroup.EndpointIDs)
if err != nil {
return nil, httperror.InternalServerError("Unable to retrieve environment types for Edge group", err)
}
@@ -111,15 +118,26 @@ func getEdgeGroupList(tx dataservices.DataStoreTx) ([]decoratedEdgeGroup, error)
return decoratedEdgeGroups, nil
}
func getEndpointTypes(tx dataservices.DataStoreTx, endpointIds []portainer.EndpointID) ([]portainer.EndpointType, error) {
func getEndpointTypes(tx dataservices.DataStoreTx, endpointIds roar.Roar[portainer.EndpointID]) ([]portainer.EndpointType, error) {
var innerErr error
typeSet := map[portainer.EndpointType]bool{}
for _, endpointID := range endpointIds {
endpointIds.Iterate(func(endpointID portainer.EndpointID) bool {
endpoint, err := tx.Endpoint().Endpoint(endpointID)
if err != nil {
return nil, fmt.Errorf("failed fetching environment: %w", err)
innerErr = fmt.Errorf("failed fetching environment: %w", err)
return false
}
typeSet[endpoint.Type] = true
return true
})
if innerErr != nil {
return nil, innerErr
}
endpointTypes := make([]portainer.EndpointType, 0, len(typeSet))

View File

@@ -1,11 +1,19 @@
package edgegroups
import (
"net/http"
"net/http/httptest"
"strconv"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_getEndpointTypes(t *testing.T) {
@@ -38,7 +46,7 @@ func Test_getEndpointTypes(t *testing.T) {
}
for _, test := range tests {
ans, err := getEndpointTypes(datastore, test.endpointIds)
ans, err := getEndpointTypes(datastore, roar.FromSlice(test.endpointIds))
assert.NoError(t, err, "getEndpointTypes shouldn't fail")
assert.ElementsMatch(t, test.expected, ans, "getEndpointTypes expected to return %b for %v, but returned %b", test.expected, test.endpointIds, ans)
@@ -48,6 +56,61 @@ func Test_getEndpointTypes(t *testing.T) {
func Test_getEndpointTypes_failWhenEndpointDontExist(t *testing.T) {
datastore := testhelpers.NewDatastore(testhelpers.WithEndpoints([]portainer.Endpoint{}))
_, err := getEndpointTypes(datastore, []portainer.EndpointID{1})
_, err := getEndpointTypes(datastore, roar.FromSlice([]portainer.EndpointID{1}))
assert.Error(t, err, "getEndpointTypes should fail")
}
func TestEdgeGroupListHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
for i := range 3 {
err = store.Endpoint().Create(&portainer.Endpoint{
ID: portainer.EndpointID(i + 1),
Name: "Test Endpoint " + strconv.Itoa(i+1),
Type: portainer.EdgeAgentOnDockerEnvironment,
GroupID: 1,
})
require.NoError(t, err)
err = store.EndpointRelation().Create(&portainer.EndpointRelation{
EndpointID: portainer.EndpointID(i + 1),
EdgeStacks: map[portainer.EdgeStackID]bool{},
})
require.NoError(t, err)
}
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Test Edge Group",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1, 2, 3}),
})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodGet,
"/edge_groups",
nil,
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroups []decoratedEdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroups)
require.NoError(t, err)
require.Len(t, responseGroups, 1)
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroups[0].Endpoints)
require.Len(t, responseGroups[0].TrustedEndpoints, 0)
}

View File

@@ -158,7 +158,7 @@ func (handler *Handler) edgeGroupUpdate(w http.ResponseWriter, r *http.Request)
return nil
})
return txResponse(w, edgeGroup, err)
return txResponse(w, shadowedEdgeGroup{EdgeGroup: *edgeGroup}, err)
}
func (handler *Handler) updateEndpointStacks(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint, edgeGroups []portainer.EdgeGroup, edgeStacks []portainer.EdgeStack) error {

View File

@@ -0,0 +1,70 @@
package edgegroups
import (
"net/http"
"net/http/httptest"
"strconv"
"strings"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/require"
)
func TestEdgeGroupUpdateHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
err := store.EndpointGroup().Create(&portainer.EndpointGroup{
ID: 1,
Name: "Test Group",
})
require.NoError(t, err)
for i := range 3 {
err = store.Endpoint().Create(&portainer.Endpoint{
ID: portainer.EndpointID(i + 1),
Name: "Test Endpoint " + strconv.Itoa(i+1),
Type: portainer.EdgeAgentOnDockerEnvironment,
GroupID: 1,
})
require.NoError(t, err)
err = store.EndpointRelation().Create(&portainer.EndpointRelation{
EndpointID: portainer.EndpointID(i + 1),
EdgeStacks: map[portainer.EdgeStackID]bool{},
})
require.NoError(t, err)
}
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Test Edge Group",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
})
require.NoError(t, err)
rr := httptest.NewRecorder()
req := httptest.NewRequest(
http.MethodPut,
"/edge_groups/1",
strings.NewReader(`{"Endpoints": [1, 2, 3]}`),
)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var responseGroup portainer.EdgeGroup
err = json.NewDecoder(rr.Body).Decode(&responseGroup)
require.NoError(t, err)
require.ElementsMatch(t, []portainer.EndpointID{1, 2, 3}, responseGroup.Endpoints)
}

View File

@@ -1,21 +1,25 @@
package edgejobs
import (
"errors"
"fmt"
"maps"
"net/http"
"strings"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/http/utils/filters"
"github.com/portainer/portainer/api/internal/edge"
httperror "github.com/portainer/portainer/pkg/libhttp/error"
"github.com/portainer/portainer/pkg/libhttp/request"
)
type taskContainer struct {
ID string `json:"Id"`
EndpointID portainer.EndpointID `json:"EndpointId"`
LogsStatus portainer.EdgeJobLogsStatus `json:"LogsStatus"`
ID string `json:"Id"`
EndpointID portainer.EndpointID `json:"EndpointId"`
EndpointName string `json:"EndpointName"`
LogsStatus portainer.EdgeJobLogsStatus `json:"LogsStatus"`
}
// @id EdgeJobTasksList
@@ -37,16 +41,42 @@ func (handler *Handler) edgeJobTasksList(w http.ResponseWriter, r *http.Request)
return httperror.BadRequest("Invalid Edge job identifier route variable", err)
}
var tasks []taskContainer
err = handler.DataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
params := filters.ExtractListModifiersQueryParams(r)
var tasks []*taskContainer
err = handler.DataStore.ViewTx(func(tx dataservices.DataStoreTx) error {
tasks, err = listEdgeJobTasks(tx, portainer.EdgeJobID(edgeJobID))
return err
})
return txResponse(w, tasks, err)
results := filters.SearchOrderAndPaginate(tasks, params, filters.Config[*taskContainer]{
SearchAccessors: []filters.SearchAccessor[*taskContainer]{
func(tc *taskContainer) (string, error) {
switch tc.LogsStatus {
case portainer.EdgeJobLogsStatusPending:
return "pending", nil
case 0, portainer.EdgeJobLogsStatusIdle:
return "idle", nil
case portainer.EdgeJobLogsStatusCollected:
return "collected", nil
}
return "", errors.New("unknown state")
},
func(tc *taskContainer) (string, error) {
return tc.EndpointName, nil
},
},
SortBindings: []filters.SortBinding[*taskContainer]{
{Key: "EndpointName", Fn: func(a, b *taskContainer) int { return strings.Compare(a.EndpointName, b.EndpointName) }},
},
})
filters.ApplyFilterResultsHeaders(&w, results)
return txResponse(w, results.Items, err)
}
func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]taskContainer, error) {
func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID) ([]*taskContainer, error) {
edgeJob, err := tx.EdgeJob().Read(edgeJobID)
if tx.IsErrObjectNotFound(err) {
return nil, httperror.NotFound("Unable to find an Edge job with the specified identifier inside the database", err)
@@ -54,7 +84,12 @@ func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID
return nil, httperror.InternalServerError("Unable to find an Edge job with the specified identifier inside the database", err)
}
tasks := make([]taskContainer, 0)
endpoints, err := tx.Endpoint().Endpoints()
if err != nil {
return nil, err
}
tasks := make([]*taskContainer, 0)
endpointsMap := map[portainer.EndpointID]portainer.EdgeJobEndpointMeta{}
if len(edgeJob.EdgeGroups) > 0 {
@@ -70,10 +105,19 @@ func listEdgeJobTasks(tx dataservices.DataStoreTx, edgeJobID portainer.EdgeJobID
maps.Copy(endpointsMap, edgeJob.Endpoints)
for endpointID, meta := range endpointsMap {
tasks = append(tasks, taskContainer{
ID: fmt.Sprintf("edgejob_task_%d_%d", edgeJob.ID, endpointID),
EndpointID: endpointID,
LogsStatus: meta.LogsStatus,
endpointName := ""
for idx := range endpoints {
if endpoints[idx].ID == endpointID {
endpointName = endpoints[idx].Name
}
}
tasks = append(tasks, &taskContainer{
ID: fmt.Sprintf("edgejob_task_%d_%d", edgeJob.ID, endpointID),
EndpointID: endpointID,
EndpointName: endpointName,
LogsStatus: meta.LogsStatus,
})
}

View File

@@ -0,0 +1,131 @@
package edgejobs
import (
"encoding/json"
"net/http"
"net/http/httptest"
"strconv"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_EdgeJobTasksListHandler(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, false)
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
addEnv := func(env *portainer.Endpoint) {
err := store.EndpointService.Create(env)
require.NoError(t, err)
}
addEdgeGroup := func(group *portainer.EdgeGroup) {
err := store.EdgeGroupService.Create(group)
require.NoError(t, err)
}
addJob := func(job *portainer.EdgeJob) {
err := store.EdgeJobService.Create(job)
require.NoError(t, err)
}
envCount := 6
for i := range envCount {
addEnv(&portainer.Endpoint{ID: portainer.EndpointID(i + 1), Name: "env_" + strconv.Itoa(i+1)})
}
addEdgeGroup(&portainer.EdgeGroup{ID: 1, Name: "edge_group_1", EndpointIDs: roar.FromSlice([]portainer.EndpointID{5, 6})})
addJob(&portainer.EdgeJob{
ID: 1,
Endpoints: map[portainer.EndpointID]portainer.EdgeJobEndpointMeta{
1: {},
2: {LogsStatus: portainer.EdgeJobLogsStatusIdle},
3: {LogsStatus: portainer.EdgeJobLogsStatusPending},
4: {LogsStatus: portainer.EdgeJobLogsStatusCollected}},
EdgeGroups: []portainer.EdgeGroupID{1},
})
test := func(params string, expect []taskContainer, expectedCount int) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/edge_jobs/1/tasks"+params, nil)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
var response []taskContainer
err := json.NewDecoder(rr.Body).Decode(&response)
require.NoError(t, err)
assert.ElementsMatch(t, expect, response)
tcStr := rr.Header().Get("x-total-count")
assert.NotEmpty(t, tcStr)
totalCount, err := strconv.Atoi(tcStr)
assert.NoError(t, err)
assert.Equal(t, expectedCount, totalCount)
taStr := rr.Header().Get("x-total-available")
assert.NotEmpty(t, taStr)
totalAvailable, err := strconv.Atoi(taStr)
assert.NoError(t, err)
assert.Equal(t, envCount, totalAvailable)
}
tasks := []taskContainer{
{},
{"edgejob_task_1_1", 1, "env_1", 0},
{"edgejob_task_1_2", 2, "env_2", portainer.EdgeJobLogsStatusIdle},
{"edgejob_task_1_3", 3, "env_3", portainer.EdgeJobLogsStatusPending},
{"edgejob_task_1_4", 4, "env_4", portainer.EdgeJobLogsStatusCollected},
{"edgejob_task_1_5", 5, "env_5", 0},
{"edgejob_task_1_6", 6, "env_6", 0},
}
t.Run("should return no results", func(t *testing.T) {
test("?search=foo", []taskContainer{}, 0) // unknown search
test("?start=100&limit=1", []taskContainer{}, 6) // overflowing start. Still return the correct count header
})
t.Run("should return one element", func(t *testing.T) {
// limit the *returned* results but not the total count
test("?start=0&limit=1&sort=EndpointName&order=asc", []taskContainer{tasks[1]}, envCount) // limit
test("?start=5&limit=10&sort=EndpointName&order=asc", []taskContainer{tasks[6]}, envCount) // start = last element + overflowing limit
// limit the number of results
test("?search=env_1", []taskContainer{tasks[1]}, 1) // only 1 result
})
t.Run("should filter by status", func(t *testing.T) {
test("?search=idle", []taskContainer{tasks[1], tasks[2], tasks[5], tasks[6]}, 4) // 0 (default value) is IDLE
test("?search=pending", []taskContainer{tasks[3]}, 1)
test("?search=collected", []taskContainer{tasks[4]}, 1)
})
t.Run("should return all elements", func(t *testing.T) {
test("", tasks[1:], envCount) // default
test("?some=invalid_param", tasks[1:], envCount) // unknown query params
test("?limit=-1", tasks[1:], envCount) // underflowing limit
test("?start=100", tasks[1:], envCount) // overflowing start without limit
test("?search=env", tasks[1:], envCount) // search in a match-all keyword
})
testError := func(params string, status int) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/edge_jobs/2/tasks"+params, nil)
handler.ServeHTTP(rr, req)
require.Equal(t, status, rr.Result().StatusCode)
}
t.Run("errors", func(t *testing.T) {
testError("", http.StatusNotFound) // unknown job id
})
}

View File

@@ -33,6 +33,8 @@ type edgeStackFromGitRepositoryPayload struct {
RepositoryUsername string `example:"myGitUsername"`
// Password used in basic authentication. Required when RepositoryAuthentication is true.
RepositoryPassword string `example:"myGitPassword"`
// RepositoryAuthorizationType is the authorization type to use
RepositoryAuthorizationType gittypes.GitCredentialAuthType `example:"0"`
// Path to the Stack file inside the Git repository
FilePathInRepository string `example:"docker-compose.yml" default:"docker-compose.yml"`
// List of identifiers of EdgeGroups
@@ -125,8 +127,9 @@ func (handler *Handler) createEdgeStackFromGitRepository(r *http.Request, tx dat
if payload.RepositoryAuthentication {
repoConfig.Authentication = &gittypes.GitAuthentication{
Username: payload.RepositoryUsername,
Password: payload.RepositoryPassword,
Username: payload.RepositoryUsername,
Password: payload.RepositoryPassword,
AuthorizationType: payload.RepositoryAuthorizationType,
}
}
@@ -145,12 +148,22 @@ func (handler *Handler) storeManifestFromGitRepository(tx dataservices.DataStore
projectPath = handler.FileService.GetEdgeStackProjectPath(stackFolder)
repositoryUsername := ""
repositoryPassword := ""
repositoryAuthType := gittypes.GitCredentialAuthType_Basic
if repositoryConfig.Authentication != nil && repositoryConfig.Authentication.Password != "" {
repositoryUsername = repositoryConfig.Authentication.Username
repositoryPassword = repositoryConfig.Authentication.Password
repositoryAuthType = repositoryConfig.Authentication.AuthorizationType
}
if err := handler.GitService.CloneRepository(projectPath, repositoryConfig.URL, repositoryConfig.ReferenceName, repositoryUsername, repositoryPassword, repositoryConfig.TLSSkipVerify); err != nil {
if err := handler.GitService.CloneRepository(
projectPath,
repositoryConfig.URL,
repositoryConfig.ReferenceName,
repositoryUsername,
repositoryPassword,
repositoryAuthType,
repositoryConfig.TLSSkipVerify,
); err != nil {
return "", "", "", err
}

View File

@@ -8,9 +8,10 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/require"
)
// Create
@@ -24,7 +25,7 @@ func TestCreateAndInspect(t *testing.T) {
Name: "EdgeGroup 1",
Dynamic: false,
TagIDs: nil,
Endpoints: []portainer.EndpointID{endpoint.ID},
EndpointIDs: roar.FromSlice([]portainer.EndpointID{endpoint.ID}),
PartialMatch: false,
}

View File

@@ -15,6 +15,7 @@ import (
"github.com/portainer/portainer/api/internal/edge/edgestacks"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/jwt"
"github.com/portainer/portainer/api/roar"
"github.com/pkg/errors"
"github.com/stretchr/testify/require"
@@ -103,7 +104,7 @@ func createEdgeStack(t *testing.T, store dataservices.DataStore, endpointID port
Name: "EdgeGroup 1",
Dynamic: false,
TagIDs: nil,
Endpoints: []portainer.EndpointID{endpointID},
EndpointIDs: roar.FromSlice([]portainer.EndpointID{endpointID}),
PartialMatch: false,
}

View File

@@ -99,7 +99,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
groupsIds := stack.EdgeGroups
if payload.EdgeGroups != nil {
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack.ID, payload.EdgeGroups, relatedEndpointIds, relationConfig)
newRelated, _, err := handler.handleChangeEdgeGroups(tx, stack, payload.EdgeGroups, relatedEndpointIds, relationConfig)
if err != nil {
return nil, httperror.InternalServerError("Unable to handle edge groups change", err)
}
@@ -136,7 +136,7 @@ func (handler *Handler) updateEdgeStack(tx dataservices.DataStoreTx, stackID por
return stack, nil
}
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStackID portainer.EdgeStackID, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edgeStack *portainer.EdgeStack, newEdgeGroupsIDs []portainer.EdgeGroupID, oldRelatedEnvironmentIDs []portainer.EndpointID, relationConfig *edge.EndpointRelationsConfig) ([]portainer.EndpointID, set.Set[portainer.EndpointID], error) {
newRelatedEnvironmentIDs, err := edge.EdgeStackRelatedEndpoints(newEdgeGroupsIDs, relationConfig.Endpoints, relationConfig.EndpointGroups, relationConfig.EdgeGroups)
if err != nil {
return nil, nil, errors.WithMessage(err, "Unable to retrieve edge stack related environments from database")
@@ -149,13 +149,13 @@ func (handler *Handler) handleChangeEdgeGroups(tx dataservices.DataStoreTx, edge
relatedEnvironmentsToRemove := oldRelatedEnvironmentsSet.Difference(newRelatedEnvironmentsSet)
if len(relatedEnvironmentsToRemove) > 0 {
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStackID); err != nil {
if err := tx.EndpointRelation().RemoveEndpointRelationsForEdgeStack(relatedEnvironmentsToRemove.Keys(), edgeStack.ID); err != nil {
return nil, nil, errors.WithMessage(err, "Unable to remove edge stack relations from the database")
}
}
if len(relatedEnvironmentsToAdd) > 0 {
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStackID); err != nil {
if err := tx.EndpointRelation().AddEndpointRelationsForEdgeStack(relatedEnvironmentsToAdd.Keys(), edgeStack); err != nil {
return nil, nil, errors.WithMessage(err, "Unable to add edge stack relations to the database")
}
}

View File

@@ -9,9 +9,10 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/stretchr/testify/require"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/require"
)
// Update
@@ -43,7 +44,7 @@ func TestUpdateAndInspect(t *testing.T) {
Name: "EdgeGroup 2",
Dynamic: false,
TagIDs: nil,
Endpoints: []portainer.EndpointID{newEndpoint.ID},
EndpointIDs: roar.FromSlice([]portainer.EndpointID{newEndpoint.ID}),
PartialMatch: false,
}
@@ -112,7 +113,7 @@ func TestUpdateWithInvalidEdgeGroups(t *testing.T) {
Name: "EdgeGroup 2",
Dynamic: false,
TagIDs: nil,
Endpoints: []portainer.EndpointID{8889},
EndpointIDs: roar.FromSlice([]portainer.EndpointID{8889}),
PartialMatch: false,
}

View File

@@ -170,7 +170,7 @@ func (handler *Handler) inspectStatus(tx dataservices.DataStoreTx, r *http.Reque
Credentials: tunnel.Credentials,
}
schedules, handlerErr := handler.buildSchedules(tx, endpoint.ID)
schedules, handlerErr := handler.buildSchedules(tx, endpoint)
if handlerErr != nil {
return nil, handlerErr
}
@@ -208,7 +208,7 @@ func parseAgentPlatform(r *http.Request) (portainer.EndpointType, error) {
}
}
func (handler *Handler) buildSchedules(tx dataservices.DataStoreTx, endpointID portainer.EndpointID) ([]edgeJobResponse, *httperror.HandlerError) {
func (handler *Handler) buildSchedules(tx dataservices.DataStoreTx, endpoint *portainer.Endpoint) ([]edgeJobResponse, *httperror.HandlerError) {
schedules := []edgeJobResponse{}
edgeJobs, err := tx.EdgeJob().ReadAll()
@@ -216,11 +216,16 @@ func (handler *Handler) buildSchedules(tx dataservices.DataStoreTx, endpointID p
return nil, httperror.InternalServerError("Unable to retrieve Edge Jobs", err)
}
endpointGroups, err := tx.EndpointGroup().ReadAll()
if err != nil {
return nil, httperror.InternalServerError("Unable to retrieve endpoint groups", err)
}
for _, job := range edgeJobs {
_, endpointHasJob := job.Endpoints[endpointID]
_, endpointHasJob := job.Endpoints[endpoint.ID]
if !endpointHasJob {
for _, edgeGroupID := range job.EdgeGroups {
member, _, err := edge.EndpointInEdgeGroup(tx, endpointID, edgeGroupID)
member, _, err := edge.EndpointInEdgeGroup(tx, endpoint, edgeGroupID, endpointGroups)
if err != nil {
return nil, httperror.InternalServerError("Unable to retrieve relations", err)
} else if member {
@@ -236,10 +241,10 @@ func (handler *Handler) buildSchedules(tx dataservices.DataStoreTx, endpointID p
}
var collectLogs bool
if _, ok := job.GroupLogsCollection[endpointID]; ok {
collectLogs = job.GroupLogsCollection[endpointID].CollectLogs
if _, ok := job.GroupLogsCollection[endpoint.ID]; ok {
collectLogs = job.GroupLogsCollection[endpoint.ID].CollectLogs
} else {
collectLogs = job.Endpoints[endpointID].CollectLogs
collectLogs = job.Endpoints[endpoint.ID].CollectLogs
}
schedule := edgeJobResponse{

View File

@@ -16,6 +16,7 @@ import (
"github.com/portainer/portainer/api/filesystem"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/jwt"
"github.com/portainer/portainer/api/roar"
"github.com/segmentio/encoding/json"
"github.com/stretchr/testify/assert"
@@ -366,8 +367,8 @@ func TestEdgeJobsResponse(t *testing.T) {
unrelatedEndpoint := localCreateEndpoint(80, nil)
staticEdgeGroup := portainer.EdgeGroup{
ID: 1,
Endpoints: []portainer.EndpointID{endpointFromStaticEdgeGroup.ID},
ID: 1,
EndpointIDs: roar.FromSlice([]portainer.EndpointID{endpointFromStaticEdgeGroup.ID}),
}
err := handler.DataStore.EdgeGroup().Create(&staticEdgeGroup)
require.NoError(t, err)

View File

@@ -1,7 +1,6 @@
package endpoints
import (
"crypto/tls"
"errors"
"net/http"
"runtime"
@@ -285,8 +284,6 @@ func (handler *Handler) endpointCreate(w http.ResponseWriter, r *http.Request) *
}
func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *endpointCreatePayload) (*portainer.Endpoint, *httperror.HandlerError) {
var err error
switch payload.EndpointCreationType {
case azureEnvironment:
return handler.createAzureEndpoint(tx, payload)
@@ -301,12 +298,9 @@ func (handler *Handler) createEndpoint(tx dataservices.DataStoreTx, payload *end
endpointType := portainer.DockerEnvironment
var agentVersion string
if payload.EndpointCreationType == agentEnvironment {
var tlsConfig *tls.Config
if payload.TLS {
tlsConfig, err = crypto.CreateTLSConfigurationFromBytes(payload.TLSCACertFile, payload.TLSCertFile, payload.TLSKeyFile, payload.TLSSkipVerify, payload.TLSSkipClientVerify)
if err != nil {
return nil, httperror.InternalServerError("Unable to create TLS configuration", err)
}
tlsConfig, err := crypto.CreateTLSConfigurationFromBytes(payload.TLS, payload.TLSCACertFile, payload.TLSCertFile, payload.TLSKeyFile, payload.TLSSkipVerify, payload.TLSSkipClientVerify)
if err != nil {
return nil, httperror.InternalServerError("Unable to create TLS configuration", err)
}
agentPlatform, version, err := agent.GetAgentVersionAndPlatform(payload.URL, tlsConfig)
@@ -378,10 +372,16 @@ func (handler *Handler) createEdgeAgentEndpoint(tx dataservices.DataStoreTx, pay
edgeKey := handler.ReverseTunnelService.GenerateEdgeKey(payload.URL, portainerHost, endpointID)
endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(endpointID),
Name: payload.Name,
URL: portainerHost,
Type: portainer.EdgeAgentOnDockerEnvironment,
ID: portainer.EndpointID(endpointID),
Name: payload.Name,
URL: portainerHost,
Type: func() portainer.EndpointType {
// an empty container engine means that the endpoint is a Kubernetes endpoint
if payload.ContainerEngine == "" {
return portainer.EdgeAgentOnKubernetesEnvironment
}
return portainer.EdgeAgentOnDockerEnvironment
}(),
ContainerEngine: payload.ContainerEngine,
GroupID: portainer.EndpointGroupID(payload.GroupID),
Gpus: payload.Gpus,

View File

@@ -6,11 +6,11 @@ import (
"testing"
portainer "github.com/portainer/portainer/api"
helper "github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/internal/testhelpers"
)
func TestEmptyGlobalKey(t *testing.T) {
handler := NewHandler(helper.NewTestRequestBouncer())
handler := NewHandler(testhelpers.NewTestRequestBouncer())
req, err := http.NewRequest(http.MethodPost, "https://portainer.io:9443/endpoints/global-key", nil)
if err != nil {

View File

@@ -0,0 +1,172 @@
package endpoints
import (
"net/http"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/chisel"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/pkg/fips"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
// EE-only kubeconfig validation tests removed for CE
func TestSaveEndpointAndUpdateAuthorizations(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, false)
endpointGroup := &portainer.EndpointGroup{
ID: 1,
Name: "test-endpoint-group",
}
err := store.EndpointGroup().Create(endpointGroup)
require.NoError(t, err)
h := &Handler{
DataStore: store,
}
testCases := []struct {
name string
endpointType portainer.EndpointType
expectRelation bool
}{
{
name: "create azure environment, expect no relation to be created",
endpointType: portainer.AzureEnvironment,
expectRelation: false,
},
{
name: "create edge agent environment, expect relation to be created",
endpointType: portainer.EdgeAgentOnDockerEnvironment,
expectRelation: true,
},
{
name: "create kubernetes environment, expect no relation to be created",
endpointType: portainer.KubernetesLocalEnvironment,
expectRelation: false,
},
{
name: "create kubeconfig environment, expect no relation to be created",
endpointType: portainer.AgentOnKubernetesEnvironment,
expectRelation: false,
},
{
name: "create agent docker environment, expect no relation to be created",
endpointType: portainer.AgentOnDockerEnvironment,
expectRelation: false,
},
{
name: "create unsecured environment, expect no relation to be created",
endpointType: portainer.DockerEnvironment,
expectRelation: false,
},
}
for _, testCase := range testCases {
t.Run(testCase.name, func(t *testing.T) {
endpoint := &portainer.Endpoint{
ID: portainer.EndpointID(store.Endpoint().GetNextIdentifier()),
Type: testCase.endpointType,
GroupID: portainer.EndpointGroupID(endpointGroup.ID),
}
err := h.saveEndpointAndUpdateAuthorizations(store, endpoint)
require.NoError(t, err)
relation, relationErr := store.EndpointRelation().EndpointRelation(endpoint.ID)
if testCase.expectRelation {
require.NoError(t, relationErr)
require.NotNil(t, relation)
} else {
require.Error(t, relationErr)
require.True(t, store.IsErrObjectNotFound(relationErr))
require.Nil(t, relation)
}
})
}
}
func TestCreateEndpointFailure(t *testing.T) {
fips.InitFIPS(false)
_, store := datastore.MustNewTestStore(t, true, false)
h := NewHandler(testhelpers.NewTestRequestBouncer())
h.DataStore = store
payload := &endpointCreatePayload{
Name: "Test Endpoint",
EndpointCreationType: agentEnvironment,
TLS: true,
TLSCertFile: []byte("invalid data"),
TLSKeyFile: []byte("invalid data"),
}
endpoint, httpErr := h.createEndpoint(store, payload)
require.NotNil(t, httpErr)
require.Equal(t, http.StatusInternalServerError, httpErr.StatusCode)
require.Nil(t, endpoint)
}
func TestCreateEdgeAgentEndpoint_ContainerEngineMapping(t *testing.T) {
fips.InitFIPS(false)
_, store := datastore.MustNewTestStore(t, true, false)
// required group for save flow
endpointGroup := &portainer.EndpointGroup{ID: 1, Name: "test-group"}
err := store.EndpointGroup().Create(endpointGroup)
require.NoError(t, err)
h := &Handler{
DataStore: store,
ReverseTunnelService: chisel.NewService(store, nil, nil),
}
tests := []struct {
name string
engine string
wantType portainer.EndpointType
}{
{
name: "empty engine -> EdgeAgentOnKubernetesEnvironment",
engine: "",
wantType: portainer.EdgeAgentOnKubernetesEnvironment,
},
{
name: "docker engine -> EdgeAgentOnDockerEnvironment",
engine: portainer.ContainerEngineDocker,
wantType: portainer.EdgeAgentOnDockerEnvironment,
},
{
name: "podman engine -> EdgeAgentOnDockerEnvironment",
engine: portainer.ContainerEnginePodman,
wantType: portainer.EdgeAgentOnDockerEnvironment,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
payload := &endpointCreatePayload{
Name: "edge-endpoint",
EndpointCreationType: edgeAgentEnvironment,
ContainerEngine: tc.engine,
GroupID: 1,
URL: "https://portainer.example:9443",
}
ep, httpErr := h.createEdgeAgentEndpoint(store, payload)
require.Nil(t, httpErr)
require.NotNil(t, ep)
assert.Equal(t, tc.wantType, ep.Type)
assert.Equal(t, tc.engine, ep.ContainerEngine)
})
}
}

View File

@@ -3,7 +3,6 @@ package endpoints
import (
"errors"
"net/http"
"slices"
"strconv"
portainer "github.com/portainer/portainer/api"
@@ -200,9 +199,7 @@ func (handler *Handler) deleteEndpoint(tx dataservices.DataStoreTx, endpointID p
}
for _, edgeGroup := range edgeGroups {
edgeGroup.Endpoints = slices.DeleteFunc(edgeGroup.Endpoints, func(e portainer.EndpointID) bool {
return e == endpoint.ID
})
edgeGroup.EndpointIDs.Remove(endpoint.ID)
if err := tx.EdgeGroup().Update(edgeGroup.ID, &edgeGroup); err != nil {
log.Warn().Err(err).Msg("Unable to update edge group")

View File

@@ -11,6 +11,7 @@ import (
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/http/proxy"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
)
func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
@@ -21,7 +22,7 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
handler := NewHandler(testhelpers.NewTestRequestBouncer())
handler.DataStore = store
handler.ProxyManager = proxy.NewManager(nil)
handler.ProxyManager.NewProxyFactory(nil, nil, nil, nil, nil, nil, nil, nil)
handler.ProxyManager.NewProxyFactory(nil, nil, nil, nil, nil, nil, nil, nil, nil)
// Create all the environments and add them to the same edge group
@@ -42,9 +43,9 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
}
if err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "edgegroup-1",
Endpoints: endpointIDs,
ID: 1,
Name: "edgegroup-1",
EndpointIDs: roar.FromSlice(endpointIDs),
}); err != nil {
t.Fatal("could not create edge group:", err)
}
@@ -78,7 +79,7 @@ func TestEndpointDeleteEdgeGroupsConcurrently(t *testing.T) {
t.Fatal("could not retrieve the edge group:", err)
}
if len(edgeGroup.Endpoints) > 0 {
if edgeGroup.EndpointIDs.Len() > 0 {
t.Fatal("the edge group is not consistent")
}
}

View File

@@ -49,7 +49,7 @@ func (handler *Handler) endpointSnapshots(w http.ResponseWriter, r *http.Request
continue
}
endpoint.Status = portainer.EndpointStatusUp
latestEndpointReference.Status = portainer.EndpointStatusUp
if snapshotError != nil {
log.Debug().
Str("endpoint", endpoint.Name).
@@ -57,7 +57,7 @@ func (handler *Handler) endpointSnapshots(w http.ResponseWriter, r *http.Request
Err(snapshotError).
Msg("background schedule error (environment snapshot), unable to create snapshot")
endpoint.Status = portainer.EndpointStatusDown
latestEndpointReference.Status = portainer.EndpointStatusDown
}
latestEndpointReference.Agent.Version = endpoint.Agent.Version

View File

@@ -0,0 +1,107 @@
package endpoints
import (
"errors"
"io"
"net/http"
"net/http/httptest"
"sync/atomic"
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func Test_endpointSnapshots(t *testing.T) {
_, store := datastore.MustNewTestStore(t, true, true)
endpointID := portainer.EndpointID(123)
endpoint := &portainer.Endpoint{
ID: endpointID,
Name: "mock",
URL: "http://mock.example/",
Status: portainer.EndpointStatusDown, // starts in down state
}
err := store.Endpoint().Create(endpoint)
require.NoError(t, err, "error creating environment")
err = store.User().Create(
&portainer.User{
Username: "admin",
Role: portainer.AdministratorRole,
},
)
require.NoError(t, err, "error creating a user")
bouncer := testhelpers.NewTestRequestBouncer()
snapshotService := &mockSnapshotService{
snapshotEndpointShouldSucceed: atomic.Bool{},
}
snapshotService.snapshotEndpointShouldSucceed.Store(true)
h := NewHandler(bouncer)
h.DataStore = store
h.SnapshotService = snapshotService
doPostRequest := func() {
req := httptest.NewRequest(http.MethodPost, "/endpoints/snapshot", nil)
ctx := security.StoreTokenData(req, &portainer.TokenData{ID: 1, Username: "admin", Role: 1})
req = req.WithContext(ctx)
testhelpers.AddTestSecurityCookie(req, "Bearer dummytoken")
rr := httptest.NewRecorder()
h.ServeHTTP(rr, req)
require.Equal(t, http.StatusNoContent, rr.Code, "Status should be 204")
_, err := io.ReadAll(rr.Body)
require.NoError(t, err, "ReadAll should not return error")
}
doPostRequest()
// check that the endpoint has been immediately set to up
endpoint, err = store.Endpoint().Endpoint(endpointID)
require.NoError(t, err, "error getting endpoint")
assert.Equal(t, portainer.EndpointStatusUp, endpoint.Status, "endpoint should be up (1) since mock snapshot returned ok")
// set the mock to return an error
snapshotService.snapshotEndpointShouldSucceed.Store(false)
doPostRequest()
// check that the endpoint has been immediately set to down
endpoint, err = store.Endpoint().Endpoint(endpointID)
require.NoError(t, err, "error getting endpoint")
assert.Equal(t, portainer.EndpointStatusDown, endpoint.Status, "endpoint should be down (2) since mock snapshot returned error")
}
var _ portainer.SnapshotService = &mockSnapshotService{}
type mockSnapshotService struct {
snapshotEndpointShouldSucceed atomic.Bool
}
func (s *mockSnapshotService) Start() {
}
func (s *mockSnapshotService) SetSnapshotInterval(snapshotInterval string) error {
return nil
}
func (s *mockSnapshotService) SnapshotEndpoint(endpoint *portainer.Endpoint) error {
if s.snapshotEndpointShouldSucceed.Load() {
return nil
}
return errors.New("snapshot failed")
}
func (s *mockSnapshotService) FillSnapshotData(endpoint *portainer.Endpoint, includeRaw bool) error {
return nil
}

View File

@@ -14,7 +14,7 @@ import (
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/slicesx"
"github.com/portainer/portainer/api/roar"
"github.com/portainer/portainer/pkg/libhttp/request"
"github.com/pkg/errors"
@@ -146,7 +146,9 @@ func (handler *Handler) filterEndpointsByQuery(
totalAvailableEndpoints := len(filteredEndpoints)
if len(query.endpointIds) > 0 {
filteredEndpoints = filteredEndpointsByIds(filteredEndpoints, query.endpointIds)
endpointIDs := roar.FromSlice(query.endpointIds)
filteredEndpoints = filteredEndpointsByIds(filteredEndpoints, endpointIDs)
}
if len(query.excludeIds) > 0 {
@@ -275,7 +277,7 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
return nil, errors.WithMessage(err, "Unable to retrieve edge stack from the database")
}
envIds := make([]portainer.EndpointID, 0)
envIds := roar.Roar[portainer.EndpointID]{}
for _, edgeGroupdId := range stack.EdgeGroups {
edgeGroup, err := datastore.EdgeGroup().Read(edgeGroupdId)
if err != nil {
@@ -287,32 +289,37 @@ func filterEndpointsByEdgeStack(endpoints []portainer.Endpoint, edgeStackId port
if err != nil {
return nil, errors.WithMessage(err, "Unable to retrieve environments and environment groups for Edge group")
}
edgeGroup.Endpoints = endpointIDs
edgeGroup.EndpointIDs = roar.FromSlice(endpointIDs)
}
envIds = append(envIds, edgeGroup.Endpoints...)
envIds.Union(edgeGroup.EndpointIDs)
}
if statusFilter != nil {
n := 0
for _, envId := range envIds {
var innerErr error
envIds.Iterate(func(envId portainer.EndpointID) bool {
edgeStackStatus, err := datastore.EdgeStackStatus().Read(edgeStackId, envId)
if dataservices.IsErrObjectNotFound(err) {
continue
return true
} else if err != nil {
return nil, errors.WithMessagef(err, "Unable to retrieve edge stack status for environment %d", envId)
innerErr = errors.WithMessagef(err, "Unable to retrieve edge stack status for environment %d", envId)
return false
}
if endpointStatusInStackMatchesFilter(edgeStackStatus, envId, *statusFilter) {
envIds[n] = envId
n++
if !endpointStatusInStackMatchesFilter(edgeStackStatus, portainer.EndpointID(envId), *statusFilter) {
envIds.Remove(envId)
}
return true
})
if innerErr != nil {
return nil, innerErr
}
envIds = envIds[:n]
}
uniqueIds := slicesx.Unique(envIds)
filteredEndpoints := filteredEndpointsByIds(endpoints, uniqueIds)
filteredEndpoints := filteredEndpointsByIds(endpoints, envIds)
return filteredEndpoints, nil
}
@@ -344,16 +351,14 @@ func filterEndpointsByEdgeGroupIDs(endpoints []portainer.Endpoint, edgeGroups []
}
edgeGroups = edgeGroups[:n]
endpointIDSet := make(map[portainer.EndpointID]struct{})
endpointIDSet := roar.Roar[portainer.EndpointID]{}
for _, edgeGroup := range edgeGroups {
for _, endpointID := range edgeGroup.Endpoints {
endpointIDSet[endpointID] = struct{}{}
}
endpointIDSet.Union(edgeGroup.EndpointIDs)
}
n = 0
for _, endpoint := range endpoints {
if _, exists := endpointIDSet[endpoint.ID]; exists {
if endpointIDSet.Contains(endpoint.ID) {
endpoints[n] = endpoint
n++
}
@@ -369,12 +374,11 @@ func filterEndpointsByExcludeEdgeGroupIDs(endpoints []portainer.Endpoint, edgeGr
}
n := 0
excludeEndpointIDSet := make(map[portainer.EndpointID]struct{})
excludeEndpointIDSet := roar.Roar[portainer.EndpointID]{}
for _, edgeGroup := range edgeGroups {
if _, ok := excludeEdgeGroupIDSet[edgeGroup.ID]; ok {
for _, endpointID := range edgeGroup.Endpoints {
excludeEndpointIDSet[endpointID] = struct{}{}
}
excludeEndpointIDSet.Union(edgeGroup.EndpointIDs)
} else {
edgeGroups[n] = edgeGroup
n++
@@ -384,7 +388,7 @@ func filterEndpointsByExcludeEdgeGroupIDs(endpoints []portainer.Endpoint, edgeGr
n = 0
for _, endpoint := range endpoints {
if _, ok := excludeEndpointIDSet[endpoint.ID]; !ok {
if !excludeEndpointIDSet.Contains(endpoint.ID) {
endpoints[n] = endpoint
n++
}
@@ -609,15 +613,10 @@ func endpointFullMatchTags(endpoint portainer.Endpoint, endpointGroup portainer.
return len(missingTags) == 0
}
func filteredEndpointsByIds(endpoints []portainer.Endpoint, ids []portainer.EndpointID) []portainer.Endpoint {
idsSet := make(map[portainer.EndpointID]bool, len(ids))
for _, id := range ids {
idsSet[id] = true
}
func filteredEndpointsByIds(endpoints []portainer.Endpoint, ids roar.Roar[portainer.EndpointID]) []portainer.Endpoint {
n := 0
for _, endpoint := range endpoints {
if idsSet[endpoint.ID] {
if ids.Contains(endpoint.ID) {
endpoints[n] = endpoint
n++
}
@@ -724,27 +723,22 @@ func getEdgeStackStatusParam(r *http.Request) (*portainer.EdgeStackStatusType, e
}
func getShortestAsyncInterval(endpoint *portainer.Endpoint, settings *portainer.Settings) int {
var edgeIntervalUseDefault int = -1
const edgeIntervalUseDefault = -1
pingInterval := endpoint.Edge.PingInterval
if pingInterval == edgeIntervalUseDefault {
pingInterval = settings.Edge.PingInterval
}
shortestAsyncInterval := pingInterval
snapshotInterval := endpoint.Edge.SnapshotInterval
if snapshotInterval == edgeIntervalUseDefault {
snapshotInterval = settings.Edge.SnapshotInterval
}
if shortestAsyncInterval > snapshotInterval {
shortestAsyncInterval = snapshotInterval
}
commandInterval := endpoint.Edge.CommandInterval
if commandInterval == edgeIntervalUseDefault {
commandInterval = settings.Edge.CommandInterval
}
if shortestAsyncInterval > commandInterval {
shortestAsyncInterval = commandInterval
}
return shortestAsyncInterval
return min(pingInterval, snapshotInterval, commandInterval)
}

View File

@@ -8,9 +8,11 @@ import (
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/http/security"
"github.com/portainer/portainer/api/internal/testhelpers"
"github.com/portainer/portainer/api/roar"
"github.com/portainer/portainer/api/slicesx"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type filterTest struct {
@@ -20,7 +22,6 @@ type filterTest struct {
}
func Test_Filter_AgentVersion(t *testing.T) {
version1Endpoint := portainer.Endpoint{ID: 1, GroupID: 1,
Type: portainer.AgentOnDockerEnvironment,
Agent: struct {
@@ -76,7 +77,6 @@ func Test_Filter_AgentVersion(t *testing.T) {
}
func Test_Filter_edgeFilter(t *testing.T) {
trustedEdgeAsync := portainer.Endpoint{ID: 1, UserTrusted: true, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
untrustedEdgeAsync := portainer.Endpoint{ID: 2, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: true}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
regularUntrustedEdgeStandard := portainer.Endpoint{ID: 3, UserTrusted: false, Edge: portainer.EnvironmentEdgeSettings{AsyncMode: false}, GroupID: 1, Type: portainer.EdgeAgentOnDockerEnvironment}
@@ -175,7 +175,7 @@ func BenchmarkFilterEndpointsBySearchCriteria_PartialMatch(b *testing.B) {
edgeGroups = append(edgeGroups, portainer.EdgeGroup{
ID: portainer.EdgeGroupID(i + 1),
Name: "edge-group-" + strconv.Itoa(i+1),
Endpoints: append([]portainer.EndpointID{}, endpointIDs...),
EndpointIDs: roar.FromSlice(endpointIDs),
Dynamic: true,
TagIDs: []portainer.TagID{1, 2, 3},
PartialMatch: true,
@@ -222,11 +222,11 @@ func BenchmarkFilterEndpointsBySearchCriteria_FullMatch(b *testing.B) {
edgeGroups := []portainer.EdgeGroup{}
for i := range 1000 {
edgeGroups = append(edgeGroups, portainer.EdgeGroup{
ID: portainer.EdgeGroupID(i + 1),
Name: "edge-group-" + strconv.Itoa(i+1),
Endpoints: append([]portainer.EndpointID{}, endpointIDs...),
Dynamic: true,
TagIDs: []portainer.TagID{1},
ID: portainer.EdgeGroupID(i + 1),
Name: "edge-group-" + strconv.Itoa(i+1),
EndpointIDs: roar.FromSlice(endpointIDs),
Dynamic: true,
TagIDs: []portainer.TagID{1},
})
}
@@ -278,7 +278,6 @@ func runTest(t *testing.T, test filterTest, handler *Handler, endpoints []portai
}
is.ElementsMatch(test.expected, respIds)
}
func setupFilterTest(t *testing.T, endpoints []portainer.Endpoint) *Handler {
@@ -300,3 +299,149 @@ func setupFilterTest(t *testing.T, endpoints []portainer.Endpoint) *Handler {
return handler
}
func TestFilterEndpointsByEdgeStack(t *testing.T) {
_, store := datastore.MustNewTestStore(t, false, false)
endpoints := []portainer.Endpoint{
{ID: 1, Name: "Endpoint 1"},
{ID: 2, Name: "Endpoint 2"},
{ID: 3, Name: "Endpoint 3"},
{ID: 4, Name: "Endpoint 4"},
}
edgeStackId := portainer.EdgeStackID(1)
err := store.EdgeStack().Create(edgeStackId, &portainer.EdgeStack{
ID: edgeStackId,
Name: "Test Edge Stack",
EdgeGroups: []portainer.EdgeGroupID{1, 2},
})
require.NoError(t, err)
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Edge Group 1",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
})
require.NoError(t, err)
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 2,
Name: "Edge Group 2",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{2, 3}),
})
require.NoError(t, err)
es, err := filterEndpointsByEdgeStack(endpoints, edgeStackId, nil, store)
require.NoError(t, err)
require.Len(t, es, 3)
require.Contains(t, es, endpoints[0]) // Endpoint 1
require.Contains(t, es, endpoints[1]) // Endpoint 2
require.Contains(t, es, endpoints[2]) // Endpoint 3
require.NotContains(t, es, endpoints[3]) // Endpoint 4
}
func TestFilterEndpointsByEdgeGroup(t *testing.T) {
_, store := datastore.MustNewTestStore(t, false, false)
endpoints := []portainer.Endpoint{
{ID: 1, Name: "Endpoint 1"},
{ID: 2, Name: "Endpoint 2"},
{ID: 3, Name: "Endpoint 3"},
{ID: 4, Name: "Endpoint 4"},
}
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Edge Group 1",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
})
require.NoError(t, err)
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 2,
Name: "Edge Group 2",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{2, 3}),
})
require.NoError(t, err)
edgeGroups, err := store.EdgeGroup().ReadAll()
require.NoError(t, err)
es, egs := filterEndpointsByEdgeGroupIDs(endpoints, edgeGroups, []portainer.EdgeGroupID{1, 2})
require.NoError(t, err)
require.Len(t, es, 3)
require.Contains(t, es, endpoints[0]) // Endpoint 1
require.Contains(t, es, endpoints[1]) // Endpoint 2
require.Contains(t, es, endpoints[2]) // Endpoint 3
require.NotContains(t, es, endpoints[3]) // Endpoint 4
require.Len(t, egs, 2)
require.Equal(t, egs[0].ID, portainer.EdgeGroupID(1))
require.Equal(t, egs[1].ID, portainer.EdgeGroupID(2))
}
func TestFilterEndpointsByExcludeEdgeGroupIDs(t *testing.T) {
_, store := datastore.MustNewTestStore(t, false, false)
endpoints := []portainer.Endpoint{
{ID: 1, Name: "Endpoint 1"},
{ID: 2, Name: "Endpoint 2"},
{ID: 3, Name: "Endpoint 3"},
{ID: 4, Name: "Endpoint 4"},
}
err := store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 1,
Name: "Edge Group 1",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{1}),
})
require.NoError(t, err)
err = store.EdgeGroup().Create(&portainer.EdgeGroup{
ID: 2,
Name: "Edge Group 2",
EndpointIDs: roar.FromSlice([]portainer.EndpointID{2, 3}),
})
require.NoError(t, err)
edgeGroups, err := store.EdgeGroup().ReadAll()
require.NoError(t, err)
es, egs := filterEndpointsByExcludeEdgeGroupIDs(endpoints, edgeGroups, []portainer.EdgeGroupID{1})
require.NoError(t, err)
require.Len(t, es, 3)
require.Equal(t, es, []portainer.Endpoint{
{ID: 2, Name: "Endpoint 2"},
{ID: 3, Name: "Endpoint 3"},
{ID: 4, Name: "Endpoint 4"},
})
require.Len(t, egs, 1)
require.Equal(t, egs[0].ID, portainer.EdgeGroupID(2))
}
func TestGetShortestAsyncInterval(t *testing.T) {
endpoint := &portainer.Endpoint{
ID: 1,
Name: "Test Endpoint",
Edge: portainer.EnvironmentEdgeSettings{
PingInterval: -1,
SnapshotInterval: -1,
CommandInterval: -1,
},
}
settings := &portainer.Settings{
Edge: portainer.Edge{
PingInterval: 10,
SnapshotInterval: 20,
CommandInterval: 30,
},
}
require.Equal(t, 10, getShortestAsyncInterval(endpoint, settings))
}

View File

@@ -1,12 +1,11 @@
package endpoints
import (
"slices"
"github.com/pkg/errors"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/set"
"github.com/pkg/errors"
)
func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []portainer.EdgeGroupID, environmentID portainer.EndpointID) (bool, error) {
@@ -19,10 +18,8 @@ func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []po
environmentEdgeGroupsSet := set.Set[portainer.EdgeGroupID]{}
for _, edgeGroup := range edgeGroups {
for _, eID := range edgeGroup.Endpoints {
if eID == environmentID {
environmentEdgeGroupsSet[edgeGroup.ID] = true
}
if edgeGroup.EndpointIDs.Contains(environmentID) {
environmentEdgeGroupsSet[edgeGroup.ID] = true
}
}
@@ -52,20 +49,16 @@ func updateEnvironmentEdgeGroups(tx dataservices.DataStoreTx, newEdgeGroups []po
}
removeEdgeGroups := environmentEdgeGroupsSet.Difference(newEdgeGroupsSet)
err = updateSet(removeEdgeGroups, func(edgeGroup *portainer.EdgeGroup) {
edgeGroup.Endpoints = slices.DeleteFunc(edgeGroup.Endpoints, func(eID portainer.EndpointID) bool {
return eID == environmentID
})
})
if err != nil {
if err := updateSet(removeEdgeGroups, func(edgeGroup *portainer.EdgeGroup) {
edgeGroup.EndpointIDs.Remove(environmentID)
}); err != nil {
return false, err
}
addToEdgeGroups := newEdgeGroupsSet.Difference(environmentEdgeGroupsSet)
err = updateSet(addToEdgeGroups, func(edgeGroup *portainer.EdgeGroup) {
edgeGroup.Endpoints = append(edgeGroup.Endpoints, environmentID)
})
if err != nil {
if err := updateSet(addToEdgeGroups, func(edgeGroup *portainer.EdgeGroup) {
edgeGroup.EndpointIDs.Add(environmentID)
}); err != nil {
return false, err
}

View File

@@ -6,6 +6,7 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/assert"
)
@@ -14,10 +15,9 @@ func Test_updateEdgeGroups(t *testing.T) {
groups := make([]portainer.EdgeGroup, len(names))
for index, name := range names {
group := &portainer.EdgeGroup{
Name: name,
Dynamic: false,
TagIDs: make([]portainer.TagID, 0),
Endpoints: make([]portainer.EndpointID, 0),
Name: name,
Dynamic: false,
TagIDs: make([]portainer.TagID, 0),
}
if err := store.EdgeGroup().Create(group); err != nil {
@@ -35,13 +35,8 @@ func Test_updateEdgeGroups(t *testing.T) {
group, err := store.EdgeGroup().Read(groupID)
is.NoError(err)
for _, endpoint := range group.Endpoints {
if endpoint == endpointID {
return
}
}
is.Fail("expected endpoint to be in group")
is.True(group.EndpointIDs.Contains(endpointID),
"expected endpoint to be in group")
}
}
@@ -81,7 +76,7 @@ func Test_updateEdgeGroups(t *testing.T) {
endpointGroups := groupsByName(groups, testCase.endpointGroupNames)
for _, group := range endpointGroups {
group.Endpoints = append(group.Endpoints, testCase.endpoint.ID)
group.EndpointIDs.Add(testCase.endpoint.ID)
err = store.EdgeGroup().Update(group.ID, &group)
is.NoError(err)

View File

@@ -10,7 +10,6 @@ import (
)
func Test_updateTags(t *testing.T) {
createTags := func(store *datastore.Store, tagNames []string) ([]portainer.Tag, error) {
tags := make([]portainer.Tag, len(tagNames))
for index, tagName := range tagNames {

View File

@@ -55,6 +55,10 @@ func (handler *Handler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
}
if r.RequestURI == "/" || strings.HasSuffix(r.RequestURI, ".html") {
w.Header().Set("Permissions-Policy", strings.Join(permissions, ","))
}
if !isHTML(r.Header["Accept"]) {
w.Header().Set("Cache-Control", "max-age=31536000")
} else {

View File

@@ -0,0 +1,70 @@
package file_test
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/portainer/portainer/api/http/handler/file"
"github.com/stretchr/testify/require"
)
func TestNormalServe(t *testing.T) {
handler := file.NewHandler("", false, func() bool { return false })
require.NotNil(t, handler)
request := func(path string) (*http.Request, *httptest.ResponseRecorder) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, path, nil)
handler.ServeHTTP(rr, req)
return req, rr
}
_, rr := request("/timeout.html")
require.Equal(t, http.StatusTemporaryRedirect, rr.Result().StatusCode)
loc, err := rr.Result().Location()
require.NoError(t, err)
require.NotNil(t, loc)
require.Equal(t, "/", loc.Path)
_, rr = request("/")
require.Equal(t, http.StatusOK, rr.Result().StatusCode)
}
func TestPermissionsPolicyHeader(t *testing.T) {
handler := file.NewHandler("", false, func() bool { return false })
require.NotNil(t, handler)
test := func(path string, exist bool) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, path, nil)
handler.ServeHTTP(rr, req)
require.Equal(t, exist, rr.Result().Header.Get("Permissions-Policy") != "")
}
test("/", true)
test("/index.html", true)
test("/api", false)
test("/an/image.png", false)
}
func TestRedirectInstanceDisabled(t *testing.T) {
handler := file.NewHandler("", false, func() bool { return true })
require.NotNil(t, handler)
test := func(path string) {
rr := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, path, nil)
handler.ServeHTTP(rr, req)
require.Equal(t, http.StatusTemporaryRedirect, rr.Result().StatusCode)
loc, err := rr.Result().Location()
require.NoError(t, err)
require.NotNil(t, loc)
require.Equal(t, "/timeout.html", loc.Path)
}
test("/")
test("/index.html")
}

View File

@@ -0,0 +1,91 @@
package file
var permissions = []string{
"accelerometer=()",
"ambient-light-sensor=()",
"attribution-reporting=()",
"autoplay=()",
"battery=()",
"browsing-topics=()",
"camera=()",
"captured-surface-control=()",
"ch-device-memory=()",
"ch-downlink=()",
"ch-dpr=()",
"ch-ect=()",
"ch-prefers-color-scheme=()",
"ch-prefers-reduced-motion=()",
"ch-prefers-reduced-transparency=()",
"ch-rtt=()",
"ch-save-data=()",
"ch-ua=()",
"ch-ua-arch=()",
"ch-ua-bitness=()",
"ch-ua-form-factors=()",
"ch-ua-full-version=()",
"ch-ua-full-version-list=()",
"ch-ua-mobile=()",
"ch-ua-model=()",
"ch-ua-platform=()",
"ch-ua-platform-version=()",
"ch-ua-wow64=()",
"ch-viewport-height=()",
"ch-viewport-width=()",
"ch-width=()",
"compute-pressure=()",
"conversion-measurement=()",
"cross-origin-isolated=()",
"deferred-fetch=()",
"deferred-fetch-minimal=()",
"display-capture=()",
"document-domain=()",
"encrypted-media=()",
"execution-while-not-rendered=()",
"execution-while-out-of-viewport=()",
"focus-without-user-activation=()",
"fullscreen=()",
"gamepad=()",
"geolocation=()",
"gyroscope=()",
"hid=()",
"identity-credentials-get=()",
"idle-detection=()",
"interest-cohort=()",
"join-ad-interest-group=()",
"keyboard-map=()",
"language-detector=()",
"local-fonts=()",
"magnetometer=()",
"microphone=()",
"midi=()",
"navigation-override=()",
"otp-credentials=()",
"payment=()",
"picture-in-picture=()",
"private-aggregation=()",
"private-state-token-issuance=()",
"private-state-token-redemption=()",
"publickey-credentials-create=()",
"publickey-credentials-get=()",
"rewriter=()",
"run-ad-auction=()",
"screen-wake-lock=()",
"serial=()",
"shared-storage=()",
"shared-storage-select-url=()",
"speaker-selection=()",
"storage-access=()",
"summarizer=()",
"sync-script=()",
"sync-xhr=()",
"translator=()",
"trust-token-redemption=()",
"unload=()",
"usb=()",
"vertical-scroll=()",
"web-share=()",
"window-management=()",
"window-placement=()",
"writer=()",
"xr-spatial-tracking=()",
}

Some files were not shown because too many files have changed in this diff Show More