Compare commits
6 Commits
develop
...
release/2.
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
38ca054aee | ||
|
|
faee22c907 | ||
|
|
6547b06f97 | ||
|
|
87b1cc80ed | ||
|
|
eaf5283cdb | ||
|
|
afd8507042 |
@@ -17,7 +17,7 @@ plugins:
|
||||
- import
|
||||
|
||||
parserOptions:
|
||||
ecmaVersion: latest
|
||||
ecmaVersion: 2018
|
||||
sourceType: module
|
||||
project: './tsconfig.json'
|
||||
ecmaFeatures:
|
||||
@@ -114,13 +114,7 @@ overrides:
|
||||
'@typescript-eslint/explicit-module-boundary-types': off
|
||||
'@typescript-eslint/no-unused-vars': 'error'
|
||||
'@typescript-eslint/no-explicit-any': 'error'
|
||||
'jsx-a11y/label-has-associated-control':
|
||||
- error
|
||||
- assert: either
|
||||
controlComponents:
|
||||
- Input
|
||||
- Checkbox
|
||||
'jsx-a11y/control-has-associated-label': off
|
||||
'jsx-a11y/label-has-associated-control': ['error', { 'assert': 'either', controlComponents: ['Input', 'Checkbox'] }]
|
||||
'react/function-component-definition': ['error', { 'namedComponents': 'function-declaration' }]
|
||||
'react/jsx-no-bind': off
|
||||
'no-await-in-loop': 'off'
|
||||
@@ -139,19 +133,15 @@ overrides:
|
||||
'react/jsx-props-no-spreading': off
|
||||
- files:
|
||||
- app/**/*.test.*
|
||||
plugins:
|
||||
- '@vitest'
|
||||
extends:
|
||||
- 'plugin:@vitest/legacy-recommended'
|
||||
- 'plugin:vitest/recommended'
|
||||
env:
|
||||
'@vitest/env': true
|
||||
'vitest/env': true
|
||||
rules:
|
||||
'react/jsx-no-constructed-context-values': off
|
||||
'@typescript-eslint/no-restricted-imports': off
|
||||
no-restricted-imports: off
|
||||
'react/jsx-props-no-spreading': off
|
||||
'@vitest/no-conditional-expect': warn
|
||||
'max-classes-per-file': off
|
||||
- files:
|
||||
- app/**/*.stories.*
|
||||
rules:
|
||||
@@ -159,4 +149,3 @@ overrides:
|
||||
'@typescript-eslint/no-restricted-imports': off
|
||||
no-restricted-imports: off
|
||||
'react/jsx-props-no-spreading': off
|
||||
'storybook/no-renderer-packages': off
|
||||
|
||||
2
.github/DISCUSSION_TEMPLATE/ideas.yaml
vendored
2
.github/DISCUSSION_TEMPLATE/ideas.yaml
vendored
@@ -6,7 +6,7 @@ body:
|
||||
|
||||
Thanks for suggesting an idea for Portainer!
|
||||
|
||||
Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion category](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.
|
||||
Before opening a new idea or feature request, make sure that we do not have any duplicates already open. You can ensure this by [searching this discussion cagetory](https://github.com/orgs/portainer/discussions/categories/ideas). If there is a duplicate, please add a comment to the existing idea instead.
|
||||
|
||||
Also, be sure to check our [knowledge base](https://portal.portainer.io/knowledge) and [documentation](https://docs.portainer.io) as they may point you toward a solution.
|
||||
|
||||
|
||||
35
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
35
.github/ISSUE_TEMPLATE/bug_report.yml
vendored
@@ -22,7 +22,7 @@ body:
|
||||
options:
|
||||
- label: Yes, I've searched similar issues on [GitHub](https://github.com/portainer/portainer/issues).
|
||||
required: true
|
||||
- label: Yes, I've checked whether this issue is covered in the Portainer [documentation](https://docs.portainer.io).
|
||||
- label: Yes, I've checked whether this issue is covered in the Portainer [documentation](https://docs.portainer.io) or [knowledge base](https://portal.portainer.io/knowledge).
|
||||
required: true
|
||||
|
||||
- type: markdown
|
||||
@@ -94,39 +94,8 @@ body:
|
||||
description: We only provide support for current versions of Portainer as per the lifecycle policy linked above. If you are on an older version of Portainer we recommend [updating first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
|
||||
multiple: false
|
||||
options:
|
||||
- '2.39.0'
|
||||
- '2.38.1'
|
||||
- '2.38.0'
|
||||
- '2.37.0'
|
||||
- '2.36.0'
|
||||
- '2.35.0'
|
||||
- '2.34.0'
|
||||
- '2.33.7'
|
||||
- '2.33.6'
|
||||
- '2.33.5'
|
||||
- '2.33.4'
|
||||
- '2.33.3'
|
||||
- '2.33.2'
|
||||
- '2.33.1'
|
||||
- '2.33.0'
|
||||
- '2.32.0'
|
||||
- '2.31.3'
|
||||
- '2.31.2'
|
||||
- '2.31.1'
|
||||
- '2.31.0'
|
||||
- '2.30.1'
|
||||
- '2.30.0'
|
||||
- '2.29.2'
|
||||
- '2.29.1'
|
||||
- '2.29.0'
|
||||
- '2.28.1'
|
||||
- '2.28.0'
|
||||
- '2.27.9'
|
||||
- '2.27.8'
|
||||
- '2.27.7'
|
||||
- '2.27.6'
|
||||
- '2.27.5'
|
||||
- '2.27.4'
|
||||
- '2.27.3'
|
||||
- '2.27.2'
|
||||
- '2.27.1'
|
||||
@@ -143,6 +112,8 @@ body:
|
||||
- '2.21.4'
|
||||
- '2.21.3'
|
||||
- '2.21.2'
|
||||
- '2.21.1'
|
||||
- '2.21.0'
|
||||
validations:
|
||||
required: true
|
||||
|
||||
|
||||
2
.gitignore
vendored
2
.gitignore
vendored
@@ -18,5 +18,3 @@ api/docs
|
||||
.env
|
||||
go.work.sum
|
||||
|
||||
.vitest
|
||||
|
||||
|
||||
@@ -1,16 +0,0 @@
|
||||
version: "2"
|
||||
linters:
|
||||
default: none
|
||||
enable:
|
||||
- forbidigo
|
||||
settings:
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^dataservices.DataStore.(EdgeGroup|EdgeJob|EdgeStack|EndpointRelation|Endpoint|GitCredential|Registry|ResourceControl|Role|Settings|Snapshot|SSLSettings|Stack|Tag|User)$
|
||||
msg: Use a transaction instead
|
||||
analyze-types: true
|
||||
exclusions:
|
||||
rules:
|
||||
- path: _test\.go
|
||||
linters:
|
||||
- forbidigo
|
||||
136
.golangci.yaml
136
.golangci.yaml
@@ -1,108 +1,40 @@
|
||||
version: "2"
|
||||
|
||||
run:
|
||||
allow-parallel-runners: true
|
||||
linters:
|
||||
default: none
|
||||
# Disable all linters, the defaults don't pass on our code yet
|
||||
disable-all: true
|
||||
|
||||
# Enable these for now
|
||||
enable:
|
||||
- bodyclose
|
||||
- copyloopvar
|
||||
- unused
|
||||
- depguard
|
||||
- errcheck
|
||||
- errorlint
|
||||
- forbidigo
|
||||
- gosimple
|
||||
- govet
|
||||
- ineffassign
|
||||
- errorlint
|
||||
- copyloopvar
|
||||
- intrange
|
||||
- perfsprint
|
||||
- staticcheck
|
||||
- unused
|
||||
- mirror
|
||||
- durationcheck
|
||||
- errorlint
|
||||
- govet
|
||||
- usetesting
|
||||
- zerologlint
|
||||
- testifylint
|
||||
- modernize
|
||||
- unconvert
|
||||
- unused
|
||||
- zerologlint
|
||||
- exptostd
|
||||
settings:
|
||||
staticcheck:
|
||||
checks: ["all", "-ST1003", "-ST1005", "-ST1016", "-SA1019", "-QF1003"]
|
||||
depguard:
|
||||
rules:
|
||||
main:
|
||||
files:
|
||||
- '!**/*_test.go'
|
||||
- '!**/base.go'
|
||||
- '!**/base_tx.go'
|
||||
deny:
|
||||
- pkg: encoding/json
|
||||
desc: use github.com/segmentio/encoding/json
|
||||
- pkg: golang.org/x/exp
|
||||
desc: exp is not allowed
|
||||
- pkg: github.com/portainer/libcrypto
|
||||
desc: use github.com/portainer/portainer/pkg/libcrypto
|
||||
- pkg: github.com/portainer/libhttp
|
||||
desc: use github.com/portainer/portainer/pkg/libhttp
|
||||
- pkg: golang.org/x/crypto
|
||||
desc: golang.org/x/crypto is not allowed because of FIPS mode
|
||||
- pkg: github.com/ProtonMail/go-crypto/openpgp
|
||||
desc: github.com/ProtonMail/go-crypto/openpgp is not allowed because of FIPS mode
|
||||
- pkg: github.com/cosi-project/runtime
|
||||
desc: github.com/cosi-project/runtime is not allowed because of FIPS mode
|
||||
- pkg: gopkg.in/yaml.v2
|
||||
desc: use go.yaml.in/yaml/v3 instead
|
||||
- pkg: gopkg.in/yaml.v3
|
||||
desc: use go.yaml.in/yaml/v3 instead
|
||||
- pkg: github.com/golang-jwt/jwt/v4
|
||||
desc: use github.com/golang-jwt/jwt/v5 instead
|
||||
- pkg: github.com/mitchellh/mapstructure
|
||||
desc: use github.com/go-viper/mapstructure/v2 instead
|
||||
- pkg: gopkg.in/alecthomas/kingpin.v2
|
||||
desc: use github.com/alecthomas/kingpin/v2 instead
|
||||
- pkg: github.com/jcmturner/gokrb5$
|
||||
desc: use github.com/jcmturner/gokrb5/v8 instead
|
||||
- pkg: github.com/gofrs/uuid
|
||||
desc: use github.com/google/uuid
|
||||
- pkg: github.com/Masterminds/semver$
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
- pkg: github.com/blang/semver
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
- pkg: github.com/coreos/go-semver
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
- pkg: github.com/hashicorp/go-version
|
||||
desc: use github.com/Masterminds/semver/v3
|
||||
forbidigo:
|
||||
forbid:
|
||||
- pattern: ^tls\.Config$
|
||||
msg: Use crypto.CreateTLSConfiguration() instead
|
||||
- pattern: ^tls\.Config\.(InsecureSkipVerify|MinVersion|MaxVersion|CipherSuites|CurvePreferences)$
|
||||
msg: Do not set this field directly, use crypto.CreateTLSConfiguration() instead
|
||||
- pattern: ^object\.(Commit|Tag)\.Verify$
|
||||
msg: "Not allowed because of FIPS mode"
|
||||
- pattern: ^(types\.SystemContext\.)?(DockerDaemonInsecureSkipTLSVerify|DockerInsecureSkipTLSVerify|OCIInsecureSkipTLSVerify)$
|
||||
msg: "Not allowed because of FIPS mode"
|
||||
analyze-types: true
|
||||
exclusions:
|
||||
generated: lax
|
||||
presets:
|
||||
- comments
|
||||
- common-false-positives
|
||||
- legacy
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
formatters:
|
||||
enable:
|
||||
- gofmt
|
||||
exclusions:
|
||||
generated: lax
|
||||
paths:
|
||||
- third_party$
|
||||
- builtin$
|
||||
- examples$
|
||||
|
||||
linters-settings:
|
||||
depguard:
|
||||
rules:
|
||||
main:
|
||||
deny:
|
||||
- pkg: 'encoding/json'
|
||||
desc: 'use github.com/segmentio/encoding/json'
|
||||
- pkg: 'golang.org/x/exp'
|
||||
desc: 'exp is not allowed'
|
||||
- pkg: 'github.com/portainer/libcrypto'
|
||||
desc: 'use github.com/portainer/portainer/pkg/libcrypto'
|
||||
- pkg: 'github.com/portainer/libhttp'
|
||||
desc: 'use github.com/portainer/portainer/pkg/libhttp'
|
||||
files:
|
||||
- '!**/*_test.go'
|
||||
- '!**/base.go'
|
||||
- '!**/base_tx.go'
|
||||
|
||||
# errorlint is causing a typecheck error for some reason. The go compiler will report these
|
||||
# anyway, so ignore them from the linter
|
||||
issues:
|
||||
exclude-rules:
|
||||
- path: ./
|
||||
linters:
|
||||
- typecheck
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
#!/usr/bin/env sh
|
||||
. "$(dirname -- "$0")/_/husky.sh"
|
||||
|
||||
cd $(dirname -- "$0") && pnpm lint-staged
|
||||
cd $(dirname -- "$0") && yarn lint-staged
|
||||
@@ -1,3 +1,2 @@
|
||||
dist
|
||||
api/datastore/test_data
|
||||
coverage
|
||||
api/datastore/test_data
|
||||
@@ -9,38 +9,20 @@ const config: StorybookConfig = {
|
||||
addons: [
|
||||
'@storybook/addon-links',
|
||||
'@storybook/addon-essentials',
|
||||
'@storybook/addon-webpack5-compiler-swc',
|
||||
'@chromatic-com/storybook',
|
||||
{
|
||||
name: '@storybook/addon-styling-webpack',
|
||||
|
||||
name: '@storybook/addon-styling',
|
||||
options: {
|
||||
rules: [
|
||||
{
|
||||
test: /\.css$/,
|
||||
sideEffects: true,
|
||||
use: [
|
||||
require.resolve('style-loader'),
|
||||
{
|
||||
loader: require.resolve('css-loader'),
|
||||
options: {
|
||||
importLoaders: 1,
|
||||
modules: {
|
||||
localIdentName: '[path][name]__[local]',
|
||||
auto: true,
|
||||
exportLocalsConvention: 'camelCaseOnly',
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
loader: require.resolve('postcss-loader'),
|
||||
options: {
|
||||
implementation: postcss,
|
||||
},
|
||||
},
|
||||
],
|
||||
cssLoaderOptions: {
|
||||
importLoaders: 1,
|
||||
modules: {
|
||||
localIdentName: '[path][name]__[local]',
|
||||
auto: true,
|
||||
exportLocalsConvention: 'camelCaseOnly',
|
||||
},
|
||||
],
|
||||
},
|
||||
postCss: {
|
||||
implementation: postcss,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import '../app/assets/css';
|
||||
import React from 'react';
|
||||
import { pushStateLocationPlugin, UIRouter } from '@uirouter/react';
|
||||
import { initialize as initMSW, mswLoader } from 'msw-storybook-addon';
|
||||
import { handlers } from '../app/setup-tests/server-handlers';
|
||||
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
|
||||
import { Preview } from '@storybook/react';
|
||||
|
||||
initMSW(
|
||||
{
|
||||
@@ -21,30 +21,31 @@ initMSW(
|
||||
handlers
|
||||
);
|
||||
|
||||
export const parameters = {
|
||||
actions: { argTypesRegex: '^on[A-Z].*' },
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/,
|
||||
},
|
||||
},
|
||||
msw: {
|
||||
handlers,
|
||||
},
|
||||
};
|
||||
|
||||
const testQueryClient = new QueryClient({
|
||||
defaultOptions: { queries: { retry: false } },
|
||||
});
|
||||
|
||||
const preview: Preview = {
|
||||
decorators: (Story) => (
|
||||
export const decorators = [
|
||||
(Story) => (
|
||||
<QueryClientProvider client={testQueryClient}>
|
||||
<UIRouter plugins={[pushStateLocationPlugin]}>
|
||||
<Story />
|
||||
</UIRouter>
|
||||
</QueryClientProvider>
|
||||
),
|
||||
loaders: [mswLoader],
|
||||
parameters: {
|
||||
controls: {
|
||||
matchers: {
|
||||
color: /(background|color)$/i,
|
||||
date: /Date$/,
|
||||
},
|
||||
},
|
||||
msw: {
|
||||
handlers,
|
||||
},
|
||||
},
|
||||
};
|
||||
];
|
||||
|
||||
export default preview;
|
||||
export const loaders = [mswLoader];
|
||||
|
||||
44
CLAUDE.md
44
CLAUDE.md
@@ -1,44 +0,0 @@
|
||||
# Portainer Community Edition
|
||||
|
||||
Open-source container management platform with full Docker and Kubernetes support.
|
||||
|
||||
see also:
|
||||
|
||||
- docs/guidelines/server-architecture.md
|
||||
- docs/guidelines/go-conventions.md
|
||||
- docs/guidelines/typescript-conventions.md
|
||||
|
||||
## Package Manager
|
||||
|
||||
- **PNPM** 10+ (for frontend)
|
||||
- **Go** 1.25.7 (for backend)
|
||||
|
||||
## Build Commands
|
||||
|
||||
```bash
|
||||
# Full build
|
||||
make build # Build both client and server
|
||||
make build-client # Build React/AngularJS frontend
|
||||
make build-server # Build Go binary
|
||||
make build-image # Build Docker image
|
||||
|
||||
# Development
|
||||
make dev # Run both in dev mode
|
||||
make dev-client # Start webpack-dev-server (port 8999)
|
||||
make dev-server # Run containerized Go server
|
||||
|
||||
pnpm run dev # Webpack dev server
|
||||
pnpm run build # Build frontend with webpack
|
||||
pnpm run test # Run frontend tests
|
||||
|
||||
# Testing
|
||||
make test # All tests (backend + frontend)
|
||||
make test-server # Backend tests only
|
||||
make lint # Lint all code
|
||||
make format # Format code
|
||||
```
|
||||
|
||||
## Development Servers
|
||||
|
||||
- Frontend: http://localhost:8999
|
||||
- Backend: http://localhost:9000 (HTTP) / https://localhost:9443 (HTTPS)
|
||||
@@ -77,7 +77,7 @@ The feature request process is similar to the bug report process but has an extr
|
||||
|
||||
## Build and run Portainer locally
|
||||
|
||||
Ensure you have Docker, Node.js, pnpm, and Golang installed in the correct versions.
|
||||
Ensure you have Docker, Node.js, yarn, and Golang installed in the correct versions.
|
||||
|
||||
Install dependencies:
|
||||
|
||||
|
||||
36
Makefile
36
Makefile
@@ -1,3 +1,9 @@
|
||||
# See: https://gist.github.com/asukakenji/f15ba7e588ac42795f421b48b8aede63
|
||||
# For a list of valid GOOS and GOARCH values
|
||||
# Note: these can be overriden on the command line e.g. `make PLATFORM=<platform> ARCH=<arch>`
|
||||
PLATFORM=$(shell go env GOOS)
|
||||
ARCH=$(shell go env GOARCH)
|
||||
|
||||
# build target, can be one of "production", "testing", "development"
|
||||
ENV=development
|
||||
WEBPACK_CONFIG=webpack/webpack.$(ENV).js
|
||||
@@ -20,7 +26,7 @@ all: tidy deps build-server build-client ## Build the client, server and downloa
|
||||
build-all: all ## Alias for the 'all' target (used by CI)
|
||||
|
||||
build-client: init-dist ## Build the client
|
||||
export NODE_ENV=$(ENV) && pnpm run build --config $(WEBPACK_CONFIG)
|
||||
export NODE_ENV=$(ENV) && yarn build --config $(WEBPACK_CONFIG)
|
||||
|
||||
build-server: init-dist ## Build the server binary
|
||||
./build/build_binary.sh "$(PLATFORM)" "$(ARCH)"
|
||||
@@ -29,7 +35,11 @@ build-image: build-all ## Build the Portainer image locally
|
||||
docker buildx build --load -t portainerci/portainer-ce:$(TAG) -f build/linux/Dockerfile .
|
||||
|
||||
build-storybook: ## Build and serve the storybook files
|
||||
pnpm run storybook:build
|
||||
yarn storybook:build
|
||||
|
||||
devops: clean deps build-client ## Build the everything target specifically for CI
|
||||
echo "Building the devops binary..."
|
||||
@./build/build_binary_azuredevops.sh "$(PLATFORM)" "$(ARCH)"
|
||||
|
||||
##@ Build dependencies
|
||||
.PHONY: deps server-deps client-deps tidy
|
||||
@@ -39,23 +49,25 @@ server-deps: init-dist ## Download dependant server binaries
|
||||
@./build/download_binaries.sh $(PLATFORM) $(ARCH)
|
||||
|
||||
client-deps: ## Install client dependencies
|
||||
pnpm install
|
||||
yarn
|
||||
|
||||
tidy: ## Tidy up the go.mod file
|
||||
@go mod tidy
|
||||
|
||||
|
||||
##@ Cleanup
|
||||
.PHONY: clean
|
||||
clean: ## Remove all build and download artifacts
|
||||
@echo "Clearing the dist directory..."
|
||||
@rm -rf dist/*
|
||||
|
||||
|
||||
##@ Testing
|
||||
.PHONY: test test-client test-server
|
||||
test: test-server test-client ## Run all tests
|
||||
|
||||
test-client: ## Run client tests
|
||||
pnpm run test $(ARGS) --coverage
|
||||
yarn test $(ARGS) --coverage
|
||||
|
||||
test-server: ## Run server tests
|
||||
$(GOTESTSUM) --format pkgname-and-test-fails --format-hide-empty-pkg --hide-summary skipped -- -cover -covermode=atomic -coverprofile=coverage.out ./...
|
||||
@@ -67,7 +79,7 @@ dev: ## Run both the client and server in development mode
|
||||
make dev-client
|
||||
|
||||
dev-client: ## Run the client in development mode
|
||||
pnpm install && pnpm run dev
|
||||
yarn dev
|
||||
|
||||
dev-server: build-server ## Run the server in development mode
|
||||
@./dev/run_container.sh
|
||||
@@ -81,7 +93,7 @@ dev-server-podman: build-server ## Run the server in development mode
|
||||
format: format-client format-server ## Format all code
|
||||
|
||||
format-client: ## Format client code
|
||||
pnpm run format
|
||||
yarn format
|
||||
|
||||
format-server: ## Format server code
|
||||
go fmt ./...
|
||||
@@ -91,26 +103,26 @@ format-server: ## Format server code
|
||||
lint: lint-client lint-server ## Lint all code
|
||||
|
||||
lint-client: ## Lint client code
|
||||
pnpm run lint
|
||||
yarn lint
|
||||
|
||||
lint-server: tidy ## Lint server code
|
||||
lint-server: ## Lint server code
|
||||
golangci-lint run --timeout=10m -c .golangci.yaml
|
||||
golangci-lint run --timeout=10m --new-from-rev=HEAD~ -c .golangci-forward.yaml
|
||||
|
||||
|
||||
##@ Extension
|
||||
.PHONY: dev-extension
|
||||
dev-extension: build-server build-client ## Run the extension in development mode
|
||||
make local -f build/docker-extension/Makefile
|
||||
|
||||
|
||||
##@ Docs
|
||||
.PHONY: docs-build docs-validate docs-clean docs-validate-clean
|
||||
docs-build: init-dist ## Build docs
|
||||
go mod download -x
|
||||
cd api && $(SWAG) init -o "../dist/docs" -ot "yaml" -g ./http/handler/handler.go --parseDependency --parseInternal --parseDepth 2 -p pascalcase --markdownFiles ./
|
||||
|
||||
docs-validate: docs-build ## Validate docs
|
||||
pnpm swagger2openapi --warnOnly dist/docs/swagger.yaml -o dist/docs/openapi.yaml
|
||||
pnpm swagger-cli validate dist/docs/openapi.yaml
|
||||
yarn swagger2openapi --warnOnly dist/docs/swagger.yaml -o dist/docs/openapi.yaml
|
||||
yarn swagger-cli validate dist/docs/openapi.yaml
|
||||
|
||||
##@ Helpers
|
||||
.PHONY: help
|
||||
|
||||
19
README.md
19
README.md
@@ -8,9 +8,9 @@ Portainer consists of a single container that can run on any cluster. It can be
|
||||
|
||||
**Portainer Business Edition** builds on the open-source base and includes a range of advanced features and functions (like RBAC and Support) that are specific to the needs of business users.
|
||||
|
||||
- [Compare Portainer CE and Compare Portainer BE](https://www.portainer.io/features)
|
||||
- [Compare Portainer CE and Compare Portainer BE](https://portainer.io/products)
|
||||
- [Take3 – get 3 free nodes of Portainer Business for as long as you want them](https://www.portainer.io/take-3)
|
||||
- [Portainer BE install guide](https://academy.portainer.io/install/)
|
||||
- [Portainer BE install guide](https://install.portainer.io)
|
||||
|
||||
## Latest Version
|
||||
|
||||
@@ -20,19 +20,22 @@ Portainer CE is updated regularly. We aim to do an update release every couple o
|
||||
|
||||
## Getting started
|
||||
|
||||
- [Deploy Portainer](https://docs.portainer.io/start/install-ce)
|
||||
- [Deploy Portainer](https://docs.portainer.io/start/install)
|
||||
- [Documentation](https://docs.portainer.io)
|
||||
- [Contribute to the project](https://docs.portainer.io/contribute/contribute)
|
||||
|
||||
## Features & Functions
|
||||
|
||||
View [this](https://www.portainer.io/features) table to see all of the Portainer CE functionality and compare to Portainer Business.
|
||||
View [this](https://www.portainer.io/products) table to see all of the Portainer CE functionality and compare to Portainer Business.
|
||||
|
||||
- [Portainer CE for Docker / Docker Swarm](https://www.portainer.io/solutions/docker)
|
||||
- [Portainer CE for Kubernetes](https://www.portainer.io/solutions/kubernetes-ui)
|
||||
|
||||
## Getting help
|
||||
|
||||
Portainer CE is an open source project and is supported by the community. You can buy a supported version of Portainer at portainer.io
|
||||
|
||||
Learn more about Portainer's community support channels [here.](https://www.portainer.io/resources/get-help/get-support)
|
||||
Learn more about Portainer's community support channels [here.](https://www.portainer.io/get-support-for-portainer)
|
||||
|
||||
- Issues: https://github.com/portainer/portainer/issues
|
||||
- Slack (chat): [https://portainer.io/slack](https://portainer.io/slack)
|
||||
@@ -46,17 +49,17 @@ You can join the Portainer Community by visiting [https://www.portainer.io/join-
|
||||
|
||||
## Security
|
||||
|
||||
For information about reporting security vulnerabilities, please see our [Security Policy](SECURITY.md).
|
||||
- Here at Portainer, we believe in [responsible disclosure](https://en.wikipedia.org/wiki/Responsible_disclosure) of security issues. If you have found a security issue, please report it to <security@portainer.io>.
|
||||
|
||||
## Work for us
|
||||
|
||||
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to success@portainer.io with your details and/or visit our [careers page](https://apply.workable.com/portainer/).
|
||||
If you are a developer, and our code in this repo makes sense to you, we would love to hear from you. We are always on the hunt for awesome devs, either freelance or employed. Drop us a line to info@portainer.io with your details and/or visit our [careers page](https://portainer.io/careers).
|
||||
|
||||
## Privacy
|
||||
|
||||
**To make sure we focus our development effort in the right places we need to know which features get used most often. To give us this information we use [Matomo Analytics](https://matomo.org/), which is hosted in Germany and is fully GDPR compliant.**
|
||||
|
||||
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/legal/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
|
||||
When Portainer first starts, you are given the option to DISABLE analytics. If you **don't** choose to disable it, we collect anonymous usage as per [our privacy policy](https://www.portainer.io/privacy-policy). **Please note**, there is no personally identifiable information sent or stored at any time and we only use the data to help us improve Portainer.
|
||||
|
||||
## Limitations
|
||||
|
||||
|
||||
61
SECURITY.md
61
SECURITY.md
@@ -1,61 +0,0 @@
|
||||
# Security Policy
|
||||
|
||||
## Supported Versions
|
||||
|
||||
Portainer maintains both Short-Term Support (STS) and Long-Term Support (LTS) versions in accordance with our official [Portainer Lifecycle Policy](https://docs.portainer.io/start/lifecycle).
|
||||
|
||||
| Version Type | Support Status |
|
||||
| --- | --- |
|
||||
| LTS (Long-Term Support) | Supported for critical security fixes |
|
||||
| STS (Short-Term Support) | Supported until the next STS or LTS release |
|
||||
| Legacy / EOL | Not supported |
|
||||
|
||||
For a detailed breakdown of current versions and their specific End of Life (EOL) dates,
|
||||
please refer to the [Portainer Lifecycle Policy](https://docs.portainer.io/start/lifecycle).
|
||||
|
||||
## Reporting a Vulnerability
|
||||
|
||||
The Portainer team takes the security of our products seriously. If you believe you have found a security vulnerability in any Portainer-owned repository, please report it to us responsibly.
|
||||
|
||||
**Please do not report security vulnerabilities via public GitHub issues.**
|
||||
|
||||
### Disclosure Process
|
||||
|
||||
1. **Report**: You can report in one of two ways:
|
||||
|
||||
- **GitHub**: Use the **Report a vulnerability** button on the **Security** tab of this repository.
|
||||
|
||||
- **Email**: Send your findings to security@portainer.io.
|
||||
|
||||
2. **Details**: To help us verify the issue, please include:
|
||||
|
||||
- A description of the vulnerability and its potential impact.
|
||||
|
||||
- Step-by-step instructions to reproduce the issue (e.g. proof-of-concept code, scripts, or screenshots).
|
||||
|
||||
- The version of the software and the environment in which it was found.
|
||||
|
||||
3. **Acknowledge**: We will acknowledge receipt of your report and provide an initial assessment.
|
||||
|
||||
4. **Resolution**: We will work to resolve the issue as quickly as possible. We request that you do not disclose the vulnerability publicly until we have released a fix and notified affected users.
|
||||
|
||||
## Our Commitment
|
||||
|
||||
If you follow the responsible disclosure process, we will:
|
||||
|
||||
- Respond to your report in a timely manner.
|
||||
|
||||
- Provide an estimated timeline for remediation.
|
||||
|
||||
- Notify you when the vulnerability has been patched.
|
||||
|
||||
- Give credit for the discovery (if desired) once the fix is public.
|
||||
|
||||
|
||||
We will make every effort to promptly address any security weaknesses. Security advisories and fixes will be published through GitHub Security Advisories and other channels as needed.
|
||||
|
||||
Thank you for helping keep Portainer and our community secure.
|
||||
|
||||
## Resources
|
||||
|
||||
- [Contributing to Portainer](https://docs.portainer.io/contribute/contribute#contributing-to-the-portainer-ce-codebase)
|
||||
@@ -11,18 +11,20 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/url"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// GetAgentVersionAndPlatform returns the agent version and platform
|
||||
//
|
||||
// it sends a ping to the agent and parses the version and platform from the headers
|
||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) { //nolint:forbidigo
|
||||
httpCli := &http.Client{Timeout: 3 * time.Second}
|
||||
func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (portainer.AgentPlatform, string, error) {
|
||||
httpCli := &http.Client{
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
if tlsConfig != nil {
|
||||
httpCli.Transport = &http.Transport{TLSClientConfig: tlsConfig}
|
||||
httpCli.Transport = &http.Transport{
|
||||
TLSClientConfig: tlsConfig,
|
||||
}
|
||||
}
|
||||
|
||||
parsedURL, err := url.ParseURL(endpointUrl + "/ping")
|
||||
@@ -42,10 +44,8 @@ func GetAgentVersionAndPlatform(endpointUrl string, tlsConfig *tls.Config) (port
|
||||
return 0, "", err
|
||||
}
|
||||
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
if err := resp.Body.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to close response body")
|
||||
}
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusNoContent {
|
||||
return 0, "", fmt.Errorf("Failed request with status %d", resp.StatusCode)
|
||||
|
||||
@@ -10,31 +10,31 @@ func Test_generateRandomKey(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
wantLength int
|
||||
name string
|
||||
wantLenth int
|
||||
}{
|
||||
{
|
||||
name: "Generate a random key of length 16",
|
||||
wantLength: 16,
|
||||
name: "Generate a random key of length 16",
|
||||
wantLenth: 16,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 32",
|
||||
wantLength: 32,
|
||||
name: "Generate a random key of length 32",
|
||||
wantLenth: 32,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 64",
|
||||
wantLength: 64,
|
||||
name: "Generate a random key of length 64",
|
||||
wantLenth: 64,
|
||||
},
|
||||
{
|
||||
name: "Generate a random key of length 128",
|
||||
wantLength: 128,
|
||||
name: "Generate a random key of length 128",
|
||||
wantLenth: 128,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := GenerateRandomKey(tt.wantLength)
|
||||
is.Len(got, tt.wantLength)
|
||||
got := GenerateRandomKey(tt.wantLenth)
|
||||
is.Equal(tt.wantLenth, len(got))
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@@ -10,10 +10,9 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_SatisfiesAPIKeyServiceInterface(t *testing.T) {
|
||||
@@ -31,7 +30,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
t.Run("Successfully generates API key", func(t *testing.T) {
|
||||
desc := "test-1"
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, desc)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.NotEmpty(rawKey)
|
||||
is.NotEmpty(apiKey)
|
||||
is.Equal(desc, apiKey.Description)
|
||||
@@ -39,7 +38,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Api key prefix is 7 chars", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(rawKey[:7], apiKey.Prefix)
|
||||
is.Len(apiKey.Prefix, 7)
|
||||
@@ -47,7 +46,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Api key has 'ptr_' as prefix", func(t *testing.T) {
|
||||
rawKey, _, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(portainerAPIKeyPrefix, "ptr_")
|
||||
is.True(strings.HasPrefix(rawKey, "ptr_"))
|
||||
@@ -56,7 +55,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
t.Run("Successfully caches API key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-3")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userFromCache, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -66,7 +65,7 @@ func Test_GenerateApiKey(t *testing.T) {
|
||||
|
||||
t.Run("Decoded raw api-key digest matches generated digest", func(t *testing.T) {
|
||||
rawKey, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-4")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
generatedDigest := sha256.Sum256([]byte(rawKey))
|
||||
|
||||
@@ -84,10 +83,10 @@ func Test_GetAPIKey(t *testing.T) {
|
||||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
apiKeyGot, err := service.GetAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
is.Equal(apiKey, apiKeyGot)
|
||||
})
|
||||
@@ -103,12 +102,12 @@ func Test_GetAPIKeys(t *testing.T) {
|
||||
t.Run("Successfully returns all API keys", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, _, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
_, _, err = service.GenerateApiKey(user, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
keys, err := service.GetAPIKeys(user.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Len(keys, 2)
|
||||
})
|
||||
}
|
||||
@@ -123,10 +122,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
||||
t.Run("Successfully returns user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
})
|
||||
@@ -134,10 +133,10 @@ func Test_GetDigestUserAndKey(t *testing.T) {
|
||||
t.Run("Successfully caches user and api key associated to digest", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
userGot, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(user, userGot)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
@@ -157,19 +156,16 @@ func Test_UpdateAPIKey(t *testing.T) {
|
||||
|
||||
t.Run("Successfully updates the api-key LastUsed time", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
|
||||
err := store.User().Create(&user)
|
||||
require.NoError(t, err)
|
||||
|
||||
store.User().Create(&user)
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-x")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
apiKey.LastUsed = time.Now().UTC().Unix()
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
log.Debug().Str("wanted", fmt.Sprintf("%+v", apiKey)).Str("got", fmt.Sprintf("%+v", apiKeyGot)).Msg("")
|
||||
|
||||
@@ -178,7 +174,7 @@ func Test_UpdateAPIKey(t *testing.T) {
|
||||
|
||||
t.Run("Successfully updates api-key in cache upon api-key update", func(t *testing.T) {
|
||||
_, apiKey, err := service.GenerateApiKey(portainer.User{ID: 1}, "test-x2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -188,7 +184,7 @@ func Test_UpdateAPIKey(t *testing.T) {
|
||||
is.NotEqual(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.UpdateAPIKey(apiKey)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, updatedAPIKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
@@ -206,30 +202,30 @@ func Test_DeleteAPIKey(t *testing.T) {
|
||||
t.Run("Successfully updates the api-key", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyGot, err := service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(*apiKey, apiKeyGot)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, _, err = service.GetDigestUserAndKey(apiKey.Digest)
|
||||
require.Error(t, err)
|
||||
is.Error(err)
|
||||
})
|
||||
|
||||
t.Run("Successfully removes api-key from cache upon deletion", func(t *testing.T) {
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey.Digest)
|
||||
is.True(ok)
|
||||
is.Equal(*apiKey, apiKeyFromCache)
|
||||
|
||||
err = service.DeleteAPIKey(apiKey.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, _, ok = service.cache.Get(apiKey.Digest)
|
||||
is.False(ok)
|
||||
@@ -247,10 +243,10 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
||||
// generate api keys
|
||||
user := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
_, apiKey2, err := service.GenerateApiKey(user, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
// verify api keys are present in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
@@ -277,11 +273,11 @@ func Test_InvalidateUserKeyCache(t *testing.T) {
|
||||
// generate keys for 2 users
|
||||
user1 := portainer.User{ID: 1}
|
||||
_, apiKey1, err := service.GenerateApiKey(user1, "test-1")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
user2 := portainer.User{ID: 2}
|
||||
_, apiKey2, err := service.GenerateApiKey(user2, "test-2")
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
// verify keys in cache
|
||||
_, apiKeyFromCache, ok := service.cache.Get(apiKey1.Digest)
|
||||
|
||||
@@ -17,15 +17,18 @@ func TarFileInBuffer(fileContent []byte, fileName string, mode int64) ([]byte, e
|
||||
Size: int64(len(fileContent)),
|
||||
}
|
||||
|
||||
if err := tarWriter.WriteHeader(header); err != nil {
|
||||
err := tarWriter.WriteHeader(header)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := tarWriter.Write(fileContent); err != nil {
|
||||
_, err = tarWriter.Write(fileContent)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := tarWriter.Close(); err != nil {
|
||||
err = tarWriter.Close()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -40,7 +43,10 @@ type tarFileInBuffer struct {
|
||||
|
||||
func NewTarFileInBuffer() *tarFileInBuffer {
|
||||
var b bytes.Buffer
|
||||
return &tarFileInBuffer{b: &b, w: tar.NewWriter(&b)}
|
||||
return &tarFileInBuffer{
|
||||
b: &b,
|
||||
w: tar.NewWriter(&b),
|
||||
}
|
||||
}
|
||||
|
||||
// Put puts a single file to tar archive buffer.
|
||||
@@ -55,9 +61,11 @@ func (t *tarFileInBuffer) Put(fileContent []byte, fileName string, mode int64) e
|
||||
return err
|
||||
}
|
||||
|
||||
_, err := t.w.Write(fileContent)
|
||||
if _, err := t.w.Write(fileContent); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return err
|
||||
return nil
|
||||
}
|
||||
|
||||
// Bytes returns the archive as a byte array.
|
||||
|
||||
@@ -9,9 +9,6 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
)
|
||||
|
||||
// TarGzDir creates a tar.gz archive and returns it's path.
|
||||
@@ -23,13 +20,12 @@ func TarGzDir(absolutePath string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer logs.CloseAndLogErr(outFile)
|
||||
defer outFile.Close()
|
||||
|
||||
zipWriter := gzip.NewWriter(outFile)
|
||||
defer logs.CloseAndLogErr(zipWriter)
|
||||
|
||||
defer zipWriter.Close()
|
||||
tarWriter := tar.NewWriter(zipWriter)
|
||||
defer logs.CloseAndLogErr(tarWriter)
|
||||
defer tarWriter.Close()
|
||||
|
||||
err = filepath.Walk(absolutePath, func(path string, info os.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
@@ -90,7 +86,7 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer logs.CloseAndLogErr(zipReader)
|
||||
defer zipReader.Close()
|
||||
|
||||
tarReader := tar.NewReader(zipReader)
|
||||
|
||||
@@ -109,7 +105,7 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
case tar.TypeDir:
|
||||
// skip, dir will be created with a file
|
||||
case tar.TypeReg:
|
||||
p := filesystem.JoinPaths(outputDirPath, header.Name)
|
||||
p := filepath.Clean(filepath.Join(outputDirPath, header.Name))
|
||||
if err := os.MkdirAll(filepath.Dir(p), 0o744); err != nil {
|
||||
return fmt.Errorf("Failed to extract dir %s", filepath.Dir(p))
|
||||
}
|
||||
@@ -120,7 +116,7 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
|
||||
if _, err := io.Copy(outFile, tarReader); err != nil {
|
||||
return fmt.Errorf("Failed to extract file %s", header.Name)
|
||||
}
|
||||
logs.CloseAndLogErr(outFile)
|
||||
outFile.Close()
|
||||
default:
|
||||
return fmt.Errorf("tar: unknown type: %v in %s",
|
||||
header.Typeflag,
|
||||
|
||||
@@ -1,34 +1,24 @@
|
||||
package archive
|
||||
|
||||
import (
|
||||
"archive/tar"
|
||||
"compress/gzip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func listFiles(dir string) []string {
|
||||
items := make([]string, 0)
|
||||
|
||||
if err := filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
filepath.Walk(dir, func(path string, info os.FileInfo, err error) error {
|
||||
if path == dir {
|
||||
return nil
|
||||
}
|
||||
|
||||
items = append(items, path)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to list files in directory")
|
||||
}
|
||||
})
|
||||
|
||||
return items
|
||||
}
|
||||
@@ -36,21 +26,13 @@ func listFiles(dir string) []string {
|
||||
func Test_shouldCreateArchive(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
content := []byte("content")
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
@@ -63,8 +45,7 @@ func Test_shouldCreateArchive(t *testing.T) {
|
||||
wasExtracted := func(p string) {
|
||||
fullpath := path.Join(extractionDir, p)
|
||||
assert.Contains(t, extractedFiles, fullpath)
|
||||
copyContent, err := os.ReadFile(fullpath)
|
||||
require.NoError(t, err)
|
||||
copyContent, _ := os.ReadFile(fullpath)
|
||||
assert.Equal(t, content, copyContent)
|
||||
}
|
||||
|
||||
@@ -76,21 +57,13 @@ func Test_shouldCreateArchive(t *testing.T) {
|
||||
func Test_shouldCreateArchive2(t *testing.T) {
|
||||
tmpdir := t.TempDir()
|
||||
content := []byte("content")
|
||||
|
||||
err := os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
require.NoError(t, err)
|
||||
os.WriteFile(path.Join(tmpdir, "outer"), content, 0600)
|
||||
os.MkdirAll(path.Join(tmpdir, "dir"), 0700)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", ".dotfile"), content, 0600)
|
||||
os.WriteFile(path.Join(tmpdir, "dir", "inner"), content, 0600)
|
||||
|
||||
gzPath, err := TarGzDir(tmpdir)
|
||||
require.NoError(t, err)
|
||||
assert.Nil(t, err)
|
||||
assert.Equal(t, filepath.Join(tmpdir, filepath.Base(tmpdir)+".tar.gz"), gzPath)
|
||||
|
||||
extractionDir := t.TempDir()
|
||||
@@ -111,56 +84,3 @@ func Test_shouldCreateArchive2(t *testing.T) {
|
||||
wasExtracted("dir/inner")
|
||||
wasExtracted("dir/.dotfile")
|
||||
}
|
||||
|
||||
func TestExtractTarGzPathTraversal(t *testing.T) {
|
||||
testDir := t.TempDir()
|
||||
|
||||
// Create an evil file with a path traversal attempt
|
||||
tarPath := filesystem.JoinPaths(testDir, "evil.tar.gz")
|
||||
|
||||
evilFile, err := os.Create(tarPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
gzWriter := gzip.NewWriter(evilFile)
|
||||
tarWriter := tar.NewWriter(gzWriter)
|
||||
|
||||
content := []byte("evil content")
|
||||
|
||||
header := &tar.Header{
|
||||
Name: "../evil.txt",
|
||||
Mode: 0600,
|
||||
Size: int64(len(content)),
|
||||
Typeflag: tar.TypeReg,
|
||||
}
|
||||
|
||||
err = tarWriter.WriteHeader(header)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tarWriter.Write(content)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tarWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = gzWriter.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = evilFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Attempt to extract the evil file
|
||||
extractionDir := filesystem.JoinPaths(testDir, "extraction")
|
||||
err = os.Mkdir(extractionDir, 0700)
|
||||
require.NoError(t, err)
|
||||
|
||||
tarFile, err := os.Open(tarPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the file didn't escape
|
||||
err = ExtractTarGz(tarFile, extractionDir)
|
||||
require.NoError(t, err)
|
||||
require.NoFileExists(t, filesystem.JoinPaths(testDir, "evil.txt"))
|
||||
|
||||
err = tarFile.Close()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -2,17 +2,60 @@ package archive
|
||||
|
||||
import (
|
||||
"archive/zip"
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// UnzipArchive will unzip an archive from bytes into the dest destination folder on disk
|
||||
func UnzipArchive(archiveData []byte, dest string) error {
|
||||
zipReader, err := zip.NewReader(bytes.NewReader(archiveData), int64(len(archiveData)))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, zipFile := range zipReader.File {
|
||||
err := extractFileFromArchive(zipFile, dest)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func extractFileFromArchive(file *zip.File, dest string) error {
|
||||
f, err := file.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer f.Close()
|
||||
|
||||
data, err := io.ReadAll(f)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fpath := filepath.Join(dest, file.Name)
|
||||
|
||||
outFile, err := os.OpenFile(fpath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, file.Mode())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = io.Copy(outFile, bytes.NewReader(data))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return outFile.Close()
|
||||
}
|
||||
|
||||
// UnzipFile will decompress a zip archive, moving all files and folders
|
||||
// within the zip file (parameter 1) to an output directory (parameter 2).
|
||||
func UnzipFile(src string, dest string) error {
|
||||
@@ -20,7 +63,7 @@ func UnzipFile(src string, dest string) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer logs.CloseAndLogErr(r)
|
||||
defer r.Close()
|
||||
|
||||
for _, f := range r.File {
|
||||
p := filepath.Join(dest, f.Name)
|
||||
@@ -32,14 +75,12 @@ func UnzipFile(src string, dest string) error {
|
||||
|
||||
if f.FileInfo().IsDir() {
|
||||
// Make Folder
|
||||
if err := os.MkdirAll(p, os.ModePerm); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
os.MkdirAll(p, os.ModePerm)
|
||||
continue
|
||||
}
|
||||
|
||||
if err := unzipFile(f, p); err != nil {
|
||||
err = unzipFile(f, p)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -52,20 +93,20 @@ func unzipFile(f *zip.File, p string) error {
|
||||
if err := os.MkdirAll(filepath.Dir(p), os.ModePerm); err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't make a path %s", p)
|
||||
}
|
||||
|
||||
outFile, err := os.OpenFile(p, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, f.Mode())
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't create file %s", p)
|
||||
}
|
||||
defer logs.CloseAndLogErr(outFile)
|
||||
|
||||
defer outFile.Close()
|
||||
rc, err := f.Open()
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't open zip file %s in the archive", f.Name)
|
||||
}
|
||||
defer logs.CloseAndLogErr(rc)
|
||||
defer rc.Close()
|
||||
|
||||
if _, err = io.Copy(outFile, rc); err != nil {
|
||||
_, err = io.Copy(outFile, rc)
|
||||
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "unzipFile: can't copy an archived file content")
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUnzipFile(t *testing.T) {
|
||||
@@ -21,7 +20,7 @@ func TestUnzipFile(t *testing.T) {
|
||||
|
||||
err := UnzipFile("./testdata/sample_archive.zip", dir)
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
archiveDir := dir + "/sample_archive"
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0.txt"))
|
||||
assert.FileExists(t, filepath.Join(archiveDir, "0", "1.txt"))
|
||||
|
||||
@@ -6,15 +6,6 @@ import (
|
||||
"github.com/aws/aws-sdk-go-v2/service/ecr"
|
||||
)
|
||||
|
||||
// Registry represents an ECR registry endpoint information.
|
||||
// This struct is used to parse and validate ECR endpoint URLs.
|
||||
type Registry struct {
|
||||
ID string // AWS account ID (empty for accountless endpoints like "ecr-fips.us-west-1.amazonaws.com")
|
||||
FIPS bool // Whether this is a FIPS endpoint (contains "-fips" in the URL)
|
||||
Region string // AWS region (e.g., "us-east-1", "us-gov-west-1")
|
||||
Public bool // Whether this is ecr-public.aws.com
|
||||
}
|
||||
|
||||
type (
|
||||
Service struct {
|
||||
accessKey string
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
package ecr
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// ecrEndpointPattern matches all valid ECR endpoints including account-prefixed and accountless formats.
|
||||
// Based on AWS ECR credential helper regex but extended to support accountless endpoints.
|
||||
//
|
||||
// Supported formats:
|
||||
// - Account-prefixed: 123456789012.dkr.ecr-fips.us-east-1.amazonaws.com
|
||||
// - Account-prefixed (hyphen): 123456789012.dkr-ecr-fips.us-west-1.on.aws
|
||||
// - Accountless service: ecr-fips.us-west-1.amazonaws.com
|
||||
// - Accountless API: ecr-fips.us-east-1.api.aws
|
||||
// - Non-FIPS variants: All formats above without "-fips"
|
||||
//
|
||||
// Regex groups:
|
||||
// - Group 1: Full account prefix (optional) - e.g., "123456789012.dkr." or "123456789012.dkr-"
|
||||
// - Group 2: Account ID (optional) - e.g., "123456789012"
|
||||
// - Group 3: FIPS flag (optional) - either "-fips" or empty string
|
||||
// - Group 4: Region - e.g., "us-east-1", "us-gov-west-1"
|
||||
// - Group 5: Domain suffix - e.g., "amazonaws.com", "api.aws"
|
||||
var ecrEndpointPattern = regexp.MustCompile(
|
||||
`^((\d{12})\.dkr[\.\-])?ecr(\-fips)?\.([a-zA-Z0-9][a-zA-Z0-9-_]*)\.(amazonaws\.(?:com(?:\.cn)?|eu)|api\.aws|on\.(?:aws|amazonwebservices\.com\.cn)|sc2s\.sgov\.gov|c2s\.ic\.gov|cloud\.adc-e\.uk|csp\.hci\.ic\.gov)$`,
|
||||
)
|
||||
|
||||
// ParseECREndpoint parses an ECR registry URL and extracts registry information.
|
||||
|
||||
// This function replaces the AWS ECR credential helper library's ExtractRegistry function,
|
||||
// which only supports account-prefixed endpoints.
|
||||
//
|
||||
// Reference: https://docs.aws.amazon.com/general/latest/gr/ecr.html
|
||||
func ParseECREndpoint(urlStr string) (*Registry, error) {
|
||||
// Normalize URL by adding https:// prefix if not present
|
||||
if !strings.HasPrefix(urlStr, "https://") && !strings.HasPrefix(urlStr, "http://") {
|
||||
urlStr = "https://" + urlStr
|
||||
}
|
||||
|
||||
u, err := url.Parse(urlStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid URL: %w", err)
|
||||
}
|
||||
|
||||
hostname := u.Hostname()
|
||||
|
||||
// Special case: ECR Public
|
||||
// ECR Public uses a different domain and doesn't have FIPS variant
|
||||
if hostname == "ecr-public.aws.com" {
|
||||
return &Registry{
|
||||
FIPS: false,
|
||||
Public: true,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Parse standard ECR endpoints using regex
|
||||
matches := ecrEndpointPattern.FindStringSubmatch(hostname)
|
||||
if len(matches) == 0 {
|
||||
return nil, fmt.Errorf("not a valid ECR endpoint: %s", hostname)
|
||||
}
|
||||
|
||||
return &Registry{
|
||||
ID: matches[2], // Account ID (may be empty for accountless endpoints)
|
||||
FIPS: matches[3] == "-fips", // Check if "-fips" is present
|
||||
Region: matches[4], // AWS region
|
||||
Public: false,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,253 +0,0 @@
|
||||
package ecr
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestParseECREndpoint(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
url string
|
||||
want *Registry
|
||||
wantError bool
|
||||
}{
|
||||
// Standard AWS Commercial - Account-prefixed FIPS
|
||||
{
|
||||
name: "account-prefixed FIPS us-east-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS us-west-2",
|
||||
url: "123456789012.dkr.ecr-fips.us-west-2.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-west-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Accountless FIPS service endpoints
|
||||
{
|
||||
name: "accountless FIPS us-west-1",
|
||||
url: "ecr-fips.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS us-east-2",
|
||||
url: "ecr-fips.us-east-2.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Accountless FIPS API endpoints
|
||||
{
|
||||
name: "accountless FIPS API us-west-1",
|
||||
url: "ecr-fips.us-west-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS API us-east-1",
|
||||
url: "ecr-fips.us-east-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// on.aws domain with hyphen separator
|
||||
{
|
||||
name: "account-prefixed FIPS hyphen us-west-1",
|
||||
url: "123456789012.dkr-ecr-fips.us-west-1.on.aws",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS hyphen us-east-2",
|
||||
url: "123456789012.dkr-ecr-fips.us-east-2.on.aws",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// AWS GovCloud
|
||||
{
|
||||
name: "account-prefixed FIPS us-gov-east-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-gov-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-gov-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "account-prefixed FIPS us-gov-west-1",
|
||||
url: "123456789012.dkr.ecr-fips.us-gov-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: true,
|
||||
Region: "us-gov-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS us-gov-west-1",
|
||||
url: "ecr-fips.us-gov-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-gov-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless FIPS API us-gov-east-1",
|
||||
url: "ecr-fips.us-gov-east-1.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-gov-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// ECR Public
|
||||
{
|
||||
name: "ecr-public",
|
||||
url: "ecr-public.aws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "",
|
||||
Public: true,
|
||||
},
|
||||
},
|
||||
|
||||
// Non-FIPS endpoints (valid ECR but FIPS=false)
|
||||
{
|
||||
name: "account-prefixed non-FIPS us-east-1",
|
||||
url: "123456789012.dkr.ecr.us-east-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "123456789012",
|
||||
FIPS: false,
|
||||
Region: "us-east-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless non-FIPS us-west-1",
|
||||
url: "ecr.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "accountless non-FIPS API us-east-2",
|
||||
url: "ecr.us-east-2.api.aws",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: false,
|
||||
Region: "us-east-2",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// URLs with https:// prefix
|
||||
{
|
||||
name: "with https prefix",
|
||||
url: "https://ecr-fips.us-west-1.amazonaws.com",
|
||||
want: &Registry{
|
||||
ID: "",
|
||||
FIPS: true,
|
||||
Region: "us-west-1",
|
||||
Public: false,
|
||||
},
|
||||
},
|
||||
|
||||
// Invalid endpoints
|
||||
{
|
||||
name: "not an ECR URL",
|
||||
url: "not-an-ecr-url.com",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "invalid account ID length",
|
||||
url: "123.dkr.ecr-fips.us-east-1.amazonaws.com",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "empty string",
|
||||
url: "",
|
||||
wantError: true,
|
||||
},
|
||||
{
|
||||
name: "docker hub",
|
||||
url: "docker.io",
|
||||
wantError: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseECREndpoint(tt.url)
|
||||
|
||||
if tt.wantError {
|
||||
if err == nil {
|
||||
t.Errorf("ParseECREndpoint() expected error but got none")
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
t.Errorf("ParseECREndpoint() unexpected error: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if got.ID != tt.want.ID {
|
||||
t.Errorf("ParseECREndpoint() ID = %v, want %v", got.ID, tt.want.ID)
|
||||
}
|
||||
if got.FIPS != tt.want.FIPS {
|
||||
t.Errorf("ParseECREndpoint() FIPS = %v, want %v", got.FIPS, tt.want.FIPS)
|
||||
}
|
||||
if got.Region != tt.want.Region {
|
||||
t.Errorf("ParseECREndpoint() Region = %v, want %v", got.Region, tt.want.Region)
|
||||
}
|
||||
if got.Public != tt.want.Public {
|
||||
t.Errorf("ParseECREndpoint() Public = %v, want %v", got.Public, tt.want.Public)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/http/offlinegate"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -98,7 +97,7 @@ func encrypt(path string, passphrase string) (string, error) {
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer logs.CloseAndLogErr(in)
|
||||
defer in.Close()
|
||||
|
||||
outFileName := path + ".encrypted"
|
||||
out, err := os.Create(outFileName)
|
||||
@@ -106,5 +105,7 @@ func encrypt(path string, passphrase string) (string, error) {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return outFileName, crypto.AesEncrypt(in, out, []byte(passphrase))
|
||||
err = crypto.AesEncrypt(in, out, []byte(passphrase))
|
||||
|
||||
return outFileName, err
|
||||
}
|
||||
|
||||
@@ -16,8 +16,6 @@ import (
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/portainer/portainer/api/http/offlinegate"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
var filesToRestore = append(filesToBackup, "portainer.db")
|
||||
@@ -33,20 +31,17 @@ func RestoreArchive(archive io.Reader, password string, filestorePath string, ga
|
||||
}
|
||||
|
||||
restorePath := filepath.Join(filestorePath, "restore", time.Now().Format("20060102150405"))
|
||||
defer func() {
|
||||
if err := os.RemoveAll(filepath.Dir(restorePath)); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to clean up restore files")
|
||||
}
|
||||
}()
|
||||
defer os.RemoveAll(filepath.Dir(restorePath))
|
||||
|
||||
if err := extractArchive(archive, restorePath); err != nil {
|
||||
err = extractArchive(archive, restorePath)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "cannot extract files from the archive. Please ensure the password is correct and try again")
|
||||
}
|
||||
|
||||
unlock := gate.Lock()
|
||||
defer unlock()
|
||||
|
||||
if err := datastore.Close(); err != nil {
|
||||
if err = datastore.Close(); err != nil {
|
||||
return errors.Wrap(err, "Failed to stop db")
|
||||
}
|
||||
|
||||
@@ -56,7 +51,7 @@ func RestoreArchive(archive io.Reader, password string, filestorePath string, ga
|
||||
return errors.Wrap(err, "failed to restore from backup. Portainer database missing from backup file")
|
||||
}
|
||||
|
||||
if err := restoreFiles(restorePath, filestorePath); err != nil {
|
||||
if err = restoreFiles(restorePath, filestorePath); err != nil {
|
||||
return errors.Wrap(err, "failed to restore the system state")
|
||||
}
|
||||
|
||||
@@ -94,7 +89,8 @@ func getRestoreSourcePath(dir string) (string, error) {
|
||||
|
||||
func restoreFiles(srcDir string, destinationDir string) error {
|
||||
for _, filename := range filesToRestore {
|
||||
if err := filesystem.CopyPath(filepath.Join(srcDir, filename), destinationDir); err != nil {
|
||||
err := filesystem.CopyPath(filepath.Join(srcDir, filename), destinationDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -102,18 +98,14 @@ func restoreFiles(srcDir string, destinationDir string) error {
|
||||
// TODO: This is very boltdb module specific once again due to the filename. Move to bolt module? Refactor for another day
|
||||
|
||||
// Prevent the possibility of having both databases. Remove any default new instance
|
||||
if err := os.Remove(filepath.Join(destinationDir, boltdb.DatabaseFileName)); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := os.Remove(filepath.Join(destinationDir, boltdb.EncryptedDatabaseFileName)); err != nil && !os.IsNotExist(err) {
|
||||
return err
|
||||
}
|
||||
os.Remove(filepath.Join(destinationDir, boltdb.DatabaseFileName))
|
||||
os.Remove(filepath.Join(destinationDir, boltdb.EncryptedDatabaseFileName))
|
||||
|
||||
// Now copy the database. It'll be either portainer.db or portainer.edb
|
||||
|
||||
// Note: CopyPath does not return an error if the source file doesn't exist
|
||||
if err := filesystem.CopyPath(filepath.Join(srcDir, boltdb.EncryptedDatabaseFileName), destinationDir); err != nil {
|
||||
err := filesystem.CopyPath(filepath.Join(srcDir, boltdb.EncryptedDatabaseFileName), destinationDir)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -54,8 +54,8 @@ func ecdsaGenerateKey(c elliptic.Curve, rand io.Reader) (*ecdsa.PrivateKey, erro
|
||||
}
|
||||
|
||||
priv := new(ecdsa.PrivateKey)
|
||||
priv.Curve = c
|
||||
priv.PublicKey.Curve = c
|
||||
priv.D = k
|
||||
priv.X, priv.Y = c.ScalarBaseMult(k.Bytes())
|
||||
priv.PublicKey.X, priv.PublicKey.Y = c.ScalarBaseMult(k.Bytes())
|
||||
return priv, nil
|
||||
}
|
||||
|
||||
@@ -89,8 +89,10 @@ func (service *Service) pingAgent(endpointID portainer.EndpointID) error {
|
||||
return err
|
||||
}
|
||||
|
||||
_, _ = io.Copy(io.Discard, resp.Body)
|
||||
return resp.Body.Close()
|
||||
io.Copy(io.Discard, resp.Body)
|
||||
resp.Body.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// KeepTunnelAlive keeps the tunnel of the given environment for maxAlive duration, or until ctx is done
|
||||
|
||||
@@ -9,15 +9,10 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
func TestPingAgentPanic(t *testing.T) {
|
||||
endpoint := &portainer.Endpoint{
|
||||
ID: 1,
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -13,7 +14,6 @@ import (
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
"github.com/portainer/portainer/api/internal/endpointutils"
|
||||
"github.com/portainer/portainer/pkg/libcrypto"
|
||||
"github.com/portainer/portainer/pkg/librand"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/rs/zerolog/log"
|
||||
@@ -142,9 +142,7 @@ func (s *Service) TunnelAddr(endpoint *portainer.Endpoint) (string, error) {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failed to close tcp connection")
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
break
|
||||
}
|
||||
@@ -202,9 +200,7 @@ func (service *Service) getUnusedPort() int {
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
|
||||
if err == nil {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Warn().Msg("failed to close tcp connection that checks if port is free")
|
||||
}
|
||||
conn.Close()
|
||||
|
||||
log.Debug().
|
||||
Int("port", port).
|
||||
@@ -217,7 +213,7 @@ func (service *Service) getUnusedPort() int {
|
||||
}
|
||||
|
||||
func randomInt(min, max int) int {
|
||||
return min + librand.Intn(max-min)
|
||||
return min + rand.Intn(max-min)
|
||||
}
|
||||
|
||||
func generateRandomCredentials() (string, string) {
|
||||
|
||||
@@ -1,79 +0,0 @@
|
||||
package chisel
|
||||
|
||||
import (
|
||||
"net"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
type testSettingsService struct {
|
||||
dataservices.SettingsService
|
||||
}
|
||||
|
||||
func (s *testSettingsService) Settings() (*portainer.Settings, error) {
|
||||
return &portainer.Settings{
|
||||
EdgeAgentCheckinInterval: 1,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type testStore struct {
|
||||
dataservices.DataStore
|
||||
}
|
||||
|
||||
func (s *testStore) Settings() dataservices.SettingsService {
|
||||
return &testSettingsService{}
|
||||
}
|
||||
|
||||
func TestGetUnusedPort(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
existingTunnels map[portainer.EndpointID]*portainer.TunnelDetails
|
||||
expectedError error
|
||||
}{
|
||||
{
|
||||
name: "simple case",
|
||||
},
|
||||
{
|
||||
name: "existing tunnels",
|
||||
existingTunnels: map[portainer.EndpointID]*portainer.TunnelDetails{
|
||||
portainer.EndpointID(1): {
|
||||
Port: 53072,
|
||||
},
|
||||
portainer.EndpointID(2): {
|
||||
Port: 63072,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
store := &testStore{}
|
||||
s := NewService(store, nil, nil)
|
||||
s.activeTunnels = tc.existingTunnels
|
||||
port := s.getUnusedPort()
|
||||
|
||||
if port < 49152 || port > 65535 {
|
||||
t.Fatalf("Expected port to be inbetween 49152 and 65535 but got %d", port)
|
||||
}
|
||||
|
||||
for _, tun := range tc.existingTunnels {
|
||||
if tun.Port == port {
|
||||
t.Fatalf("returned port %d already has an existing tunnel", port)
|
||||
}
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
|
||||
if err == nil {
|
||||
// Ignore error
|
||||
_ = conn.Close()
|
||||
t.Fatalf("expected port %d to be unused", port)
|
||||
} else if !strings.Contains(err.Error(), "connection refused") {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -9,8 +9,8 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"github.com/rs/zerolog/log"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
// Service implements the CLIService interface
|
||||
@@ -32,12 +32,19 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
Assets: kingpin.Flag("assets", "Path to the assets").Default(defaultAssetsDirectory).Short('a').String(),
|
||||
Data: kingpin.Flag("data", "Path to the folder where the data is stored").Default(defaultDataDirectory).Short('d').String(),
|
||||
EndpointURL: kingpin.Flag("host", "Environment URL").Short('H').String(),
|
||||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Envar(portainer.FeatureFlagEnvVar).Strings(),
|
||||
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
|
||||
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
|
||||
NoAnalytics: kingpin.Flag("no-analytics", "Disable Analytics in app (deprecated)").Bool(),
|
||||
TLS: kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).Bool(),
|
||||
TLSSkipVerify: kingpin.Flag("tlsskipverify", "Disable TLS server verification").Default(defaultTLSSkipVerify).Bool(),
|
||||
TLSCacert: kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String(),
|
||||
TLSCert: kingpin.Flag("tlscert", "Path to the TLS certificate file").Default(defaultTLSCertPath).String(),
|
||||
TLSKey: kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).String(),
|
||||
HTTPDisabled: kingpin.Flag("http-disabled", "Serve portainer only on https").Default(defaultHTTPDisabled).Bool(),
|
||||
HTTPEnabled: kingpin.Flag("http-enabled", "Serve portainer on http").Default(defaultHTTPEnabled).Bool(),
|
||||
SSL: kingpin.Flag("ssl", "Secure Portainer instance using SSL (deprecated)").Default(defaultSSL).Bool(),
|
||||
SSLCert: kingpin.Flag("sslcert", "Path to the SSL certificate used to secure the Portainer instance").String(),
|
||||
SSLKey: kingpin.Flag("sslkey", "Path to the SSL key used to secure the Portainer instance").String(),
|
||||
Rollback: kingpin.Flag("rollback", "Rollback the database to the previous backup").Bool(),
|
||||
SnapshotInterval: kingpin.Flag("snapshot-interval", "Duration between each environment snapshot job").String(),
|
||||
AdminPassword: kingpin.Flag("admin-password", "Set admin password with provided hash").String(),
|
||||
@@ -52,53 +59,17 @@ func CLIFlags() *portainer.CLIFlags {
|
||||
SecretKeyName: kingpin.Flag("secret-key-name", "Secret key name for encryption and will be used as /run/secrets/<secret-key-name>.").Default(defaultSecretKeyName).String(),
|
||||
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
|
||||
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
|
||||
KubectlShellImage: kingpin.Flag("kubectl-shell-image", "Kubectl shell image").Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String(),
|
||||
PullLimitCheckDisabled: kingpin.Flag("pull-limit-check-disabled", "Pull limit check").Envar(portainer.PullLimitCheckDisabledEnvVar).Default(defaultPullLimitCheckDisabled).Bool(),
|
||||
TrustedOrigins: kingpin.Flag("trusted-origins", "List of trusted origins for CSRF protection. Separate multiple origins with a comma.").Envar(portainer.TrustedOriginsEnvVar).String(),
|
||||
CSP: kingpin.Flag("csp", "Content Security Policy (CSP) header").Envar(portainer.CSPEnvVar).Default("true").Bool(),
|
||||
CompactDB: kingpin.Flag("compact-db", "Enable database compaction on startup").Envar(portainer.CompactDBEnvVar).Default("false").Bool(),
|
||||
}
|
||||
}
|
||||
|
||||
// ParseFlags parse the CLI flags and return a portainer.Flags struct
|
||||
func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
kingpin.Version(version)
|
||||
|
||||
var hasSSLFlag, hasSSLCertFlag, hasSSLKeyFlag bool
|
||||
sslFlag := kingpin.Flag(
|
||||
"ssl",
|
||||
"Secure Portainer instance using SSL (deprecated)",
|
||||
).Default(defaultSSL).IsSetByUser(&hasSSLFlag)
|
||||
ssl := sslFlag.Bool()
|
||||
sslCertFlag := kingpin.Flag(
|
||||
"sslcert",
|
||||
"Path to the SSL certificate used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLCertFlag)
|
||||
sslCert := sslCertFlag.String()
|
||||
sslKeyFlag := kingpin.Flag(
|
||||
"sslkey",
|
||||
"Path to the SSL key used to secure the Portainer instance",
|
||||
).IsSetByUser(&hasSSLKeyFlag)
|
||||
sslKey := sslKeyFlag.String()
|
||||
|
||||
flags := CLIFlags()
|
||||
|
||||
var hasTLSFlag, hasTLSCertFlag, hasTLSKeyFlag bool
|
||||
tlsFlag := kingpin.Flag("tlsverify", "TLS support").Default(defaultTLS).IsSetByUser(&hasTLSFlag)
|
||||
flags.TLS = tlsFlag.Bool()
|
||||
tlsCertFlag := kingpin.Flag(
|
||||
"tlscert",
|
||||
"Path to the TLS certificate file",
|
||||
).Default(defaultTLSCertPath).IsSetByUser(&hasTLSCertFlag)
|
||||
flags.TLSCert = tlsCertFlag.String()
|
||||
tlsKeyFlag := kingpin.Flag("tlskey", "Path to the TLS key").Default(defaultTLSKeyPath).IsSetByUser(&hasTLSKeyFlag)
|
||||
flags.TLSKey = tlsKeyFlag.String()
|
||||
flags.TLSCacert = kingpin.Flag("tlscacert", "Path to the CA").Default(defaultTLSCACertPath).String()
|
||||
|
||||
flags.KubectlShellImage = kingpin.Flag(
|
||||
"kubectl-shell-image",
|
||||
"Kubectl shell image",
|
||||
).Envar(portainer.KubectlShellImageEnvVar).Default(portainer.DefaultKubectlShellImage).String()
|
||||
|
||||
kingpin.Parse()
|
||||
|
||||
if !filepath.IsAbs(*flags.Assets) {
|
||||
@@ -110,46 +81,11 @@ func (Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
|
||||
*flags.Assets = filepath.Join(filepath.Dir(ex), *flags.Assets)
|
||||
}
|
||||
|
||||
// If the user didn't provide a tls flag remove the defaults to match previous behaviour
|
||||
if !hasTLSFlag {
|
||||
if !hasTLSCertFlag {
|
||||
*flags.TLSCert = ""
|
||||
}
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
*flags.TLSKey = ""
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslFlag.Model().Name, tlsFlag.Model().Name)
|
||||
|
||||
if !hasTLSFlag {
|
||||
flags.TLS = ssl
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLCertFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslCertFlag.Model().Name, tlsCertFlag.Model().Name)
|
||||
|
||||
if !hasTLSCertFlag {
|
||||
flags.TLSCert = sslCert
|
||||
}
|
||||
}
|
||||
|
||||
if hasSSLKeyFlag {
|
||||
log.Warn().Msgf("the %q flag is deprecated. use %q instead.", sslKeyFlag.Model().Name, tlsKeyFlag.Model().Name)
|
||||
|
||||
if !hasTLSKeyFlag {
|
||||
flags.TLSKey = sslKey
|
||||
}
|
||||
}
|
||||
|
||||
return flags, nil
|
||||
}
|
||||
|
||||
// ValidateFlags validates the values of the flags.
|
||||
func (Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
|
||||
displayDeprecationWarnings(flags)
|
||||
|
||||
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
|
||||
@@ -171,6 +107,10 @@ func displayDeprecationWarnings(flags *portainer.CLIFlags) {
|
||||
if *flags.NoAnalytics {
|
||||
log.Warn().Msg("the --no-analytics flag has been kept to allow migration of instances running a previous version of Portainer with this flag enabled, to version 2.0 where enabling this flag will have no effect")
|
||||
}
|
||||
|
||||
if *flags.SSL {
|
||||
log.Warn().Msg("SSL is enabled by default and there is no need for the --ssl flag, it has been kept to allow migration of instances running a previous version of Portainer with this flag enabled")
|
||||
}
|
||||
}
|
||||
|
||||
func validateEndpointURL(endpointURL string) error {
|
||||
|
||||
@@ -1,209 +0,0 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
zerolog "github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestOptionParser(t *testing.T) {
|
||||
p := Service{}
|
||||
require.NotNil(t, p)
|
||||
|
||||
a := os.Args
|
||||
defer func() { os.Args = a }()
|
||||
|
||||
os.Args = []string{"portainer", "--edge-compute"}
|
||||
|
||||
opts, err := p.ParseFlags("2.34.5")
|
||||
require.NoError(t, err)
|
||||
|
||||
require.False(t, *opts.HTTPDisabled)
|
||||
require.True(t, *opts.EnableEdgeComputeFeatures)
|
||||
}
|
||||
|
||||
func TestParseTLSFlags(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
args []string
|
||||
expectedTLSFlag bool
|
||||
expectedTLSCertFlag string
|
||||
expectedTLSKeyFlag string
|
||||
expectedLogMessages []string
|
||||
}{
|
||||
{
|
||||
name: "no flags",
|
||||
expectedTLSFlag: false,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only ssl flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "only tls flag",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: defaultTLSCertPath,
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "",
|
||||
},
|
||||
{
|
||||
name: "partial tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: defaultTLSKeyPath,
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "partial tls and ssl flags 2",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "ssl-cert-flag-value",
|
||||
expectedTLSKeyFlag: "ssl-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "tls flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
},
|
||||
{
|
||||
name: "tls and ssl flags",
|
||||
args: []string{
|
||||
"portainer",
|
||||
"--tlsverify",
|
||||
"--tlscert=tls-cert-flag-value",
|
||||
"--tlskey=tls-key-flag-value",
|
||||
"--ssl",
|
||||
"--sslcert=ssl-cert-flag-value",
|
||||
"--sslkey=ssl-key-flag-value",
|
||||
},
|
||||
expectedTLSFlag: true,
|
||||
expectedTLSCertFlag: "tls-cert-flag-value",
|
||||
expectedTLSKeyFlag: "tls-key-flag-value",
|
||||
expectedLogMessages: []string{
|
||||
"the \\\"ssl\\\" flag is deprecated. use \\\"tlsverify\\\" instead.",
|
||||
"the \\\"sslcert\\\" flag is deprecated. use \\\"tlscert\\\" instead.",
|
||||
"the \\\"sslkey\\\" flag is deprecated. use \\\"tlskey\\\" instead.",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
var logOutput strings.Builder
|
||||
setupLogOutput(t, &logOutput)
|
||||
|
||||
if tc.args == nil {
|
||||
tc.args = []string{"portainer"}
|
||||
}
|
||||
setOsArgs(t, tc.args)
|
||||
|
||||
s := Service{}
|
||||
flags, err := s.ParseFlags("test-version")
|
||||
if err != nil {
|
||||
t.Fatalf("error parsing flags: %v", err)
|
||||
}
|
||||
|
||||
if flags.TLS == nil {
|
||||
t.Fatal("TLS flag was nil")
|
||||
}
|
||||
|
||||
require.Equal(t, tc.expectedTLSFlag, *flags.TLS, "tlsverify flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSCertFlag, *flags.TLSCert, "tlscert flag didn't match")
|
||||
require.Equal(t, tc.expectedTLSKeyFlag, *flags.TLSKey, "tlskey flag didn't match")
|
||||
|
||||
for _, expectedLogMessage := range tc.expectedLogMessages {
|
||||
require.Contains(t, logOutput.String(), expectedLogMessage, "Log didn't contain expected message")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func setOsArgs(t *testing.T, args []string) {
|
||||
t.Helper()
|
||||
previousArgs := os.Args
|
||||
os.Args = args
|
||||
t.Cleanup(func() {
|
||||
os.Args = previousArgs
|
||||
})
|
||||
}
|
||||
|
||||
func setupLogOutput(t *testing.T, w io.Writer) {
|
||||
t.Helper()
|
||||
|
||||
oldLogger := zerolog.Logger
|
||||
zerolog.Logger = zerolog.Output(w)
|
||||
t.Cleanup(func() {
|
||||
zerolog.Logger = oldLogger
|
||||
})
|
||||
}
|
||||
@@ -1,4 +1,5 @@
|
||||
//go:build !windows
|
||||
// +build !windows
|
||||
|
||||
package cli
|
||||
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/alecthomas/kingpin/v2"
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
type pairList []portainer.Pair
|
||||
|
||||
45
api/cli/pairlistbool.go
Normal file
45
api/cli/pairlistbool.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package cli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"gopkg.in/alecthomas/kingpin.v2"
|
||||
)
|
||||
|
||||
type pairListBool []portainer.Pair
|
||||
|
||||
// Set implementation for a list of portainer.Pair
|
||||
func (l *pairListBool) Set(value string) error {
|
||||
p := new(portainer.Pair)
|
||||
|
||||
// default to true. example setting=true is equivalent to setting
|
||||
parts := strings.SplitN(value, "=", 2)
|
||||
if len(parts) != 2 {
|
||||
p.Name = parts[0]
|
||||
p.Value = "true"
|
||||
} else {
|
||||
p.Name = parts[0]
|
||||
p.Value = parts[1]
|
||||
}
|
||||
|
||||
*l = append(*l, *p)
|
||||
return nil
|
||||
}
|
||||
|
||||
// String implementation for a list of pair
|
||||
func (l *pairListBool) String() string {
|
||||
return ""
|
||||
}
|
||||
|
||||
// IsCumulative implementation for a list of pair
|
||||
func (l *pairListBool) IsCumulative() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
func BoolPairs(s kingpin.Settings) (target *[]portainer.Pair) {
|
||||
target = new([]portainer.Pair)
|
||||
s.SetValue((*pairListBool)(target))
|
||||
return
|
||||
}
|
||||
@@ -1,8 +1,7 @@
|
||||
package logs
|
||||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
stdlog "log"
|
||||
"os"
|
||||
|
||||
@@ -11,7 +10,7 @@ import (
|
||||
"github.com/rs/zerolog/pkgerrors"
|
||||
)
|
||||
|
||||
func ConfigureLogger() {
|
||||
func configureLogger() {
|
||||
zerolog.ErrorStackFieldName = "stack_trace"
|
||||
zerolog.ErrorStackMarshaler = pkgerrors.MarshalStack
|
||||
zerolog.TimeFieldFormat = zerolog.TimeFormatUnix
|
||||
@@ -22,7 +21,7 @@ func ConfigureLogger() {
|
||||
log.Logger = log.Logger.With().Caller().Stack().Logger()
|
||||
}
|
||||
|
||||
func SetLoggingLevel(level string) {
|
||||
func setLoggingLevel(level string) {
|
||||
switch level {
|
||||
case "ERROR":
|
||||
zerolog.SetGlobalLevel(zerolog.ErrorLevel)
|
||||
@@ -35,7 +34,7 @@ func SetLoggingLevel(level string) {
|
||||
}
|
||||
}
|
||||
|
||||
func SetLoggingMode(mode string) {
|
||||
func setLoggingMode(mode string) {
|
||||
switch mode {
|
||||
case "PRETTY":
|
||||
log.Logger = log.Output(zerolog.ConsoleWriter{
|
||||
@@ -62,9 +61,3 @@ func formatMessage(i any) string {
|
||||
|
||||
return fmt.Sprintf("%s |", i)
|
||||
}
|
||||
|
||||
func CloseAndLogErr(c io.Closer) {
|
||||
if err := c.Close(); err != nil {
|
||||
log.Error().Err(err).Msg("failure to close resource")
|
||||
}
|
||||
}
|
||||
@@ -39,7 +39,6 @@ import (
|
||||
"github.com/portainer/portainer/api/kubernetes"
|
||||
kubecli "github.com/portainer/portainer/api/kubernetes/cli"
|
||||
"github.com/portainer/portainer/api/ldap"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/api/oauth"
|
||||
"github.com/portainer/portainer/api/pendingactions"
|
||||
"github.com/portainer/portainer/api/pendingactions/actions"
|
||||
@@ -49,18 +48,16 @@ import (
|
||||
"github.com/portainer/portainer/api/stacks/deployments"
|
||||
"github.com/portainer/portainer/pkg/build"
|
||||
"github.com/portainer/portainer/pkg/featureflags"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
"github.com/portainer/portainer/pkg/libhelm"
|
||||
libhelmtypes "github.com/portainer/portainer/pkg/libhelm/types"
|
||||
"github.com/portainer/portainer/pkg/libstack/compose"
|
||||
"github.com/portainer/portainer/pkg/validate"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func initCLI() *portainer.CLIFlags {
|
||||
cliService := cli.Service{}
|
||||
cliService := &cli.Service{}
|
||||
|
||||
flags, err := cliService.ParseFlags(portainer.APIVersion)
|
||||
if err != nil {
|
||||
@@ -84,7 +81,7 @@ func initFileService(dataStorePath string) portainer.FileService {
|
||||
}
|
||||
|
||||
func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService portainer.FileService, shutdownCtx context.Context) dataservices.DataStore {
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey, *flags.CompactDB)
|
||||
connection, err := database.NewDatabase("boltdb", *flags.Data, secretKey)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed creating database connection")
|
||||
}
|
||||
@@ -119,7 +116,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
||||
}
|
||||
|
||||
if isNew {
|
||||
instanceId, err := uuid.NewRandom()
|
||||
instanceId, err := uuid.NewV4()
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("failed generating instance id")
|
||||
}
|
||||
@@ -134,16 +131,15 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
||||
InstanceID: instanceId.String(),
|
||||
MigratorCount: migratorCount,
|
||||
}
|
||||
|
||||
if err := store.VersionService.UpdateVersion(&v); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to update version")
|
||||
}
|
||||
store.VersionService.UpdateVersion(&v)
|
||||
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed updating settings from flags")
|
||||
}
|
||||
} else if err := store.MigrateData(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed migration")
|
||||
} else {
|
||||
if err := store.MigrateData(); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed migration")
|
||||
}
|
||||
}
|
||||
|
||||
if err := updateSettingsFromFlags(store, flags); err != nil {
|
||||
@@ -154,7 +150,7 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
|
||||
go func() {
|
||||
<-shutdownCtx.Done()
|
||||
|
||||
defer logs.CloseAndLogErr(connection)
|
||||
defer connection.Close()
|
||||
}()
|
||||
|
||||
return store
|
||||
@@ -170,8 +166,8 @@ func checkDBSchemaServerVersionMatch(dbStore dataservices.DataStore, serverVersi
|
||||
return v.SchemaVersion == serverVersion && v.Edition == serverEdition
|
||||
}
|
||||
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
func initKubernetesDeployer(kubernetesTokenCacheManager *kubeproxy.TokenCacheManager, kubernetesClientFactory *kubecli.ClientFactory, dataStore dataservices.DataStore, reverseTunnelService portainer.ReverseTunnelService, signatureService portainer.DigitalSignatureService, proxyManager *proxy.Manager, assetsPath string) portainer.KubernetesDeployer {
|
||||
return exec.NewKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, assetsPath)
|
||||
}
|
||||
|
||||
func initHelmPackageManager() (libhelmtypes.HelmPackageManager, error) {
|
||||
@@ -308,19 +304,8 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
|
||||
return generateAndStoreKeyPair(fileService, signatureService)
|
||||
}
|
||||
|
||||
// dbSecretPath build the path to the file that contains the db encryption
|
||||
// secret. Normally in Docker this is built from the static path inside
|
||||
// /run/secrets for example: /run/secrets/<keyFilenameFlag> but for ease of
|
||||
// use outside Docker it also accepts an absolute path
|
||||
func dbSecretPath(keyFilenameFlag string) string {
|
||||
if path.IsAbs(keyFilenameFlag) {
|
||||
return keyFilenameFlag
|
||||
}
|
||||
return path.Join("/run/secrets", keyFilenameFlag)
|
||||
}
|
||||
|
||||
func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
content, err := os.ReadFile(keyfilename)
|
||||
content, err := os.ReadFile(path.Join("/run/secrets", keyfilename))
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
log.Info().Str("filename", keyfilename).Msg("encryption key file not present")
|
||||
@@ -332,7 +317,6 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
|
||||
}
|
||||
|
||||
// return a 32 byte hash of the secret (required for AES)
|
||||
// fips compliant version of this is not implemented in -ce
|
||||
hash := sha256.Sum256(content)
|
||||
|
||||
return hash[:]
|
||||
@@ -345,23 +329,8 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
featureflags.Parse(*flags.FeatureFlags, portainer.SupportedFeatureFlags)
|
||||
}
|
||||
|
||||
trustedOrigins := []string{}
|
||||
if *flags.TrustedOrigins != "" {
|
||||
// validate if the trusted origins are valid urls
|
||||
for origin := range strings.SplitSeq(*flags.TrustedOrigins, ",") {
|
||||
if !validate.IsTrustedOrigin(origin) {
|
||||
log.Fatal().Str("trusted_origin", origin).Msg("invalid url for trusted origin. Please check the trusted origins flag.")
|
||||
}
|
||||
|
||||
trustedOrigins = append(trustedOrigins, origin)
|
||||
}
|
||||
}
|
||||
|
||||
// -ce can not ever be run in FIPS mode
|
||||
fips.InitFIPS(false)
|
||||
|
||||
fileService := initFileService(*flags.Data)
|
||||
encryptionKey := loadEncryptionSecretKey(dbSecretPath(*flags.SecretKeyName))
|
||||
encryptionKey := loadEncryptionSecretKey(*flags.SecretKeyName)
|
||||
if encryptionKey == nil {
|
||||
log.Info().Msg("proceeding without encryption key")
|
||||
}
|
||||
@@ -394,22 +363,21 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
log.Fatal().Err(err).Msg("failed initializing JWT service")
|
||||
}
|
||||
|
||||
ldapService := ldap.Service{}
|
||||
ldapService := &ldap.Service{}
|
||||
|
||||
oauthService := oauth.NewService()
|
||||
|
||||
gitService := git.NewService(shutdownCtx)
|
||||
|
||||
// Setting insecureSkipVerify to true to preserve the old behaviour.
|
||||
openAMTService := openamt.NewService(true)
|
||||
openAMTService := openamt.NewService()
|
||||
|
||||
cryptoService := crypto.Service{}
|
||||
cryptoService := &crypto.Service{}
|
||||
|
||||
signatureService := initDigitalSignatureService()
|
||||
|
||||
edgeStacksService := edgestacks.NewService(dataStore)
|
||||
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.TLSCert, *flags.TLSKey, fileService, dataStore, shutdownTrigger)
|
||||
sslService, err := initSSLService(*flags.AddrHTTPS, *flags.SSLCert, *flags.SSLKey, fileService, dataStore, shutdownTrigger)
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
@@ -454,7 +422,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
|
||||
}
|
||||
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager)
|
||||
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
|
||||
|
||||
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
|
||||
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
|
||||
@@ -468,7 +436,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
|
||||
snapshotService.Start()
|
||||
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService, jwtService)
|
||||
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
|
||||
|
||||
helmPackageManager, err := initHelmPackageManager()
|
||||
if err != nil {
|
||||
@@ -530,9 +498,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
|
||||
scheduler := scheduler.NewScheduler(shutdownCtx)
|
||||
stackDeployer := deployments.NewStackDeployer(swarmStackManager, composeStackManager, kubernetesDeployer, dockerClientFactory, dataStore)
|
||||
if err := deployments.StartStackSchedules(scheduler, stackDeployer, dataStore, gitService); err != nil {
|
||||
log.Fatal().Err(err).Msg("failed to start stack scheduler")
|
||||
}
|
||||
deployments.StartStackSchedules(scheduler, stackDeployer, dataStore, gitService)
|
||||
|
||||
sslDBSettings, err := dataStore.SSLSettings().Settings()
|
||||
if err != nil {
|
||||
@@ -578,7 +544,6 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
Status: applicationStatus,
|
||||
BindAddress: *flags.Addr,
|
||||
BindAddressHTTPS: *flags.AddrHTTPS,
|
||||
CSP: *flags.CSP,
|
||||
HTTPEnabled: sslDBSettings.HTTPEnabled,
|
||||
AssetsPath: *flags.Assets,
|
||||
DataStore: dataStore,
|
||||
@@ -612,18 +577,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
|
||||
PendingActionsService: pendingActionsService,
|
||||
PlatformService: platformService,
|
||||
PullLimitCheckDisabled: *flags.PullLimitCheckDisabled,
|
||||
TrustedOrigins: trustedOrigins,
|
||||
}
|
||||
}
|
||||
|
||||
func main() {
|
||||
logs.ConfigureLogger()
|
||||
logs.SetLoggingMode("PRETTY")
|
||||
configureLogger()
|
||||
setLoggingMode("PRETTY")
|
||||
|
||||
flags := initCLI()
|
||||
|
||||
logs.SetLoggingLevel(*flags.LogLevel)
|
||||
logs.SetLoggingMode(*flags.LogMode)
|
||||
setLoggingLevel(*flags.LogLevel)
|
||||
setLoggingMode(*flags.LogMode)
|
||||
|
||||
for {
|
||||
server := buildServer(flags)
|
||||
@@ -633,7 +597,7 @@ func main() {
|
||||
Str("build_number", build.BuildNumber).
|
||||
Str("image_tag", build.ImageTag).
|
||||
Str("nodejs_version", build.NodejsVersion).
|
||||
Str("pnpm_version", build.PnpmVersion).
|
||||
Str("yarn_version", build.YarnVersion).
|
||||
Str("webpack_version", build.WebpackVersion).
|
||||
Str("go_version", build.GoVersion).
|
||||
Msg("starting Portainer")
|
||||
|
||||
@@ -1,57 +0,0 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const secretFileName = "secret.txt"
|
||||
|
||||
func createPasswordFile(t *testing.T, secretPath, password string) string {
|
||||
err := os.WriteFile(secretPath, []byte(password), 0600)
|
||||
require.NoError(t, err)
|
||||
return secretPath
|
||||
}
|
||||
|
||||
func TestLoadEncryptionSecretKey(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
secretPath := path.Join(tempDir, secretFileName)
|
||||
|
||||
// first pointing to file that does not exist, gives nil hash (no encryption)
|
||||
encryptionKey := loadEncryptionSecretKey(secretPath)
|
||||
require.Nil(t, encryptionKey)
|
||||
|
||||
// point to a directory instead of a file
|
||||
encryptionKey = loadEncryptionSecretKey(tempDir)
|
||||
require.Nil(t, encryptionKey)
|
||||
|
||||
password := "portainer@1234"
|
||||
createPasswordFile(t, secretPath, password)
|
||||
|
||||
encryptionKey = loadEncryptionSecretKey(secretPath)
|
||||
require.NotNil(t, encryptionKey)
|
||||
// should be 32 bytes for aes256 encryption
|
||||
require.Len(t, encryptionKey, 32)
|
||||
}
|
||||
|
||||
func TestDBSecretPath(t *testing.T) {
|
||||
tests := []struct {
|
||||
keyFilenameFlag string
|
||||
expected string
|
||||
}{
|
||||
{keyFilenameFlag: "secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "/tmp/secret.txt", expected: "/tmp/secret.txt"},
|
||||
{keyFilenameFlag: "/run/secrets/secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "./secret.txt", expected: "/run/secrets/secret.txt"},
|
||||
{keyFilenameFlag: "../secret.txt", expected: "/run/secret.txt"},
|
||||
{keyFilenameFlag: "foo/bar/secret.txt", expected: "/run/secrets/foo/bar/secret.txt"},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
assert.Equal(t, test.expected, dbSecretPath(test.keyFilenameFlag))
|
||||
}
|
||||
}
|
||||
@@ -5,19 +5,13 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/pbkdf2"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/argon2" //nolint:depguard
|
||||
"golang.org/x/crypto/scrypt" //nolint:depguard
|
||||
"golang.org/x/crypto/argon2"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -25,32 +19,20 @@ const (
|
||||
aesGcmHeader = "AES256-GCM" // The encrypted file header
|
||||
aesGcmBlockSize = 1024 * 1024 // 1MB block for aes gcm
|
||||
|
||||
aesGcmFIPSHeader = "FIPS-AES256-GCM"
|
||||
aesGcmFIPSBlockSize = 16 * 1024 * 1024 // 16MB block for aes gcm
|
||||
|
||||
// Argon2 settings
|
||||
// Recommended settings lower memory hardware according to current OWASP recommendations
|
||||
// Recommded settings lower memory hardware according to current OWASP recommendations
|
||||
// Considering some people run portainer on a NAS I think it's prudent not to assume we're on server grade hardware
|
||||
// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id
|
||||
argon2MemoryCost = 12 * 1024
|
||||
argon2TimeCost = 3
|
||||
argon2Threads = 1
|
||||
argon2KeyLength = 32
|
||||
|
||||
pbkdf2Iterations = 600_000 // use recommended iterations from https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#pbkdf2 a little overkill for this use
|
||||
pbkdf2SaltLength = 32
|
||||
)
|
||||
|
||||
// AesEncrypt reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key
|
||||
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
if fips.FIPSMode() {
|
||||
if err := aesEncryptGCMFIPS(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
} else {
|
||||
if err := aesEncryptGCM(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
if err := aesEncryptGCM(input, output, passphrase); err != nil {
|
||||
return fmt.Errorf("error encrypting file: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -58,35 +40,14 @@ func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
|
||||
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from
|
||||
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func aesDecrypt(input io.Reader, passphrase []byte, fipsMode bool) (io.Reader, error) {
|
||||
// Read file header to determine how it was encrypted
|
||||
inputReader := bufio.NewReader(input)
|
||||
header, err := inputReader.Peek(len(aesGcmFIPSHeader))
|
||||
header, err := inputReader.Peek(len(aesGcmHeader))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error reading encrypted backup file header: %w", err)
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(header), aesGcmFIPSHeader) {
|
||||
if !fipsMode {
|
||||
return nil, errors.New("fips encrypted file detected but fips mode is not enabled")
|
||||
}
|
||||
|
||||
reader, err := aesDecryptGCMFIPS(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
}
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
if strings.HasPrefix(string(header), aesGcmHeader) {
|
||||
if fipsMode {
|
||||
return nil, errors.New("fips mode is enabled but non-fips encrypted file detected")
|
||||
}
|
||||
|
||||
if string(header) == aesGcmHeader {
|
||||
reader, err := aesDecryptGCM(inputReader, passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error decrypting file: %w", err)
|
||||
@@ -153,20 +114,19 @@ func aesEncryptGCM(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext using the nonce returning the updated slice.
|
||||
ciphertext = aesgcm.Seal(ciphertext[:0], nonce.Value(), buf[:n], nil)
|
||||
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
_, err = output.Write(ciphertext)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := nonce.Increment(); err != nil {
|
||||
return err
|
||||
}
|
||||
nonce.Increment()
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -223,7 +183,7 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -237,134 +197,7 @@ func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := nonce.Increment(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return &buf, nil
|
||||
}
|
||||
|
||||
// aesEncryptGCMFIPS reads from input, encrypts with AES-256 in a fips compliant
|
||||
// way and writes to output. passphrase is used to generate an encryption key.
|
||||
func aesEncryptGCMFIPS(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// write the header
|
||||
if _, err := output.Write([]byte(aesGcmFIPSHeader)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write nonce and salt to the output file
|
||||
if _, err := output.Write(salt); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Buffer for reading plaintext blocks
|
||||
buf := make([]byte, aesGcmFIPSBlockSize)
|
||||
|
||||
// Encrypt plaintext in blocks
|
||||
for {
|
||||
// new random nonce for each block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return fmt.Errorf("error creating gcm: %w", err)
|
||||
}
|
||||
|
||||
n, err := io.ReadFull(input, buf)
|
||||
if n == 0 {
|
||||
break // end of plaintext input
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return err
|
||||
}
|
||||
|
||||
// Seal encrypts the plaintext
|
||||
ciphertext := aesgcm.Seal(nil, nil, buf[:n], nil)
|
||||
|
||||
if _, err := output.Write(ciphertext); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// aesDecryptGCMFIPS reads from input, decrypts with AES-256 in a fips compliant
|
||||
// way and returns the reader to read the decrypted content from.
|
||||
func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// Reader & verify header
|
||||
header := make([]byte, len(aesGcmFIPSHeader))
|
||||
if _, err := io.ReadFull(input, header); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if string(header) != aesGcmFIPSHeader {
|
||||
return nil, errors.New("invalid header")
|
||||
}
|
||||
|
||||
// Read salt
|
||||
salt := make([]byte, pbkdf2SaltLength)
|
||||
if _, err := io.ReadFull(input, salt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
key, err := pbkdf2.Key(sha256.New, string(passphrase), salt, pbkdf2Iterations, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("error deriving key: %w", err)
|
||||
}
|
||||
|
||||
// Initialize AES cipher block
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Initialize a buffer to store decrypted data
|
||||
buf := bytes.Buffer{}
|
||||
|
||||
// Decrypt the ciphertext in blocks
|
||||
for {
|
||||
// Create GCM mode with the cipher block
|
||||
aesgcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Read a block of ciphertext from the input reader
|
||||
ciphertextBlock := make([]byte, aesGcmFIPSBlockSize+aesgcm.Overhead())
|
||||
n, err := io.ReadFull(input, ciphertextBlock)
|
||||
if n == 0 {
|
||||
break // end of ciphertext
|
||||
}
|
||||
|
||||
if err != nil && !errors.Is(err, io.EOF) && !errors.Is(err, io.ErrUnexpectedEOF) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Decrypt the block of ciphertext
|
||||
plaintext, err := aesgcm.Open(nil, nil, ciphertextBlock[:n], nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if _, err := buf.Write(plaintext); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
nonce.Increment()
|
||||
}
|
||||
|
||||
return &buf, nil
|
||||
@@ -374,9 +207,11 @@ func aesDecryptGCMFIPS(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
// passphrase is used to generate an encryption key.
|
||||
// note: This function used to decrypt files that were encrypted without a header i.e. old archives
|
||||
func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
var emptySalt []byte = make([]byte, 0)
|
||||
|
||||
// making a 32 bytes key that would correspond to AES-256
|
||||
// don't necessarily need a salt, so just kept in empty
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -393,18 +228,3 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
|
||||
return reader, nil
|
||||
}
|
||||
|
||||
// HasEncryptedHeader checks if the data has an encrypted header, note that fips
|
||||
// mode changes this behavior and so will only recognize data encrypted by the
|
||||
// same mode (fips enabled or disabled)
|
||||
func HasEncryptedHeader(data []byte) bool {
|
||||
return hasEncryptedHeader(data, fips.FIPSMode())
|
||||
}
|
||||
|
||||
func hasEncryptedHeader(data []byte, fipsMode bool) bool {
|
||||
if fipsMode {
|
||||
return bytes.HasPrefix(data, []byte(aesGcmFIPSHeader))
|
||||
}
|
||||
|
||||
return bytes.HasPrefix(data, []byte(aesGcmHeader))
|
||||
}
|
||||
|
||||
@@ -1,26 +1,15 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"io"
|
||||
"math/rand"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/scrypt"
|
||||
)
|
||||
|
||||
func init() {
|
||||
fips.InitFIPS(false)
|
||||
}
|
||||
|
||||
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
|
||||
|
||||
func randBytes(n int) []byte {
|
||||
@@ -28,417 +17,201 @@ func randBytes(n int) []byte {
|
||||
for i := range b {
|
||||
b[i] = letterBytes[rand.Intn(len(letterBytes))]
|
||||
}
|
||||
|
||||
return b
|
||||
}
|
||||
|
||||
type encryptFunc func(input io.Reader, output io.Writer, passphrase []byte) error
|
||||
type decryptFunc func(input io.Reader, passphrase []byte) (io.Reader, error)
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
|
||||
const passphrase = "passphrase"
|
||||
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc, decryptShouldSucceed bool) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1024*1024*100 + 523)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := randBytes(1024*1024*100 + 523)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
if !decryptShouldSucceed {
|
||||
require.Error(t, err, "Failed to decrypt file as indicated by decryptShouldSucceed")
|
||||
} else {
|
||||
require.NoError(t, err, "Failed to decrypt file indicated by decryptShouldSucceed")
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS, true)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM, true)
|
||||
})
|
||||
|
||||
t.Run("system_fips_mode_public_entry_points", func(t *testing.T) {
|
||||
// use the init mode, public entry points
|
||||
testFunc(t, AesEncrypt, AesDecrypt, true)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_header_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, false)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_header_fails_in_fips_mode", func(t *testing.T) {
|
||||
// use aesDecrypt which checks the header, confirm that it fails
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCM, decrypt, false)
|
||||
})
|
||||
|
||||
t.Run("fips_encrypted_file_fails_in_non_fips_mode", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCM, false)
|
||||
})
|
||||
|
||||
t.Run("non_fips_encrypted_file_with_fips_mode_should_fail", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCMFIPS, false)
|
||||
})
|
||||
|
||||
t.Run("fips_with_base_aesDecrypt", func(t *testing.T) {
|
||||
// maximize coverage, use the base aesDecrypt function with valid fips mode
|
||||
decrypt := func(input io.Reader, passphrase []byte) (io.Reader, error) {
|
||||
return aesDecrypt(input, passphrase, true)
|
||||
}
|
||||
|
||||
testFunc(t, aesEncryptGCMFIPS, decrypt, true)
|
||||
})
|
||||
|
||||
t.Run("legacy", func(t *testing.T) {
|
||||
testFunc(t, legacyAesEncrypt, aesDecryptOFB, true)
|
||||
})
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
|
||||
const passphrase = "A strong passphrase with special characters: !@#$%^&*()_+"
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
content := randBytes(500)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(passphrase))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin2")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
|
||||
)
|
||||
|
||||
content := randBytes(500)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := randBytes(500)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
logs.CloseAndLogErr(encryptedFileWriter)
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedFileWriter.Close()
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1024 * 50)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := randBytes(1024 * 50)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte(""))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(""))
|
||||
assert.Nil(t, err, "Failed to decrypt file")
|
||||
|
||||
decryptedReader, err := decrypt(encryptedFileReader, []byte(""))
|
||||
require.NoError(t, err, "Failed to decrypt file")
|
||||
io.Copy(decryptedFileWriter, decryptedReader)
|
||||
|
||||
_, err = io.Copy(decryptedFileWriter, decryptedReader)
|
||||
require.NoError(t, err)
|
||||
|
||||
decryptedContent, err := os.ReadFile(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
decryptedContent, _ := os.ReadFile(decryptedFilePath)
|
||||
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
|
||||
}
|
||||
|
||||
func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T) {
|
||||
testFunc := func(t *testing.T, encrypt encryptFunc, decrypt decryptFunc) {
|
||||
tmpdir := t.TempDir()
|
||||
tmpdir := t.TempDir()
|
||||
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
var (
|
||||
originFilePath = filepath.Join(tmpdir, "origin")
|
||||
encryptedFilePath = filepath.Join(tmpdir, "encrypted")
|
||||
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
|
||||
)
|
||||
|
||||
content := randBytes(1034)
|
||||
err := os.WriteFile(originFilePath, content, 0600)
|
||||
require.NoError(t, err)
|
||||
content := randBytes(1034)
|
||||
os.WriteFile(originFilePath, content, 0600)
|
||||
|
||||
originFile, err := os.Open(originFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(originFile)
|
||||
originFile, _ := os.Open(originFilePath)
|
||||
defer originFile.Close()
|
||||
|
||||
encryptedFileWriter, err := os.Create(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileWriter)
|
||||
encryptedFileWriter, _ := os.Create(encryptedFilePath)
|
||||
defer encryptedFileWriter.Close()
|
||||
|
||||
err = encrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
require.NoError(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
require.NoError(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
|
||||
assert.Nil(t, err, "Failed to encrypt a file")
|
||||
encryptedContent, err := os.ReadFile(encryptedFilePath)
|
||||
assert.Nil(t, err, "Couldn't read encrypted file")
|
||||
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
|
||||
|
||||
encryptedFileReader, err := os.Open(encryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(encryptedFileReader)
|
||||
encryptedFileReader, _ := os.Open(encryptedFilePath)
|
||||
defer encryptedFileReader.Close()
|
||||
|
||||
decryptedFileWriter, err := os.Create(decryptedFilePath)
|
||||
require.NoError(t, err)
|
||||
defer logs.CloseAndLogErr(decryptedFileWriter)
|
||||
decryptedFileWriter, _ := os.Create(decryptedFilePath)
|
||||
defer decryptedFileWriter.Close()
|
||||
|
||||
_, err = decrypt(encryptedFileReader, []byte("garbage"))
|
||||
require.Error(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
}
|
||||
|
||||
t.Run("fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCMFIPS, aesDecryptGCMFIPS)
|
||||
})
|
||||
|
||||
t.Run("non_fips", func(t *testing.T) {
|
||||
testFunc(t, aesEncryptGCM, aesDecryptGCM)
|
||||
})
|
||||
}
|
||||
|
||||
func legacyAesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
|
||||
key, err := scrypt.Key(passphrase, nil, 32768, 8, 1, 32)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var iv [aes.BlockSize]byte
|
||||
stream := cipher.NewOFB(block, iv[:])
|
||||
|
||||
writer := &cipher.StreamWriter{S: stream, W: output}
|
||||
if _, err := io.Copy(writer, input); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func Test_hasEncryptedHeader(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
data []byte
|
||||
fipsMode bool
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "non-FIPS mode with valid header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "non-FIPS mode with FIPS header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with valid header",
|
||||
data: []byte("FIPS-AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "FIPS mode with non-FIPS header",
|
||||
data: []byte("AES256-GCM" + "some encrypted data"),
|
||||
fipsMode: true,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "invalid header",
|
||||
data: []byte("INVALID-HEADER" + "some data"),
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "empty data",
|
||||
data: []byte{},
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "nil data",
|
||||
data: nil,
|
||||
fipsMode: false,
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := hasEncryptedHeader(tt.data, tt.fipsMode)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
_, err = AesDecrypt(encryptedFileReader, []byte("garbage"))
|
||||
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
|
||||
}
|
||||
|
||||
@@ -112,7 +112,7 @@ func (service *ECDSAService) CreateSignature(message string) (string, error) {
|
||||
message = service.secret
|
||||
}
|
||||
|
||||
hash := libcrypto.InsecureHashFromBytes([]byte(message))
|
||||
hash := libcrypto.HashFromBytes([]byte(message))
|
||||
|
||||
r, s, err := ecdsa.Sign(rand.Reader, service.privateKey, hash)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateSignature(t *testing.T) {
|
||||
var s = NewECDSAService("secret")
|
||||
|
||||
privKey, pubKey, err := s.GenerateKeyPair()
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, privKey)
|
||||
require.NotEmpty(t, pubKey)
|
||||
|
||||
m := "test message"
|
||||
r, err := s.CreateSignature(m)
|
||||
require.NoError(t, err)
|
||||
require.NotEqual(t, r, m)
|
||||
require.NotEmpty(t, r)
|
||||
}
|
||||
@@ -1,24 +1,22 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
// Not allowed in FIPS mode
|
||||
"golang.org/x/crypto/bcrypt" //nolint:depguard
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
// Service represents a service for encrypting/hashing data.
|
||||
type Service struct{}
|
||||
|
||||
// Hash hashes a string using the bcrypt algorithm
|
||||
func (Service) Hash(data string) (string, error) {
|
||||
func (*Service) Hash(data string) (string, error) {
|
||||
bytes, err := bcrypt.GenerateFromPassword([]byte(data), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return string(bytes), err
|
||||
}
|
||||
|
||||
// CompareHashAndData compares a hash to clear data and returns an error if the comparison fails.
|
||||
func (Service) CompareHashAndData(hash string, data string) error {
|
||||
func (*Service) CompareHashAndData(hash string, data string) error {
|
||||
return bcrypt.CompareHashAndPassword([]byte(hash), []byte(data))
|
||||
}
|
||||
|
||||
@@ -2,12 +2,10 @@ package crypto
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestService_Hash(t *testing.T) {
|
||||
var s = Service{}
|
||||
var s = &Service{}
|
||||
|
||||
type args struct {
|
||||
hash string
|
||||
@@ -53,11 +51,3 @@ func TestService_Hash(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHash(t *testing.T) {
|
||||
s := Service{}
|
||||
|
||||
hash, err := s.Hash("Passw0rd!")
|
||||
require.NoError(t, err)
|
||||
require.NotEmpty(t, hash)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func NewNonce(size int) *Nonce {
|
||||
}
|
||||
|
||||
// NewRandomNonce generates a new initial nonce with the lower byte set to a random value
|
||||
// This ensures there are plenty of nonce values available before rolling over
|
||||
// This ensures there are plenty of nonce values availble before rolling over
|
||||
// Based on ideas from the Secure Programming Cookbook for C and C++ by John Viega, Matt Messier
|
||||
// https://www.oreilly.com/library/view/secure-programming-cookbook/0596003943/ch04s09.html
|
||||
func NewRandomNonce(size int) (*Nonce, error) {
|
||||
|
||||
@@ -4,32 +4,11 @@ import (
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"os"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/pkg/fips"
|
||||
)
|
||||
|
||||
// CreateTLSConfiguration creates a basic tls.Config with recommended TLS settings
|
||||
func CreateTLSConfiguration(insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
return createTLSConfiguration(fips.FIPSMode(), insecureSkipVerify)
|
||||
}
|
||||
|
||||
func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Config { //nolint:forbidigo
|
||||
if fipsEnabled {
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
MinVersion: tls.VersionTLS12,
|
||||
MaxVersion: tls.VersionTLS13,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
},
|
||||
CurvePreferences: []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521},
|
||||
}
|
||||
}
|
||||
|
||||
return &tls.Config{ //nolint:forbidigo
|
||||
func CreateTLSConfiguration() *tls.Config {
|
||||
return &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
CipherSuites: []uint16{
|
||||
tls.TLS_AES_128_GCM_SHA256,
|
||||
@@ -50,33 +29,24 @@ func createTLSConfiguration(fipsEnabled bool, insecureSkipVerify bool) *tls.Conf
|
||||
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
|
||||
},
|
||||
InsecureSkipVerify: insecureSkipVerify, //nolint:forbidigo
|
||||
}
|
||||
}
|
||||
|
||||
// CreateTLSConfigurationFromBytes initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from memory.
|
||||
func CreateTLSConfigurationFromBytes(useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromBytes(fips.FIPSMode(), useTLS, caCert, cert, key, skipClientVerification, skipServerVerification)
|
||||
}
|
||||
func CreateTLSConfigurationFromBytes(caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := CreateTLSConfiguration()
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
|
||||
func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key []byte, skipClientVerification, skipServerVerification bool) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !useTLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
config := createTLSConfiguration(fipsEnabled, skipServerVerification)
|
||||
|
||||
if !skipClientVerification || fipsEnabled {
|
||||
if !skipClientVerification {
|
||||
certificate, err := tls.X509KeyPair(cert, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
config.Certificates = []tls.Certificate{certificate}
|
||||
}
|
||||
|
||||
if !skipServerVerification || fipsEnabled {
|
||||
if !skipServerVerification {
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
config.RootCAs = caCertPool
|
||||
@@ -87,38 +57,29 @@ func createTLSConfigurationFromBytes(fipsEnabled, useTLS bool, caCert, cert, key
|
||||
|
||||
// CreateTLSConfigurationFromDisk initializes a tls.Config using a CA certificate, a certificate and a key
|
||||
// loaded from disk.
|
||||
func CreateTLSConfigurationFromDisk(config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
return createTLSConfigurationFromDisk(fips.FIPSMode(), config)
|
||||
}
|
||||
func CreateTLSConfigurationFromDisk(caCertPath, certPath, keyPath string, skipServerVerification bool) (*tls.Config, error) {
|
||||
config := CreateTLSConfiguration()
|
||||
config.InsecureSkipVerify = skipServerVerification
|
||||
|
||||
func createTLSConfigurationFromDisk(fipsEnabled bool, config portainer.TLSConfiguration) (*tls.Config, error) { //nolint:forbidigo
|
||||
if !config.TLS && fipsEnabled {
|
||||
return nil, fips.ErrTLSRequired
|
||||
} else if !config.TLS {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
tlsConfig := createTLSConfiguration(fipsEnabled, config.TLSSkipVerify)
|
||||
|
||||
if config.TLSCertPath != "" && config.TLSKeyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(config.TLSCertPath, config.TLSKeyPath)
|
||||
if certPath != "" && keyPath != "" {
|
||||
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
tlsConfig.Certificates = []tls.Certificate{cert}
|
||||
config.Certificates = []tls.Certificate{cert}
|
||||
}
|
||||
|
||||
if !tlsConfig.InsecureSkipVerify && config.TLSCACertPath != "" { //nolint:forbidigo
|
||||
caCert, err := os.ReadFile(config.TLSCACertPath)
|
||||
if !skipServerVerification && caCertPath != "" {
|
||||
caCert, err := os.ReadFile(caCertPath)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
caCertPool := x509.NewCertPool()
|
||||
caCertPool.AppendCertsFromPEM(caCert)
|
||||
tlsConfig.RootCAs = caCertPool
|
||||
config.RootCAs = caCertPool
|
||||
}
|
||||
|
||||
return tlsConfig, nil
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -1,87 +0,0 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCreateTLSConfiguration(t *testing.T) {
|
||||
// InsecureSkipVerify = false
|
||||
config := CreateTLSConfiguration(false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
|
||||
// InsecureSkipVerify = true
|
||||
config = CreateTLSConfiguration(true)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.True(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
fipsCipherSuites := []uint16{
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
|
||||
tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
|
||||
}
|
||||
|
||||
fipsCurvePreferences := []tls.CurveID{tls.CurveP256, tls.CurveP384, tls.CurveP521}
|
||||
|
||||
config := createTLSConfiguration(fips, false)
|
||||
require.Equal(t, config.MinVersion, uint16(tls.VersionTLS12)) //nolint:forbidigo
|
||||
require.Equal(t, config.MaxVersion, uint16(tls.VersionTLS13)) //nolint:forbidigo
|
||||
require.Equal(t, config.CipherSuites, fipsCipherSuites) //nolint:forbidigo
|
||||
require.Equal(t, config.CurvePreferences, fipsCurvePreferences) //nolint:forbidigo
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromBytes(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromBytes(false, nil, nil, nil, false, false)
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS client/server verifications
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, true, true)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
|
||||
// Empty TLS
|
||||
config, err = CreateTLSConfigurationFromBytes(true, nil, nil, nil, false, false)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDisk(t *testing.T) {
|
||||
// No TLS
|
||||
config, err := CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, config)
|
||||
|
||||
// Skip TLS verifications
|
||||
config, err = CreateTLSConfigurationFromDisk(portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
}
|
||||
|
||||
func TestCreateTLSConfigurationFromDiskFIPS(t *testing.T) {
|
||||
fips := true
|
||||
|
||||
// Skipping TLS verifications cannot be done in FIPS mode
|
||||
config, err := createTLSConfigurationFromDisk(fips, portainer.TLSConfiguration{
|
||||
TLS: true,
|
||||
TLSSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, config)
|
||||
require.False(t, config.InsecureSkipVerify) //nolint:forbidigo
|
||||
}
|
||||
@@ -21,9 +21,6 @@ import (
|
||||
const (
|
||||
DatabaseFileName = "portainer.db"
|
||||
EncryptedDatabaseFileName = "portainer.edb"
|
||||
|
||||
txMaxSize = 65536
|
||||
compactedSuffix = ".compacted"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -38,7 +35,6 @@ type DbConnection struct {
|
||||
InitialMmapSize int
|
||||
EncryptionKey []byte
|
||||
isEncrypted bool
|
||||
Compact bool
|
||||
|
||||
*bolt.DB
|
||||
}
|
||||
@@ -136,8 +132,13 @@ func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
|
||||
func (connection *DbConnection) Open() error {
|
||||
log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")
|
||||
|
||||
// Now we open the db
|
||||
databasePath := connection.GetDatabaseFilePath()
|
||||
db, err := bolt.Open(databasePath, 0600, connection.boltOptions(connection.Compact))
|
||||
|
||||
db, err := bolt.Open(databasePath, 0600, &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -146,24 +147,6 @@ func (connection *DbConnection) Open() error {
|
||||
db.MaxBatchDelay = connection.MaxBatchDelay
|
||||
connection.DB = db
|
||||
|
||||
if connection.Compact {
|
||||
log.Info().Msg("compacting database")
|
||||
if err := connection.compact(); err != nil {
|
||||
log.Error().Err(err).Msg("failed to compact database")
|
||||
|
||||
// Close the read-only database and re-open in read-write mode
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after failed compaction")
|
||||
}
|
||||
|
||||
connection.Compact = false
|
||||
|
||||
return connection.Open()
|
||||
} else {
|
||||
log.Info().Msg("database compaction completed")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -429,48 +412,3 @@ func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
// compact attempts to compact the database and replace it iff it succeeds
|
||||
func (connection *DbConnection) compact() (err error) {
|
||||
compactedPath := connection.GetDatabaseFilePath() + compactedSuffix
|
||||
|
||||
if err := os.Remove(compactedPath); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return fmt.Errorf("failure to remove an existing compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB, err := bolt.Open(compactedPath, 0o600, connection.boltOptions(false))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failure to create the compacted database: %w", err)
|
||||
}
|
||||
|
||||
compactedDB.MaxBatchSize = connection.MaxBatchSize
|
||||
compactedDB.MaxBatchDelay = connection.MaxBatchDelay
|
||||
|
||||
if err := bolt.Compact(compactedDB, connection.DB, txMaxSize); err != nil {
|
||||
return fmt.Errorf("failure to compact the database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := os.Rename(compactedPath, connection.GetDatabaseFilePath()); err != nil {
|
||||
return fmt.Errorf("failure to move the compacted database: %w",
|
||||
errors.Join(err, compactedDB.Close(), os.Remove(compactedPath)))
|
||||
}
|
||||
|
||||
if err := connection.Close(); err != nil {
|
||||
log.Warn().Err(err).Msg("failure to close the database after compaction")
|
||||
}
|
||||
|
||||
connection.DB = compactedDB
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (connection *DbConnection) boltOptions(readOnly bool) *bolt.Options {
|
||||
return &bolt.Options{
|
||||
Timeout: 1 * time.Second,
|
||||
InitialMmapSize: connection.InitialMmapSize,
|
||||
FreelistType: bolt.FreelistMapType,
|
||||
NoFreelistSync: true,
|
||||
ReadOnly: readOnly,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -5,11 +5,7 @@ import (
|
||||
"path"
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.etcd.io/bbolt"
|
||||
)
|
||||
|
||||
func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
@@ -98,36 +94,18 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
// Special case. If portainer.db and portainer.edb exist.
|
||||
dbFile1 := path.Join(connection.Path, DatabaseFileName)
|
||||
f, _ := os.Create(dbFile1)
|
||||
|
||||
err := f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile1)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile1)
|
||||
|
||||
dbFile2 := path.Join(connection.Path, EncryptedDatabaseFileName)
|
||||
f, _ = os.Create(dbFile2)
|
||||
|
||||
err = f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile2)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile2)
|
||||
} else if tc.dbname != "" {
|
||||
dbFile := path.Join(connection.Path, tc.dbname)
|
||||
f, _ := os.Create(dbFile)
|
||||
|
||||
err := f.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer func() {
|
||||
err := os.Remove(dbFile)
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
f.Close()
|
||||
defer os.Remove(dbFile)
|
||||
}
|
||||
|
||||
if tc.key {
|
||||
@@ -141,60 +119,3 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBCompaction(t *testing.T) {
|
||||
db := &DbConnection{Path: t.TempDir()}
|
||||
|
||||
err := db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Update(func(tx *bbolt.Tx) error {
|
||||
b, err := tx.CreateBucketIfNotExists([]byte("testbucket"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = b.Put([]byte("key"), []byte("value"))
|
||||
require.NoError(t, err)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Reopen the DB to trigger compaction
|
||||
db.Compact = true
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that the data is still there
|
||||
err = db.View(func(tx *bbolt.Tx) error {
|
||||
b := tx.Bucket([]byte("testbucket"))
|
||||
if b == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
val := b.Get([]byte("key"))
|
||||
require.Equal(t, []byte("value"), val)
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = db.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Failures
|
||||
compactedPath := db.GetDatabaseFilePath() + compactedSuffix
|
||||
err = os.Mkdir(compactedPath, 0o755)
|
||||
require.NoError(t, err)
|
||||
|
||||
f, err := os.Create(filesystem.JoinPaths(compactedPath, "somefile"))
|
||||
require.NoError(t, err)
|
||||
require.NoError(t, f.Close())
|
||||
|
||||
err = db.Open()
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package boltdb
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/segmentio/encoding/json"
|
||||
bolt "go.etcd.io/bbolt"
|
||||
@@ -38,7 +37,7 @@ func (c *DbConnection) ExportJSON(databasePath string, metadata bool) ([]byte, e
|
||||
if err != nil {
|
||||
return []byte("{}"), err
|
||||
}
|
||||
defer logs.CloseAndLogErr(connection)
|
||||
defer connection.Close()
|
||||
|
||||
backup := make(map[string]any)
|
||||
if metadata {
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"io"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/segmentio/encoding/json"
|
||||
@@ -45,12 +47,12 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, object); err != nil {
|
||||
if e := json.Unmarshal(data, object); e != nil {
|
||||
// Special case for the VERSION bucket. Here we're not using json
|
||||
// So we need to return it as a string
|
||||
s, ok := object.(*string)
|
||||
if !ok {
|
||||
return errors.Wrap(err, "Failed unmarshalling object")
|
||||
return errors.Wrap(err, e.Error())
|
||||
}
|
||||
|
||||
*s = string(data)
|
||||
@@ -63,18 +65,18 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
|
||||
// https://gist.github.com/atoponce/07d8d4c833873be2f68c34f9afc5a78a#symmetric-encryption
|
||||
|
||||
func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
// NewGCMWithRandomNonce in go 1.24 handles setting up the nonce and adding it to the encrypted output
|
||||
gcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
if err != nil {
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nil, nil, plaintext, nil), nil
|
||||
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
|
||||
@@ -87,17 +89,19 @@ func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err err
|
||||
return encrypted, errors.Wrap(err, "Error creating cypher block")
|
||||
}
|
||||
|
||||
// NewGCMWithRandomNonce in go 1.24 handles reading the nonce from the encrypted input for us
|
||||
gcm, err := cipher.NewGCMWithRandomNonce(block)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating GCM")
|
||||
}
|
||||
|
||||
if len(encrypted) < gcm.NonceSize() {
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(encrypted) < nonceSize {
|
||||
return encrypted, errEncryptedStringTooShort
|
||||
}
|
||||
|
||||
plaintextByte, err = gcm.Open(nil, nil, encrypted, nil)
|
||||
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
|
||||
|
||||
plaintextByte, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error decrypting text")
|
||||
}
|
||||
|
||||
@@ -1,23 +1,16 @@
|
||||
package boltdb
|
||||
|
||||
import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
"testing"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/gofrs/uuid"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","HelmRepositoryURL":"https://charts.bitnami.com/bitnami","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
jsonobject = `{"LogoURL":"","BlackListedLabels":[],"AuthenticationMethod":1,"InternalAuthSettings": {"RequiredPasswordLength": 12}"LDAPSettings":{"AnonymousMode":true,"ReaderDN":"","URL":"","TLSConfig":{"TLS":false,"TLSSkipVerify":false},"StartTLS":false,"SearchSettings":[{"BaseDN":"","Filter":"","UserNameAttribute":""}],"GroupSearchSettings":[{"GroupBaseDN":"","GroupFilter":"","GroupAttribute":""}],"AutoCreateUsers":true},"OAuthSettings":{"ClientID":"","AccessTokenURI":"","AuthorizationURI":"","ResourceURI":"","RedirectURI":"","UserIdentifier":"","Scopes":"","OAuthAutoCreateUsers":false,"DefaultTeamID":0,"SSO":true,"LogoutURI":"","KubeSecretKey":"j0zLVtY/lAWBk62ByyF0uP80SOXaitsABP0TTJX8MhI="},"OpenAMTConfiguration":{"Enabled":false,"MPSServer":"","MPSUser":"","MPSPassword":"","MPSToken":"","CertFileContent":"","CertFileName":"","CertFilePassword":"","DomainName":""},"FeatureFlagSettings":{},"SnapshotInterval":"5m","TemplatesURL":"https://raw.githubusercontent.com/portainer/templates/master/templates-2.0.json","EdgeAgentCheckinInterval":5,"EnableEdgeComputeFeatures":false,"UserSessionTimeout":"8h","KubeconfigExpiry":"0","EnableTelemetry":true,"HelmRepositoryURL":"https://charts.bitnami.com/bitnami","KubectlShellImage":"portainer/kubectl-shell","DisplayDonationHeader":false,"DisplayExternalContributors":false,"EnableHostManagementFeatures":false,"AllowVolumeBrowserForRegularUsers":false,"AllowBindMountsForRegularUsers":false,"AllowPrivilegedModeForRegularUsers":false,"AllowHostNamespaceForRegularUsers":false,"AllowStackManagementForRegularUsers":false,"AllowDeviceMappingForRegularUsers":false,"AllowContainerCapabilitiesForRegularUsers":false}`
|
||||
passphrase = "my secret key"
|
||||
)
|
||||
|
||||
@@ -29,7 +22,7 @@ func secretToEncryptionKey(passphrase string) []byte {
|
||||
func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
is := assert.New(t)
|
||||
|
||||
uuid := uuid.New()
|
||||
uuid := uuid.Must(uuid.NewV4())
|
||||
|
||||
tests := []struct {
|
||||
object any
|
||||
@@ -94,7 +87,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(test.expected, string(data))
|
||||
})
|
||||
}
|
||||
@@ -135,7 +128,7 @@ func Test_UnMarshalObjectUnencrypted(t *testing.T) {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
var object string
|
||||
err := conn.UnmarshalObject(test.object, &object)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(test.expected, object)
|
||||
})
|
||||
}
|
||||
@@ -167,109 +160,18 @@ func Test_ObjectMarshallingEncrypted(t *testing.T) {
|
||||
}
|
||||
|
||||
key := secretToEncryptionKey(passphrase)
|
||||
conn := DbConnection{EncryptionKey: key, isEncrypted: true}
|
||||
conn := DbConnection{EncryptionKey: key}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("%s -> %s", test.object, test.expected), func(t *testing.T) {
|
||||
|
||||
data, err := conn.MarshalObject(test.object)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
var object []byte
|
||||
err = conn.UnmarshalObject(data, &object)
|
||||
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(test.object, object)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_NonceSources(t *testing.T) {
|
||||
// ensure that the new go 1.24 NewGCMWithRandomNonce works correctly with
|
||||
// the old way of creating and including the nonce
|
||||
|
||||
encryptOldFn := func(plaintext []byte, passphrase []byte) (encrypted []byte, err error) {
|
||||
block, _ := aes.NewCipher(passphrase)
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
nonce := make([]byte, gcm.NonceSize())
|
||||
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
|
||||
return encrypted, err
|
||||
}
|
||||
|
||||
return gcm.Seal(nonce, nonce, plaintext, nil), nil
|
||||
}
|
||||
|
||||
decryptOldFn := func(encrypted []byte, passphrase []byte) (plaintext []byte, err error) {
|
||||
block, err := aes.NewCipher(passphrase)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating cypher block")
|
||||
}
|
||||
|
||||
gcm, err := cipher.NewGCM(block)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error creating GCM")
|
||||
}
|
||||
|
||||
nonceSize := gcm.NonceSize()
|
||||
if len(encrypted) < nonceSize {
|
||||
return encrypted, errEncryptedStringTooShort
|
||||
}
|
||||
|
||||
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
|
||||
|
||||
plaintext, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
|
||||
if err != nil {
|
||||
return encrypted, errors.Wrap(err, "Error decrypting text")
|
||||
}
|
||||
|
||||
return plaintext, err
|
||||
}
|
||||
|
||||
encryptNewFn := encrypt
|
||||
decryptNewFn := decrypt
|
||||
|
||||
passphrase := make([]byte, 32)
|
||||
_, err := io.ReadFull(rand.Reader, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
junk := make([]byte, 1024)
|
||||
_, err = io.ReadFull(rand.Reader, junk)
|
||||
require.NoError(t, err)
|
||||
|
||||
junkEnc := make([]byte, base64.StdEncoding.EncodedLen(len(junk)))
|
||||
base64.StdEncoding.Encode(junkEnc, junk)
|
||||
|
||||
cases := [][]byte{
|
||||
[]byte("test"),
|
||||
[]byte("35"),
|
||||
[]byte("9ca4a1dd-a439-4593-b386-a7dfdc2e9fc6"),
|
||||
[]byte(jsonobject),
|
||||
passphrase,
|
||||
junk,
|
||||
junkEnc,
|
||||
}
|
||||
|
||||
for _, plain := range cases {
|
||||
var enc, dec []byte
|
||||
var err error
|
||||
|
||||
enc, err = encryptOldFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptNewFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
|
||||
enc, err = encryptNewFn(plain, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
dec, err = decryptOldFn(enc, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.Equal(t, plain, dec)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const testBucketName = "test-bucket"
|
||||
@@ -18,55 +17,70 @@ type testStruct struct {
|
||||
}
|
||||
|
||||
func TestTxs(t *testing.T) {
|
||||
conn := DbConnection{Path: t.TempDir()}
|
||||
conn := DbConnection{
|
||||
Path: t.TempDir(),
|
||||
}
|
||||
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := conn.Close()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Error propagation
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return errors.New("this is an error")
|
||||
})
|
||||
require.Error(t, err)
|
||||
if err == nil {
|
||||
t.Fatal("an error was expected, got nil instead")
|
||||
}
|
||||
|
||||
// Create an object
|
||||
newObj := testStruct{Key: "key", Value: "value"}
|
||||
newObj := testStruct{
|
||||
Key: "key",
|
||||
Value: "value",
|
||||
}
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
if err := tx.SetServiceName(testBucketName); err != nil {
|
||||
err = tx.SetServiceName(testBucketName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.CreateObjectWithId(testBucketName, testId, newObj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
obj := testStruct{}
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if obj.Key != newObj.Key || obj.Value != newObj.Value {
|
||||
t.Fatalf("expected %s:%s, got %s:%s instead", newObj.Key, newObj.Value, obj.Key, obj.Value)
|
||||
}
|
||||
|
||||
// Update an object
|
||||
updatedObj := testStruct{Key: "updated-key", Value: "updated-value"}
|
||||
updatedObj := testStruct{
|
||||
Key: "updated-key",
|
||||
Value: "updated-value",
|
||||
}
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.UpdateObject(testBucketName, conn.ConvertToKey(testId), &updatedObj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if obj.Key != updatedObj.Key || obj.Value != updatedObj.Value {
|
||||
t.Fatalf("expected %s:%s, got %s:%s instead", updatedObj.Key, updatedObj.Value, obj.Key, obj.Value)
|
||||
@@ -76,12 +90,16 @@ func TestTxs(t *testing.T) {
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return tx.DeleteObject(testBucketName, conn.ConvertToKey(testId))
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.GetObject(testBucketName, conn.ConvertToKey(testId), &obj)
|
||||
})
|
||||
require.True(t, dataservices.IsErrObjectNotFound(err))
|
||||
if !dataservices.IsErrObjectNotFound(err) {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Get next identifier
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
@@ -94,11 +112,15 @@ func TestTxs(t *testing.T) {
|
||||
|
||||
return nil
|
||||
})
|
||||
require.NoError(t, err)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Try to write in a read transaction
|
||||
err = conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
return tx.CreateObjectWithId(testBucketName, testId, newObj)
|
||||
})
|
||||
require.Error(t, err)
|
||||
if err == nil {
|
||||
t.Fatal("an error was expected, got nil instead")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -8,12 +8,11 @@ import (
|
||||
)
|
||||
|
||||
// NewDatabase should use config options to return a connection to the requested database
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte, compact bool) (connection portainer.Connection, err error) {
|
||||
func NewDatabase(storeType, storePath string, encryptionKey []byte) (connection portainer.Connection, err error) {
|
||||
if storeType == "boltdb" {
|
||||
return &boltdb.DbConnection{
|
||||
Path: storePath,
|
||||
EncryptionKey: encryptionKey,
|
||||
Compact: compact,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,24 +0,0 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNewDatabase(t *testing.T) {
|
||||
dbPath := filesystem.JoinPaths(t.TempDir(), "test.db")
|
||||
connection, err := NewDatabase("boltdb", dbPath, nil, false)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, connection)
|
||||
|
||||
_, ok := connection.(*boltdb.DbConnection)
|
||||
require.True(t, ok)
|
||||
|
||||
connection, err = NewDatabase("unknown", dbPath, nil, false)
|
||||
require.Error(t, err)
|
||||
require.Nil(t, connection)
|
||||
}
|
||||
@@ -10,7 +10,7 @@ type BaseCRUD[T any, I constraints.Integer] interface {
|
||||
Create(element *T) error
|
||||
Read(ID I) (*T, error)
|
||||
Exists(ID I) (bool, error)
|
||||
ReadAll(predicates ...func(T) bool) ([]T, error)
|
||||
ReadAll() ([]T, error)
|
||||
Update(ID I, element *T) error
|
||||
Delete(ID I) error
|
||||
}
|
||||
@@ -56,13 +56,12 @@ func (service BaseDataService[T, I]) Exists(ID I) (bool, error) {
|
||||
return exists, err
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service BaseDataService[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
|
||||
func (service BaseDataService[T, I]) ReadAll() ([]T, error) {
|
||||
var collection = make([]T, 0)
|
||||
|
||||
return collection, service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
collection, err = service.Tx(tx).ReadAll(predicates...)
|
||||
collection, err = service.Tx(tx).ReadAll()
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
package dataservices
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/slicesx"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
type testObject struct {
|
||||
ID int
|
||||
Value int
|
||||
}
|
||||
|
||||
type mockConnection struct {
|
||||
store map[int]testObject
|
||||
|
||||
portainer.Connection
|
||||
}
|
||||
|
||||
func (m mockConnection) UpdateObject(bucket string, key []byte, value any) error {
|
||||
obj := value.(*testObject)
|
||||
|
||||
m.store[obj.ID] = *obj
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
|
||||
for _, v := range m.store {
|
||||
if _, err := appendFn(&v); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m mockConnection) UpdateTx(fn func(portainer.Transaction) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m mockConnection) ViewTx(fn func(portainer.Transaction) error) error {
|
||||
return fn(m)
|
||||
}
|
||||
|
||||
func (m mockConnection) ConvertToKey(v int) []byte {
|
||||
return []byte(strconv.Itoa(v))
|
||||
}
|
||||
func TestReadAll(t *testing.T) {
|
||||
service := BaseDataService[testObject, int]{
|
||||
Bucket: "testBucket",
|
||||
Connection: mockConnection{store: make(map[int]testObject)},
|
||||
}
|
||||
|
||||
data := []testObject{
|
||||
{ID: 1, Value: 1},
|
||||
{ID: 2, Value: 2},
|
||||
{ID: 3, Value: 3},
|
||||
{ID: 4, Value: 4},
|
||||
{ID: 5, Value: 5},
|
||||
}
|
||||
|
||||
for _, item := range data {
|
||||
err := service.Update(item.ID, &item)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
// ReadAll without predicates
|
||||
result, err := service.ReadAll()
|
||||
require.NoError(t, err)
|
||||
|
||||
expected := append([]testObject{}, data...)
|
||||
|
||||
require.ElementsMatch(t, expected, result)
|
||||
|
||||
// ReadAll with predicates
|
||||
hasLowID := func(obj testObject) bool { return obj.ID < 3 }
|
||||
isEven := func(obj testObject) bool { return obj.Value%2 == 0 }
|
||||
|
||||
result, err = service.ReadAll(hasLowID, isEven)
|
||||
require.NoError(t, err)
|
||||
|
||||
expected = slicesx.Filter(expected, hasLowID)
|
||||
expected = slicesx.Filter(expected, isEven)
|
||||
|
||||
require.ElementsMatch(t, expected, result)
|
||||
}
|
||||
@@ -34,32 +34,13 @@ func (service BaseDataServiceTx[T, I]) Exists(ID I) (bool, error) {
|
||||
return service.Tx.KeyExists(service.Bucket, identifier)
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service BaseDataServiceTx[T, I]) ReadAll(predicates ...func(T) bool) ([]T, error) {
|
||||
func (service BaseDataServiceTx[T, I]) ReadAll() ([]T, error) {
|
||||
var collection = make([]T, 0)
|
||||
|
||||
if len(predicates) == 0 {
|
||||
return collection, service.Tx.GetAll(
|
||||
service.Bucket,
|
||||
new(T),
|
||||
AppendFn(&collection),
|
||||
)
|
||||
}
|
||||
|
||||
filterFn := func(element T) bool {
|
||||
for _, p := range predicates {
|
||||
if !p(element) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
return collection, service.Tx.GetAll(
|
||||
service.Bucket,
|
||||
new(T),
|
||||
FilterFn(&collection, filterFn),
|
||||
AppendFn(&collection),
|
||||
)
|
||||
}
|
||||
|
||||
@@ -72,13 +53,3 @@ func (service BaseDataServiceTx[T, I]) Delete(ID I) error {
|
||||
identifier := service.Connection.ConvertToKey(int(ID))
|
||||
return service.Tx.DeleteObject(service.Bucket, identifier)
|
||||
}
|
||||
|
||||
func Read[T any](tx portainer.Transaction, bucket string, key []byte) (*T, error) {
|
||||
var element T
|
||||
|
||||
if err := tx.GetObject(bucket, key, &element); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &element, nil
|
||||
}
|
||||
|
||||
@@ -28,12 +28,13 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (service *Service) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).Create(customTemplate)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -1,19 +0,0 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreate(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1}))
|
||||
e, err := ds.CustomTemplate().Read(1)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, portainer.CustomTemplateID(1), e.ID)
|
||||
}
|
||||
@@ -1,31 +0,0 @@
|
||||
package customtemplate
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
// Service represents a service for managing custom template data.
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.CustomTemplate, portainer.CustomTemplateID]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.Connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// CreateCustomTemplate uses the existing id and saves it.
|
||||
// TODO: where does the ID come from, and is it safe?
|
||||
func (service ServiceTx) Create(customTemplate *portainer.CustomTemplate) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(customTemplate.ID), customTemplate)
|
||||
}
|
||||
@@ -1,28 +0,0 @@
|
||||
package customtemplate_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCustomTemplateCreateTx(t *testing.T) {
|
||||
_, ds := datastore.MustNewTestStore(t, true, false)
|
||||
require.NotNil(t, ds)
|
||||
|
||||
require.NoError(t, ds.UpdateTx(func(tx dataservices.DataStoreTx) error {
|
||||
return tx.CustomTemplate().Create(&portainer.CustomTemplate{ID: 1})
|
||||
}))
|
||||
|
||||
var template *portainer.CustomTemplate
|
||||
require.NoError(t, ds.ViewTx(func(tx dataservices.DataStoreTx) error {
|
||||
var err error
|
||||
template, err = tx.CustomTemplate().Read(1)
|
||||
return err
|
||||
}))
|
||||
|
||||
require.Equal(t, portainer.CustomTemplateID(1), template.ID)
|
||||
}
|
||||
@@ -17,29 +17,11 @@ func (service ServiceTx) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFun
|
||||
}
|
||||
|
||||
func (service ServiceTx) Create(group *portainer.EdgeGroup) error {
|
||||
es := group.Endpoints
|
||||
group.Endpoints = nil // Clear deprecated field
|
||||
|
||||
err := service.Tx.CreateObject(
|
||||
return service.Tx.CreateObject(
|
||||
BucketName,
|
||||
func(id uint64) (int, any) {
|
||||
group.ID = portainer.EdgeGroupID(id)
|
||||
return int(group.ID), group
|
||||
},
|
||||
)
|
||||
|
||||
group.Endpoints = es // Restore endpoints after create
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (service ServiceTx) Update(ID portainer.EdgeGroupID, group *portainer.EdgeGroup) error {
|
||||
es := group.Endpoints
|
||||
group.Endpoints = nil // Clear deprecated field
|
||||
|
||||
err := service.BaseDataServiceTx.Update(ID, group)
|
||||
|
||||
group.Endpoints = es // Restore endpoints after update
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
package edgestack
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdate(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer logs.CloseAndLogErr(conn)
|
||||
|
||||
service, err := NewService(conn, func(portainer.Transaction, portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
const edgeStackID = 1
|
||||
edgeStack := &portainer.EdgeStack{
|
||||
ID: edgeStackID,
|
||||
Name: "Test Stack",
|
||||
}
|
||||
|
||||
err = service.Create(edgeStackID, edgeStack)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = service.UpdateEdgeStackFunc(edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.Name = "Updated Stack"
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedStack, err := service.EdgeStack(edgeStackID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Stack", updatedStack.Name)
|
||||
|
||||
err = conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.UpdateEdgeStackFuncTx(tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.Name = "Updated Stack Again"
|
||||
})
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
updatedStack, err = service.EdgeStack(edgeStackID)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "Updated Stack Again", updatedStack.Name)
|
||||
}
|
||||
@@ -1,89 +0,0 @@
|
||||
package edgestackstatus
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
var _ dataservices.EdgeStackStatusService = &Service{}
|
||||
|
||||
const BucketName = "edge_stack_status"
|
||||
|
||||
type Service struct {
|
||||
conn portainer.Connection
|
||||
}
|
||||
|
||||
func (service *Service) BucketName() string {
|
||||
return BucketName
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
if err := connection.SetServiceName(BucketName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Service{conn: connection}, nil
|
||||
}
|
||||
|
||||
func (s *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
service: s,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(edgeStackID, endpointID, status)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error) {
|
||||
var element *portainer.EdgeStackStatusForEnv
|
||||
|
||||
return element, s.conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
element, err = s.Tx(tx).Read(edgeStackID, endpointID)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error) {
|
||||
var collection = make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
return collection, s.conn.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
collection, err = s.Tx(tx).ReadAll(edgeStackID)
|
||||
|
||||
return err
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Update(edgeStackID, endpointID, status)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Delete(edgeStackID, endpointID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) DeleteAll(edgeStackID portainer.EdgeStackID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).DeleteAll(edgeStackID)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error {
|
||||
return s.conn.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Clear(edgeStackID, relatedEnvironmentsIDs)
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) key(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) []byte {
|
||||
return append(s.conn.ConvertToKey(int(edgeStackID)), s.conn.ConvertToKey(int(endpointID))...)
|
||||
}
|
||||
@@ -1,95 +0,0 @@
|
||||
package edgestackstatus
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
var _ dataservices.EdgeStackStatusService = &Service{}
|
||||
|
||||
type ServiceTx struct {
|
||||
service *Service
|
||||
tx portainer.Transaction
|
||||
}
|
||||
|
||||
func (service ServiceTx) Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
identifier := service.service.key(edgeStackID, endpointID)
|
||||
return service.tx.CreateObjectWithStringId(BucketName, identifier, status)
|
||||
}
|
||||
|
||||
func (s ServiceTx) Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error) {
|
||||
var status portainer.EdgeStackStatusForEnv
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
|
||||
if err := s.tx.GetObject(BucketName, identifier, &status); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &status, nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error) {
|
||||
keyPrefix := s.service.conn.ConvertToKey(int(edgeStackID))
|
||||
|
||||
statuses := make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
if err := s.tx.GetAllWithKeyPrefix(BucketName, keyPrefix, &portainer.EdgeStackStatusForEnv{}, dataservices.AppendFn(&statuses)); err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve EdgeStackStatus for EdgeStack %d: %w", edgeStackID, err)
|
||||
}
|
||||
|
||||
return statuses, nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error {
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
return s.tx.UpdateObject(BucketName, identifier, status)
|
||||
}
|
||||
|
||||
func (s ServiceTx) Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error {
|
||||
identifier := s.service.key(edgeStackID, endpointID)
|
||||
return s.tx.DeleteObject(BucketName, identifier)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteAll(edgeStackID portainer.EdgeStackID) error {
|
||||
keyPrefix := s.service.conn.ConvertToKey(int(edgeStackID))
|
||||
|
||||
statuses := make([]portainer.EdgeStackStatusForEnv, 0)
|
||||
|
||||
if err := s.tx.GetAllWithKeyPrefix(BucketName, keyPrefix, &portainer.EdgeStackStatusForEnv{}, dataservices.AppendFn(&statuses)); err != nil {
|
||||
return fmt.Errorf("unable to retrieve EdgeStackStatus for EdgeStack %d: %w", edgeStackID, err)
|
||||
}
|
||||
|
||||
for _, status := range statuses {
|
||||
if err := s.tx.DeleteObject(BucketName, s.service.key(edgeStackID, status.EndpointID)); err != nil {
|
||||
return fmt.Errorf("unable to delete EdgeStackStatus for EdgeStack %d and Endpoint %d: %w", edgeStackID, status.EndpointID, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s ServiceTx) Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error {
|
||||
for _, envID := range relatedEnvironmentsIDs {
|
||||
existingStatus, err := s.Read(edgeStackID, envID)
|
||||
if err != nil && !dataservices.IsErrObjectNotFound(err) {
|
||||
return fmt.Errorf("unable to retrieve status for environment %d: %w", envID, err)
|
||||
}
|
||||
|
||||
var deploymentInfo portainer.StackDeploymentInfo
|
||||
if existingStatus != nil {
|
||||
deploymentInfo = existingStatus.DeploymentInfo
|
||||
}
|
||||
|
||||
if err := s.Update(edgeStackID, envID, &portainer.EdgeStackStatusForEnv{
|
||||
EndpointID: envID,
|
||||
Status: []portainer.EdgeStackDeploymentStatus{},
|
||||
DeploymentInfo: deploymentInfo,
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -119,19 +119,6 @@ func (service *Service) Endpoints() ([]portainer.Endpoint, error) {
|
||||
return endpoints, nil
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service *Service) ReadAll(predicates ...func(endpoint portainer.Endpoint) bool) ([]portainer.Endpoint, error) {
|
||||
var endpoints []portainer.Endpoint
|
||||
var err error
|
||||
|
||||
err = service.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
endpoints, err = service.Tx(tx).ReadAll(predicates...)
|
||||
return err
|
||||
})
|
||||
|
||||
return endpoints, err
|
||||
}
|
||||
|
||||
// EndpointIDByEdgeID returns the EndpointID from the given EdgeID using an in-memory index
|
||||
func (service *Service) EndpointIDByEdgeID(edgeID string) (portainer.EndpointID, bool) {
|
||||
service.mu.RLock()
|
||||
|
||||
@@ -89,11 +89,6 @@ func (service ServiceTx) Endpoints() ([]portainer.Endpoint, error) {
|
||||
)
|
||||
}
|
||||
|
||||
// ReadAll retrieves all the elements that satisfy all the provided predicates.
|
||||
func (service ServiceTx) ReadAll(predicates ...func(endpoint portainer.Endpoint) bool) ([]portainer.Endpoint, error) {
|
||||
return dataservices.BaseDataServiceTx[portainer.Endpoint, portainer.EndpointID]{Bucket: BucketName, Connection: service.service.connection, Tx: service.tx}.ReadAll(predicates...)
|
||||
}
|
||||
|
||||
func (service ServiceTx) EndpointIDByEdgeID(edgeID string) (portainer.EndpointID, bool) {
|
||||
log.Error().Str("func", "EndpointIDByEdgeID").Msg("cannot be called inside a transaction")
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
// BucketName represents the name of the bucket where this service stores data.
|
||||
@@ -14,6 +16,7 @@ const BucketName = "endpoint_relations"
|
||||
// Service represents a service for managing environment(endpoint) relation data.
|
||||
type Service struct {
|
||||
connection portainer.Connection
|
||||
updateStackFn func(ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
|
||||
updateStackFnTx func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error
|
||||
endpointRelationsCache []portainer.EndpointRelation
|
||||
mu sync.Mutex
|
||||
@@ -26,8 +29,10 @@ func (service *Service) BucketName() string {
|
||||
}
|
||||
|
||||
func (service *Service) RegisterUpdateStackFunction(
|
||||
updateFunc func(portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
|
||||
updateFuncTx func(portainer.Transaction, portainer.EdgeStackID, func(*portainer.EdgeStack)) error,
|
||||
) {
|
||||
service.updateStackFn = updateFunc
|
||||
service.updateStackFnTx = updateFuncTx
|
||||
}
|
||||
|
||||
@@ -86,26 +91,106 @@ func (service *Service) Create(endpointRelation *portainer.EndpointRelation) err
|
||||
|
||||
// UpdateEndpointRelation updates an Environment(Endpoint) relation object
|
||||
func (service *Service) UpdateEndpointRelation(endpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).UpdateEndpointRelation(endpointID, endpointRelation)
|
||||
})
|
||||
previousRelationState, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
identifier := service.connection.ConvertToKey(int(endpointID))
|
||||
err := service.connection.UpdateObject(BucketName, identifier, endpointRelation)
|
||||
cache.Del(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
updatedRelationState, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
service.mu.Lock()
|
||||
service.endpointRelationsCache = nil
|
||||
service.mu.Unlock()
|
||||
|
||||
service.updateEdgeStacksAfterRelationChange(previousRelationState, updatedRelationState)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStack)
|
||||
func (service *Service) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
return service.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).AddEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
})
|
||||
}
|
||||
|
||||
func (service *Service) RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).RemoveEndpointRelationsForEdgeStack(endpointIDs, edgeStackID)
|
||||
})
|
||||
}
|
||||
|
||||
// DeleteEndpointRelation deletes an Environment(Endpoint) relation object
|
||||
func (service *Service) DeleteEndpointRelation(endpointID portainer.EndpointID) error {
|
||||
return service.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return service.Tx(tx).DeleteEndpointRelation(endpointID)
|
||||
})
|
||||
deletedRelation, _ := service.EndpointRelation(endpointID)
|
||||
|
||||
identifier := service.connection.ConvertToKey(int(endpointID))
|
||||
err := service.connection.DeleteObject(BucketName, identifier)
|
||||
cache.Del(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
service.mu.Lock()
|
||||
service.endpointRelationsCache = nil
|
||||
service.mu.Unlock()
|
||||
|
||||
service.updateEdgeStacksAfterRelationChange(deletedRelation, nil)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service *Service) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
|
||||
relations, _ := service.EndpointRelations()
|
||||
|
||||
stacksToUpdate := map[portainer.EdgeStackID]bool{}
|
||||
|
||||
if previousRelationState != nil {
|
||||
for stackId, enabled := range previousRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the updated relation state
|
||||
// = stack has been removed for this relation
|
||||
// or this relation has been deleted
|
||||
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedRelationState != nil {
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// for each stack referenced by the updated relation
|
||||
// list how many time this stack is referenced in all relations
|
||||
// in order to update the stack deployments count
|
||||
for refStackId, refStackEnabled := range stacksToUpdate {
|
||||
if !refStackEnabled {
|
||||
continue
|
||||
}
|
||||
|
||||
numDeployments := 0
|
||||
|
||||
for _, r := range relations {
|
||||
for sId, enabled := range r.EdgeStacks {
|
||||
if enabled && sId == refStackId {
|
||||
numDeployments += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := service.updateStackFn(refStackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments = numDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,141 +0,0 @@
|
||||
package endpointrelation
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/dataservices/edgestack"
|
||||
"github.com/portainer/portainer/api/internal/edge/cache"
|
||||
"github.com/portainer/portainer/api/logs"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestUpdateRelation(t *testing.T) {
|
||||
const endpointID = 1
|
||||
const edgeStackID1 = 1
|
||||
const edgeStackID2 = 2
|
||||
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer logs.CloseAndLogErr(conn)
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
updateStackFnTxCalled := false
|
||||
|
||||
edgeStacks := make(map[portainer.EdgeStackID]portainer.EdgeStack)
|
||||
edgeStacks[edgeStackID1] = portainer.EdgeStack{ID: edgeStackID1}
|
||||
edgeStacks[edgeStackID2] = portainer.EdgeStack{ID: edgeStackID2}
|
||||
|
||||
service.RegisterUpdateStackFunction(func(tx portainer.Transaction, ID portainer.EdgeStackID, updateFunc func(edgeStack *portainer.EdgeStack)) error {
|
||||
updateStackFnTxCalled = true
|
||||
|
||||
s, ok := edgeStacks[ID]
|
||||
require.True(t, ok)
|
||||
|
||||
updateFunc(&s)
|
||||
edgeStacks[ID] = s
|
||||
|
||||
return nil
|
||||
})
|
||||
|
||||
// Nil relation
|
||||
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, nil)
|
||||
_, cacheKeyExists := cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.False(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
|
||||
// Add a relation to two edge stacks
|
||||
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
|
||||
EndpointID: endpointID,
|
||||
EdgeStacks: map[portainer.EdgeStackID]bool{
|
||||
edgeStackID1: true,
|
||||
edgeStackID2: true,
|
||||
},
|
||||
})
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
|
||||
|
||||
// Remove a relation to one edge stack
|
||||
|
||||
updateStackFnTxCalled = false
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.UpdateEndpointRelation(endpointID, &portainer.EndpointRelation{
|
||||
EndpointID: endpointID,
|
||||
EdgeStacks: map[portainer.EdgeStackID]bool{
|
||||
2: true,
|
||||
},
|
||||
})
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 1, edgeStacks[edgeStackID2].NumDeployments)
|
||||
|
||||
// Delete the relation
|
||||
|
||||
updateStackFnTxCalled = false
|
||||
cache.Set(endpointID, []byte("value"))
|
||||
|
||||
err = service.DeleteEndpointRelation(endpointID)
|
||||
|
||||
_, cacheKeyExists = cache.Get(endpointID)
|
||||
require.NoError(t, err)
|
||||
require.True(t, updateStackFnTxCalled)
|
||||
require.False(t, cacheKeyExists)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID1].NumDeployments)
|
||||
require.Equal(t, 0, edgeStacks[edgeStackID2].NumDeployments)
|
||||
}
|
||||
|
||||
func TestAddEndpointRelationsForEdgeStack(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer logs.CloseAndLogErr(conn)
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
edgeStackService, err := edgestack.NewService(conn, func(t portainer.Transaction, esi portainer.EdgeStackID) {})
|
||||
require.NoError(t, err)
|
||||
|
||||
service.RegisterUpdateStackFunction(edgeStackService.UpdateEdgeStackFuncTx)
|
||||
require.NoError(t, edgeStackService.Create(1, &portainer.EdgeStack{}))
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1, EdgeStacks: map[portainer.EdgeStackID]bool{}}))
|
||||
require.NoError(t, service.AddEndpointRelationsForEdgeStack([]portainer.EndpointID{1}, &portainer.EdgeStack{ID: 1}))
|
||||
}
|
||||
|
||||
func TestEndpointRelations(t *testing.T) {
|
||||
var conn portainer.Connection = &boltdb.DbConnection{Path: t.TempDir()}
|
||||
err := conn.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
defer logs.CloseAndLogErr(conn)
|
||||
|
||||
service, err := NewService(conn)
|
||||
require.NoError(t, err)
|
||||
|
||||
require.NoError(t, service.Create(&portainer.EndpointRelation{EndpointID: 1}))
|
||||
rels, err := service.EndpointRelations()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, rels, 1)
|
||||
}
|
||||
@@ -76,14 +76,14 @@ func (service ServiceTx) UpdateEndpointRelation(endpointID portainer.EndpointID,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error {
|
||||
func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error {
|
||||
for _, endpointID := range endpointIDs {
|
||||
rel, err := service.EndpointRelation(endpointID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
rel.EdgeStacks[edgeStack.ID] = true
|
||||
rel.EdgeStacks[edgeStackID] = true
|
||||
|
||||
identifier := service.service.connection.ConvertToKey(int(endpointID))
|
||||
err = service.tx.UpdateObject(BucketName, identifier, rel)
|
||||
@@ -97,12 +97,8 @@ func (service ServiceTx) AddEndpointRelationsForEdgeStack(endpointIDs []portaine
|
||||
service.service.endpointRelationsCache = nil
|
||||
service.service.mu.Unlock()
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStack.ID, func(es *portainer.EdgeStack) {
|
||||
es.NumDeployments += len(endpointIDs)
|
||||
|
||||
// sync changes in `edgeStack` in case it is re-persisted after `AddEndpointRelationsForEdgeStack` call
|
||||
// to avoid overriding with the previous values
|
||||
edgeStack.NumDeployments = es.NumDeployments
|
||||
if err := service.service.updateStackFnTx(service.tx, edgeStackID, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments += len(endpointIDs)
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
@@ -190,49 +186,53 @@ func (service ServiceTx) cachedEndpointRelations() ([]portainer.EndpointRelation
|
||||
}
|
||||
|
||||
func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationState *portainer.EndpointRelation, updatedRelationState *portainer.EndpointRelation) {
|
||||
relations, _ := service.EndpointRelations()
|
||||
|
||||
stacksToUpdate := map[portainer.EdgeStackID]bool{}
|
||||
|
||||
if previousRelationState != nil {
|
||||
for stackId, enabled := range previousRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the updated relation state
|
||||
// = stack has been removed for this relation
|
||||
// or this relation has been deleted
|
||||
if enabled && (updatedRelationState == nil || !updatedRelationState.EdgeStacks[stackId]) {
|
||||
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
|
||||
// Sanity check
|
||||
if edgeStack.NumDeployments <= 0 {
|
||||
log.Error().
|
||||
Int("edgestack_id", int(edgeStack.ID)).
|
||||
Int("endpoint_id", int(previousRelationState.EndpointID)).
|
||||
Int("num_deployments", edgeStack.NumDeployments).
|
||||
Msg("cannot decrement the number of deployments for an edge stack with zero deployments")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
edgeStack.NumDeployments--
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
|
||||
cache.Del(previousRelationState.EndpointID)
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if updatedRelationState == nil {
|
||||
return
|
||||
if updatedRelationState != nil {
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
stacksToUpdate[stackId] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for stackId, enabled := range updatedRelationState.EdgeStacks {
|
||||
// flag stack for update if stack is not in the previous relation state
|
||||
// = stack has been added for this relation
|
||||
if enabled && (previousRelationState == nil || !previousRelationState.EdgeStacks[stackId]) {
|
||||
if err := service.service.updateStackFnTx(service.tx, stackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments++
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
// for each stack referenced by the updated relation
|
||||
// list how many time this stack is referenced in all relations
|
||||
// in order to update the stack deployments count
|
||||
for refStackId, refStackEnabled := range stacksToUpdate {
|
||||
if !refStackEnabled {
|
||||
continue
|
||||
}
|
||||
|
||||
cache.Del(updatedRelationState.EndpointID)
|
||||
numDeployments := 0
|
||||
|
||||
for _, r := range relations {
|
||||
for sId, enabled := range r.EdgeStacks {
|
||||
if enabled && sId == refStackId {
|
||||
numDeployments += 1
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if err := service.service.updateStackFnTx(service.tx, refStackId, func(edgeStack *portainer.EdgeStack) {
|
||||
edgeStack.NumDeployments = numDeployments
|
||||
}); err != nil {
|
||||
log.Error().Err(err).Msg("could not update the number of deployments")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
var (
|
||||
ErrObjectNotFound = errors.New("object not found inside the database")
|
||||
ErrWrongDBEdition = errors.New("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://docs.portainer.io/faqs/upgrading/can-i-downgrade-from-portainer-business-to-portainer-ce")
|
||||
ErrWrongDBEdition = errors.New("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/")
|
||||
ErrDBImportFailed = errors.New("importing backup failed")
|
||||
ErrDatabaseIsUpdating = errors.New("database is currently in updating state. Failed prior upgrade. Please restore from backup or delete the database and restart Portainer")
|
||||
)
|
||||
|
||||
@@ -12,7 +12,6 @@ type (
|
||||
EdgeGroup() EdgeGroupService
|
||||
EdgeJob() EdgeJobService
|
||||
EdgeStack() EdgeStackService
|
||||
EdgeStackStatus() EdgeStackStatusService
|
||||
Endpoint() EndpointService
|
||||
EndpointGroup() EndpointGroupService
|
||||
EndpointRelation() EndpointRelationService
|
||||
@@ -40,8 +39,8 @@ type (
|
||||
Open() (newStore bool, err error)
|
||||
Init() error
|
||||
Close() error
|
||||
UpdateTx(func(tx DataStoreTx) error) error
|
||||
ViewTx(func(tx DataStoreTx) error) error
|
||||
UpdateTx(func(DataStoreTx) error) error
|
||||
ViewTx(func(DataStoreTx) error) error
|
||||
MigrateData() error
|
||||
Rollback(force bool) error
|
||||
CheckCurrentEdition() error
|
||||
@@ -90,21 +89,8 @@ type (
|
||||
BucketName() string
|
||||
}
|
||||
|
||||
EdgeStackStatusService interface {
|
||||
Create(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error
|
||||
Read(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) (*portainer.EdgeStackStatusForEnv, error)
|
||||
ReadAll(edgeStackID portainer.EdgeStackID) ([]portainer.EdgeStackStatusForEnv, error)
|
||||
Update(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID, status *portainer.EdgeStackStatusForEnv) error
|
||||
Delete(edgeStackID portainer.EdgeStackID, endpointID portainer.EndpointID) error
|
||||
DeleteAll(edgeStackID portainer.EdgeStackID) error
|
||||
Clear(edgeStackID portainer.EdgeStackID, relatedEnvironmentsIDs []portainer.EndpointID) error
|
||||
}
|
||||
|
||||
// EndpointService represents a service for managing environment(endpoint) data
|
||||
EndpointService interface {
|
||||
// partial dataservices.BaseCRUD[portainer.Endpoint, portainer.EndpointID]
|
||||
ReadAll(predicates ...func(endpoint portainer.Endpoint) bool) ([]portainer.Endpoint, error)
|
||||
|
||||
Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error)
|
||||
EndpointIDByEdgeID(edgeID string) (portainer.EndpointID, bool)
|
||||
EndpointsByTeamID(teamID portainer.TeamID) ([]portainer.Endpoint, error)
|
||||
@@ -129,7 +115,7 @@ type (
|
||||
EndpointRelation(EndpointID portainer.EndpointID) (*portainer.EndpointRelation, error)
|
||||
Create(endpointRelation *portainer.EndpointRelation) error
|
||||
UpdateEndpointRelation(EndpointID portainer.EndpointID, endpointRelation *portainer.EndpointRelation) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStack *portainer.EdgeStack) error
|
||||
AddEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
RemoveEndpointRelationsForEdgeStack(endpointIDs []portainer.EndpointID, edgeStackID portainer.EdgeStackID) error
|
||||
DeleteEndpointRelation(EndpointID portainer.EndpointID) error
|
||||
BucketName() string
|
||||
@@ -226,7 +212,6 @@ type (
|
||||
UserService interface {
|
||||
BaseCRUD[portainer.User, portainer.UserID]
|
||||
UserByUsername(username string) (*portainer.User, error)
|
||||
UserIDByUsername(username string) (portainer.UserID, error)
|
||||
UsersByRole(role portainer.UserRole) ([]portainer.User, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -1,16 +1,26 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
const BucketName = "pending_actions"
|
||||
const (
|
||||
BucketName = "pending_actions"
|
||||
)
|
||||
|
||||
type Service struct {
|
||||
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func NewService(connection portainer.Connection) (*Service, error) {
|
||||
err := connection.SetServiceName(BucketName)
|
||||
if err != nil {
|
||||
@@ -25,11 +35,6 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
func (s Service) Create(config *portainer.PendingAction) error {
|
||||
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return s.Tx(tx).Create(config)
|
||||
@@ -57,3 +62,44 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.BaseDataServiceTx.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
err := s.BaseDataServiceTx.Delete(pendingAction.ID)
|
||||
if err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service *Service) GetNextIdentifier() int {
|
||||
return service.Connection.GetNextIdentifier(BucketName)
|
||||
}
|
||||
|
||||
@@ -1,32 +0,0 @@
|
||||
package pendingactions_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestDeleteByEndpoint(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, false, false)
|
||||
|
||||
// Create Endpoint 1
|
||||
err := store.PendingActions().Create(&portainer.PendingAction{EndpointID: 1})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create Endpoint 2
|
||||
err = store.PendingActions().Create(&portainer.PendingAction{EndpointID: 2})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Delete Endpoint 1
|
||||
err = store.PendingActions().DeleteByEndpointID(1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Check that only Endpoint 2 remains
|
||||
pendingActions, err := store.PendingActions().ReadAll()
|
||||
require.NoError(t, err)
|
||||
require.Len(t, pendingActions, 1)
|
||||
require.Equal(t, portainer.EndpointID(2), pendingActions[0].EndpointID)
|
||||
}
|
||||
@@ -1,49 +0,0 @@
|
||||
package pendingactions
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
|
||||
}
|
||||
|
||||
func (s ServiceTx) Create(config *portainer.PendingAction) error {
|
||||
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
|
||||
config.ID = portainer.PendingActionID(id)
|
||||
config.CreatedAt = time.Now().Unix()
|
||||
|
||||
return int(config.ID), config
|
||||
})
|
||||
}
|
||||
|
||||
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
|
||||
return s.BaseDataServiceTx.Update(ID, config)
|
||||
}
|
||||
|
||||
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
|
||||
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
|
||||
pendingActions, err := s.ReadAll()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
|
||||
}
|
||||
|
||||
for _, pendingAction := range pendingActions {
|
||||
if pendingAction.EndpointID == ID {
|
||||
if err := s.Delete(pendingAction.ID); err != nil {
|
||||
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetNextIdentifier returns the next identifier for a custom template.
|
||||
func (service ServiceTx) GetNextIdentifier() int {
|
||||
return service.Tx.GetNextIdentifier(BucketName)
|
||||
}
|
||||
@@ -3,7 +3,6 @@ package resourcecontrol
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
@@ -65,9 +64,11 @@ func (service *Service) ResourceControlByResourceIDAndType(resourceID string, re
|
||||
return nil, stop
|
||||
}
|
||||
|
||||
if slices.Contains(rc.SubResourceIDs, resourceID) {
|
||||
resourceControl = rc
|
||||
return nil, stop
|
||||
for _, subResourceID := range rc.SubResourceIDs {
|
||||
if subResourceID == resourceID {
|
||||
resourceControl = rc
|
||||
return nil, stop
|
||||
}
|
||||
}
|
||||
|
||||
return &portainer.ResourceControl{}, nil
|
||||
|
||||
@@ -3,7 +3,6 @@ package resourcecontrol
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
@@ -36,9 +35,11 @@ func (service ServiceTx) ResourceControlByResourceIDAndType(resourceID string, r
|
||||
return nil, stop
|
||||
}
|
||||
|
||||
if slices.Contains(rc.SubResourceIDs, resourceID) {
|
||||
resourceControl = rc
|
||||
return nil, stop
|
||||
for _, subResourceID := range rc.SubResourceIDs {
|
||||
if subResourceID == resourceID {
|
||||
resourceControl = rc
|
||||
return nil, stop
|
||||
}
|
||||
}
|
||||
|
||||
return &portainer.ResourceControl{}, nil
|
||||
|
||||
@@ -51,20 +51,3 @@ func (service *Service) ReadWithoutSnapshotRaw(ID portainer.EndpointID) (*portai
|
||||
|
||||
return snapshot, err
|
||||
}
|
||||
|
||||
func (service *Service) ReadRawMessage(ID portainer.EndpointID) (*portainer.SnapshotRawMessage, error) {
|
||||
var snapshot *portainer.SnapshotRawMessage
|
||||
|
||||
err := service.Connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
var err error
|
||||
snapshot, err = service.Tx(tx).ReadRawMessage(ID)
|
||||
|
||||
return err
|
||||
})
|
||||
|
||||
return snapshot, err
|
||||
}
|
||||
|
||||
func (service *Service) CreateRawMessage(snapshot *portainer.SnapshotRawMessage) error {
|
||||
return service.Connection.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
@@ -35,19 +35,3 @@ func (service ServiceTx) ReadWithoutSnapshotRaw(ID portainer.EndpointID) (*porta
|
||||
|
||||
return &snapshot.Snapshot, nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) ReadRawMessage(ID portainer.EndpointID) (*portainer.SnapshotRawMessage, error) {
|
||||
var snapshot = portainer.SnapshotRawMessage{}
|
||||
|
||||
identifier := service.Connection.ConvertToKey(int(ID))
|
||||
|
||||
if err := service.Tx.GetObject(service.Bucket, identifier, &snapshot); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &snapshot, nil
|
||||
}
|
||||
|
||||
func (service ServiceTx) CreateRawMessage(snapshot *portainer.SnapshotRawMessage) error {
|
||||
return service.Tx.CreateObjectWithId(BucketName, int(snapshot.EndpointID), snapshot)
|
||||
}
|
||||
|
||||
@@ -31,13 +31,6 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
service: service,
|
||||
tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// Settings retrieve the ssl settings object.
|
||||
func (service *Service) Settings() (*portainer.SSLSettings, error) {
|
||||
var settings portainer.SSLSettings
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
package ssl
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
service *Service
|
||||
tx portainer.Transaction
|
||||
}
|
||||
|
||||
func (service ServiceTx) BucketName() string {
|
||||
return BucketName
|
||||
}
|
||||
|
||||
// Settings retrieve the settings object.
|
||||
func (service ServiceTx) Settings() (*portainer.SSLSettings, error) {
|
||||
var settings portainer.SSLSettings
|
||||
|
||||
err := service.tx.GetObject(BucketName, []byte(key), &settings)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &settings, nil
|
||||
}
|
||||
|
||||
// UpdateSettings persists a Settings object.
|
||||
func (service ServiceTx) UpdateSettings(settings *portainer.SSLSettings) error {
|
||||
return service.tx.UpdateObject(BucketName, []byte(key), settings)
|
||||
}
|
||||
@@ -4,18 +4,17 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/gofrs/uuid"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/filesystem"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func newGuidString(t *testing.T) string {
|
||||
uuid, err := uuid.NewRandom()
|
||||
require.NoError(t, err)
|
||||
uuid, err := uuid.NewV4()
|
||||
assert.NoError(t, err)
|
||||
|
||||
return uuid.String()
|
||||
}
|
||||
@@ -42,7 +41,7 @@ func TestService_StackByWebhookID(t *testing.T) {
|
||||
|
||||
// can find a stack by webhook ID
|
||||
got, err := store.StackService.StackByWebhookID(webhookID)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, stack, *got)
|
||||
|
||||
// returns nil and object not found error if there's no stack associated with the webhook
|
||||
@@ -95,10 +94,10 @@ func Test_RefreshableStacks(t *testing.T) {
|
||||
|
||||
for _, stack := range []*portainer.Stack{&staticStack, &stackWithWebhook, &refreshableStack} {
|
||||
err := store.Stack().Create(stack)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
stacks, err := store.Stack().RefreshableStacks()
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.ElementsMatch(t, []portainer.Stack{refreshableStack}, stacks)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,7 @@ import (
|
||||
|
||||
"github.com/portainer/portainer/api/dataservices/errors"
|
||||
"github.com/portainer/portainer/api/datastore"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func Test_teamByName(t *testing.T) {
|
||||
@@ -15,7 +13,7 @@ func Test_teamByName(t *testing.T) {
|
||||
_, store := datastore.MustNewTestStore(t, true, true)
|
||||
|
||||
_, err := store.Team().TeamByName("name")
|
||||
require.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
assert.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
|
||||
})
|
||||
|
||||
@@ -31,7 +29,7 @@ func Test_teamByName(t *testing.T) {
|
||||
teamBuilder.createNew("name1")
|
||||
|
||||
_, err := store.Team().TeamByName("name")
|
||||
require.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
assert.ErrorIs(t, err, errors.ErrObjectNotFound)
|
||||
})
|
||||
|
||||
t.Run("When there is an object with the same name should return the object", func(t *testing.T) {
|
||||
@@ -46,7 +44,7 @@ func Test_teamByName(t *testing.T) {
|
||||
expectedTeam := teamBuilder.createNew("name1")
|
||||
|
||||
team, err := store.Team().TeamByName("name1")
|
||||
require.NoError(t, err, "TeamByName should succeed")
|
||||
assert.NoError(t, err, "TeamByName should succeed")
|
||||
assert.Equal(t, expectedTeam, team)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -36,18 +36,6 @@ func (service ServiceTx) UserByUsername(username string) (*portainer.User, error
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (service ServiceTx) UserIDByUsername(username string) (portainer.UserID, error) {
|
||||
user, err := service.UserByUsername(username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return 0, dserrors.ErrObjectNotFound
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// UsersByRole return an array containing all the users with the specified role.
|
||||
func (service ServiceTx) UsersByRole(role portainer.UserRole) ([]portainer.User, error) {
|
||||
var users = make([]portainer.User, 0)
|
||||
|
||||
@@ -65,18 +65,6 @@ func (service *Service) UserByUsername(username string) (*portainer.User, error)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
func (service *Service) UserIDByUsername(username string) (portainer.UserID, error) {
|
||||
user, err := service.UserByUsername(username)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if user == nil {
|
||||
return 0, dserrors.ErrObjectNotFound
|
||||
}
|
||||
return user.ID, nil
|
||||
}
|
||||
|
||||
// UsersByRole return an array containing all the users with the specified role.
|
||||
func (service *Service) UsersByRole(role portainer.UserRole) ([]portainer.User, error) {
|
||||
var users = make([]portainer.User, 0)
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
package version
|
||||
|
||||
import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/portainer/portainer/api/dataservices"
|
||||
)
|
||||
|
||||
type ServiceTx struct {
|
||||
dataservices.BaseDataServiceTx[models.Version, int] // ID is not used
|
||||
}
|
||||
|
||||
func (tx ServiceTx) InstanceID() (string, error) {
|
||||
v, err := tx.Version()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v.InstanceID, nil
|
||||
}
|
||||
|
||||
func (tx ServiceTx) UpdateInstanceID(ID string) error {
|
||||
v, err := tx.Version()
|
||||
if err != nil {
|
||||
if !dataservices.IsErrObjectNotFound(err) {
|
||||
return err
|
||||
}
|
||||
|
||||
v = &models.Version{}
|
||||
}
|
||||
|
||||
v.InstanceID = ID
|
||||
|
||||
return tx.UpdateVersion(v)
|
||||
}
|
||||
|
||||
func (tx ServiceTx) Edition() (portainer.SoftwareEdition, error) {
|
||||
v, err := tx.Version()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return portainer.SoftwareEdition(v.Edition), nil
|
||||
}
|
||||
|
||||
func (tx ServiceTx) Version() (*models.Version, error) {
|
||||
var v models.Version
|
||||
|
||||
err := tx.Tx.GetObject(BucketName, []byte(versionKey), &v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &v, nil
|
||||
}
|
||||
|
||||
func (tx ServiceTx) UpdateVersion(version *models.Version) error {
|
||||
return tx.Tx.UpdateObject(BucketName, []byte(versionKey), version)
|
||||
}
|
||||
|
||||
func (tx ServiceTx) SchemaVersion() (string, error) {
|
||||
var v models.Version
|
||||
|
||||
err := tx.Tx.GetObject(BucketName, []byte(versionKey), &v)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
return v.SchemaVersion, nil
|
||||
}
|
||||
@@ -33,16 +33,6 @@ func NewService(connection portainer.Connection) (*Service, error) {
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
|
||||
return ServiceTx{
|
||||
BaseDataServiceTx: dataservices.BaseDataServiceTx[models.Version, int]{
|
||||
Bucket: BucketName,
|
||||
Connection: service.connection,
|
||||
Tx: tx,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (service *Service) SchemaVersion() (string, error) {
|
||||
v, err := service.Version()
|
||||
if err != nil {
|
||||
|
||||
@@ -14,40 +14,33 @@ import (
|
||||
// corruption and if a path is not given a default is used.
|
||||
// The path or an error are returned.
|
||||
func (store *Store) Backup(path string) (string, error) {
|
||||
if err := store.Close(); err != nil {
|
||||
return "", fmt.Errorf("failed to close store before backup: %w", err)
|
||||
}
|
||||
|
||||
filename, err := store.backupDBFile(path)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
if _, err := store.Open(); err != nil {
|
||||
return "", fmt.Errorf("failed to reopen store after backup: %w", err)
|
||||
}
|
||||
|
||||
return filename, nil
|
||||
}
|
||||
|
||||
// backupDBFile copies the database file to the backup location.
|
||||
// Does not manage connection state - works with the database file directly regardless of connection state.
|
||||
func (store *Store) backupDBFile(backupPath string) (string, error) {
|
||||
if err := store.createBackupPath(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
backupFilename := store.backupFilename()
|
||||
if backupPath != "" {
|
||||
backupFilename = backupPath
|
||||
if path != "" {
|
||||
backupFilename = path
|
||||
}
|
||||
log.Info().Str("from", store.connection.GetDatabaseFilePath()).Str("to", backupFilename).Msgf("Backing up database")
|
||||
|
||||
// Close the store before backing up
|
||||
err := store.Close()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to close store before backup: %w", err)
|
||||
}
|
||||
|
||||
log.Info().Str("from", store.connection.GetDatabaseFilePath()).Str("to", backupFilename).Msg("Backing up database")
|
||||
|
||||
if err := store.fileService.Copy(store.connection.GetDatabaseFilePath(), backupFilename, true); err != nil {
|
||||
err = store.fileService.Copy(store.connection.GetDatabaseFilePath(), backupFilename, true)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to create backup file: %w", err)
|
||||
}
|
||||
|
||||
// reopen the store
|
||||
_, err = store.Open()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to reopen store after backup: %w", err)
|
||||
}
|
||||
|
||||
return backupFilename, nil
|
||||
}
|
||||
|
||||
@@ -57,17 +50,15 @@ func (store *Store) Restore() error {
|
||||
}
|
||||
|
||||
func (store *Store) RestoreFromFile(backupFilename string) error {
|
||||
if err := store.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
store.Close()
|
||||
if err := store.fileService.Copy(backupFilename, store.connection.GetDatabaseFilePath(), true); err != nil {
|
||||
return fmt.Errorf("unable to restore backup file %q. err: %w", backupFilename, err)
|
||||
}
|
||||
|
||||
log.Info().Str("from", backupFilename).Str("to", store.connection.GetDatabaseFilePath()).Msgf("database restored")
|
||||
|
||||
if _, err := store.Open(); err != nil {
|
||||
_, err := store.Open()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unable to determine version of restored portainer backup file: %w", err)
|
||||
}
|
||||
|
||||
@@ -89,7 +80,6 @@ func (store *Store) createBackupPath() error {
|
||||
return fmt.Errorf("unable to create backup folder: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -1,20 +1,19 @@
|
||||
package datastore
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/rs/zerolog/log"
|
||||
)
|
||||
|
||||
func TestStoreCreation(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, true)
|
||||
require.NotNil(t, store)
|
||||
if store == nil {
|
||||
t.Fatal("Expect to create a store")
|
||||
}
|
||||
|
||||
v, err := store.VersionService.Version()
|
||||
if err != nil {
|
||||
@@ -38,12 +37,8 @@ func TestBackup(t *testing.T) {
|
||||
Edition: int(portainer.PortainerCE),
|
||||
SchemaVersion: portainer.APIVersion,
|
||||
}
|
||||
|
||||
err := store.VersionService.UpdateVersion(&v)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.Backup("")
|
||||
require.NoError(t, err)
|
||||
store.VersionService.UpdateVersion(&v)
|
||||
store.Backup("")
|
||||
|
||||
if !isFileExist(backupFileName) {
|
||||
t.Errorf("Expect backup file to be created %s", backupFileName)
|
||||
@@ -59,14 +54,10 @@ func TestRestore(t *testing.T) {
|
||||
updateEdition(store, portainer.PortainerCE)
|
||||
updateVersion(store, "2.4")
|
||||
|
||||
_, err := store.Backup("")
|
||||
require.NoError(t, err)
|
||||
|
||||
store.Backup("")
|
||||
updateVersion(store, "2.16")
|
||||
testVersion(store, "2.16", t)
|
||||
|
||||
err = store.Restore()
|
||||
require.NoError(t, err)
|
||||
store.Restore()
|
||||
|
||||
// check if the restore is successful and the version is correct
|
||||
testVersion(store, "2.4", t)
|
||||
@@ -76,65 +67,13 @@ func TestRestore(t *testing.T) {
|
||||
// override and set initial db version and edition
|
||||
updateEdition(store, portainer.PortainerCE)
|
||||
updateVersion(store, "2.4")
|
||||
|
||||
_, err := store.Backup("")
|
||||
require.NoError(t, err)
|
||||
|
||||
store.Backup("")
|
||||
updateVersion(store, "2.14")
|
||||
updateVersion(store, "2.16")
|
||||
testVersion(store, "2.16", t)
|
||||
|
||||
err = store.Restore()
|
||||
require.NoError(t, err)
|
||||
store.Restore()
|
||||
|
||||
// check if the restore is successful and the version is correct
|
||||
testVersion(store, "2.4", t)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackupDBFile(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
|
||||
t.Run("creates backup file without managing connection state", func(t *testing.T) {
|
||||
// Verify connection is usable before
|
||||
_, err := store.VersionService.Version()
|
||||
require.NoError(t, err, "connection should be usable before backupDBFile")
|
||||
|
||||
// backupDBFile should work without closing the connection
|
||||
backupFilename, err := store.backupDBFile("")
|
||||
require.NoError(t, err)
|
||||
require.FileExists(t, backupFilename)
|
||||
|
||||
// Verify connection is still usable after (not closed/reopened)
|
||||
_, err = store.VersionService.Version()
|
||||
require.NoError(t, err, "connection should still be usable after backupDBFile")
|
||||
|
||||
require.NoError(t, os.Remove(backupFilename))
|
||||
})
|
||||
|
||||
t.Run("uses custom path when provided", func(t *testing.T) {
|
||||
customPath := t.TempDir() + "/custom-backup.db"
|
||||
backupFilename, err := store.backupDBFile(customPath)
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, customPath, backupFilename)
|
||||
require.FileExists(t, backupFilename)
|
||||
})
|
||||
}
|
||||
|
||||
func TestBackupDBFileUsesCorrectPath(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
|
||||
t.Run("backs up unencrypted db when encrypted flag is false", func(t *testing.T) {
|
||||
store.connection.SetEncrypted(false)
|
||||
|
||||
backupFilename, err := store.backupDBFile("")
|
||||
require.NoError(t, err)
|
||||
require.FileExists(t, backupFilename)
|
||||
|
||||
// Verify it backed up the unencrypted file (portainer.db)
|
||||
require.Contains(t, backupFilename, boltdb.DatabaseFileName)
|
||||
require.NotContains(t, backupFilename, boltdb.EncryptedDatabaseFileName)
|
||||
|
||||
require.NoError(t, os.Remove(backupFilename))
|
||||
})
|
||||
}
|
||||
|
||||
@@ -32,38 +32,34 @@ func (store *Store) Open() (newStore bool, err error) {
|
||||
}
|
||||
|
||||
if encryptionReq {
|
||||
// NeedsEncryptionMigration() sets encrypted=true as a side effect when a key exists.
|
||||
// We need to set it back to false so GetDatabaseFilePath() returns the path to the
|
||||
// actual unencrypted file (portainer.db) that we want to back up.
|
||||
store.connection.SetEncrypted(false)
|
||||
|
||||
// Use backupDBFile directly since connection isn't open yet
|
||||
// and we don't want to trigger the close/open cycle of Backup()
|
||||
backupFilename, err := store.backupDBFile("")
|
||||
backupFilename, err := store.Backup("")
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to backup database prior to encrypting: %w", err)
|
||||
}
|
||||
|
||||
if err := store.encryptDB(); err != nil {
|
||||
innerErr := store.RestoreFromFile(backupFilename) // restore from backup if encryption fails
|
||||
return false, errors.Join(err, innerErr)
|
||||
err = store.encryptDB()
|
||||
if err != nil {
|
||||
store.RestoreFromFile(backupFilename) // restore from backup if encryption fails
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := store.connection.Open(); err != nil {
|
||||
err = store.connection.Open()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
if err := store.initServices(); err != nil {
|
||||
err = store.initServices()
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
// If no settings object exists then assume we have a new store
|
||||
if _, err := store.SettingsService.Settings(); err != nil {
|
||||
_, err = store.SettingsService.Settings()
|
||||
if err != nil {
|
||||
if store.IsErrObjectNotFound(err) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
return false, err
|
||||
}
|
||||
|
||||
@@ -76,13 +72,19 @@ func (store *Store) Close() error {
|
||||
|
||||
func (store *Store) UpdateTx(fn func(dataservices.DataStoreTx) error) error {
|
||||
return store.connection.UpdateTx(func(tx portainer.Transaction) error {
|
||||
return fn(&StoreTx{store: store, tx: tx})
|
||||
return fn(&StoreTx{
|
||||
store: store,
|
||||
tx: tx,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func (store *Store) ViewTx(fn func(dataservices.DataStoreTx) error) error {
|
||||
return store.connection.ViewTx(func(tx portainer.Transaction) error {
|
||||
return fn(&StoreTx{store: store, tx: tx})
|
||||
return fn(&StoreTx{
|
||||
store: store,
|
||||
tx: tx,
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
@@ -97,7 +99,6 @@ func (store *Store) CheckCurrentEdition() error {
|
||||
if store.edition() != portainer.Edition {
|
||||
return portainerErrors.ErrWrongDBEdition
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -106,7 +107,6 @@ func (store *Store) edition() portainer.SoftwareEdition {
|
||||
if store.IsErrObjectNotFound(err) {
|
||||
edition = portainer.PortainerCE
|
||||
}
|
||||
|
||||
return edition
|
||||
}
|
||||
|
||||
@@ -125,11 +125,13 @@ func (store *Store) Rollback(force bool) error {
|
||||
|
||||
func (store *Store) encryptDB() error {
|
||||
store.connection.SetEncrypted(false)
|
||||
if err := store.connection.Open(); err != nil {
|
||||
err := store.connection.Open()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := store.initServices(); err != nil {
|
||||
err = store.initServices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -142,7 +144,8 @@ func (store *Store) encryptDB() error {
|
||||
|
||||
log.Info().Str("filename", exportFilename).Msg("exporting database backup")
|
||||
|
||||
if err := store.Export(exportFilename); err != nil {
|
||||
err = store.Export(exportFilename)
|
||||
if err != nil {
|
||||
log.Error().Str("filename", exportFilename).Err(err).Msg("failed to export")
|
||||
|
||||
return err
|
||||
@@ -151,33 +154,38 @@ func (store *Store) encryptDB() error {
|
||||
log.Info().Msg("database backup exported")
|
||||
|
||||
// Close existing un-encrypted db so that we can delete the file later
|
||||
if err := store.connection.Close(); err != nil {
|
||||
store.connection.Close()
|
||||
|
||||
// Tell the db layer to create an encrypted db when opened
|
||||
store.connection.SetEncrypted(true)
|
||||
store.connection.Open()
|
||||
|
||||
// We have to init services before import
|
||||
err = store.initServices()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := store.Import(exportFilename); err != nil {
|
||||
log.Error().Err(err).Msg("failed to import database backup")
|
||||
|
||||
err = store.Import(exportFilename)
|
||||
if err != nil {
|
||||
// Remove the new encrypted file that we failed to import
|
||||
if err := os.Remove(store.connection.GetDatabaseFilePath()); err != nil {
|
||||
log.Error().Msg("failed to remove the file after import failure")
|
||||
}
|
||||
os.Remove(store.connection.GetDatabaseFilePath())
|
||||
|
||||
log.Fatal().Err(portainerErrors.ErrDBImportFailed).Msg("")
|
||||
}
|
||||
|
||||
if err := os.Remove(oldFilename); err != nil {
|
||||
err = os.Remove(oldFilename)
|
||||
if err != nil {
|
||||
log.Error().Msg("failed to remove the un-encrypted db file")
|
||||
}
|
||||
|
||||
if err := os.Remove(exportFilename); err != nil {
|
||||
err = os.Remove(exportFilename)
|
||||
if err != nil {
|
||||
log.Error().Msg("failed to remove the json backup file")
|
||||
}
|
||||
|
||||
// Close db connection
|
||||
if err := store.connection.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
store.connection.Close()
|
||||
|
||||
log.Info().Msg("database successfully encrypted")
|
||||
|
||||
|
||||
@@ -6,14 +6,12 @@ import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/pkg/errors"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/chisel"
|
||||
"github.com/portainer/portainer/api/crypto"
|
||||
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -32,32 +30,56 @@ func TestStoreFull(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, true)
|
||||
|
||||
testCases := map[string]func(t *testing.T){
|
||||
"User Accounts": store.testUserAccounts,
|
||||
"Environments": store.testEnvironments,
|
||||
"Settings": store.testSettings,
|
||||
"SSL Settings": store.testSSLSettings,
|
||||
"Tunnel Server": store.testTunnelServer,
|
||||
"Custom Templates": store.testCustomTemplates,
|
||||
"Registries": store.testRegistries,
|
||||
"Resource Control": store.testResourceControl,
|
||||
"Schedules": store.testSchedules,
|
||||
"Tags": store.testTags,
|
||||
"User Accounts": func(t *testing.T) {
|
||||
store.testUserAccounts(t)
|
||||
},
|
||||
"Environments": func(t *testing.T) {
|
||||
store.testEnvironments(t)
|
||||
},
|
||||
"Settings": func(t *testing.T) {
|
||||
store.testSettings(t)
|
||||
},
|
||||
"SSL Settings": func(t *testing.T) {
|
||||
store.testSSLSettings(t)
|
||||
},
|
||||
"Tunnel Server": func(t *testing.T) {
|
||||
store.testTunnelServer(t)
|
||||
},
|
||||
"Custom Templates": func(t *testing.T) {
|
||||
store.testCustomTemplates(t)
|
||||
},
|
||||
"Registries": func(t *testing.T) {
|
||||
store.testRegistries(t)
|
||||
},
|
||||
"Resource Control": func(t *testing.T) {
|
||||
store.testResourceControl(t)
|
||||
},
|
||||
"Schedules": func(t *testing.T) {
|
||||
store.testSchedules(t)
|
||||
},
|
||||
"Tags": func(t *testing.T) {
|
||||
store.testTags(t)
|
||||
},
|
||||
|
||||
// "Test Title": func(t *testing.T) {
|
||||
// },
|
||||
}
|
||||
|
||||
for name, test := range testCases {
|
||||
t.Run(name, test)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (store *Store) testEnvironments(t *testing.T) {
|
||||
id := store.CreateEndpoint(t, "local", portainer.KubernetesLocalEnvironment, "", true)
|
||||
store.CreateEndpointRelation(t, id)
|
||||
store.CreateEndpointRelation(id)
|
||||
|
||||
id = store.CreateEndpoint(t, "agent", portainer.AgentOnDockerEnvironment, agentOnDockerEnvironmentUrl, true)
|
||||
store.CreateEndpointRelation(t, id)
|
||||
store.CreateEndpointRelation(id)
|
||||
|
||||
id = store.CreateEndpoint(t, "edge", portainer.EdgeAgentOnKubernetesEnvironment, edgeAgentOnKubernetesEnvironmentUrl, true)
|
||||
store.CreateEndpointRelation(t, id)
|
||||
store.CreateEndpointRelation(id)
|
||||
}
|
||||
|
||||
func newEndpoint(endpointType portainer.EndpointType, id portainer.EndpointID, name, URL string, TLS bool) *portainer.Endpoint {
|
||||
@@ -90,7 +112,18 @@ func newEndpoint(endpointType portainer.EndpointType, id portainer.EndpointID, n
|
||||
}
|
||||
|
||||
func setEndpointAuthorizations(endpoint *portainer.Endpoint) {
|
||||
endpoint.SecuritySettings = portainer.DefaultEndpointSecuritySettings()
|
||||
endpoint.SecuritySettings = portainer.EndpointSecuritySettings{
|
||||
AllowVolumeBrowserForRegularUsers: false,
|
||||
EnableHostManagementFeatures: false,
|
||||
|
||||
AllowSysctlSettingForRegularUsers: true,
|
||||
AllowBindMountsForRegularUsers: true,
|
||||
AllowPrivilegedModeForRegularUsers: true,
|
||||
AllowHostNamespaceForRegularUsers: true,
|
||||
AllowContainerCapabilitiesForRegularUsers: true,
|
||||
AllowDeviceMappingForRegularUsers: true,
|
||||
AllowStackManagementForRegularUsers: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (store *Store) CreateEndpoint(t *testing.T, name string, endpointType portainer.EndpointType, URL string, tls bool) portainer.EndpointID {
|
||||
@@ -131,25 +164,22 @@ func (store *Store) CreateEndpoint(t *testing.T, name string, endpointType porta
|
||||
}
|
||||
|
||||
setEndpointAuthorizations(expectedEndpoint)
|
||||
|
||||
err := store.Endpoint().Create(expectedEndpoint)
|
||||
require.NoError(t, err)
|
||||
store.Endpoint().Create(expectedEndpoint)
|
||||
|
||||
endpoint, err := store.Endpoint().Endpoint(id)
|
||||
require.NoError(t, err, "Endpoint() should not return an error")
|
||||
is.NoError(err, "Endpoint() should not return an error")
|
||||
is.Equal(expectedEndpoint, endpoint, "endpoint should be the same")
|
||||
|
||||
return endpoint.ID
|
||||
}
|
||||
|
||||
func (store *Store) CreateEndpointRelation(t *testing.T, id portainer.EndpointID) {
|
||||
func (store *Store) CreateEndpointRelation(id portainer.EndpointID) {
|
||||
relation := &portainer.EndpointRelation{
|
||||
EndpointID: id,
|
||||
EdgeStacks: map[portainer.EdgeStackID]bool{},
|
||||
}
|
||||
|
||||
err := store.EndpointRelation().Create(relation)
|
||||
require.NoError(t, err)
|
||||
store.EndpointRelation().Create(relation)
|
||||
}
|
||||
|
||||
func (store *Store) testSSLSettings(t *testing.T) {
|
||||
@@ -161,11 +191,10 @@ func (store *Store) testSSLSettings(t *testing.T) {
|
||||
SelfSigned: true,
|
||||
}
|
||||
|
||||
err := store.SSLSettings().UpdateSettings(ssl)
|
||||
require.NoError(t, err)
|
||||
store.SSLSettings().UpdateSettings(ssl)
|
||||
|
||||
settings, err := store.SSLSettings().Settings()
|
||||
require.NoError(t, err, "Get sslsettings should succeed")
|
||||
is.NoError(err, "Get sslsettings should succeed")
|
||||
is.Equal(ssl, settings, "Stored SSLSettings should be the same as what is read out")
|
||||
}
|
||||
|
||||
@@ -174,27 +203,27 @@ func (store *Store) testTunnelServer(t *testing.T) {
|
||||
expectPrivateKeySeed := uniuri.NewLen(16)
|
||||
|
||||
err := store.TunnelServer().UpdateInfo(&portainer.TunnelServerInfo{PrivateKeySeed: expectPrivateKeySeed})
|
||||
require.NoError(t, err, "UpdateInfo should have succeeded")
|
||||
is.NoError(err, "UpdateInfo should have succeeded")
|
||||
|
||||
serverInfo, err := store.TunnelServer().Info()
|
||||
require.NoError(t, err, "Info should have succeeded")
|
||||
is.NoError(err, "Info should have succeeded")
|
||||
|
||||
is.Equal(expectPrivateKeySeed, serverInfo.PrivateKeySeed, "hashed passwords should not differ")
|
||||
}
|
||||
|
||||
// add users, read them back and check the details are unchanged
|
||||
func (store *Store) testUserAccounts(t *testing.T) {
|
||||
err := store.createAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
require.NoError(t, err, "CreateAccount should succeed")
|
||||
is := assert.New(t)
|
||||
|
||||
err = store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
require.NoError(t, err, "Account failure")
|
||||
err := store.createAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
is.NoError(err, "CreateAccount should succeed")
|
||||
store.checkAccount(adminUsername, adminPassword, portainer.AdministratorRole)
|
||||
is.NoError(err, "Account failure")
|
||||
|
||||
err = store.createAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
require.NoError(t, err, "CreateAccount should succeed")
|
||||
|
||||
err = store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
require.NoError(t, err, "Account failure")
|
||||
is.NoError(err, "CreateAccount should succeed")
|
||||
store.checkAccount(standardUsername, standardPassword, portainer.StandardUserRole)
|
||||
is.NoError(err, "Account failure")
|
||||
}
|
||||
|
||||
// create an account with the provided details
|
||||
@@ -203,13 +232,18 @@ func (store *Store) createAccount(username, password string, role portainer.User
|
||||
user := &portainer.User{Username: username, Role: role}
|
||||
|
||||
// encrypt the password
|
||||
cs := crypto.Service{}
|
||||
cs := &crypto.Service{}
|
||||
user.Password, err = cs.Hash(password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return store.User().Create(user)
|
||||
err = store.User().Create(user)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *Store) checkAccount(username, expectPassword string, expectRole portainer.UserRole) error {
|
||||
@@ -225,8 +259,13 @@ func (store *Store) checkAccount(username, expectPassword string, expectRole por
|
||||
}
|
||||
|
||||
// Check the password
|
||||
cs := crypto.Service{}
|
||||
if cs.CompareHashAndData(user.Password, expectPassword) != nil {
|
||||
cs := &crypto.Service{}
|
||||
expectPasswordHash, err := cs.Hash(expectPassword)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "hash failed")
|
||||
}
|
||||
|
||||
if user.Password != expectPasswordHash {
|
||||
return fmt.Errorf("%s user password hash failure", user.Username)
|
||||
}
|
||||
|
||||
@@ -238,7 +277,7 @@ func (store *Store) testSettings(t *testing.T) {
|
||||
|
||||
// since many settings are default and basically nil, I'm going to update some and read them back
|
||||
expectedSettings, err := store.Settings().Settings()
|
||||
require.NoError(t, err, "Settings() should not return an error")
|
||||
is.NoError(err, "Settings() should not return an error")
|
||||
expectedSettings.TemplatesURL = "http://portainer.io/application-templates"
|
||||
expectedSettings.HelmRepositoryURL = "http://portainer.io/helm-repository"
|
||||
expectedSettings.EdgeAgentCheckinInterval = 60
|
||||
@@ -252,10 +291,10 @@ func (store *Store) testSettings(t *testing.T) {
|
||||
expectedSettings.SnapshotInterval = "10m"
|
||||
|
||||
err = store.Settings().UpdateSettings(expectedSettings)
|
||||
require.NoError(t, err, "UpdateSettings() should succeed")
|
||||
is.NoError(err, "UpdateSettings() should succeed")
|
||||
|
||||
settings, err := store.Settings().Settings()
|
||||
require.NoError(t, err, "Settings() should not return an error")
|
||||
is.NoError(err, "Settings() should not return an error")
|
||||
is.Equal(expectedSettings, settings, "stored settings should match")
|
||||
}
|
||||
|
||||
@@ -275,11 +314,10 @@ func (store *Store) testCustomTemplates(t *testing.T) {
|
||||
CreatedByUserID: 10,
|
||||
}
|
||||
|
||||
err := customTemplate.Create(expectedTemplate)
|
||||
require.NoError(t, err)
|
||||
customTemplate.Create(expectedTemplate)
|
||||
|
||||
actualTemplate, err := customTemplate.Read(expectedTemplate.ID)
|
||||
require.NoError(t, err, "CustomTemplate should not return an error")
|
||||
is.NoError(err, "CustomTemplate should not return an error")
|
||||
is.Equal(expectedTemplate, actualTemplate, "expected and actual template do not match")
|
||||
}
|
||||
|
||||
@@ -307,17 +345,17 @@ func (store *Store) testRegistries(t *testing.T) {
|
||||
}
|
||||
|
||||
err := regService.Create(reg1)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
err = regService.Create(reg2)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
|
||||
actualReg1, err := regService.Read(reg1.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(reg1, actualReg1, "registries differ")
|
||||
|
||||
actualReg2, err := regService.Read(reg2.ID)
|
||||
require.NoError(t, err)
|
||||
is.NoError(err)
|
||||
is.Equal(reg2, actualReg2, "registries differ")
|
||||
}
|
||||
|
||||
@@ -340,10 +378,10 @@ func (store *Store) testSchedules(t *testing.T) {
|
||||
}
|
||||
|
||||
err := schedule.CreateSchedule(s)
|
||||
require.NoError(t, err, "CreateSchedule should succeed")
|
||||
is.NoError(err, "CreateSchedule should succeed")
|
||||
|
||||
actual, err := schedule.Schedule(s.ID)
|
||||
require.NoError(t, err, "schedule should be found")
|
||||
is.NoError(err, "schedule should be found")
|
||||
is.Equal(s, actual, "schedules differ")
|
||||
}
|
||||
|
||||
@@ -363,16 +401,16 @@ func (store *Store) testTags(t *testing.T) {
|
||||
}
|
||||
|
||||
err := tags.Create(tag1)
|
||||
require.NoError(t, err, "Tags.Create should succeed")
|
||||
is.NoError(err, "Tags.Create should succeed")
|
||||
|
||||
err = tags.Create(tag2)
|
||||
require.NoError(t, err, "Tags.Create should succeed")
|
||||
is.NoError(err, "Tags.Create should succeed")
|
||||
|
||||
actual, err := tags.Read(tag1.ID)
|
||||
require.NoError(t, err, "tag1 should be found")
|
||||
is.NoError(err, "tag1 should be found")
|
||||
is.Equal(tag1, actual, "tags differ")
|
||||
|
||||
actual, err = tags.Read(tag2.ID)
|
||||
require.NoError(t, err, "tag2 should be found")
|
||||
is.NoError(err, "tag2 should be found")
|
||||
is.Equal(tag2, actual, "tags differ")
|
||||
}
|
||||
|
||||
@@ -31,6 +31,7 @@ func (store *Store) checkOrCreateDefaultSettings() error {
|
||||
settings, err := store.SettingsService.Settings()
|
||||
if store.IsErrObjectNotFound(err) {
|
||||
defaultSettings := &portainer.Settings{
|
||||
EnableTelemetry: false,
|
||||
AuthenticationMethod: portainer.AuthenticationInternal,
|
||||
BlackListedLabels: make([]portainer.Pair, 0),
|
||||
InternalAuthSettings: portainer.InternalAuthSettings{
|
||||
|
||||
@@ -40,11 +40,13 @@ func (store *Store) MigrateData() error {
|
||||
}
|
||||
|
||||
// before we alter anything in the DB, create a backup
|
||||
if _, err := store.Backup(""); err != nil {
|
||||
_, err = store.Backup("")
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "while backing up database")
|
||||
}
|
||||
|
||||
if err := store.FailSafeMigrate(migrator, version); err != nil {
|
||||
err = store.FailSafeMigrate(migrator, version)
|
||||
if err != nil {
|
||||
err = errors.Wrap(err, "failed to migrate database")
|
||||
|
||||
log.Warn().Err(err).Msg("migration failed, restoring database to previous version")
|
||||
@@ -83,9 +85,7 @@ func (store *Store) newMigratorParameters(version *models.Version, flags *portai
|
||||
DockerhubService: store.DockerHubService,
|
||||
AuthorizationService: authorization.NewService(store),
|
||||
EdgeStackService: store.EdgeStackService,
|
||||
EdgeStackStatusService: store.EdgeStackStatusService,
|
||||
EdgeJobService: store.EdgeJobService,
|
||||
EdgeGroupService: store.EdgeGroupService,
|
||||
TunnelServerService: store.TunnelServerService,
|
||||
PendingActionsService: store.PendingActionsService,
|
||||
}
|
||||
@@ -140,7 +140,8 @@ func (store *Store) connectionRollback(force bool) error {
|
||||
}
|
||||
}
|
||||
|
||||
if err := store.Restore(); err != nil {
|
||||
err := store.Restore()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -9,15 +9,14 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/Masterminds/semver"
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/database/boltdb"
|
||||
"github.com/portainer/portainer/api/database/models"
|
||||
"github.com/portainer/portainer/api/datastore/migrator"
|
||||
|
||||
"github.com/Masterminds/semver/v3"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
"github.com/rs/zerolog/log"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateData(t *testing.T) {
|
||||
@@ -54,11 +53,9 @@ func TestMigrateData(t *testing.T) {
|
||||
}
|
||||
|
||||
testVersion(store, portainer.APIVersion, t)
|
||||
err := store.Close()
|
||||
require.NoError(t, err)
|
||||
store.Close()
|
||||
|
||||
newStore, err = store.Open()
|
||||
require.NoError(t, err)
|
||||
newStore, _ = store.Open()
|
||||
if newStore {
|
||||
t.Error("Expect store to NOT be new DB")
|
||||
}
|
||||
@@ -66,11 +63,8 @@ func TestMigrateData(t *testing.T) {
|
||||
|
||||
t.Run("MigrateData should create backup file upon update", func(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
err := store.VersionService.UpdateVersion(&models.Version{SchemaVersion: "2.0", Edition: int(portainer.PortainerCE)})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.MigrateData()
|
||||
require.NoError(t, err)
|
||||
store.VersionService.UpdateVersion(&models.Version{SchemaVersion: "1.0", Edition: int(portainer.PortainerCE)})
|
||||
store.MigrateData()
|
||||
|
||||
backupfilename := store.backupFilename()
|
||||
if exists, _ := store.fileService.FileExists(backupfilename); !exists {
|
||||
@@ -79,28 +73,21 @@ func TestMigrateData(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("MigrateData should recover and restore backup during migration critical failure", func(t *testing.T) {
|
||||
t.Setenv("PORTAINER_TEST_MIGRATE_FAIL", "FAIL")
|
||||
os.Setenv("PORTAINER_TEST_MIGRATE_FAIL", "FAIL")
|
||||
|
||||
version := "2.15"
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
store.VersionService.UpdateVersion(&models.Version{SchemaVersion: version, Edition: int(portainer.PortainerCE)})
|
||||
store.MigrateData()
|
||||
|
||||
err := store.VersionService.UpdateVersion(&models.Version{SchemaVersion: version, Edition: int(portainer.PortainerCE)})
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.MigrateData()
|
||||
require.Error(t, err)
|
||||
|
||||
store.Open()
|
||||
testVersion(store, version, t)
|
||||
})
|
||||
|
||||
t.Run("MigrateData should fail to create backup if database file is set to updating", func(t *testing.T) {
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
|
||||
err := store.VersionService.StoreIsUpdating(true)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.MigrateData()
|
||||
require.Error(t, err)
|
||||
store.VersionService.StoreIsUpdating(true)
|
||||
store.MigrateData()
|
||||
|
||||
// If you get an error, it usually means that the backup folder doesn't exist (no backups). Expected!
|
||||
// If the backup file is not blank, then it means a backup was created. We don't want that because we
|
||||
@@ -128,12 +115,10 @@ func TestMigrateData(t *testing.T) {
|
||||
|
||||
if latestMigrations.Version.Equal(semver.MustParse(portainer.APIVersion)) {
|
||||
v.MigratorCount = len(latestMigrations.MigrationFuncs)
|
||||
err = store.VersionService.UpdateVersion(v)
|
||||
require.NoError(t, err)
|
||||
store.VersionService.UpdateVersion(v)
|
||||
}
|
||||
|
||||
err = store.MigrateData()
|
||||
require.NoError(t, err)
|
||||
store.MigrateData()
|
||||
|
||||
// If you get an error, it usually means that the backup folder doesn't exist (no backups). Expected!
|
||||
// If the backup file is not blank, then it means a backup was created. We don't want that because we
|
||||
@@ -156,12 +141,8 @@ func TestMigrateData(t *testing.T) {
|
||||
}
|
||||
|
||||
v.MigratorCount = 1000
|
||||
|
||||
err = store.VersionService.UpdateVersion(v)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = store.MigrateData()
|
||||
require.NoError(t, err)
|
||||
store.VersionService.UpdateVersion(v)
|
||||
store.MigrateData()
|
||||
|
||||
// If you get an error, it usually means that the backup folder doesn't exist (no backups). Expected!
|
||||
// If the backup file is not blank, then it means a backup was created. We don't want that because we
|
||||
@@ -177,14 +158,14 @@ func TestRollback(t *testing.T) {
|
||||
t.Run("Rollback should restore upgrade after backup", func(t *testing.T) {
|
||||
version := "2.11"
|
||||
|
||||
v := models.Version{SchemaVersion: version}
|
||||
v := models.Version{
|
||||
SchemaVersion: version,
|
||||
}
|
||||
|
||||
_, store := MustNewTestStore(t, false, false)
|
||||
store.VersionService.UpdateVersion(&v)
|
||||
|
||||
err := store.VersionService.UpdateVersion(&v)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.Backup("")
|
||||
_, err := store.Backup("")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
@@ -203,9 +184,7 @@ func TestRollback(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = store.Open()
|
||||
require.NoError(t, err)
|
||||
|
||||
store.Open()
|
||||
testVersion(store, version, t)
|
||||
})
|
||||
|
||||
@@ -218,11 +197,9 @@ func TestRollback(t *testing.T) {
|
||||
}
|
||||
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
store.VersionService.UpdateVersion(&v)
|
||||
|
||||
err := store.VersionService.UpdateVersion(&v)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = store.Backup("")
|
||||
_, err := store.Backup("")
|
||||
if err != nil {
|
||||
log.Fatal().Err(err).Msg("")
|
||||
}
|
||||
@@ -241,8 +218,7 @@ func TestRollback(t *testing.T) {
|
||||
return
|
||||
}
|
||||
|
||||
_, err = store.Open()
|
||||
require.NoError(t, err)
|
||||
store.Open()
|
||||
testVersion(store, version, t)
|
||||
})
|
||||
}
|
||||
@@ -261,17 +237,17 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
|
||||
_, store := MustNewTestStore(t, true, false)
|
||||
|
||||
fmt.Println("store.path=", store.GetConnection().GetDatabaseFilePath())
|
||||
|
||||
err = store.connection.DeleteObject("version", []byte("VERSION"))
|
||||
require.NoError(t, err)
|
||||
store.connection.DeleteObject("version", []byte("VERSION"))
|
||||
|
||||
// defer teardown()
|
||||
if err := importJSON(t, bytes.NewReader(srcJSON), store); err != nil {
|
||||
err = importJSON(t, bytes.NewReader(srcJSON), store)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Run the actual migrations on our input database.
|
||||
if err := store.MigrateData(); err != nil {
|
||||
err = store.MigrateData()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -284,7 +260,8 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
|
||||
}
|
||||
|
||||
v.InstanceID = "463d5c47-0ea5-4aca-85b1-405ceefee254"
|
||||
if err := store.VersionService.UpdateVersion(v); err != nil {
|
||||
err = store.VersionService.UpdateVersion(v)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@@ -293,10 +270,10 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
|
||||
// exportJson rather than ExportRaw. The exportJson function allows us to
|
||||
// strip out the metadata which we don't want for our tests.
|
||||
// TODO: update connection interface in CE to allow us to use ExportRaw and pass meta false
|
||||
if err := store.connection.Close(); err != nil {
|
||||
err = store.connection.Close()
|
||||
if err != nil {
|
||||
t.Fatalf("err closing bolt connection: %v", err)
|
||||
}
|
||||
|
||||
con, ok := store.connection.(*boltdb.DbConnection)
|
||||
if !ok {
|
||||
t.Fatalf("backing database is not using boltdb, but the migrations test requires it")
|
||||
@@ -325,15 +302,11 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
|
||||
// Compare the result we got with the one we wanted.
|
||||
if diff := cmp.Diff(wantJSON, gotJSON); diff != "" {
|
||||
gotPath := filepath.Join(os.TempDir(), "portainer-migrator-test-fail.json")
|
||||
err = os.WriteFile(
|
||||
os.WriteFile(
|
||||
gotPath,
|
||||
gotJSON,
|
||||
0o600,
|
||||
)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("failed writing migrated output to temp file")
|
||||
}
|
||||
|
||||
t.Errorf(
|
||||
"migrate data from %s to %s failed\nwrote migrated input to %s\nmismatch (-want +got):\n%s",
|
||||
srcPath,
|
||||
|
||||
@@ -6,9 +6,7 @@ import (
|
||||
portainer "github.com/portainer/portainer/api"
|
||||
"github.com/portainer/portainer/api/datastore/migrator"
|
||||
gittypes "github.com/portainer/portainer/api/git/types"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMigrateStackEntryPoint(t *testing.T) {
|
||||
@@ -30,25 +28,25 @@ func TestMigrateStackEntryPoint(t *testing.T) {
|
||||
|
||||
for _, s := range stacks {
|
||||
err := stackService.Create(s)
|
||||
require.NoError(t, err, "failed to create stack")
|
||||
assert.NoError(t, err, "failed to create stack")
|
||||
}
|
||||
|
||||
s, err := stackService.Read(1)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, s.GitConfig, "first stack should not have git config")
|
||||
|
||||
s, err = stackService.Read(2)
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, s.GitConfig.ConfigFilePath, "not migrated yet migrated")
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "", s.GitConfig.ConfigFilePath, "not migrated yet migrated")
|
||||
|
||||
err = migrator.MigrateStackEntryPoint(stackService)
|
||||
require.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath")
|
||||
assert.NoError(t, err, "failed to migrate entry point to Git ConfigFilePath")
|
||||
|
||||
s, err = stackService.Read(1)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Nil(t, s.GitConfig, "first stack should not have git config")
|
||||
|
||||
s, err = stackService.Read(2)
|
||||
require.NoError(t, err)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "dir/sub/compose.yml", s.GitConfig.ConfigFilePath, "second stack should have config file path migrated")
|
||||
}
|
||||
|
||||
@@ -105,18 +105,12 @@ func (store *Store) getOrMigrateLegacyVersion() (*models.Version, error) {
|
||||
|
||||
// finishMigrateLegacyVersion writes the new version to the DB and removes the old version keys from the DB
|
||||
func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) error {
|
||||
if err := store.VersionService.UpdateVersion(versionToWrite); err != nil {
|
||||
return err
|
||||
}
|
||||
err := store.VersionService.UpdateVersion(versionToWrite)
|
||||
|
||||
// Remove legacy keys if present
|
||||
if err := store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey))
|
||||
store.connection.DeleteObject(bucketName, []byte(legacyEditionKey))
|
||||
store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
|
||||
|
||||
if err := store.connection.DeleteObject(bucketName, []byte(legacyEditionKey)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
|
||||
return err
|
||||
}
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user