Compare commits

..

4 Commits

Author SHA1 Message Date
Matt Hook
d26ebf6f8e update jwt 2023-12-14 16:23:03 +13:00
Matt Hook
b569ea7ef3 update db version stuff 2023-12-14 16:06:09 +13:00
Matt Hook
f797d7f3dd update websocket hijack code 2023-12-14 15:58:05 +13:00
Matt Hook
c80b8ff440 enable staticcheck. Fix some easy ones 2023-12-14 15:36:36 +13:00
2626 changed files with 38988 additions and 61102 deletions

View File

@@ -10,7 +10,6 @@ globals:
extends:
- 'eslint:recommended'
- 'plugin:storybook/recommended'
- 'plugin:import/typescript'
- prettier
plugins:
@@ -24,13 +23,12 @@ parserOptions:
modules: true
rules:
no-console: error
no-console: warn
no-alert: error
no-control-regex: 'off'
no-empty: warn
no-empty-function: warn
no-useless-escape: 'off'
import/named: error
import/order:
[
'error',
@@ -45,12 +43,6 @@ rules:
pathGroupsExcludedImportTypes: ['internal'],
},
]
no-restricted-imports:
- error
- patterns:
- group:
- '@/react/test-utils/*'
message: 'These utils are just for test files'
settings:
'import/resolver':
@@ -59,8 +51,6 @@ settings:
- ['@@', './app/react/components']
- ['@', './app']
extensions: ['.js', '.ts', '.tsx']
typescript: true
node: true
overrides:
- files:
@@ -85,9 +75,7 @@ overrides:
settings:
react:
version: 'detect'
rules:
no-console: error
import/order:
[
'error',
@@ -120,12 +108,6 @@ overrides:
'no-await-in-loop': 'off'
'react/jsx-no-useless-fragment': ['error', { allowExpressions: true }]
'regex/invalid': ['error', [{ 'regex': '<Icon icon="(.*)"', 'message': 'Please directly import the `lucide-react` icon instead of using the string' }]]
'@typescript-eslint/no-restricted-imports':
- error
- patterns:
- group:
- '@/react/test-utils/*'
message: 'These utils are just for test files'
overrides: # allow props spreading for hoc files
- files:
- app/**/with*.ts{,x}
@@ -134,18 +116,13 @@ overrides:
- files:
- app/**/*.test.*
extends:
- 'plugin:vitest/recommended'
- 'plugin:jest/recommended'
- 'plugin:jest/style'
env:
'vitest/env': true
'jest/globals': true
rules:
'react/jsx-no-constructed-context-values': off
'@typescript-eslint/no-restricted-imports': off
no-restricted-imports: off
'react/jsx-props-no-spreading': off
- files:
- app/**/*.stories.*
rules:
'no-alert': off
'@typescript-eslint/no-restricted-imports': off
no-restricted-imports: off
'react/jsx-props-no-spreading': off

View File

@@ -93,16 +93,6 @@ body:
description: We only provide support for the most recent version of Portainer and the previous 3 versions. If you are on an older version of Portainer we recommend [upgrading first](https://docs.portainer.io/start/upgrade) in case your bug has already been fixed.
multiple: false
options:
- '2.22.0'
- '2.21.3'
- '2.21.2'
- '2.21.1'
- '2.21.0'
- '2.20.3'
- '2.20.2'
- '2.20.1'
- '2.20.0'
- '2.19.5'
- '2.19.4'
- '2.19.3'
- '2.19.2'

View File

@@ -5,7 +5,7 @@ on:
push:
branches:
- 'develop'
- 'release/*'
- '!release/*'
pull_request:
branches:
- 'develop'
@@ -13,15 +13,11 @@ on:
- 'feat/*'
- 'fix/*'
- 'refactor/*'
types:
- opened
- reopened
- synchronize
- ready_for_review
env:
DOCKER_HUB_REPO: portainerci/portainer-ce
EXTENSION_HUB_REPO: portainerci/portainer-docker-extension
DOCKER_HUB_REPO: portainerci/portainer
NODE_ENV: testing
GO_VERSION: 1.21.3
NODE_VERSION: 18.x
jobs:
@@ -29,71 +25,85 @@ jobs:
strategy:
matrix:
config:
- { platform: linux, arch: amd64, version: "" }
- { platform: linux, arch: arm64, version: "" }
- { platform: linux, arch: arm, version: "" }
- { platform: linux, arch: ppc64le, version: "" }
- { platform: linux, arch: amd64 }
- { platform: linux, arch: arm64 }
- { platform: windows, arch: amd64, version: 1809 }
- { platform: windows, arch: amd64, version: ltsc2022 }
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
runs-on: arc-runner-set
steps:
- name: '[preparation] checkout the current branch'
uses: actions/checkout@v4.1.1
uses: actions/checkout@v3.5.3
with:
ref: ${{ github.event.inputs.branch }}
- name: '[preparation] set up golang'
uses: actions/setup-go@v5.0.0
uses: actions/setup-go@v4.0.1
with:
go-version-file: go.mod
go-version: ${{ env.GO_VERSION }}
cache: false
- name: '[preparation] cache paths'
id: cache-dir-path
run: |
echo "yarn-cache-dir=$(yarn cache dir)" >> "$GITHUB_OUTPUT"
echo "go-build-dir=$(go env GOCACHE)" >> "$GITHUB_OUTPUT"
echo "go-mod-dir=$(go env GOMODCACHE)" >> "$GITHUB_OUTPUT"
- name: '[preparation] cache go'
uses: actions/cache@v3
with:
path: |
${{ steps.cache-dir-path.outputs.go-build-dir }}
${{ steps.cache-dir-path.outputs.go-mod-dir }}
key: ${{ matrix.config.platform }}-${{ matrix.config.arch }}-go-${{ hashFiles('**/go.sum') }}
restore-keys: |
${{ matrix.config.platform }}-${{ matrix.config.arch }}-go-
enableCrossOsArchive: true
- name: '[preparation] set up node.js'
uses: actions/setup-node@v4.0.1
uses: actions/setup-node@v3
with:
node-version: ${{ env.NODE_VERSION }}
cache: 'yarn'
cache: ''
- name: '[preparation] cache yarn'
uses: actions/cache@v3
with:
path: |
**/node_modules
${{ steps.cache-dir-path.outputs.yarn-cache-dir }}
key: ${{ matrix.config.platform }}-${{ matrix.config.arch }}-yarn-${{ hashFiles('**/yarn.lock') }}
restore-keys: |
${{ matrix.config.platform }}-${{ matrix.config.arch }}-yarn-
enableCrossOsArchive: true
- name: '[preparation] set up qemu'
uses: docker/setup-qemu-action@v3.0.0
uses: docker/setup-qemu-action@v2
- name: '[preparation] set up docker context for buildx'
run: docker context create builders
- name: '[preparation] set up docker buildx'
uses: docker/setup-buildx-action@v3.0.0
uses: docker/setup-buildx-action@v2
with:
endpoint: builders
- name: '[preparation] docker login'
uses: docker/login-action@v3.0.0
uses: docker/login-action@v2.2.0
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: '[preparation] set the container image tag'
run: |
if [[ "${GITHUB_REF_NAME}" =~ ^release/.*$ ]]; then
# use the release branch name as the tag for release branches
# for instance, release/2.19 becomes 2.19
CONTAINER_IMAGE_TAG=$(echo $GITHUB_REF_NAME | cut -d "/" -f 2)
elif [ "${GITHUB_EVENT_NAME}" == "pull_request" ]; then
# use pr${{ github.event.number }} as the tag for pull requests
# for instance, pr123
if [ "${GITHUB_EVENT_NAME}" == "pull_request" ]; then
CONTAINER_IMAGE_TAG="pr${{ github.event.number }}"
else
# replace / with - in the branch name
# for instance, feature/1.0.0 -> feature-1.0.0
CONTAINER_IMAGE_TAG=$(echo $GITHUB_REF_NAME | sed 's/\//-/g')
fi
echo "CONTAINER_IMAGE_TAG=${CONTAINER_IMAGE_TAG}-${{ matrix.config.platform }}${{ matrix.config.version }}-${{ matrix.config.arch }}" >> $GITHUB_ENV
if [ "${{ matrix.config.platform }}" == "windows" ]; then
CONTAINER_IMAGE_TAG="${CONTAINER_IMAGE_TAG}-${{ matrix.config.platform }}${{ matrix.config.version }}-${{ matrix.config.arch }}"
else
CONTAINER_IMAGE_TAG="${CONTAINER_IMAGE_TAG}-${{ matrix.config.platform }}-${{ matrix.config.arch }}"
fi
echo "CONTAINER_IMAGE_TAG=${CONTAINER_IMAGE_TAG}" >> $GITHUB_ENV
- name: '[execution] build linux & windows portainer binaries'
run: |
export YARN_VERSION=$(yarn --version)
export WEBPACK_VERSION=$(yarn list webpack --depth=0 | grep webpack | awk -F@ '{print $2}')
export BUILDNUMBER=${GITHUB_RUN_NUMBER}
GIT_COMMIT_HASH_LONG=${{ github.sha }}
export GIT_COMMIT_HASH_SHORT={GIT_COMMIT_HASH_LONG:0:7}
NODE_ENV="testing"
if [[ "${GITHUB_REF_NAME}" =~ ^release/.*$ ]]; then
NODE_ENV="production"
fi
make build-all PLATFORM=${{ matrix.config.platform }} ARCH=${{ matrix.config.arch }} ENV=${NODE_ENV}
env:
CONTAINER_IMAGE_TAG: ${{ env.CONTAINER_IMAGE_TAG }}
@@ -101,66 +111,38 @@ jobs:
run: |
if [ "${{ matrix.config.platform }}" == "windows" ]; then
mv dist/portainer dist/portainer.exe
docker buildx build --output=type=registry --attest type=provenance,mode=max --attest type=sbom,disabled=false --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} --build-arg OSVERSION=${{ matrix.config.version }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}" -f build/${{ matrix.config.platform }}/Dockerfile .
docker buildx build --output=type=registry --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} --build-arg OSVERSION=${{ matrix.config.version }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}" -f build/${{ matrix.config.platform }}/Dockerfile .
else
docker buildx build --output=type=registry --attest type=provenance,mode=max --attest type=sbom,disabled=false --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}" -f build/${{ matrix.config.platform }}/Dockerfile .
docker buildx build --output=type=registry --attest type=provenance,mode=max --attest type=sbom,disabled=false --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-alpine" -f build/${{ matrix.config.platform }}/alpine.Dockerfile .
if [[ "${GITHUB_REF_NAME}" =~ ^release/.*$ ]]; then
docker buildx build --output=type=registry --attest type=provenance,mode=max --attest type=sbom,disabled=false --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${EXTENSION_HUB_REPO}:${CONTAINER_IMAGE_TAG}" -f build/${{ matrix.config.platform }}/Dockerfile .
docker buildx build --output=type=registry --attest type=provenance,mode=max --attest type=sbom,disabled=false --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${EXTENSION_HUB_REPO}:${CONTAINER_IMAGE_TAG}-alpine" -f build/${{ matrix.config.platform }}/alpine.Dockerfile .
fi
docker buildx build --output=type=registry --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}" -f build/${{ matrix.config.platform }}/Dockerfile .
docker buildx build --output=type=registry --platform ${{ matrix.config.platform }}/${{ matrix.config.arch }} -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-alpine" -f build/${{ matrix.config.platform }}/alpine.Dockerfile .
fi
env:
CONTAINER_IMAGE_TAG: ${{ env.CONTAINER_IMAGE_TAG }}
build_manifests:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
runs-on: arc-runner-set
needs: [build_images]
steps:
- name: '[preparation] docker login'
uses: docker/login-action@v3.0.0
uses: docker/login-action@v2.2.0
with:
username: ${{ secrets.DOCKER_HUB_USERNAME }}
password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: '[preparation] set up docker context for buildx'
run: docker version && docker context create builders
- name: '[preparation] set up docker buildx'
uses: docker/setup-buildx-action@v3.0.0
uses: docker/setup-buildx-action@v2
with:
endpoint: builders
- name: '[execution] build and push manifests'
run: |
if [[ "${GITHUB_REF_NAME}" =~ ^release/.*$ ]]; then
# use the release branch name as the tag for release branches
# for instance, release/2.19 becomes 2.19
CONTAINER_IMAGE_TAG=$(echo $GITHUB_REF_NAME | cut -d "/" -f 2)
elif [ "${GITHUB_EVENT_NAME}" == "pull_request" ]; then
# use pr${{ github.event.number }} as the tag for pull requests
# for instance, pr123
if [ "${GITHUB_EVENT_NAME}" == "pull_request" ]; then
CONTAINER_IMAGE_TAG="pr${{ github.event.number }}"
else
# replace / with - in the branch name
# for instance, feature/1.0.0 -> feature-1.0.0
CONTAINER_IMAGE_TAG=$(echo $GITHUB_REF_NAME | sed 's/\//-/g')
fi
docker buildx imagetools create -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-amd64" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-arm64" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-arm" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-ppc64le" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-windows1809-amd64" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-windowsltsc2022-amd64"
docker buildx imagetools create -t "${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-alpine" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-amd64-alpine" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-arm64-alpine" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-arm-alpine" \
"${DOCKER_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-ppc64le-alpine"
if [[ "${GITHUB_REF_NAME}" =~ ^release/.*$ ]]; then
docker buildx imagetools create -t "${EXTENSION_HUB_REPO}:${CONTAINER_IMAGE_TAG}" \
"${EXTENSION_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-amd64" \
"${EXTENSION_HUB_REPO}:${CONTAINER_IMAGE_TAG}-linux-arm64"
fi

View File

@@ -11,27 +11,20 @@ on:
- master
- develop
- release/*
types:
- opened
- reopened
- synchronize
- ready_for_review
env:
GO_VERSION: 1.22.5
NODE_VERSION: 18.x
GO_VERSION: 1.21.3
jobs:
run-linters:
name: Run linters
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v2
- uses: actions/setup-node@v2
with:
node-version: ${{ env.NODE_VERSION }}
node-version: '18'
cache: 'yarn'
- uses: actions/setup-go@v4
with:
@@ -51,5 +44,6 @@ jobs:
- name: GolangCI-Lint
uses: golangci/golangci-lint-action@v3
with:
version: v1.59.1
version: v1.54.1
working-directory: api
args: --timeout=10m -c .golangci.yaml

View File

@@ -6,9 +6,7 @@ on:
workflow_dispatch:
env:
GO_VERSION: 1.22.5
DOCKER_HUB_REPO: portainerci/portainer-ce
DOCKER_HUB_IMAGE_TAG: develop
GO_VERSION: 1.21.3
jobs:
client-dependencies:
@@ -114,7 +112,7 @@ jobs:
uses: docker://docker.io/aquasec/trivy:latest
continue-on-error: true
with:
args: image --ignore-unfixed=true --vuln-type="os,library" --exit-code=1 --format="json" --output="image-trivy.json" --no-progress ${{ env.DOCKER_HUB_REPO }}:${{ env.DOCKER_HUB_IMAGE_TAG }}
args: image --ignore-unfixed=true --vuln-type="os,library" --exit-code=1 --format="json" --output="image-trivy.json" --no-progress portainerci/portainer:develop
- name: upload Trivy image security scan result as artifact
uses: actions/upload-artifact@v3
@@ -143,10 +141,10 @@ jobs:
continue-on-error: true
with:
command: cves
image: ${{ env.DOCKER_HUB_REPO }}:${{ env.DOCKER_HUB_IMAGE_TAG }}
image: portainerci/portainer:develop
sarif-file: image-docker-scout.json
dockerhub-user: ${{ secrets.DOCKER_HUB_USERNAME }}
dockerhub-password: ${{ secrets.DOCKER_HUB_PASSWORD }}
dockerhub-password: ${{ secrets.DOCKER_HUB_PASSWORD }}
- name: upload Docker Scout image security scan result as artifact
uses: actions/upload-artifact@v3
@@ -199,7 +197,7 @@ jobs:
matrix.js.status == 'failure' ||
matrix.go.status == 'failure' ||
matrix.image-trivy.status == 'failure' ||
matrix.image-docker-scout.status == 'failure'
matrix.image-docker-scout.status == 'failure'
uses: slackapi/slack-github-action@v1.23.0
with:
payload: |

View File

@@ -14,7 +14,7 @@ on:
- '.github/workflows/pr-security.yml'
env:
GO_VERSION: 1.22.5
GO_VERSION: 1.21.3
NODE_VERSION: 18.x
jobs:
@@ -23,8 +23,7 @@ jobs:
runs-on: ubuntu-latest
if: >-
github.event.pull_request &&
github.event.review.body == '/scan' &&
github.event.pull_request.draft == false
github.event.review.body == '/scan'
outputs:
jsdiff: ${{ steps.set-diff-matrix.outputs.js_diff_result }}
steps:
@@ -78,8 +77,7 @@ jobs:
runs-on: ubuntu-latest
if: >-
github.event.pull_request &&
github.event.review.body == '/scan' &&
github.event.pull_request.draft == false
github.event.review.body == '/scan'
outputs:
godiff: ${{ steps.set-diff-matrix.outputs.go_diff_result }}
steps:
@@ -141,8 +139,7 @@ jobs:
runs-on: ubuntu-latest
if: >-
github.event.pull_request &&
github.event.review.body == '/scan' &&
github.event.pull_request.draft == false
github.event.review.body == '/scan'
outputs:
imagediff-trivy: ${{ steps.set-diff-trivy-matrix.outputs.image_diff_trivy_result }}
imagediff-docker-scout: ${{ steps.set-diff-docker-scout-matrix.outputs.image_diff_docker_scout_result }}
@@ -271,8 +268,7 @@ jobs:
runs-on: ubuntu-latest
if: >-
github.event.pull_request &&
github.event.review.body == '/scan' &&
github.event.pull_request.draft == false
github.event.review.body == '/scan'
strategy:
matrix:
jsdiff: ${{fromJson(needs.client-dependencies.outputs.jsdiff)}}

View File

@@ -1,49 +1,25 @@
name: Test
env:
GO_VERSION: 1.22.5
NODE_VERSION: 18.x
on: push
on:
workflow_dispatch:
pull_request:
branches:
- master
- develop
- release/*
types:
- opened
- reopened
- synchronize
- ready_for_review
push:
branches:
- master
- develop
- release/*
env:
GO_VERSION: 1.21.3
NODE_VERSION: 18.x
jobs:
test-client:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: 'checkout the current branch'
uses: actions/checkout@v4.1.1
with:
ref: ${{ github.event.inputs.branch }}
- name: 'set up node.js'
uses: actions/setup-node@v4.0.1
- uses: actions/checkout@v2
- uses: actions/setup-node@v2
with:
node-version: ${{ env.NODE_VERSION }}
cache: 'yarn'
- run: yarn --frozen-lockfile
- name: Run tests
run: make test-client ARGS="--maxWorkers=2 --minWorkers=1"
run: make test-client ARGS="--maxWorkers=2"
test-server:
strategy:
matrix:
@@ -53,24 +29,10 @@ jobs:
- { platform: windows, arch: amd64, version: 1809 }
- { platform: windows, arch: amd64, version: ltsc2022 }
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- name: 'checkout the current branch'
uses: actions/checkout@v4.1.1
with:
ref: ${{ github.event.inputs.branch }}
- name: 'set up golang'
uses: actions/setup-go@v5.0.0
- uses: actions/checkout@v3
- uses: actions/setup-go@v3
with:
go-version: ${{ env.GO_VERSION }}
- name: 'install dependencies'
run: make test-deps PLATFORM=linux ARCH=amd64
- name: 'update $PATH'
run: echo "$(pwd)/dist" >> $GITHUB_PATH
- name: 'run tests'
- name: Run tests
run: make test-server

View File

@@ -6,20 +6,14 @@ on:
- master
- develop
- 'release/*'
types:
- opened
- reopened
- synchronize
- ready_for_review
env:
GO_VERSION: 1.22.5
GO_VERSION: 1.21.3
NODE_VERSION: 18.x
jobs:
openapi-spec:
runs-on: ubuntu-latest
if: github.event.pull_request.draft == false
steps:
- uses: actions/checkout@v3

View File

@@ -3,7 +3,6 @@ import { StorybookConfig } from '@storybook/react-webpack5';
import TsconfigPathsPlugin from 'tsconfig-paths-webpack-plugin';
import { Configuration } from 'webpack';
import postcss from 'postcss';
const config: StorybookConfig = {
stories: ['../app/**/*.stories.@(ts|tsx)'],
addons: [
@@ -88,6 +87,9 @@ const config: StorybookConfig = {
name: '@storybook/react-webpack5',
options: {},
},
docs: {
autodocs: true,
},
};
export default config;

View File

@@ -1,25 +1,23 @@
import '../app/assets/css';
import React from 'react';
import { pushStateLocationPlugin, UIRouter } from '@uirouter/react';
import { initialize as initMSW, mswLoader } from 'msw-storybook-addon';
import { handlers } from '../app/setup-tests/server-handlers';
import { QueryClient, QueryClientProvider } from '@tanstack/react-query';
initMSW(
{
onUnhandledRequest: ({ method, url }) => {
if (url.startsWith('/api')) {
console.error(`Unhandled ${method} request to ${url}.
import { pushStateLocationPlugin, UIRouter } from '@uirouter/react';
import { initialize as initMSW, mswDecorator } from 'msw-storybook-addon';
import { handlers } from '@/setup-tests/server-handlers';
import { QueryClient, QueryClientProvider } from 'react-query';
// Initialize MSW
initMSW({
onUnhandledRequest: ({ method, url }) => {
if (url.pathname.startsWith('/api')) {
console.error(`Unhandled ${method} request to ${url}.
This exception has been only logged in the console, however, it's strongly recommended to resolve this error as you don't want unmocked data in Storybook stories.
If you wish to mock an error response, please refer to this guide: https://mswjs.io/docs/recipes/mocking-error-responses
`);
}
},
}
},
handlers
);
});
export const parameters = {
actions: { argTypesRegex: '^on[A-Z].*' },
@@ -46,6 +44,5 @@ export const decorators = [
</UIRouter>
</QueryClientProvider>
),
mswDecorator,
];
export const loaders = [mswLoader];

View File

@@ -2,22 +2,22 @@
/* tslint:disable */
/**
* Mock Service Worker (2.0.11).
* Mock Service Worker (0.36.3).
* @see https://github.com/mswjs/msw
* - Please do NOT modify this file.
* - Please do NOT serve this file on production.
*/
const INTEGRITY_CHECKSUM = 'c5f7f8e188b673ea4e677df7ea3c5a39';
const IS_MOCKED_RESPONSE = Symbol('isMockedResponse');
const INTEGRITY_CHECKSUM = '02f4ad4a2797f85668baf196e553d929';
const bypassHeaderName = 'x-msw-bypass';
const activeClientIds = new Set();
self.addEventListener('install', function () {
self.skipWaiting();
return self.skipWaiting();
});
self.addEventListener('activate', function (event) {
event.waitUntil(self.clients.claim());
self.addEventListener('activate', async function (event) {
return self.clients.claim();
});
self.addEventListener('message', async function (event) {
@@ -33,9 +33,7 @@ self.addEventListener('message', async function (event) {
return;
}
const allClients = await self.clients.matchAll({
type: 'window',
});
const allClients = await self.clients.matchAll();
switch (event.data) {
case 'KEEPALIVE_REQUEST': {
@@ -85,8 +83,165 @@ self.addEventListener('message', async function (event) {
}
});
// Resolve the "main" client for the given event.
// Client that issues a request doesn't necessarily equal the client
// that registered the worker. It's with the latter the worker should
// communicate with during the response resolving phase.
async function resolveMainClient(event) {
const client = await self.clients.get(event.clientId);
if (client.frameType === 'top-level') {
return client;
}
const allClients = await self.clients.matchAll();
return allClients
.filter((client) => {
// Get only those clients that are currently visible.
return client.visibilityState === 'visible';
})
.find((client) => {
// Find the client ID that's recorded in the
// set of clients that have registered the worker.
return activeClientIds.has(client.id);
});
}
async function handleRequest(event, requestId) {
const client = await resolveMainClient(event);
const response = await getResponse(event, client, requestId);
// Send back the response clone for the "response:*" life-cycle events.
// Ensure MSW is active and ready to handle the message, otherwise
// this message will pend indefinitely.
if (client && activeClientIds.has(client.id)) {
(async function () {
const clonedResponse = response.clone();
sendToClient(client, {
type: 'RESPONSE',
payload: {
requestId,
type: clonedResponse.type,
ok: clonedResponse.ok,
status: clonedResponse.status,
statusText: clonedResponse.statusText,
body: clonedResponse.body === null ? null : await clonedResponse.text(),
headers: serializeHeaders(clonedResponse.headers),
redirected: clonedResponse.redirected,
},
});
})();
}
return response;
}
async function getResponse(event, client, requestId) {
const { request } = event;
const requestClone = request.clone();
const getOriginalResponse = () => fetch(requestClone);
// Bypass mocking when the request client is not active.
if (!client) {
return getOriginalResponse();
}
// Bypass initial page load requests (i.e. static assets).
// The absence of the immediate/parent client in the map of the active clients
// means that MSW hasn't dispatched the "MOCK_ACTIVATE" event yet
// and is not ready to handle requests.
if (!activeClientIds.has(client.id)) {
return await getOriginalResponse();
}
// Bypass requests with the explicit bypass header
if (requestClone.headers.get(bypassHeaderName) === 'true') {
const cleanRequestHeaders = serializeHeaders(requestClone.headers);
// Remove the bypass header to comply with the CORS preflight check.
delete cleanRequestHeaders[bypassHeaderName];
const originalRequest = new Request(requestClone, {
headers: new Headers(cleanRequestHeaders),
});
return fetch(originalRequest);
}
// Send the request to the client-side MSW.
const reqHeaders = serializeHeaders(request.headers);
const body = await request.text();
const clientMessage = await sendToClient(client, {
type: 'REQUEST',
payload: {
id: requestId,
url: request.url,
method: request.method,
headers: reqHeaders,
cache: request.cache,
mode: request.mode,
credentials: request.credentials,
destination: request.destination,
integrity: request.integrity,
redirect: request.redirect,
referrer: request.referrer,
referrerPolicy: request.referrerPolicy,
body,
bodyUsed: request.bodyUsed,
keepalive: request.keepalive,
},
});
switch (clientMessage.type) {
case 'MOCK_SUCCESS': {
return delayPromise(() => respondWithMock(clientMessage), clientMessage.payload.delay);
}
case 'MOCK_NOT_FOUND': {
return getOriginalResponse();
}
case 'NETWORK_ERROR': {
const { name, message } = clientMessage.payload;
const networkError = new Error(message);
networkError.name = name;
// Rejecting a request Promise emulates a network error.
throw networkError;
}
case 'INTERNAL_ERROR': {
const parsedBody = JSON.parse(clientMessage.payload.body);
console.error(
`\
[MSW] Uncaught exception in the request handler for "%s %s":
${parsedBody.location}
This exception has been gracefully handled as a 500 response, however, it's strongly recommended to resolve this error, as it indicates a mistake in your code. If you wish to mock an error response, please see this guide: https://mswjs.io/docs/recipes/mocking-error-responses\
`,
request.method,
request.url
);
return respondWithMock(clientMessage);
}
}
return getOriginalResponse();
}
self.addEventListener('fetch', function (event) {
const { request } = event;
const accept = request.headers.get('accept') || '';
// Bypass server-sent events.
if (accept.includes('text/event-stream')) {
return;
}
// Bypass navigation requests.
if (request.mode === 'navigate') {
@@ -106,149 +261,36 @@ self.addEventListener('fetch', function (event) {
return;
}
// Generate unique request ID.
const requestId = crypto.randomUUID();
event.respondWith(handleRequest(event, requestId));
const requestId = uuidv4();
return event.respondWith(
handleRequest(event, requestId).catch((error) => {
if (error.name === 'NetworkError') {
console.warn('[MSW] Successfully emulated a network error for the "%s %s" request.', request.method, request.url);
return;
}
// At this point, any exception indicates an issue with the original request/response.
console.error(
`\
[MSW] Caught an exception from the "%s %s" request (%s). This is probably not a problem with Mock Service Worker. There is likely an additional logging output above.`,
request.method,
request.url,
`${error.name}: ${error.message}`
);
})
);
});
async function handleRequest(event, requestId) {
const client = await resolveMainClient(event);
const response = await getResponse(event, client, requestId);
// Send back the response clone for the "response:*" life-cycle events.
// Ensure MSW is active and ready to handle the message, otherwise
// this message will pend indefinitely.
if (client && activeClientIds.has(client.id)) {
(async function () {
const responseClone = response.clone();
sendToClient(
client,
{
type: 'RESPONSE',
payload: {
requestId,
isMockedResponse: IS_MOCKED_RESPONSE in response,
type: responseClone.type,
status: responseClone.status,
statusText: responseClone.statusText,
body: responseClone.body,
headers: Object.fromEntries(responseClone.headers.entries()),
},
},
[responseClone.body]
);
})();
}
return response;
}
// Resolve the main client for the given event.
// Client that issues a request doesn't necessarily equal the client
// that registered the worker. It's with the latter the worker should
// communicate with during the response resolving phase.
async function resolveMainClient(event) {
const client = await self.clients.get(event.clientId);
if (client?.frameType === 'top-level') {
return client;
}
const allClients = await self.clients.matchAll({
type: 'window',
function serializeHeaders(headers) {
const reqHeaders = {};
headers.forEach((value, name) => {
reqHeaders[name] = reqHeaders[name] ? [].concat(reqHeaders[name]).concat(value) : value;
});
return allClients
.filter((client) => {
// Get only those clients that are currently visible.
return client.visibilityState === 'visible';
})
.find((client) => {
// Find the client ID that's recorded in the
// set of clients that have registered the worker.
return activeClientIds.has(client.id);
});
return reqHeaders;
}
async function getResponse(event, client, requestId) {
const { request } = event;
// Clone the request because it might've been already used
// (i.e. its body has been read and sent to the client).
const requestClone = request.clone();
function passthrough() {
const headers = Object.fromEntries(requestClone.headers.entries());
// Remove internal MSW request header so the passthrough request
// complies with any potential CORS preflight checks on the server.
// Some servers forbid unknown request headers.
delete headers['x-msw-intention'];
return fetch(requestClone, { headers });
}
// Bypass mocking when the client is not active.
if (!client) {
return passthrough();
}
// Bypass initial page load requests (i.e. static assets).
// The absence of the immediate/parent client in the map of the active clients
// means that MSW hasn't dispatched the "MOCK_ACTIVATE" event yet
// and is not ready to handle requests.
if (!activeClientIds.has(client.id)) {
return passthrough();
}
// Bypass requests with the explicit bypass header.
// Such requests can be issued by "ctx.fetch()".
const mswIntention = request.headers.get('x-msw-intention');
if (['bypass', 'passthrough'].includes(mswIntention)) {
return passthrough();
}
// Notify the client that a request has been intercepted.
const requestBuffer = await request.arrayBuffer();
const clientMessage = await sendToClient(
client,
{
type: 'REQUEST',
payload: {
id: requestId,
url: request.url,
mode: request.mode,
method: request.method,
headers: Object.fromEntries(request.headers.entries()),
cache: request.cache,
credentials: request.credentials,
destination: request.destination,
integrity: request.integrity,
redirect: request.redirect,
referrer: request.referrer,
referrerPolicy: request.referrerPolicy,
body: requestBuffer,
keepalive: request.keepalive,
},
},
[requestBuffer]
);
switch (clientMessage.type) {
case 'MOCK_RESPONSE': {
return respondWithMock(clientMessage.data);
}
case 'MOCK_NOT_FOUND': {
return passthrough();
}
}
return passthrough();
}
function sendToClient(client, message, transferrables = []) {
function sendToClient(client, message) {
return new Promise((resolve, reject) => {
const channel = new MessageChannel();
@@ -260,25 +302,27 @@ function sendToClient(client, message, transferrables = []) {
resolve(event.data);
};
client.postMessage(message, [channel.port2].concat(transferrables.filter(Boolean)));
client.postMessage(JSON.stringify(message), [channel.port2]);
});
}
async function respondWithMock(response) {
// Setting response status code to 0 is a no-op.
// However, when responding with a "Response.error()", the produced Response
// instance will have status code set to 0. Since it's not possible to create
// a Response instance with status code 0, handle that use-case separately.
if (response.status === 0) {
return Response.error();
}
const mockedResponse = new Response(response.body, response);
Reflect.defineProperty(mockedResponse, IS_MOCKED_RESPONSE, {
value: true,
enumerable: true,
function delayPromise(cb, duration) {
return new Promise((resolve) => {
setTimeout(() => resolve(cb()), duration);
});
}
function respondWithMock(clientMessage) {
return new Response(clientMessage.payload.body, {
...clientMessage.payload,
headers: clientMessage.payload.headers,
});
}
function uuidv4() {
return 'xxxxxxxx-xxxx-4xxx-yxxx-xxxxxxxxxxxx'.replace(/[xy]/g, function (c) {
const r = (Math.random() * 16) | 0;
const v = c == 'x' ? r : (r & 0x3) | 0x8;
return v.toString(16);
});
return mockedResponse;
}

View File

@@ -7,9 +7,9 @@ ARCH=$(shell go env GOARCH)
# build target, can be one of "production", "testing", "development"
ENV=development
WEBPACK_CONFIG=webpack/webpack.$(ENV).js
TAG=local
TAG=latest
SWAG=go run github.com/swaggo/swag/cmd/swag@v1.16.2
SWAG=go run github.com/swaggo/swag/cmd/swag@v1.8.11
GOTESTSUM=go run gotest.tools/gotestsum@latest
# Don't change anything below this line unless you know what you're doing
@@ -30,7 +30,7 @@ build-server: init-dist ## Build the server binary
./build/build_binary.sh "$(PLATFORM)" "$(ARCH)"
build-image: build-all ## Build the Portainer image locally
docker buildx build --load -t portainerci/portainer-ce:$(TAG) -f build/linux/Dockerfile .
docker buildx build --load -t portainerci/portainer:$(TAG) -f build/linux/Dockerfile .
build-storybook: ## Build and serve the storybook files
yarn storybook:build
@@ -64,14 +64,11 @@ clean: ## Remove all build and download artifacts
.PHONY: test test-client test-server
test: test-server test-client ## Run all tests
test-deps: init-dist
./build/download_docker_compose_binary.sh $(PLATFORM) $(ARCH) $(shell jq -r '.dockerCompose' < "./binary-version.json")
test-client: ## Run client tests
yarn test $(ARGS)
test-server: ## Run server tests
$(GOTESTSUM) --format pkgname-and-test-fails --format-hide-empty-pkg --hide-summary skipped -- -cover ./...
cd api && $(GOTESTSUM) --format pkgname-and-test-fails --format-hide-empty-pkg --hide-summary skipped -- -cover ./...
##@ Dev
.PHONY: dev dev-client dev-server
@@ -85,8 +82,6 @@ dev-client: ## Run the client in development mode
dev-server: build-server ## Run the server in development mode
@./dev/run_container.sh
dev-server-podman: build-server ## Run the server in development mode
@./dev/run_container_podman.sh
##@ Format
.PHONY: format format-client format-server
@@ -97,7 +92,7 @@ format-client: ## Format client code
yarn format
format-server: ## Format server code
go fmt ./...
cd api && go fmt ./...
##@ Lint
.PHONY: lint lint-client lint-server
@@ -107,7 +102,7 @@ lint-client: ## Lint client code
yarn lint
lint-server: ## Lint server code
golangci-lint run --timeout=10m -c .golangci.yaml
cd api && go vet ./...
##@ Extension
@@ -119,7 +114,7 @@ dev-extension: build-server build-client ## Run the extension in development mod
##@ Docs
.PHONY: docs-build docs-validate docs-clean docs-validate-clean
docs-build: init-dist ## Build docs
cd api && $(SWAG) init -o "../dist/docs" -ot "yaml" -g ./http/handler/handler.go --parseDependency --parseInternal --parseDepth 2 -p pascalcase --markdownFiles ./
cd api && $(SWAG) init -o "../dist/docs" -ot "yaml" -g ./http/handler/handler.go --parseDependency --parseInternal --parseDepth 2 --markdownFiles ./
docs-validate: docs-build ## Validate docs
yarn swagger2openapi --warnOnly dist/docs/swagger.yaml -o dist/docs/openapi.yaml

View File

@@ -4,11 +4,11 @@ linters:
# Enable these for now
enable:
- unused
- depguard
- gosimple
- govet
- errorlint
- exportloopref
- staticcheck
linters-settings:
depguard:

View File

@@ -10,7 +10,7 @@ import (
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/url"
"github.com/portainer/portainer/api/internal/url"
)
// GetAgentVersionAndPlatform returns the agent version and platform

View File

@@ -6,11 +6,11 @@ import (
// APIKeyService represents a service for managing API keys.
type APIKeyService interface {
HashRaw(rawKey string) string
HashRaw(rawKey string) []byte
GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error)
GetAPIKey(apiKeyID portainer.APIKeyID) (*portainer.APIKey, error)
GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey, error)
GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error)
GetDigestUserAndKey(digest []byte) (portainer.User, portainer.APIKey, error)
UpdateAPIKey(apiKey *portainer.APIKey) error
DeleteAPIKey(apiKeyID portainer.APIKeyID) error
InvalidateUserKeyCache(userId portainer.UserID) bool

View File

@@ -3,6 +3,7 @@ package apikey
import (
"testing"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/stretchr/testify/assert"
)
@@ -33,19 +34,17 @@ func Test_generateRandomKey(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := GenerateRandomKey(tt.wantLenth)
got := securecookie.GenerateRandomKey(tt.wantLenth)
is.Equal(tt.wantLenth, len(got))
})
}
t.Run("Generated keys are unique", func(t *testing.T) {
keys := make(map[string]bool)
for range 100 {
key := GenerateRandomKey(8)
for i := 0; i < 100; i++ {
key := securecookie.GenerateRandomKey(8)
_, ok := keys[string(key)]
is.False(ok)
keys[string(key)] = true
}
})

View File

@@ -1,79 +1,69 @@
package apikey
import (
portainer "github.com/portainer/portainer/api"
lru "github.com/hashicorp/golang-lru"
portainer "github.com/portainer/portainer/api"
)
const DefaultAPIKeyCacheSize = 1024
const defaultAPIKeyCacheSize = 1024
// entry is a tuple containing the user and API key associated to an API key digest
type entry[T any] struct {
user T
type entry struct {
user portainer.User
apiKey portainer.APIKey
}
type UserCompareFn[T any] func(T, portainer.UserID) bool
// ApiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// apiKeyCache is a concurrency-safe, in-memory cache which primarily exists for to reduce database roundtrips.
// We store the api-key digest (keys) and the associated user and key-data (values) in the cache.
// This is required because HTTP requests will contain only the api-key digest in the x-api-key request header;
// digest value must be mapped to a portainer user (and respective key data) for validation.
// This cache is used to avoid multiple database queries to retrieve these user/key associated to the digest.
type ApiKeyCache[T any] struct {
type apiKeyCache struct {
// cache type [string]entry cache (key: string(digest), value: user/key entry)
// note: []byte keys are not supported by golang-lru Cache
cache *lru.Cache
userCmpFn UserCompareFn[T]
cache *lru.Cache
}
// NewAPIKeyCache creates a new cache for API keys
func NewAPIKeyCache[T any](cacheSize int, userCompareFn UserCompareFn[T]) *ApiKeyCache[T] {
func NewAPIKeyCache(cacheSize int) *apiKeyCache {
cache, _ := lru.New(cacheSize)
return &ApiKeyCache[T]{cache: cache, userCmpFn: userCompareFn}
return &apiKeyCache{cache: cache}
}
// Get returns the user/key associated to an api-key's digest
// This is required because HTTP requests will contain the digest of the API key in header,
// the digest value must be mapped to a portainer user.
func (c *ApiKeyCache[T]) Get(digest string) (T, portainer.APIKey, bool) {
val, ok := c.cache.Get(digest)
func (c *apiKeyCache) Get(digest []byte) (portainer.User, portainer.APIKey, bool) {
val, ok := c.cache.Get(string(digest))
if !ok {
var t T
return t, portainer.APIKey{}, false
return portainer.User{}, portainer.APIKey{}, false
}
tuple := val.(entry[T])
tuple := val.(entry)
return tuple.user, tuple.apiKey, true
}
// Set persists a user/key entry to the cache
func (c *ApiKeyCache[T]) Set(digest string, user T, apiKey portainer.APIKey) {
c.cache.Add(digest, entry[T]{
func (c *apiKeyCache) Set(digest []byte, user portainer.User, apiKey portainer.APIKey) {
c.cache.Add(string(digest), entry{
user: user,
apiKey: apiKey,
})
}
// Delete evicts a digest's user/key entry key from the cache
func (c *ApiKeyCache[T]) Delete(digest string) {
c.cache.Remove(digest)
func (c *apiKeyCache) Delete(digest []byte) {
c.cache.Remove(string(digest))
}
// InvalidateUserKeyCache loops through all the api-keys associated to a user and removes them from the cache
func (c *ApiKeyCache[T]) InvalidateUserKeyCache(userId portainer.UserID) bool {
func (c *apiKeyCache) InvalidateUserKeyCache(userId portainer.UserID) bool {
present := false
for _, k := range c.cache.Keys() {
user, _, _ := c.Get(k.(string))
if c.userCmpFn(user, userId) {
user, _, _ := c.Get([]byte(k.(string)))
if user.ID == userId {
present = c.cache.Remove(k)
}
}
return present
}

View File

@@ -10,32 +10,32 @@ import (
func Test_apiKeyCacheGet(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10, compareUser)
keyCache := NewAPIKeyCache(10)
// pre-populate cache
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string(""), entry[portainer.User]{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string(""), entry{user: portainer.User{}, apiKey: portainer.APIKey{}})
tests := []struct {
digest string
digest []byte
found bool
}{
{
digest: "foo",
digest: []byte("foo"),
found: true,
},
{
digest: "",
digest: []byte(""),
found: true,
},
{
digest: "bar",
digest: []byte("bar"),
found: false,
},
}
for _, test := range tests {
t.Run(test.digest, func(t *testing.T) {
t.Run(string(test.digest), func(t *testing.T) {
_, _, found := keyCache.Get(test.digest)
is.Equal(test.found, found)
})
@@ -45,43 +45,43 @@ func Test_apiKeyCacheGet(t *testing.T) {
func Test_apiKeyCacheSet(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10, compareUser)
keyCache := NewAPIKeyCache(10)
// pre-populate cache
keyCache.Set("bar", portainer.User{ID: 2}, portainer.APIKey{})
keyCache.Set("foo", portainer.User{ID: 1}, portainer.APIKey{})
keyCache.Set([]byte("bar"), portainer.User{ID: 2}, portainer.APIKey{})
keyCache.Set([]byte("foo"), portainer.User{ID: 1}, portainer.APIKey{})
// overwrite existing entry
keyCache.Set("foo", portainer.User{ID: 3}, portainer.APIKey{})
keyCache.Set([]byte("foo"), portainer.User{ID: 3}, portainer.APIKey{})
val, ok := keyCache.cache.Get(string("bar"))
is.True(ok)
tuple := val.(entry[portainer.User])
tuple := val.(entry)
is.Equal(portainer.User{ID: 2}, tuple.user)
val, ok = keyCache.cache.Get(string("foo"))
is.True(ok)
tuple = val.(entry[portainer.User])
tuple = val.(entry)
is.Equal(portainer.User{ID: 3}, tuple.user)
}
func Test_apiKeyCacheDelete(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10, compareUser)
keyCache := NewAPIKeyCache(10)
t.Run("Delete an existing entry", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.Delete("foo")
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.Delete([]byte("foo"))
_, ok := keyCache.cache.Get(string("foo"))
is.False(ok)
})
t.Run("Delete a non-existing entry", func(t *testing.T) {
nonPanicFunc := func() { keyCache.Delete("non-existent-key") }
nonPanicFunc := func() { keyCache.Delete([]byte("non-existent-key")) }
is.NotPanics(nonPanicFunc)
})
}
@@ -128,19 +128,19 @@ func Test_apiKeyCacheLRU(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
keyCache := NewAPIKeyCache(test.cacheLen, compareUser)
keyCache := NewAPIKeyCache(test.cacheLen)
for _, key := range test.key {
keyCache.Set(key, portainer.User{ID: 1}, portainer.APIKey{})
keyCache.Set([]byte(key), portainer.User{ID: 1}, portainer.APIKey{})
}
for _, key := range test.foundKeys {
_, _, found := keyCache.Get(key)
_, _, found := keyCache.Get([]byte(key))
is.True(found, "Key %s not found", key)
}
for _, key := range test.evictedKeys {
_, _, found := keyCache.Get(key)
_, _, found := keyCache.Get([]byte(key))
is.False(found, "key %s should have been evicted", key)
}
})
@@ -150,10 +150,10 @@ func Test_apiKeyCacheLRU(t *testing.T) {
func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
is := assert.New(t)
keyCache := NewAPIKeyCache(10, compareUser)
keyCache := NewAPIKeyCache(10)
t.Run("Removes users keys from cache", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok)
@@ -163,8 +163,8 @@ func Test_apiKeyCacheInvalidateUserKeyCache(t *testing.T) {
})
t.Run("Does not affect other keys", func(t *testing.T) {
keyCache.cache.Add(string("foo"), entry[portainer.User]{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("bar"), entry[portainer.User]{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("foo"), entry{user: portainer.User{ID: 1}, apiKey: portainer.APIKey{}})
keyCache.cache.Add(string("bar"), entry{user: portainer.User{ID: 2}, apiKey: portainer.APIKey{}})
ok := keyCache.InvalidateUserKeyCache(1)
is.True(ok)

View File

@@ -1,15 +1,14 @@
package apikey
import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/internal/securecookie"
"github.com/pkg/errors"
)
@@ -21,45 +20,30 @@ var ErrInvalidAPIKey = errors.New("Invalid API key")
type apiKeyService struct {
apiKeyRepository dataservices.APIKeyRepository
userRepository dataservices.UserService
cache *ApiKeyCache[portainer.User]
}
// GenerateRandomKey generates a random key of specified length
// source: https://github.com/gorilla/securecookie/blob/master/securecookie.go#L515
func GenerateRandomKey(length int) []byte {
k := make([]byte, length)
if _, err := io.ReadFull(rand.Reader, k); err != nil {
return nil
}
return k
}
func compareUser(u portainer.User, id portainer.UserID) bool {
return u.ID == id
cache *apiKeyCache
}
func NewAPIKeyService(apiKeyRepository dataservices.APIKeyRepository, userRepository dataservices.UserService) *apiKeyService {
return &apiKeyService{
apiKeyRepository: apiKeyRepository,
userRepository: userRepository,
cache: NewAPIKeyCache(DefaultAPIKeyCacheSize, compareUser),
cache: NewAPIKeyCache(defaultAPIKeyCacheSize),
}
}
// HashRaw computes a hash digest of provided raw API key.
func (a *apiKeyService) HashRaw(rawKey string) string {
func (a *apiKeyService) HashRaw(rawKey string) []byte {
hashDigest := sha256.Sum256([]byte(rawKey))
return base64.StdEncoding.EncodeToString(hashDigest[:])
return hashDigest[:]
}
// GenerateApiKey generates a raw API key for a user (for one-time display).
// The generated API key is stored in the cache and database.
func (a *apiKeyService) GenerateApiKey(user portainer.User, description string) (string, *portainer.APIKey, error) {
randKey := GenerateRandomKey(32)
randKey := securecookie.GenerateRandomKey(32)
encodedRawAPIKey := base64.StdEncoding.EncodeToString(randKey)
prefixedAPIKey := portainerAPIKeyPrefix + encodedRawAPIKey
hashDigest := a.HashRaw(prefixedAPIKey)
apiKey := &portainer.APIKey{
@@ -70,7 +54,8 @@ func (a *apiKeyService) GenerateApiKey(user portainer.User, description string)
Digest: hashDigest,
}
if err := a.apiKeyRepository.Create(apiKey); err != nil {
err := a.apiKeyRepository.Create(apiKey)
if err != nil {
return "", nil, errors.Wrap(err, "Unable to create API key")
}
@@ -92,7 +77,8 @@ func (a *apiKeyService) GetAPIKeys(userID portainer.UserID) ([]portainer.APIKey,
// GetDigestUserAndKey returns the user and api-key associated to a specified hash digest.
// A cache lookup is performed first; if the user/api-key is not found in the cache, respective database lookups are performed.
func (a *apiKeyService) GetDigestUserAndKey(digest string) (portainer.User, portainer.APIKey, error) {
func (a *apiKeyService) GetDigestUserAndKey(digest []byte) (portainer.User, portainer.APIKey, error) {
// get api key from cache if possible
cachedUser, cachedKey, ok := a.cache.Get(digest)
if ok {
return cachedUser, cachedKey, nil
@@ -120,21 +106,20 @@ func (a *apiKeyService) UpdateAPIKey(apiKey *portainer.APIKey) error {
if err != nil {
return errors.Wrap(err, "Unable to retrieve API key")
}
a.cache.Set(apiKey.Digest, user, *apiKey)
return a.apiKeyRepository.Update(apiKey.ID, apiKey)
}
// DeleteAPIKey deletes an API key and removes the digest/api-key entry from the cache.
func (a *apiKeyService) DeleteAPIKey(apiKeyID portainer.APIKeyID) error {
// get api-key digest to remove from cache
apiKey, err := a.apiKeyRepository.Read(apiKeyID)
if err != nil {
return errors.Wrap(err, fmt.Sprintf("Unable to retrieve API key: %d", apiKeyID))
}
// delete the user/api-key from cache
a.cache.Delete(apiKey.Digest)
return a.apiKeyRepository.Delete(apiKeyID)
}

View File

@@ -2,7 +2,6 @@ package apikey
import (
"crypto/sha256"
"encoding/base64"
"fmt"
"strings"
"testing"
@@ -69,7 +68,7 @@ func Test_GenerateApiKey(t *testing.T) {
generatedDigest := sha256.Sum256([]byte(rawKey))
is.Equal(apiKey.Digest, base64.StdEncoding.EncodeToString(generatedDigest[:]))
is.Equal(apiKey.Digest, generatedDigest[:])
})
}

View File

@@ -48,6 +48,18 @@ func TarGzDir(absolutePath string) (string, error) {
}
func addToArchive(tarWriter *tar.Writer, pathInArchive string, path string, info os.FileInfo) error {
header, err := tar.FileInfoHeader(info, info.Name())
if err != nil {
return err
}
header.Name = pathInArchive // use relative paths in archive
err = tarWriter.WriteHeader(header)
if err != nil {
return err
}
if info.IsDir() {
return nil
}
@@ -56,26 +68,6 @@ func addToArchive(tarWriter *tar.Writer, pathInArchive string, path string, info
if err != nil {
return err
}
stat, err := file.Stat()
if err != nil {
return err
}
header, err := tar.FileInfoHeader(stat, stat.Name())
if err != nil {
return err
}
header.Name = pathInArchive // use relative paths in archive
err = tarWriter.WriteHeader(header)
if err != nil {
return err
}
if stat.IsDir() {
return nil
}
_, err = io.Copy(tarWriter, file)
return err
}
@@ -106,7 +98,7 @@ func ExtractTarGz(r io.Reader, outputDirPath string) error {
// skip, dir will be created with a file
case tar.TypeReg:
p := filepath.Clean(filepath.Join(outputDirPath, header.Name))
if err := os.MkdirAll(filepath.Dir(p), 0o744); err != nil {
if err := os.MkdirAll(filepath.Dir(p), 0744); err != nil {
return fmt.Errorf("Failed to extract dir %s", filepath.Dir(p))
}
outFile, err := os.Create(p)

View File

@@ -17,7 +17,7 @@ import (
"github.com/rs/zerolog/log"
)
const rwxr__r__ os.FileMode = 0o744
const rwxr__r__ os.FileMode = 0744
var filesToBackup = []string{
"certs",
@@ -82,9 +82,14 @@ func CreateBackupArchive(password string, gate *offlinegate.OfflineGate, datasto
}
func backupDb(backupDirPath string, datastore dataservices.DataStore) error {
dbFileName := datastore.Connection().GetDatabaseFileName()
_, err := datastore.Backup(filepath.Join(backupDirPath, dbFileName))
return err
backupWriter, err := os.Create(filepath.Join(backupDirPath, "portainer.db"))
if err != nil {
return err
}
if err = datastore.BackupTo(backupWriter); err != nil {
return err
}
return backupWriter.Close()
}
func encrypt(path string, passphrase string) (string, error) {

View File

@@ -26,7 +26,7 @@ func RestoreArchive(archive io.Reader, password string, filestorePath string, ga
if password != "" {
archive, err = decrypt(archive, password)
if err != nil {
return errors.Wrap(err, "failed to decrypt the archive. Please ensure the password is correct and try again")
return errors.Wrap(err, "failed to decrypt the archive")
}
}

View File

@@ -1,12 +1,9 @@
package build
import "runtime"
// Variables to be set during the build time
var BuildNumber string
var ImageTag string
var NodejsVersion string
var YarnVersion string
var WebpackVersion string
var GoVersion string = runtime.Version()
var GitCommit string
var GoVersion string

View File

@@ -5,17 +5,6 @@ import (
"github.com/portainer/portainer/api/internal/edge/cache"
)
// EdgeJobs retrieves the edge jobs for the given environment
func (service *Service) EdgeJobs(endpointID portainer.EndpointID) []portainer.EdgeJob {
service.mu.RLock()
defer service.mu.RUnlock()
return append(
make([]portainer.EdgeJob, 0, len(service.edgeJobs[endpointID])),
service.edgeJobs[endpointID]...,
)
}
// AddEdgeJob register an EdgeJob inside the tunnel details associated to an environment(endpoint).
func (service *Service) AddEdgeJob(endpoint *portainer.Endpoint, edgeJob *portainer.EdgeJob) {
if endpoint.Edge.AsyncMode {
@@ -23,10 +12,10 @@ func (service *Service) AddEdgeJob(endpoint *portainer.Endpoint, edgeJob *portai
}
service.mu.Lock()
defer service.mu.Unlock()
tunnel := service.getTunnelDetails(endpoint.ID)
existingJobIndex := -1
for idx, existingJob := range service.edgeJobs[endpoint.ID] {
for idx, existingJob := range tunnel.Jobs {
if existingJob.ID == edgeJob.ID {
existingJobIndex = idx
@@ -35,28 +24,30 @@ func (service *Service) AddEdgeJob(endpoint *portainer.Endpoint, edgeJob *portai
}
if existingJobIndex == -1 {
service.edgeJobs[endpoint.ID] = append(service.edgeJobs[endpoint.ID], *edgeJob)
tunnel.Jobs = append(tunnel.Jobs, *edgeJob)
} else {
service.edgeJobs[endpoint.ID][existingJobIndex] = *edgeJob
tunnel.Jobs[existingJobIndex] = *edgeJob
}
cache.Del(endpoint.ID)
service.mu.Unlock()
}
// RemoveEdgeJob will remove the specified Edge job from each tunnel it was registered with.
func (service *Service) RemoveEdgeJob(edgeJobID portainer.EdgeJobID) {
service.mu.Lock()
for endpointID := range service.edgeJobs {
for endpointID, tunnel := range service.tunnelDetailsMap {
n := 0
for _, edgeJob := range service.edgeJobs[endpointID] {
for _, edgeJob := range tunnel.Jobs {
if edgeJob.ID != edgeJobID {
service.edgeJobs[endpointID][n] = edgeJob
tunnel.Jobs[n] = edgeJob
n++
}
}
service.edgeJobs[endpointID] = service.edgeJobs[endpointID][:n]
tunnel.Jobs = tunnel.Jobs[:n]
cache.Del(endpointID)
}
@@ -66,17 +57,19 @@ func (service *Service) RemoveEdgeJob(edgeJobID portainer.EdgeJobID) {
func (service *Service) RemoveEdgeJobFromEndpoint(endpointID portainer.EndpointID, edgeJobID portainer.EdgeJobID) {
service.mu.Lock()
defer service.mu.Unlock()
tunnel := service.getTunnelDetails(endpointID)
n := 0
for _, edgeJob := range service.edgeJobs[endpointID] {
for _, edgeJob := range tunnel.Jobs {
if edgeJob.ID != edgeJobID {
service.edgeJobs[endpointID][n] = edgeJob
tunnel.Jobs[n] = edgeJob
n++
}
}
service.edgeJobs[endpointID] = service.edgeJobs[endpointID][:n]
tunnel.Jobs = tunnel.Jobs[:n]
cache.Del(endpointID)
service.mu.Unlock()
}

View File

@@ -19,127 +19,99 @@ import (
const (
tunnelCleanupInterval = 10 * time.Second
requiredTimeout = 15 * time.Second
activeTimeout = 4*time.Minute + 30*time.Second
pingTimeout = 3 * time.Second
)
// Service represents a service to manage the state of multiple reverse tunnels.
// It is used to start a reverse tunnel server and to manage the connection status of each tunnel
// connected to the tunnel server.
type Service struct {
serverFingerprint string
serverPort string
activeTunnels map[portainer.EndpointID]*portainer.TunnelDetails
edgeJobs map[portainer.EndpointID][]portainer.EdgeJob
dataStore dataservices.DataStore
snapshotService portainer.SnapshotService
chiselServer *chserver.Server
shutdownCtx context.Context
ProxyManager *proxy.Manager
mu sync.RWMutex
fileService portainer.FileService
defaultCheckinInterval int
serverFingerprint string
serverPort string
tunnelDetailsMap map[portainer.EndpointID]*portainer.TunnelDetails
dataStore dataservices.DataStore
snapshotService portainer.SnapshotService
chiselServer *chserver.Server
shutdownCtx context.Context
ProxyManager *proxy.Manager
mu sync.Mutex
fileService portainer.FileService
}
// NewService returns a pointer to a new instance of Service
func NewService(dataStore dataservices.DataStore, shutdownCtx context.Context, fileService portainer.FileService) *Service {
defaultCheckinInterval := portainer.DefaultEdgeAgentCheckinIntervalInSeconds
settings, err := dataStore.Settings().Settings()
if err == nil {
defaultCheckinInterval = settings.EdgeAgentCheckinInterval
} else {
log.Error().Err(err).Msg("unable to retrieve the settings from the database")
}
return &Service{
activeTunnels: make(map[portainer.EndpointID]*portainer.TunnelDetails),
edgeJobs: make(map[portainer.EndpointID][]portainer.EdgeJob),
dataStore: dataStore,
shutdownCtx: shutdownCtx,
fileService: fileService,
defaultCheckinInterval: defaultCheckinInterval,
tunnelDetailsMap: make(map[portainer.EndpointID]*portainer.TunnelDetails),
dataStore: dataStore,
shutdownCtx: shutdownCtx,
fileService: fileService,
}
}
// pingAgent ping the given agent so that the agent can keep the tunnel alive
func (service *Service) pingAgent(endpointID portainer.EndpointID) error {
endpoint, err := service.dataStore.Endpoint().Endpoint(endpointID)
if err != nil {
return err
}
tunnelAddr, err := service.TunnelAddr(endpoint)
if err != nil {
return err
}
requestURL := fmt.Sprintf("http://%s/ping", tunnelAddr)
tunnel := service.GetTunnelDetails(endpointID)
requestURL := fmt.Sprintf("http://127.0.0.1:%d/ping", tunnel.Port)
req, err := http.NewRequest(http.MethodHead, requestURL, nil)
if err != nil {
return err
}
httpClient := &http.Client{
Timeout: pingTimeout,
Timeout: 3 * time.Second,
}
resp, err := httpClient.Do(req)
if err != nil {
return err
}
io.Copy(io.Discard, resp.Body)
resp.Body.Close()
return nil
return err
}
// KeepTunnelAlive keeps the tunnel of the given environment for maxAlive duration, or until ctx is done
func (service *Service) KeepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxAlive time.Duration) {
go service.keepTunnelAlive(endpointID, ctx, maxAlive)
}
go func() {
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("max_alive_minutes", maxAlive.Minutes()).
Msg("KeepTunnelAlive: start")
func (service *Service) keepTunnelAlive(endpointID portainer.EndpointID, ctx context.Context, maxAlive time.Duration) {
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("max_alive_minutes", maxAlive.Minutes()).
Msg("KeepTunnelAlive: start")
maxAliveTicker := time.NewTicker(maxAlive)
defer maxAliveTicker.Stop()
maxAliveTicker := time.NewTicker(maxAlive)
defer maxAliveTicker.Stop()
pingTicker := time.NewTicker(tunnelCleanupInterval)
defer pingTicker.Stop()
pingTicker := time.NewTicker(tunnelCleanupInterval)
defer pingTicker.Stop()
for {
select {
case <-pingTicker.C:
service.SetTunnelStatusToActive(endpointID)
err := service.pingAgent(endpointID)
if err != nil {
log.Debug().
Int("endpoint_id", int(endpointID)).
Err(err).
Msg("KeepTunnelAlive: ping agent")
}
case <-maxAliveTicker.C:
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("timeout_minutes", maxAlive.Minutes()).
Msg("KeepTunnelAlive: tunnel keep alive timeout")
for {
select {
case <-pingTicker.C:
service.UpdateLastActivity(endpointID)
if err := service.pingAgent(endpointID); err != nil {
return
case <-ctx.Done():
err := ctx.Err()
log.Debug().
Int("endpoint_id", int(endpointID)).
Err(err).
Msg("KeepTunnelAlive: ping agent")
Msg("KeepTunnelAlive: tunnel stop")
return
}
case <-maxAliveTicker.C:
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("timeout_minutes", maxAlive.Minutes()).
Msg("KeepTunnelAlive: tunnel keep alive timeout")
return
case <-ctx.Done():
err := ctx.Err()
log.Debug().
Int("endpoint_id", int(endpointID)).
Err(err).
Msg("KeepTunnelAlive: tunnel stop")
return
}
}
}()
}
// StartTunnelServer starts a tunnel server on the specified addr and port.
@@ -149,6 +121,7 @@ func (service *Service) keepTunnelAlive(endpointID portainer.EndpointID, ctx con
// The snapshotter is used in the tunnel status verification process.
func (service *Service) StartTunnelServer(addr, port string, snapshotService portainer.SnapshotService) error {
privateKeyFile, err := service.retrievePrivateKeyFile()
if err != nil {
return err
}
@@ -166,21 +139,21 @@ func (service *Service) StartTunnelServer(addr, port string, snapshotService por
service.serverFingerprint = chiselServer.GetFingerprint()
service.serverPort = port
if err := chiselServer.Start(addr, port); err != nil {
err = chiselServer.Start(addr, port)
if err != nil {
return err
}
service.chiselServer = chiselServer
// TODO: work-around Chisel default behavior.
// By default, Chisel will allow anyone to connect if no user exists.
username, password := generateRandomCredentials()
if err = service.chiselServer.AddUser(username, password, "127.0.0.1"); err != nil {
err = service.chiselServer.AddUser(username, password, "127.0.0.1")
if err != nil {
return err
}
service.snapshotService = snapshotService
go service.startTunnelVerificationLoop()
return nil
@@ -194,39 +167,37 @@ func (service *Service) StopTunnelServer() error {
func (service *Service) retrievePrivateKeyFile() (string, error) {
privateKeyFile := service.fileService.GetDefaultChiselPrivateKeyPath()
if exists, _ := service.fileService.FileExists(privateKeyFile); exists {
exist, _ := service.fileService.FileExists(privateKeyFile)
if !exist {
log.Debug().
Str("private-key", privateKeyFile).
Msg("Chisel private key file does not exist")
privateKey, err := ccrypto.GenerateKey("")
if err != nil {
log.Error().
Err(err).
Msg("Failed to generate chisel private key")
return "", err
}
err = service.fileService.StoreChiselPrivateKey(privateKey)
if err != nil {
log.Error().
Err(err).
Msg("Failed to save Chisel private key to disk")
return "", err
} else {
log.Info().
Str("private-key", privateKeyFile).
Msg("Generated a new Chisel private key file")
}
} else {
log.Info().
Str("private-key", privateKeyFile).
Msg("found Chisel private key file on disk")
return privateKeyFile, nil
Msg("Found Chisel private key file on disk")
}
log.Debug().
Str("private-key", privateKeyFile).
Msg("chisel private key file does not exist")
privateKey, err := ccrypto.GenerateKey("")
if err != nil {
log.Error().
Err(err).
Msg("failed to generate chisel private key")
return "", err
}
if err = service.fileService.StoreChiselPrivateKey(privateKey); err != nil {
log.Error().
Err(err).
Msg("failed to save Chisel private key to disk")
return "", err
}
log.Info().
Str("private-key", privateKeyFile).
Msg("generated a new Chisel private key file")
return privateKeyFile, nil
}
@@ -254,45 +225,63 @@ func (service *Service) startTunnelVerificationLoop() {
}
}
// checkTunnels finds the first tunnel that has not had any activity recently
// and attempts to take a snapshot, then closes it and returns
func (service *Service) checkTunnels() {
service.mu.RLock()
tunnels := make(map[portainer.EndpointID]portainer.TunnelDetails)
for endpointID, tunnel := range service.activeTunnels {
elapsed := time.Since(tunnel.LastActivity)
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("last_activity_seconds", elapsed.Seconds()).
Msg("environment tunnel monitoring")
if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed < activeTimeout {
service.mu.Lock()
for key, tunnel := range service.tunnelDetailsMap {
if tunnel.LastActivity.IsZero() || tunnel.Status == portainer.EdgeAgentIdle {
continue
}
tunnelPort := tunnel.Port
service.mu.RUnlock()
log.Debug().
Int("endpoint_id", int(endpointID)).
Float64("last_activity_seconds", elapsed.Seconds()).
Float64("timeout_seconds", activeTimeout.Seconds()).
Msg("last activity timeout exceeded")
if err := service.snapshotEnvironment(endpointID, tunnelPort); err != nil {
log.Error().
Int("endpoint_id", int(endpointID)).
Err(err).
Msg("unable to snapshot Edge environment")
if tunnel.Status == portainer.EdgeAgentManagementRequired && time.Since(tunnel.LastActivity) < requiredTimeout {
continue
}
service.close(endpointID)
if tunnel.Status == portainer.EdgeAgentActive && time.Since(tunnel.LastActivity) < activeTimeout {
continue
}
return
tunnels[key] = *tunnel
}
service.mu.Unlock()
service.mu.RUnlock()
for endpointID, tunnel := range tunnels {
elapsed := time.Since(tunnel.LastActivity)
log.Debug().
Int("endpoint_id", int(endpointID)).
Str("status", tunnel.Status).
Float64("status_time_seconds", elapsed.Seconds()).
Msg("environment tunnel monitoring")
if tunnel.Status == portainer.EdgeAgentManagementRequired && elapsed > requiredTimeout {
log.Debug().
Int("endpoint_id", int(endpointID)).
Str("status", tunnel.Status).
Float64("status_time_seconds", elapsed.Seconds()).
Float64("timeout_seconds", requiredTimeout.Seconds()).
Msg("REQUIRED state timeout exceeded")
}
if tunnel.Status == portainer.EdgeAgentActive && elapsed > activeTimeout {
log.Debug().
Int("endpoint_id", int(endpointID)).
Str("status", tunnel.Status).
Float64("status_time_seconds", elapsed.Seconds()).
Float64("timeout_seconds", activeTimeout.Seconds()).
Msg("ACTIVE state timeout exceeded")
err := service.snapshotEnvironment(endpointID, tunnel.Port)
if err != nil {
log.Error().
Int("endpoint_id", int(endpointID)).
Err(err).
Msg("unable to snapshot Edge environment")
}
}
service.SetTunnelStatusToIdle(portainer.EndpointID(endpointID))
}
}
func (service *Service) snapshotEnvironment(endpointID portainer.EndpointID, tunnelPort int) error {

View File

@@ -1,54 +0,0 @@
package chisel
import (
"context"
"net"
"net/http"
"testing"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/datastore"
"github.com/stretchr/testify/require"
)
func TestPingAgentPanic(t *testing.T) {
endpoint := &portainer.Endpoint{
ID: 1,
EdgeID: "test-edge-id",
Type: portainer.EdgeAgentOnDockerEnvironment,
UserTrusted: true,
}
_, store := datastore.MustNewTestStore(t, true, true)
s := NewService(store, nil, nil)
defer func() {
require.Nil(t, recover())
}()
mux := http.NewServeMux()
mux.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(pingTimeout + 1*time.Second)
})
ln, err := net.ListenTCP("tcp", &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: 0})
require.NoError(t, err)
srv := &http.Server{Handler: mux}
errCh := make(chan error)
go func() {
errCh <- srv.Serve(ln)
}()
err = s.Open(endpoint)
require.NoError(t, err)
s.activeTunnels[endpoint.ID].Port = ln.Addr().(*net.TCPAddr).Port
require.Error(t, s.pingAgent(endpoint.ID))
require.NoError(t, srv.Shutdown(context.Background()))
require.ErrorIs(t, <-errCh, http.ErrServerClosed)
}

View File

@@ -5,18 +5,14 @@ import (
"errors"
"fmt"
"math/rand"
"net"
"strings"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/internal/edge"
"github.com/portainer/portainer/api/internal/edge/cache"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/pkg/libcrypto"
"github.com/dchest/uniuri"
"github.com/rs/zerolog/log"
)
const (
@@ -24,191 +20,18 @@ const (
maxAvailablePort = 65535
)
var (
ErrNonEdgeEnv = errors.New("cannot open a tunnel for non-edge environments")
ErrAsyncEnv = errors.New("cannot open a tunnel for async edge environments")
ErrInvalidEnv = errors.New("cannot open a tunnel for an invalid environment")
)
// Open will mark the tunnel as REQUIRED so the agent opens it
func (s *Service) Open(endpoint *portainer.Endpoint) error {
if !endpointutils.IsEdgeEndpoint(endpoint) {
return ErrNonEdgeEnv
}
if endpoint.Edge.AsyncMode {
return ErrAsyncEnv
}
if endpoint.ID == 0 || endpoint.EdgeID == "" || !endpoint.UserTrusted {
return ErrInvalidEnv
}
s.mu.Lock()
defer s.mu.Unlock()
if _, ok := s.activeTunnels[endpoint.ID]; ok {
return nil
}
defer cache.Del(endpoint.ID)
tun := &portainer.TunnelDetails{
Status: portainer.EdgeAgentManagementRequired,
Port: s.getUnusedPort(),
LastActivity: time.Now(),
}
username, password := generateRandomCredentials()
if s.chiselServer != nil {
authorizedRemote := fmt.Sprintf("^R:0.0.0.0:%d$", tun.Port)
if err := s.chiselServer.AddUser(username, password, authorizedRemote); err != nil {
return err
}
}
credentials, err := encryptCredentials(username, password, endpoint.EdgeID)
if err != nil {
return err
}
tun.Credentials = credentials
s.activeTunnels[endpoint.ID] = tun
return nil
}
// close removes the tunnel from the map so the agent will close it
func (s *Service) close(endpointID portainer.EndpointID) {
s.mu.Lock()
defer s.mu.Unlock()
tun, ok := s.activeTunnels[endpointID]
if !ok {
return
}
if len(tun.Credentials) > 0 && s.chiselServer != nil {
user, _, _ := strings.Cut(tun.Credentials, ":")
s.chiselServer.DeleteUser(user)
}
if s.ProxyManager != nil {
s.ProxyManager.DeleteEndpointProxy(endpointID)
}
delete(s.activeTunnels, endpointID)
cache.Del(endpointID)
}
// Config returns the tunnel details needed for the agent to connect
func (s *Service) Config(endpointID portainer.EndpointID) portainer.TunnelDetails {
s.mu.RLock()
defer s.mu.RUnlock()
if tun, ok := s.activeTunnels[endpointID]; ok {
return *tun
}
return portainer.TunnelDetails{Status: portainer.EdgeAgentIdle}
}
// TunnelAddr returns the address of the local tunnel, including the port, it
// will block until the tunnel is ready
func (s *Service) TunnelAddr(endpoint *portainer.Endpoint) (string, error) {
if err := s.Open(endpoint); err != nil {
return "", err
}
tun := s.Config(endpoint.ID)
checkinInterval := time.Duration(s.tryEffectiveCheckinInterval(endpoint)) * time.Second
for t0 := time.Now(); ; {
if time.Since(t0) > 2*checkinInterval {
s.close(endpoint.ID)
return "", errors.New("unable to open the tunnel")
}
// Check if the tunnel is established
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: tun.Port})
if err != nil {
time.Sleep(checkinInterval / 100)
continue
}
conn.Close()
break
}
s.UpdateLastActivity(endpoint.ID)
return fmt.Sprintf("127.0.0.1:%d", tun.Port), nil
}
// tryEffectiveCheckinInterval avoids a potential deadlock by returning a
// previous known value after a timeout
func (s *Service) tryEffectiveCheckinInterval(endpoint *portainer.Endpoint) int {
ch := make(chan int, 1)
go func() {
ch <- edge.EffectiveCheckinInterval(s.dataStore, endpoint)
}()
select {
case <-time.After(50 * time.Millisecond):
s.mu.RLock()
defer s.mu.RUnlock()
return s.defaultCheckinInterval
case i := <-ch:
s.mu.Lock()
s.defaultCheckinInterval = i
s.mu.Unlock()
return i
}
}
// UpdateLastActivity sets the current timestamp to avoid the tunnel timeout
func (s *Service) UpdateLastActivity(endpointID portainer.EndpointID) {
s.mu.Lock()
defer s.mu.Unlock()
if tun, ok := s.activeTunnels[endpointID]; ok {
tun.LastActivity = time.Now()
}
}
// NOTE: it needs to be called with the lock acquired
// getUnusedPort is used to generate an unused random port in the dynamic port range.
// Dynamic ports (also called private ports) are 49152 to 65535.
func (service *Service) getUnusedPort() int {
port := randomInt(minAvailablePort, maxAvailablePort)
for _, tunnel := range service.activeTunnels {
for _, tunnel := range service.tunnelDetailsMap {
if tunnel.Port == port {
return service.getUnusedPort()
}
}
conn, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: net.IPv4(127, 0, 0, 1), Port: port})
if err == nil {
conn.Close()
log.Debug().
Int("port", port).
Msg("selected port is in use, trying a different one")
return service.getUnusedPort()
}
return port
}
@@ -216,10 +39,152 @@ func randomInt(min, max int) int {
return min + rand.Intn(max-min)
}
// NOTE: it needs to be called with the lock acquired
func (service *Service) getTunnelDetails(endpointID portainer.EndpointID) *portainer.TunnelDetails {
if tunnel, ok := service.tunnelDetailsMap[endpointID]; ok {
return tunnel
}
tunnel := &portainer.TunnelDetails{
Status: portainer.EdgeAgentIdle,
}
service.tunnelDetailsMap[endpointID] = tunnel
cache.Del(endpointID)
return tunnel
}
// GetTunnelDetails returns information about the tunnel associated to an environment(endpoint).
func (service *Service) GetTunnelDetails(endpointID portainer.EndpointID) portainer.TunnelDetails {
service.mu.Lock()
defer service.mu.Unlock()
return *service.getTunnelDetails(endpointID)
}
// GetActiveTunnel retrieves an active tunnel which allows communicating with edge agent
func (service *Service) GetActiveTunnel(endpoint *portainer.Endpoint) (portainer.TunnelDetails, error) {
if endpoint.Edge.AsyncMode {
return portainer.TunnelDetails{}, errors.New("cannot open tunnel on async endpoint")
}
tunnel := service.GetTunnelDetails(endpoint.ID)
if tunnel.Status == portainer.EdgeAgentActive {
// update the LastActivity
service.SetTunnelStatusToActive(endpoint.ID)
}
if tunnel.Status == portainer.EdgeAgentIdle || tunnel.Status == portainer.EdgeAgentManagementRequired {
err := service.SetTunnelStatusToRequired(endpoint.ID)
if err != nil {
return portainer.TunnelDetails{}, fmt.Errorf("failed opening tunnel to endpoint: %w", err)
}
if endpoint.EdgeCheckinInterval == 0 {
settings, err := service.dataStore.Settings().Settings()
if err != nil {
return portainer.TunnelDetails{}, fmt.Errorf("failed fetching settings from db: %w", err)
}
endpoint.EdgeCheckinInterval = settings.EdgeAgentCheckinInterval
}
time.Sleep(2 * time.Duration(endpoint.EdgeCheckinInterval) * time.Second)
}
return service.GetTunnelDetails(endpoint.ID), nil
}
// SetTunnelStatusToActive update the status of the tunnel associated to the specified environment(endpoint).
// It sets the status to ACTIVE.
func (service *Service) SetTunnelStatusToActive(endpointID portainer.EndpointID) {
service.mu.Lock()
tunnel := service.getTunnelDetails(endpointID)
tunnel.Status = portainer.EdgeAgentActive
tunnel.Credentials = ""
tunnel.LastActivity = time.Now()
service.mu.Unlock()
cache.Del(endpointID)
}
// SetTunnelStatusToIdle update the status of the tunnel associated to the specified environment(endpoint).
// It sets the status to IDLE.
// It removes any existing credentials associated to the tunnel.
func (service *Service) SetTunnelStatusToIdle(endpointID portainer.EndpointID) {
service.mu.Lock()
tunnel := service.getTunnelDetails(endpointID)
tunnel.Status = portainer.EdgeAgentIdle
tunnel.Port = 0
tunnel.LastActivity = time.Now()
credentials := tunnel.Credentials
if credentials != "" {
tunnel.Credentials = ""
if service.chiselServer != nil {
service.chiselServer.DeleteUser(strings.Split(credentials, ":")[0])
}
}
service.ProxyManager.DeleteEndpointProxy(endpointID)
service.mu.Unlock()
cache.Del(endpointID)
}
// SetTunnelStatusToRequired update the status of the tunnel associated to the specified environment(endpoint).
// It sets the status to REQUIRED.
// If no port is currently associated to the tunnel, it will associate a random unused port to the tunnel
// and generate temporary credentials that can be used to establish a reverse tunnel on that port.
// Credentials are encrypted using the Edge ID associated to the environment(endpoint).
func (service *Service) SetTunnelStatusToRequired(endpointID portainer.EndpointID) error {
defer cache.Del(endpointID)
tunnel := service.getTunnelDetails(endpointID)
service.mu.Lock()
defer service.mu.Unlock()
if tunnel.Port == 0 {
endpoint, err := service.dataStore.Endpoint().Endpoint(endpointID)
if err != nil {
return err
}
tunnel.Status = portainer.EdgeAgentManagementRequired
tunnel.Port = service.getUnusedPort()
tunnel.LastActivity = time.Now()
username, password := generateRandomCredentials()
authorizedRemote := fmt.Sprintf("^R:0.0.0.0:%d$", tunnel.Port)
if service.chiselServer != nil {
err = service.chiselServer.AddUser(username, password, authorizedRemote)
if err != nil {
return err
}
}
credentials, err := encryptCredentials(username, password, endpoint.EdgeID)
if err != nil {
return err
}
tunnel.Credentials = credentials
}
return nil
}
func generateRandomCredentials() (string, string) {
username := uniuri.NewLen(8)
password := uniuri.NewLen(8)
return username, password
}

View File

@@ -17,20 +17,24 @@ import (
type Service struct{}
var (
ErrInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
ErrSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
ErrInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
ErrAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
errInvalidEndpointProtocol = errors.New("Invalid environment protocol: Portainer only supports unix://, npipe:// or tcp://")
errSocketOrNamedPipeNotFound = errors.New("Unable to locate Unix socket or named pipe")
errInvalidSnapshotInterval = errors.New("Invalid snapshot interval")
errAdminPassExcludeAdminPassFile = errors.New("Cannot use --admin-password with --admin-password-file")
)
func CLIFlags() *portainer.CLIFlags {
return &portainer.CLIFlags{
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
flags := &portainer.CLIFlags{
Addr: kingpin.Flag("bind", "Address and port to serve Portainer").Default(defaultBindAddress).Short('p').String(),
AddrHTTPS: kingpin.Flag("bind-https", "Address and port to serve Portainer via https").Default(defaultHTTPSBindAddress).String(),
TunnelAddr: kingpin.Flag("tunnel-addr", "Address to serve the tunnel server").Default(defaultTunnelServerAddress).String(),
TunnelPort: kingpin.Flag("tunnel-port", "Port to serve the tunnel server").Default(defaultTunnelServerPort).String(),
Assets: kingpin.Flag("assets", "Path to the assets").Default(defaultAssetsDirectory).Short('a').String(),
Data: kingpin.Flag("data", "Path to the folder where the data is stored").Default(defaultDataDirectory).Short('d').String(),
DemoEnvironment: kingpin.Flag("demo", "Demo environment").Bool(),
EndpointURL: kingpin.Flag("host", "Environment URL").Short('H').String(),
FeatureFlags: kingpin.Flag("feat", "List of feature flags").Strings(),
EnableEdgeComputeFeatures: kingpin.Flag("edge-compute", "Enable Edge Compute features").Bool(),
@@ -58,15 +62,8 @@ func CLIFlags() *portainer.CLIFlags {
MaxBatchDelay: kingpin.Flag("max-batch-delay", "Maximum delay before a batch starts").Duration(),
SecretKeyName: kingpin.Flag("secret-key-name", "Secret key name for encryption and will be used as /run/secrets/<secret-key-name>.").Default(defaultSecretKeyName).String(),
LogLevel: kingpin.Flag("log-level", "Set the minimum logging level to show").Default("INFO").Enum("DEBUG", "INFO", "WARN", "ERROR"),
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("NOCOLOR", "PRETTY", "JSON"),
LogMode: kingpin.Flag("log-mode", "Set the logging output mode").Default("PRETTY").Enum("PRETTY", "JSON"),
}
}
// ParseFlags parse the CLI flags and return a portainer.Flags struct
func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
kingpin.Version(version)
flags := CLIFlags()
kingpin.Parse()
@@ -86,16 +83,18 @@ func (*Service) ParseFlags(version string) (*portainer.CLIFlags, error) {
func (*Service) ValidateFlags(flags *portainer.CLIFlags) error {
displayDeprecationWarnings(flags)
if err := validateEndpointURL(*flags.EndpointURL); err != nil {
err := validateEndpointURL(*flags.EndpointURL)
if err != nil {
return err
}
if err := validateSnapshotInterval(*flags.SnapshotInterval); err != nil {
err = validateSnapshotInterval(*flags.SnapshotInterval)
if err != nil {
return err
}
if *flags.AdminPassword != "" && *flags.AdminPasswordFile != "" {
return ErrAdminPassExcludeAdminPassFile
return errAdminPassExcludeAdminPassFile
}
return nil
@@ -117,16 +116,15 @@ func validateEndpointURL(endpointURL string) error {
}
if !strings.HasPrefix(endpointURL, "unix://") && !strings.HasPrefix(endpointURL, "tcp://") && !strings.HasPrefix(endpointURL, "npipe://") {
return ErrInvalidEndpointProtocol
return errInvalidEndpointProtocol
}
if strings.HasPrefix(endpointURL, "unix://") || strings.HasPrefix(endpointURL, "npipe://") {
socketPath := strings.TrimPrefix(endpointURL, "unix://")
socketPath = strings.TrimPrefix(socketPath, "npipe://")
if _, err := os.Stat(socketPath); err != nil {
if os.IsNotExist(err) {
return ErrSocketOrNamedPipeNotFound
return errSocketOrNamedPipeNotFound
}
return err
@@ -141,8 +139,9 @@ func validateSnapshotInterval(snapshotInterval string) error {
return nil
}
if _, err := time.ParseDuration(snapshotInterval); err != nil {
return ErrInvalidSnapshotInterval
_, err := time.ParseDuration(snapshotInterval)
if err != nil {
return errInvalidSnapshotInterval
}
return nil

View File

@@ -42,19 +42,12 @@ func setLoggingMode(mode string) {
TimeFormat: "2006/01/02 03:04PM",
FormatMessage: formatMessage,
})
case "NOCOLOR":
log.Logger = log.Output(zerolog.ConsoleWriter{
Out: os.Stderr,
TimeFormat: "2006/01/02 03:04PM",
FormatMessage: formatMessage,
NoColor: true,
})
case "JSON":
log.Logger = log.Output(os.Stderr)
}
}
func formatMessage(i any) string {
func formatMessage(i interface{}) string {
if i == nil {
return ""
}

View File

@@ -1,7 +1,6 @@
package main
import (
"cmp"
"context"
"crypto/sha256"
"os"
@@ -20,7 +19,7 @@ import (
"github.com/portainer/portainer/api/dataservices"
"github.com/portainer/portainer/api/datastore"
"github.com/portainer/portainer/api/datastore/migrator"
"github.com/portainer/portainer/api/datastore/postinit"
"github.com/portainer/portainer/api/demo"
"github.com/portainer/portainer/api/docker"
dockerclient "github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/exec"
@@ -43,9 +42,6 @@ import (
"github.com/portainer/portainer/api/ldap"
"github.com/portainer/portainer/api/oauth"
"github.com/portainer/portainer/api/pendingactions"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/portainer/portainer/api/pendingactions/handlers"
"github.com/portainer/portainer/api/platform"
"github.com/portainer/portainer/api/scheduler"
"github.com/portainer/portainer/api/stacks/deployments"
"github.com/portainer/portainer/pkg/featureflags"
@@ -58,14 +54,14 @@ import (
)
func initCLI() *portainer.CLIFlags {
cliService := &cli.Service{}
var cliService portainer.CLIService = &cli.Service{}
flags, err := cliService.ParseFlags(portainer.APIVersion)
if err != nil {
log.Fatal().Err(err).Msg("failed parsing flags")
}
if err := cliService.ValidateFlags(flags); err != nil {
err = cliService.ValidateFlags(flags)
if err != nil {
log.Fatal().Err(err).Msg("failed validating flags")
}
@@ -96,14 +92,14 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
store := datastore.NewStore(*flags.Data, fileService, connection)
isNew, err := store.Open()
if err != nil {
log.Fatal().Err(err).Msg("failed opening store")
}
if *flags.Rollback {
if err := store.Rollback(false); err != nil {
err := store.Rollback(false)
if err != nil {
log.Fatal().Err(err).Msg("failed rolling back")
}
@@ -112,7 +108,8 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
// Init sets some defaults - it's basically a migration
if err := store.Init(); err != nil {
err = store.Init()
if err != nil {
log.Fatal().Err(err).Msg("failed initializing data store")
}
@@ -134,23 +131,25 @@ func initDataStore(flags *portainer.CLIFlags, secretKey []byte, fileService port
}
store.VersionService.UpdateVersion(&v)
if err := updateSettingsFromFlags(store, flags); err != nil {
err = updateSettingsFromFlags(store, flags)
if err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags")
}
} else {
if err := store.MigrateData(); err != nil {
err = store.MigrateData()
if err != nil {
log.Fatal().Err(err).Msg("failed migration")
}
}
if err := updateSettingsFromFlags(store, flags); err != nil {
err = updateSettingsFromFlags(store, flags)
if err != nil {
log.Fatal().Err(err).Msg("failed updating settings from flags")
}
// this is for the db restore functionality - needs more tests.
go func() {
<-shutdownCtx.Done()
defer connection.Close()
}()
@@ -204,16 +203,36 @@ func initJWTService(userSessionTimeout string, dataStore dataservices.DataStore)
userSessionTimeout = portainer.DefaultUserSessionTimeout
}
return jwt.NewService(userSessionTimeout, dataStore)
jwtService, err := jwt.NewService(userSessionTimeout, dataStore)
if err != nil {
return nil, err
}
return jwtService, nil
}
func initDigitalSignatureService() portainer.DigitalSignatureService {
return crypto.NewECDSAService(os.Getenv("AGENT_SECRET"))
}
func initCryptoService() portainer.CryptoService {
return &crypto.Service{}
}
func initLDAPService() portainer.LDAPService {
return &ldap.Service{}
}
func initOAuthService() portainer.OAuthService {
return oauth.NewService()
}
func initGitService(ctx context.Context) portainer.GitService {
return git.NewService(ctx)
}
func initSSLService(addr, certPath, keyPath string, fileService portainer.FileService, dataStore dataservices.DataStore, shutdownTrigger context.CancelFunc) (*ssl.Service, error) {
slices := strings.Split(addr, ":")
host := slices[0]
if host == "" {
host = "0.0.0.0"
@@ -221,13 +240,22 @@ func initSSLService(addr, certPath, keyPath string, fileService portainer.FileSe
sslService := ssl.NewService(fileService, dataStore, shutdownTrigger)
if err := sslService.Init(host, certPath, keyPath); err != nil {
err := sslService.Init(host, certPath, keyPath)
if err != nil {
return nil, err
}
return sslService, nil
}
func initDockerClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService) *dockerclient.ClientFactory {
return dockerclient.NewClientFactory(signatureService, reverseTunnelService)
}
func initKubernetesClientFactory(signatureService portainer.DigitalSignatureService, reverseTunnelService portainer.ReverseTunnelService, dataStore dataservices.DataStore, instanceID, addrHTTPS, userSessionTimeout string) (*kubecli.ClientFactory, error) {
return kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, addrHTTPS, userSessionTimeout)
}
func initSnapshotService(
snapshotIntervalFromFlag string,
dataStore dataservices.DataStore,
@@ -260,21 +288,34 @@ func updateSettingsFromFlags(dataStore dataservices.DataStore, flags *portainer.
return err
}
settings.SnapshotInterval = *cmp.Or(flags.SnapshotInterval, &settings.SnapshotInterval)
settings.LogoURL = *cmp.Or(flags.Logo, &settings.LogoURL)
settings.EnableEdgeComputeFeatures = *cmp.Or(flags.EnableEdgeComputeFeatures, &settings.EnableEdgeComputeFeatures)
settings.TemplatesURL = *cmp.Or(flags.Templates, &settings.TemplatesURL)
if *flags.SnapshotInterval != "" {
settings.SnapshotInterval = *flags.SnapshotInterval
}
if *flags.Logo != "" {
settings.LogoURL = *flags.Logo
}
if *flags.EnableEdgeComputeFeatures {
settings.EnableEdgeComputeFeatures = *flags.EnableEdgeComputeFeatures
}
if *flags.Templates != "" {
settings.TemplatesURL = *flags.Templates
}
if *flags.Labels != nil {
settings.BlackListedLabels = *flags.Labels
}
settings.AgentSecret = ""
if agentKey, ok := os.LookupEnv("AGENT_SECRET"); ok {
settings.AgentSecret = agentKey
} else {
settings.AgentSecret = ""
}
if err := dataStore.Settings().UpdateSettings(settings); err != nil {
err = dataStore.Settings().UpdateSettings(settings)
if err != nil {
return err
}
@@ -297,7 +338,6 @@ func loadAndParseKeyPair(fileService portainer.FileService, signatureService por
if err != nil {
return err
}
return signatureService.ParseKeyPair(private, public)
}
@@ -306,9 +346,7 @@ func generateAndStoreKeyPair(fileService portainer.FileService, signatureService
if err != nil {
return err
}
privateHeader, publicHeader := signatureService.PEMHeaders()
return fileService.StoreKeyPair(private, public, privateHeader, publicHeader)
}
@@ -321,7 +359,6 @@ func initKeyPair(fileService portainer.FileService, signatureService portainer.D
if existingKeyPair {
return loadAndParseKeyPair(fileService, signatureService)
}
return generateAndStoreKeyPair(fileService, signatureService)
}
@@ -339,7 +376,6 @@ func loadEncryptionSecretKey(keyfilename string) []byte {
// return a 32 byte hash of the secret (required for AES)
hash := sha256.Sum256(content)
return hash[:]
}
@@ -384,17 +420,17 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed initializing JWT service")
}
ldapService := &ldap.Service{}
ldapService := initLDAPService()
oauthService := oauth.NewService()
oauthService := initOAuthService()
gitService := git.NewService(shutdownCtx)
gitService := initGitService(shutdownCtx)
openAMTService := openamt.NewService()
cryptoService := &crypto.Service{}
cryptoService := initCryptoService()
signatureService := initDigitalSignatureService()
digitalSignatureService := initDigitalSignatureService()
edgeStacksService := edgestacks.NewService(dataStore)
@@ -408,27 +444,35 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Err(err).Msg("failed to get SSL settings")
}
if err := initKeyPair(fileService, signatureService); err != nil {
err = initKeyPair(fileService, digitalSignatureService)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing key pair")
}
reverseTunnelService := chisel.NewService(dataStore, shutdownCtx, fileService)
dockerClientFactory := dockerclient.NewClientFactory(signatureService, reverseTunnelService)
kubernetesClientFactory, err := kubecli.NewClientFactory(signatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
dockerClientFactory := initDockerClientFactory(digitalSignatureService, reverseTunnelService)
kubernetesClientFactory, err := initKubernetesClientFactory(digitalSignatureService, reverseTunnelService, dataStore, instanceID, *flags.AddrHTTPS, settings.UserSessionTimeout)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing Kubernetes Client Factory service")
log.Fatal().Err(err).Msg("failed initializing kubernetes client factory")
}
authorizationService := authorization.NewService(dataStore)
authorizationService.K8sClientFactory = kubernetesClientFactory
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory, authorizationService, shutdownCtx)
snapshotService, err := initSnapshotService(*flags.SnapshotInterval, dataStore, dockerClientFactory, kubernetesClientFactory, shutdownCtx, pendingActionsService)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing snapshot service")
}
snapshotService.Start()
kubernetesTokenCacheManager := kubeproxy.NewTokenCacheManager()
kubeClusterAccessService := kubernetes.NewKubeClusterAccessService(*flags.BaseURL, *flags.AddrHTTPS, sslSettings.CertPath)
proxyManager := proxy.NewManager(kubernetesClientFactory)
proxyManager := proxy.NewManager(dataStore, digitalSignatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService)
reverseTunnelService.ProxyManager = proxyManager
@@ -441,45 +485,39 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
composeStackManager := initComposeStackManager(composeDeployer, proxyManager)
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, signatureService, fileService, reverseTunnelService, dataStore)
swarmStackManager, err := initSwarmStackManager(*flags.Assets, dockerConfigPath, digitalSignatureService, fileService, reverseTunnelService, dataStore)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing swarm stack manager")
}
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, signatureService, proxyManager, *flags.Assets)
pendingActionsService := pendingactions.NewService(dataStore, kubernetesClientFactory)
pendingActionsService.RegisterHandler(actions.CleanNAPWithOverridePolicies, handlers.NewHandlerCleanNAPWithOverridePolicies(authorizationService, dataStore))
pendingActionsService.RegisterHandler(actions.DeletePortainerK8sRegistrySecrets, handlers.NewHandlerDeleteRegistrySecrets(authorizationService, dataStore, kubernetesClientFactory))
pendingActionsService.RegisterHandler(actions.PostInitMigrateEnvironment, handlers.NewHandlerPostInitMigrateEnvironment(authorizationService, dataStore, kubernetesClientFactory, dockerClientFactory, *flags.Assets, kubernetesDeployer))
snapshotService, err := initSnapshotService(*flags.SnapshotInterval, dataStore, dockerClientFactory, kubernetesClientFactory, shutdownCtx, pendingActionsService)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing snapshot service")
}
snapshotService.Start()
proxyManager.NewProxyFactory(dataStore, signatureService, reverseTunnelService, dockerClientFactory, kubernetesClientFactory, kubernetesTokenCacheManager, gitService, snapshotService)
kubernetesDeployer := initKubernetesDeployer(kubernetesTokenCacheManager, kubernetesClientFactory, dataStore, reverseTunnelService, digitalSignatureService, proxyManager, *flags.Assets)
helmPackageManager, err := initHelmPackageManager(*flags.Assets)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing helm package manager")
}
if err := edge.LoadEdgeJobs(dataStore, reverseTunnelService); err != nil {
err = edge.LoadEdgeJobs(dataStore, reverseTunnelService)
if err != nil {
log.Fatal().Err(err).Msg("failed loading edge jobs from database")
}
applicationStatus := initStatus(instanceID)
demoService := demo.NewService()
if *flags.DemoEnvironment {
err := demoService.Init(dataStore, cryptoService)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing demo environment")
}
}
// channel to control when the admin user is created
adminCreationDone := make(chan struct{}, 1)
go endpointutils.InitEndpoint(shutdownCtx, adminCreationDone, flags, dataStore, snapshotService)
adminPasswordHash := ""
if *flags.AdminPasswordFile != "" {
content, err := fileService.GetFileContent(*flags.AdminPasswordFile, "")
if err != nil {
@@ -502,14 +540,14 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
if len(users) == 0 {
log.Info().Msg("created admin user with the given password.")
user := &portainer.User{
Username: "admin",
Role: portainer.AdministratorRole,
Password: adminPasswordHash,
}
if err := dataStore.User().Create(user); err != nil {
err := dataStore.User().Create(user)
if err != nil {
log.Fatal().Err(err).Msg("failed creating admin user")
}
@@ -520,7 +558,8 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
}
}
if err := reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService); err != nil {
err = reverseTunnelService.StartTunnelServer(*flags.TunnelAddr, *flags.TunnelPort, snapshotService)
if err != nil {
log.Fatal().Err(err).Msg("failed starting tunnel server")
}
@@ -533,20 +572,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
log.Fatal().Msg("failed to fetch SSL settings from DB")
}
platformService, err := platform.NewService(dataStore)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing platform service")
}
upgradeService, err := upgrade.NewService(
*flags.Assets,
kubernetesClientFactory,
dockerClientFactory,
composeStackManager,
dataStore,
fileService,
stackDeployer,
)
upgradeService, err := upgrade.NewService(*flags.Assets, composeDeployer, kubernetesClientFactory)
if err != nil {
log.Fatal().Err(err).Msg("failed initializing upgrade service")
}
@@ -555,12 +581,10 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
// but some more complex migrations require access to a kubernetes or docker
// client. Therefore we run a separate migration process just before
// starting the server.
postInitMigrator := postinit.NewPostInitMigrator(
postInitMigrator := datastore.NewPostInitMigrator(
kubernetesClientFactory,
dockerClientFactory,
dataStore,
*flags.Assets,
kubernetesDeployer,
)
if err := postInitMigrator.PostInitMigrate(); err != nil {
log.Fatal().Err(err).Msg("failure during post init migrations")
@@ -591,7 +615,7 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
ProxyManager: proxyManager,
KubernetesTokenCacheManager: kubernetesTokenCacheManager,
KubeClusterAccessService: kubeClusterAccessService,
SignatureService: signatureService,
SignatureService: digitalSignatureService,
SnapshotService: snapshotService,
SSLService: sslService,
DockerClientFactory: dockerClientFactory,
@@ -600,10 +624,10 @@ func buildServer(flags *portainer.CLIFlags) portainer.Server {
ShutdownCtx: shutdownCtx,
ShutdownTrigger: shutdownTrigger,
StackDeployer: stackDeployer,
DemoService: demoService,
UpgradeService: upgradeService,
AdminCreationDone: adminCreationDone,
PendingActionsService: pendingActionsService,
PlatformService: platformService,
}
}
@@ -618,7 +642,6 @@ func main() {
for {
server := buildServer(flags)
log.Info().
Str("version", portainer.APIVersion).
Str("build_number", build.BuildNumber).
@@ -630,7 +653,6 @@ func main() {
Msg("starting Portainer")
err := server.Start()
log.Info().Err(err).Msg("HTTP server exited")
}
}

View File

@@ -1,148 +0,0 @@
// Package concurrent provides utilities for running multiple functions concurrently in Go.
// For example, many kubernetes calls can take a while to fulfill. Oftentimes in Portainer
// we need to get a list of objects from multiple kubernetes REST APIs. We can often call these
// apis concurrently to speed up the response time.
// This package provides a clean way to do just that.
//
// Examples:
// The ConfigMaps and Secrets function converted using concurrent.Run.
/*
// GetConfigMapsAndSecrets gets all the ConfigMaps AND all the Secrets for a
// given namespace in a k8s endpoint. The result is a list of both config maps
// and secrets. The IsSecret boolean property indicates if a given struct is a
// secret or configmap.
func (kcl *KubeClient) GetConfigMapsAndSecrets(namespace string) ([]models.K8sConfigMapOrSecret, error) {
// use closures to capture the current kube client and namespace by declaring wrapper functions
// that match the interface signature for concurrent.Func
listConfigMaps := func(ctx context.Context) (any, error) {
return kcl.cli.CoreV1().ConfigMaps(namespace).List(context.Background(), meta.ListOptions{})
}
listSecrets := func(ctx context.Context) (any, error) {
return kcl.cli.CoreV1().Secrets(namespace).List(context.Background(), meta.ListOptions{})
}
// run the functions concurrently and wait for results. We can also pass in a context to cancel.
// e.g. Deadline timer.
results, err := concurrent.Run(context.TODO(), listConfigMaps, listSecrets)
if err != nil {
return nil, err
}
var configMapList *core.ConfigMapList
var secretList *core.SecretList
for _, r := range results {
switch v := r.Result.(type) {
case *core.ConfigMapList:
configMapList = v
case *core.SecretList:
secretList = v
}
}
// TODO: Applications
var combined []models.K8sConfigMapOrSecret
for _, m := range configMapList.Items {
var cm models.K8sConfigMapOrSecret
cm.UID = string(m.UID)
cm.Name = m.Name
cm.Namespace = m.Namespace
cm.Annotations = m.Annotations
cm.Data = m.Data
cm.CreationDate = m.CreationTimestamp.Time.UTC().Format(time.RFC3339)
combined = append(combined, cm)
}
for _, s := range secretList.Items {
var secret models.K8sConfigMapOrSecret
secret.UID = string(s.UID)
secret.Name = s.Name
secret.Namespace = s.Namespace
secret.Annotations = s.Annotations
secret.Data = msbToMss(s.Data)
secret.CreationDate = s.CreationTimestamp.Time.UTC().Format(time.RFC3339)
secret.IsSecret = true
secret.SecretType = string(s.Type)
combined = append(combined, secret)
}
return combined, nil
}
*/
package concurrent
import (
"context"
"sync"
)
// Result contains the result and any error returned from running a client task function
type Result struct {
Result any // the result of running the task function
Err error // any error that occurred while running the task function
}
// Func is a function returns a result or error
type Func func(ctx context.Context) (any, error)
// Run runs a list of functions returns the results
func Run(ctx context.Context, maxConcurrency int, tasks ...Func) ([]Result, error) {
var wg sync.WaitGroup
resultsChan := make(chan Result, len(tasks))
taskChan := make(chan Func, len(tasks))
localCtx, cancelCtx := context.WithCancel(ctx)
defer cancelCtx()
runTask := func() {
defer wg.Done()
for fn := range taskChan {
result, err := fn(localCtx)
resultsChan <- Result{Result: result, Err: err}
}
}
// Set maxConcurrency to the number of tasks if zero or negative
if maxConcurrency <= 0 {
maxConcurrency = len(tasks)
}
// Start worker goroutines
for range maxConcurrency {
wg.Add(1)
go runTask()
}
// Add tasks to the task channel
for _, fn := range tasks {
taskChan <- fn
}
// Close the task channel to signal workers to stop when all tasks are done
close(taskChan)
// Wait for all workers to complete
wg.Wait()
close(resultsChan)
// Collect the results and cancel on error
results := make([]Result, 0, len(tasks))
for r := range resultsChan {
if r.Err != nil {
cancelCtx()
return nil, r.Err
}
results = append(results, r)
}
return results, nil
}

View File

@@ -5,21 +5,22 @@ import (
)
type ReadTransaction interface {
GetObject(bucketName string, key []byte, object any) error
GetAll(bucketName string, obj any, append func(o any) (any, error)) error
GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, append func(o any) (any, error)) error
GetObject(bucketName string, key []byte, object interface{}) error
GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error
GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error
}
type Transaction interface {
ReadTransaction
SetServiceName(bucketName string) error
UpdateObject(bucketName string, key []byte, object any) error
UpdateObject(bucketName string, key []byte, object interface{}) error
DeleteObject(bucketName string, key []byte) error
CreateObject(bucketName string, fn func(uint64) (int, any)) error
CreateObjectWithId(bucketName string, id int, obj any) error
CreateObjectWithStringId(bucketName string, id []byte, obj any) error
DeleteAllObjects(bucketName string, obj any, matching func(o any) (id int, ok bool)) error
CreateObject(bucketName string, fn func(uint64) (int, interface{})) error
CreateObjectWithId(bucketName string, id int, obj interface{}) error
CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error
DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error
GetNextIdentifier(bucketName string) int
}
@@ -45,8 +46,8 @@ type Connection interface {
NeedsEncryptionMigration() (bool, error)
SetEncrypted(encrypted bool)
BackupMetadata() (map[string]any, error)
RestoreMetadata(s map[string]any) error
BackupMetadata() (map[string]interface{}, error)
RestoreMetadata(s map[string]interface{}) error
UpdateObjectFunc(bucketName string, key []byte, object any, updateFn func()) error
ConvertToKey(v int) []byte

View File

@@ -1,216 +1,52 @@
package crypto
import (
"bufio"
"bytes"
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"errors"
"fmt"
"io"
"golang.org/x/crypto/argon2"
"golang.org/x/crypto/scrypt"
)
const (
// AES GCM settings
aesGcmHeader = "AES256-GCM" // The encrypted file header
aesGcmBlockSize = 1024 * 1024 // 1MB block for aes gcm
// NOTE: has to go with what is considered to be a simplistic in that it omits any
// authentication of the encrypted data.
// Person with better knowledge is welcomed to improve it.
// sourced from https://golang.org/src/crypto/cipher/example_test.go
// Argon2 settings
// Recommded settings lower memory hardware according to current OWASP recommendations
// Considering some people run portainer on a NAS I think it's prudent not to assume we're on server grade hardware
// https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id
argon2MemoryCost = 12 * 1024
argon2TimeCost = 3
argon2Threads = 1
argon2KeyLength = 32
)
var emptySalt []byte = make([]byte, 0)
// AesEncrypt reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
err := aesEncryptGCM(input, output, passphrase)
if err != nil {
return fmt.Errorf("error encrypting file: %w", err)
}
return nil
}
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
// Read file header to determine how it was encrypted
inputReader := bufio.NewReader(input)
header, err := inputReader.Peek(len(aesGcmHeader))
if err != nil {
return nil, fmt.Errorf("error reading encrypted backup file header: %w", err)
}
if string(header) == aesGcmHeader {
reader, err := aesDecryptGCM(inputReader, passphrase)
if err != nil {
return nil, fmt.Errorf("error decrypting file: %w", err)
}
return reader, nil
}
// Use the previous decryption routine which has no header (to support older archives)
reader, err := aesDecryptOFB(inputReader, passphrase)
if err != nil {
return nil, fmt.Errorf("error decrypting legacy file backup: %w", err)
}
return reader, nil
}
// aesEncryptGCM reads from input, encrypts with AES-256 and writes to output. passphrase is used to generate an encryption key.
func aesEncryptGCM(input io.Reader, output io.Writer, passphrase []byte) error {
// Derive key using argon2 with a random salt
salt := make([]byte, 16) // 16 bytes salt
if _, err := io.ReadFull(rand.Reader, salt); err != nil {
return err
}
key := argon2.IDKey(passphrase, salt, argon2TimeCost, argon2MemoryCost, argon2Threads, 32)
block, err := aes.NewCipher(key)
if err != nil {
return err
}
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return err
}
// Generate nonce
nonce, err := NewRandomNonce(aesgcm.NonceSize())
if err != nil {
return err
}
// write the header
if _, err := output.Write([]byte(aesGcmHeader)); err != nil {
return err
}
// Write nonce and salt to the output file
if _, err := output.Write(salt); err != nil {
return err
}
if _, err := output.Write(nonce.Value()); err != nil {
return err
}
// Buffer for reading plaintext blocks
buf := make([]byte, aesGcmBlockSize) // Adjust buffer size as needed
ciphertext := make([]byte, len(buf)+aesgcm.Overhead())
// Encrypt plaintext in blocks
for {
n, err := io.ReadFull(input, buf)
if n == 0 {
break // end of plaintext input
}
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
return err
}
// Seal encrypts the plaintext using the nonce returning the updated slice.
ciphertext = aesgcm.Seal(ciphertext[:0], nonce.Value(), buf[:n], nil)
_, err = output.Write(ciphertext)
if err != nil {
return err
}
nonce.Increment()
}
return nil
}
// aesDecryptGCM reads from input, decrypts with AES-256 and returns the reader to read the decrypted content from.
func aesDecryptGCM(input io.Reader, passphrase []byte) (io.Reader, error) {
// Reader & verify header
header := make([]byte, len(aesGcmHeader))
if _, err := io.ReadFull(input, header); err != nil {
return nil, err
}
if string(header) != aesGcmHeader {
return nil, fmt.Errorf("invalid header")
}
// Read salt
salt := make([]byte, 16) // Salt size
if _, err := io.ReadFull(input, salt); err != nil {
return nil, err
}
key := argon2.IDKey(passphrase, salt, argon2TimeCost, argon2MemoryCost, argon2Threads, 32)
// Initialize AES cipher block
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
// Create GCM mode with the cipher block
aesgcm, err := cipher.NewGCM(block)
if err != nil {
return nil, err
}
// Read nonce from the input reader
nonce := NewNonce(aesgcm.NonceSize())
if err := nonce.Read(input); err != nil {
return nil, err
}
// Initialize a buffer to store decrypted data
buf := bytes.Buffer{}
plaintext := make([]byte, aesGcmBlockSize)
// Decrypt the ciphertext in blocks
for {
// Read a block of ciphertext from the input reader
ciphertextBlock := make([]byte, aesGcmBlockSize+aesgcm.Overhead()) // Adjust block size as needed
n, err := io.ReadFull(input, ciphertextBlock)
if n == 0 {
break // end of ciphertext
}
if err != nil && !(errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF)) {
return nil, err
}
// Decrypt the block of ciphertext
plaintext, err = aesgcm.Open(plaintext[:0], nonce.Value(), ciphertextBlock[:n], nil)
if err != nil {
return nil, err
}
_, err = buf.Write(plaintext)
if err != nil {
return nil, err
}
nonce.Increment()
}
return &buf, nil
}
// aesDecryptOFB reads from input, decrypts with AES-256 and returns the reader to a read decrypted content from.
// AesEncrypt reads from input, encrypts with AES-256 and writes to the output.
// passphrase is used to generate an encryption key.
// note: This function used to decrypt files that were encrypted without a header i.e. old archives
func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
var emptySalt []byte = make([]byte, 0)
func AesEncrypt(input io.Reader, output io.Writer, passphrase []byte) error {
// making a 32 bytes key that would correspond to AES-256
// don't necessarily need a salt, so just kept in empty
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
if err != nil {
return err
}
block, err := aes.NewCipher(key)
if err != nil {
return err
}
// If the key is unique for each ciphertext, then it's ok to use a zero
// IV.
var iv [aes.BlockSize]byte
stream := cipher.NewOFB(block, iv[:])
writer := &cipher.StreamWriter{S: stream, W: output}
// Copy the input to the output, encrypting as we go.
if _, err := io.Copy(writer, input); err != nil {
return err
}
return nil
}
// AesDecrypt reads from input, decrypts with AES-256 and returns the reader to a read decrypted content from.
// passphrase is used to generate an encryption key.
func AesDecrypt(input io.Reader, passphrase []byte) (io.Reader, error) {
// making a 32 bytes key that would correspond to AES-256
// don't necessarily need a salt, so just kept in empty
key, err := scrypt.Key(passphrase, emptySalt, 32768, 8, 1, 32)
@@ -223,9 +59,11 @@ func aesDecryptOFB(input io.Reader, passphrase []byte) (io.Reader, error) {
return nil, err
}
// If the key is unique for each ciphertext, then it's ok to use a zero IV.
// If the key is unique for each ciphertext, then it's ok to use a zero
// IV.
var iv [aes.BlockSize]byte
stream := cipher.NewOFB(block, iv[:])
reader := &cipher.StreamReader{S: stream, R: input}
return reader, nil

View File

@@ -2,7 +2,6 @@ package crypto
import (
"io"
"math/rand"
"os"
"path/filepath"
"testing"
@@ -10,19 +9,7 @@ import (
"github.com/stretchr/testify/assert"
)
const letterBytes = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789"
func randBytes(n int) []byte {
b := make([]byte, n)
for i := range b {
b[i] = letterBytes[rand.Intn(len(letterBytes))]
}
return b
}
func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
const passphrase = "passphrase"
tmpdir := t.TempDir()
var (
@@ -31,99 +18,17 @@ func Test_encryptAndDecrypt_withTheSamePassword(t *testing.T) {
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1024*1024*100 + 523)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
func Test_encryptAndDecrypt_withStrongPassphrase(t *testing.T) {
const passphrase = "A strong passphrase with special characters: !@#$%^&*()_+"
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
content := randBytes(500)
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
err := AesEncrypt(originFile, encryptedFileWriter, []byte(passphrase))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
encryptedFileReader, _ := os.Open(encryptedFilePath)
defer encryptedFileReader.Close()
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte(passphrase))
assert.Nil(t, err, "Failed to decrypt file")
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.Equal(t, content, decryptedContent, "Original and decrypted content should match")
}
func Test_encryptAndDecrypt_withTheSamePasswordSmallFile(t *testing.T) {
tmpdir := t.TempDir()
var (
originFilePath = filepath.Join(tmpdir, "origin2")
encryptedFilePath = filepath.Join(tmpdir, "encrypted2")
decryptedFilePath = filepath.Join(tmpdir, "decrypted2")
)
content := randBytes(500)
content := []byte("content")
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
defer originFile.Close()
encryptedFileWriter, _ := os.Create(encryptedFilePath)
defer encryptedFileWriter.Close()
err := AesEncrypt(originFile, encryptedFileWriter, []byte("passphrase"))
assert.Nil(t, err, "Failed to encrypt a file")
encryptedFileWriter.Close()
encryptedContent, err := os.ReadFile(encryptedFilePath)
assert.Nil(t, err, "Couldn't read encrypted file")
assert.NotEqual(t, encryptedContent, content, "Content wasn't encrypted")
@@ -152,7 +57,7 @@ func Test_encryptAndDecrypt_withEmptyPassword(t *testing.T) {
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1024 * 50)
content := []byte("content")
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
@@ -191,7 +96,7 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
decryptedFilePath = filepath.Join(tmpdir, "decrypted")
)
content := randBytes(1034)
content := []byte("content")
os.WriteFile(originFilePath, content, 0600)
originFile, _ := os.Open(originFilePath)
@@ -212,6 +117,11 @@ func Test_decryptWithDifferentPassphrase_shouldProduceWrongResult(t *testing.T)
decryptedFileWriter, _ := os.Create(decryptedFilePath)
defer decryptedFileWriter.Close()
_, err = AesDecrypt(encryptedFileReader, []byte("garbage"))
assert.NotNil(t, err, "Should not allow decrypt with wrong passphrase")
decryptedReader, err := AesDecrypt(encryptedFileReader, []byte("garbage"))
assert.Nil(t, err, "Should allow to decrypt with wrong passphrase")
io.Copy(decryptedFileWriter, decryptedReader)
decryptedContent, _ := os.ReadFile(decryptedFilePath)
assert.NotEqual(t, content, decryptedContent, "Original and decrypted content should NOT match")
}

View File

@@ -1,61 +0,0 @@
package crypto
import (
"crypto/rand"
"errors"
"io"
)
type Nonce struct {
val []byte
}
func NewNonce(size int) *Nonce {
return &Nonce{val: make([]byte, size)}
}
// NewRandomNonce generates a new initial nonce with the lower byte set to a random value
// This ensures there are plenty of nonce values availble before rolling over
// Based on ideas from the Secure Programming Cookbook for C and C++ by John Viega, Matt Messier
// https://www.oreilly.com/library/view/secure-programming-cookbook/0596003943/ch04s09.html
func NewRandomNonce(size int) (*Nonce, error) {
randomBytes := 1
if size <= randomBytes {
return nil, errors.New("nonce size must be greater than the number of random bytes")
}
randomPart := make([]byte, randomBytes)
if _, err := rand.Read(randomPart); err != nil {
return nil, err
}
zeroPart := make([]byte, size-randomBytes)
nonceVal := append(randomPart, zeroPart...)
return &Nonce{val: nonceVal}, nil
}
func (n *Nonce) Read(stream io.Reader) error {
_, err := io.ReadFull(stream, n.val)
return err
}
func (n *Nonce) Value() []byte {
return n.val
}
func (n *Nonce) Increment() error {
// Start incrementing from the least significant byte
for i := len(n.val) - 1; i >= 0; i-- {
// Increment the current byte
n.val[i]++
// Check for overflow
if n.val[i] != 0 {
// No overflow, nonce is successfully incremented
return nil
}
}
// If we reach here, it means the nonce has overflowed
return errors.New("nonce overflow")
}

View File

@@ -22,12 +22,6 @@ func CreateTLSConfiguration() *tls.Config {
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256,
tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256,
},
}
}

View File

@@ -8,7 +8,6 @@ import (
"math"
"os"
"path"
"strconv"
"time"
portainer "github.com/portainer/portainer/api"
@@ -74,6 +73,7 @@ func (connection *DbConnection) IsEncryptedStore() bool {
// NeedsEncryptionMigration returns true if database encryption is enabled and
// we have an un-encrypted DB that requires migration to an encrypted DB
func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
// Cases: Note, we need to check both portainer.db and portainer.edb
// to determine if it's a new store. We only need to differentiate between cases 2,3 and 5
@@ -121,11 +121,11 @@ func (connection *DbConnection) NeedsEncryptionMigration() (bool, error) {
// Open opens and initializes the BoltDB database.
func (connection *DbConnection) Open() error {
log.Info().Str("filename", connection.GetDatabaseFileName()).Msg("loading PortainerDB")
// Now we open the db
databasePath := connection.GetDatabaseFilePath()
db, err := bolt.Open(databasePath, 0600, &bolt.Options{
Timeout: 1 * time.Second,
InitialMmapSize: connection.InitialMmapSize,
@@ -178,7 +178,6 @@ func (connection *DbConnection) ViewTx(fn func(portainer.Transaction) error) err
func (connection *DbConnection) BackupTo(w io.Writer) error {
return connection.View(func(tx *bolt.Tx) error {
_, err := tx.WriteTo(w)
return err
})
}
@@ -193,7 +192,6 @@ func (connection *DbConnection) ExportRaw(filename string) error {
if err != nil {
return err
}
return os.WriteFile(filename, b, 0600)
}
@@ -203,7 +201,6 @@ func (connection *DbConnection) ExportRaw(filename string) error {
func (connection *DbConnection) ConvertToKey(v int) []byte {
b := make([]byte, 8)
binary.BigEndian.PutUint64(b, uint64(v))
return b
}
@@ -215,7 +212,7 @@ func keyToString(b []byte) string {
v := binary.BigEndian.Uint64(b)
if v <= math.MaxInt32 {
return strconv.FormatUint(v, 10)
return fmt.Sprintf("%d", v)
}
return string(b)
@@ -229,7 +226,7 @@ func (connection *DbConnection) SetServiceName(bucketName string) error {
}
// GetObject is a generic function used to retrieve an unmarshalled object from a database.
func (connection *DbConnection) GetObject(bucketName string, key []byte, object any) error {
func (connection *DbConnection) GetObject(bucketName string, key []byte, object interface{}) error {
return connection.ViewTx(func(tx portainer.Transaction) error {
return tx.GetObject(bucketName, key, object)
})
@@ -244,7 +241,7 @@ func (connection *DbConnection) getEncryptionKey() []byte {
}
// UpdateObject is a generic function used to update an object inside a database.
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object any) error {
func (connection *DbConnection) UpdateObject(bucketName string, key []byte, object interface{}) error {
return connection.UpdateTx(func(tx portainer.Transaction) error {
return tx.UpdateObject(bucketName, key, object)
})
@@ -285,7 +282,7 @@ func (connection *DbConnection) DeleteObject(bucketName string, key []byte) erro
// DeleteAllObjects delete all objects where matching() returns (id, ok).
// TODO: think about how to return the error inside (maybe change ok to type err, and use "notfound"?
func (connection *DbConnection) DeleteAllObjects(bucketName string, obj any, matching func(o any) (id int, ok bool)) error {
func (connection *DbConnection) DeleteAllObjects(bucketName string, obj interface{}, matching func(o interface{}) (id int, ok bool)) error {
return connection.UpdateTx(func(tx portainer.Transaction) error {
return tx.DeleteAllObjects(bucketName, obj, matching)
})
@@ -304,64 +301,71 @@ func (connection *DbConnection) GetNextIdentifier(bucketName string) int {
}
// CreateObject creates a new object in the bucket, using the next bucket sequence id
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, any)) error {
func (connection *DbConnection) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
return connection.UpdateTx(func(tx portainer.Transaction) error {
return tx.CreateObject(bucketName, fn)
})
}
// CreateObjectWithId creates a new object in the bucket, using the specified id
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj any) error {
func (connection *DbConnection) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
return connection.UpdateTx(func(tx portainer.Transaction) error {
return tx.CreateObjectWithId(bucketName, id, obj)
})
}
// CreateObjectWithStringId creates a new object in the bucket, using the specified id
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj any) error {
func (connection *DbConnection) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
return connection.UpdateTx(func(tx portainer.Transaction) error {
return tx.CreateObjectWithStringId(bucketName, id, obj)
})
}
func (connection *DbConnection) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
func (connection *DbConnection) GetAll(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
return connection.ViewTx(func(tx portainer.Transaction) error {
return tx.GetAll(bucketName, obj, appendFn)
return tx.GetAll(bucketName, obj, append)
})
}
func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, appendFn func(o any) (any, error)) error {
// TODO: decide which Unmarshal to use, and use one...
func (connection *DbConnection) GetAllWithJsoniter(bucketName string, obj interface{}, append func(o interface{}) (interface{}, error)) error {
return connection.ViewTx(func(tx portainer.Transaction) error {
return tx.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, appendFn)
return tx.GetAllWithJsoniter(bucketName, obj, append)
})
}
func (connection *DbConnection) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, append func(o interface{}) (interface{}, error)) error {
return connection.ViewTx(func(tx portainer.Transaction) error {
return tx.GetAllWithKeyPrefix(bucketName, keyPrefix, obj, append)
})
}
// BackupMetadata will return a copy of the boltdb sequence numbers for all buckets.
func (connection *DbConnection) BackupMetadata() (map[string]any, error) {
buckets := map[string]any{}
func (connection *DbConnection) BackupMetadata() (map[string]interface{}, error) {
buckets := map[string]interface{}{}
err := connection.View(func(tx *bolt.Tx) error {
return tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
err := tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
bucketName := string(name)
seqId := bucket.Sequence()
buckets[bucketName] = int(seqId)
return nil
})
return err
})
return buckets, err
}
// RestoreMetadata will restore the boltdb sequence numbers for all buckets.
func (connection *DbConnection) RestoreMetadata(s map[string]any) error {
func (connection *DbConnection) RestoreMetadata(s map[string]interface{}) error {
var err error
for bucketName, v := range s {
id, ok := v.(float64) // JSON ints are unmarshalled to interface as float64. See: https://pkg.go.dev/encoding/json#Decoder.Decode
if !ok {
log.Error().Str("bucket", bucketName).Msg("failed to restore metadata to bucket, skipped")
continue
}

View File

@@ -87,7 +87,10 @@ func Test_NeedsEncryptionMigration(t *testing.T) {
}
for _, tc := range cases {
tc := tc
t.Run(tc.name, func(t *testing.T) {
connection := DbConnection{Path: dir}
if tc.dbname == "both" {

View File

@@ -8,8 +8,8 @@ import (
bolt "go.etcd.io/bbolt"
)
func backupMetadata(connection *bolt.DB) (map[string]any, error) {
buckets := map[string]any{}
func backupMetadata(connection *bolt.DB) (map[string]interface{}, error) {
buckets := map[string]interface{}{}
err := connection.View(func(tx *bolt.Tx) error {
err := tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
@@ -39,7 +39,7 @@ func (c *DbConnection) ExportJSON(databasePath string, metadata bool) ([]byte, e
}
defer connection.Close()
backup := make(map[string]any)
backup := make(map[string]interface{})
if metadata {
meta, err := backupMetadata(connection)
if err != nil {
@@ -52,7 +52,7 @@ func (c *DbConnection) ExportJSON(databasePath string, metadata bool) ([]byte, e
err = connection.View(func(tx *bolt.Tx) error {
err = tx.ForEach(func(name []byte, bucket *bolt.Bucket) error {
bucketName := string(name)
var list []any
var list []interface{}
version := make(map[string]string)
cursor := bucket.Cursor()
for k, v := cursor.First(); k != nil; k, v = cursor.Next() {
@@ -60,7 +60,7 @@ func (c *DbConnection) ExportJSON(databasePath string, metadata bool) ([]byte, e
continue
}
var obj any
var obj interface{}
err := c.UnmarshalObject(v, &obj)
if err != nil {
log.Error().

View File

@@ -5,16 +5,17 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/rand"
"fmt"
"io"
"github.com/pkg/errors"
"github.com/segmentio/encoding/json"
)
var errEncryptedStringTooShort = errors.New("encrypted string too short")
var errEncryptedStringTooShort = fmt.Errorf("encrypted string too short")
// MarshalObject encodes an object to binary format
func (connection *DbConnection) MarshalObject(object any) ([]byte, error) {
func (connection *DbConnection) MarshalObject(object interface{}) ([]byte, error) {
buf := &bytes.Buffer{}
// Special case for the VERSION bucket. Here we're not using json
@@ -38,7 +39,7 @@ func (connection *DbConnection) MarshalObject(object any) ([]byte, error) {
}
// UnmarshalObject decodes an object from binary data
func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
func (connection *DbConnection) UnmarshalObject(data []byte, object interface{}) error {
var err error
if connection.getEncryptionKey() != nil {
data, err = decrypt(data, connection.getEncryptionKey())
@@ -46,8 +47,8 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
return errors.Wrap(err, "Failed decrypting object")
}
}
if e := json.Unmarshal(data, object); e != nil {
e := json.Unmarshal(data, object)
if e != nil {
// Special case for the VERSION bucket. Here we're not using json
// So we need to return it as a string
s, ok := object.(*string)
@@ -57,7 +58,6 @@ func (connection *DbConnection) UnmarshalObject(data []byte, object any) error {
*s = string(data)
}
return err
}
@@ -70,20 +70,22 @@ func encrypt(plaintext []byte, passphrase []byte) (encrypted []byte, err error)
if err != nil {
return encrypted, err
}
nonce := make([]byte, gcm.NonceSize())
if _, err := io.ReadFull(rand.Reader, nonce); err != nil {
if _, err = io.ReadFull(rand.Reader, nonce); err != nil {
return encrypted, err
}
return gcm.Seal(nonce, nonce, plaintext, nil), nil
ciphertextByte := gcm.Seal(
nonce,
nonce,
plaintext,
nil)
return ciphertextByte, nil
}
func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err error) {
if string(encrypted) == "false" {
return []byte("false"), nil
}
block, err := aes.NewCipher(passphrase)
if err != nil {
return encrypted, errors.Wrap(err, "Error creating cypher block")
@@ -100,8 +102,11 @@ func decrypt(encrypted []byte, passphrase []byte) (plaintextByte []byte, err err
}
nonce, ciphertextByteClean := encrypted[:nonceSize], encrypted[nonceSize:]
plaintextByte, err = gcm.Open(nil, nonce, ciphertextByteClean, nil)
plaintextByte, err = gcm.Open(
nil,
nonce,
ciphertextByteClean,
nil)
if err != nil {
return encrypted, errors.Wrap(err, "Error decrypting text")
}

View File

@@ -25,7 +25,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
uuid := uuid.Must(uuid.NewV4())
tests := []struct {
object any
object interface{}
expected string
}{
{
@@ -57,7 +57,7 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
expected: uuid.String(),
},
{
object: map[string]any{"key": "value"},
object: map[string]interface{}{"key": "value"},
expected: `{"key":"value"}`,
},
{
@@ -73,11 +73,11 @@ func Test_MarshalObjectUnencrypted(t *testing.T) {
expected: `["1","2","3"]`,
},
{
object: []map[string]any{{"key1": "value1"}, {"key2": "value2"}},
object: []map[string]interface{}{{"key1": "value1"}, {"key2": "value2"}},
expected: `[{"key1":"value1"},{"key2":"value2"}]`,
},
{
object: []any{1, "2", false, map[string]any{"key1": "value1"}},
object: []interface{}{1, "2", false, map[string]interface{}{"key1": "value1"}},
expected: `[1,"2",false,{"key1":"value1"}]`,
},
}

View File

@@ -20,7 +20,7 @@ func (tx *DbTransaction) SetServiceName(bucketName string) error {
return err
}
func (tx *DbTransaction) GetObject(bucketName string, key []byte, object any) error {
func (tx *DbTransaction) GetObject(bucketName string, key []byte, object interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
value := bucket.Get(key)
@@ -31,7 +31,7 @@ func (tx *DbTransaction) GetObject(bucketName string, key []byte, object any) er
return tx.conn.UnmarshalObject(value, object)
}
func (tx *DbTransaction) UpdateObject(bucketName string, key []byte, object any) error {
func (tx *DbTransaction) UpdateObject(bucketName string, key []byte, object interface{}) error {
data, err := tx.conn.MarshalObject(object)
if err != nil {
return err
@@ -46,7 +46,7 @@ func (tx *DbTransaction) DeleteObject(bucketName string, key []byte) error {
return bucket.Delete(key)
}
func (tx *DbTransaction) DeleteAllObjects(bucketName string, obj any, matchingFn func(o any) (id int, ok bool)) error {
func (tx *DbTransaction) DeleteAllObjects(bucketName string, obj interface{}, matchingFn func(o interface{}) (id int, ok bool)) error {
var ids []int
bucket := tx.tx.Bucket([]byte(bucketName))
@@ -74,18 +74,16 @@ func (tx *DbTransaction) DeleteAllObjects(bucketName string, obj any, matchingFn
func (tx *DbTransaction) GetNextIdentifier(bucketName string) int {
bucket := tx.tx.Bucket([]byte(bucketName))
id, err := bucket.NextSequence()
if err != nil {
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifier")
log.Error().Err(err).Str("bucket", bucketName).Msg("failed to get the next identifer")
return 0
}
return int(id)
}
func (tx *DbTransaction) CreateObject(bucketName string, fn func(uint64) (int, any)) error {
func (tx *DbTransaction) CreateObject(bucketName string, fn func(uint64) (int, interface{})) error {
bucket := tx.tx.Bucket([]byte(bucketName))
seqId, _ := bucket.NextSequence()
@@ -99,7 +97,7 @@ func (tx *DbTransaction) CreateObject(bucketName string, fn func(uint64) (int, a
return bucket.Put(tx.conn.ConvertToKey(id), data)
}
func (tx *DbTransaction) CreateObjectWithId(bucketName string, id int, obj any) error {
func (tx *DbTransaction) CreateObjectWithId(bucketName string, id int, obj interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
data, err := tx.conn.MarshalObject(obj)
if err != nil {
@@ -109,7 +107,7 @@ func (tx *DbTransaction) CreateObjectWithId(bucketName string, id int, obj any)
return bucket.Put(tx.conn.ConvertToKey(id), data)
}
func (tx *DbTransaction) CreateObjectWithStringId(bucketName string, id []byte, obj any) error {
func (tx *DbTransaction) CreateObjectWithStringId(bucketName string, id []byte, obj interface{}) error {
bucket := tx.tx.Bucket([]byte(bucketName))
data, err := tx.conn.MarshalObject(obj)
if err != nil {
@@ -119,7 +117,7 @@ func (tx *DbTransaction) CreateObjectWithStringId(bucketName string, id []byte,
return bucket.Put(id, data)
}
func (tx *DbTransaction) GetAll(bucketName string, obj any, appendFn func(o any) (any, error)) error {
func (tx *DbTransaction) GetAll(bucketName string, obj interface{}, appendFn func(o interface{}) (interface{}, error)) error {
bucket := tx.tx.Bucket([]byte(bucketName))
return bucket.ForEach(func(k []byte, v []byte) error {
@@ -132,7 +130,20 @@ func (tx *DbTransaction) GetAll(bucketName string, obj any, appendFn func(o any)
})
}
func (tx *DbTransaction) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj any, appendFn func(o any) (any, error)) error {
func (tx *DbTransaction) GetAllWithJsoniter(bucketName string, obj interface{}, appendFn func(o interface{}) (interface{}, error)) error {
bucket := tx.tx.Bucket([]byte(bucketName))
return bucket.ForEach(func(k []byte, v []byte) error {
err := tx.conn.UnmarshalObject(v, obj)
if err == nil {
obj, err = appendFn(obj)
}
return err
})
}
func (tx *DbTransaction) GetAllWithKeyPrefix(bucketName string, keyPrefix []byte, obj interface{}, appendFn func(o interface{}) (interface{}, error)) error {
cursor := tx.tx.Bucket([]byte(bucketName)).Cursor()
for k, v := cursor.Seek(keyPrefix); k != nil && bytes.HasPrefix(k, keyPrefix); k, v = cursor.Next() {

View File

@@ -1,6 +1,7 @@
package apikeyrepository
import (
"bytes"
"errors"
"fmt"
@@ -36,12 +37,12 @@ func NewService(connection portainer.Connection) (*Service, error) {
// GetAPIKeysByUserID returns a slice containing all the APIKeys a user has access to.
func (service *Service) GetAPIKeysByUserID(userID portainer.UserID) ([]portainer.APIKey, error) {
result := make([]portainer.APIKey, 0)
var result = make([]portainer.APIKey, 0)
err := service.Connection.GetAll(
BucketName,
&portainer.APIKey{},
func(obj any) (any, error) {
func(obj interface{}) (interface{}, error) {
record, ok := obj.(*portainer.APIKey)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to APIKey object")
@@ -60,19 +61,19 @@ func (service *Service) GetAPIKeysByUserID(userID portainer.UserID) ([]portainer
// GetAPIKeyByDigest returns the API key for the associated digest.
// Note: there is a 1-to-1 mapping of api-key and digest
func (service *Service) GetAPIKeyByDigest(digest string) (*portainer.APIKey, error) {
func (service *Service) GetAPIKeyByDigest(digest []byte) (*portainer.APIKey, error) {
var k *portainer.APIKey
stop := fmt.Errorf("ok")
err := service.Connection.GetAll(
BucketName,
&portainer.APIKey{},
func(obj any) (any, error) {
func(obj interface{}) (interface{}, error) {
key, ok := obj.(*portainer.APIKey)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to APIKey object")
return nil, fmt.Errorf("failed to convert to APIKey object: %s", obj)
}
if key.Digest == digest {
if bytes.Equal(key.Digest, digest) {
k = key
return nil, stop
}
@@ -95,7 +96,7 @@ func (service *Service) GetAPIKeyByDigest(digest string) (*portainer.APIKey, err
func (service *Service) Create(record *portainer.APIKey) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
record.ID = portainer.APIKeyID(id)
return int(record.ID), record

View File

@@ -31,7 +31,7 @@ func (service BaseDataServiceTx[T, I]) Read(ID I) (*T, error) {
func (service BaseDataServiceTx[T, I]) ReadAll() ([]T, error) {
var collection = make([]T, 0)
return collection, service.Tx.GetAll(
return collection, service.Tx.GetAllWithJsoniter(
service.Bucket,
new(T),
AppendFn(&collection),

View File

@@ -19,7 +19,7 @@ func (service ServiceTx) UpdateEdgeGroupFunc(ID portainer.EdgeGroupID, updateFun
func (service ServiceTx) Create(group *portainer.EdgeGroup) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
group.ID = portainer.EdgeGroupID(id)
return int(group.ID), group
},

View File

@@ -24,7 +24,7 @@ func (service ServiceTx) EdgeStacks() ([]portainer.EdgeStack, error) {
err := service.tx.GetAll(
BucketName,
&portainer.EdgeStack{},
func(obj any) (any, error) {
func(obj interface{}) (interface{}, error) {
stack, ok := obj.(*portainer.EdgeStack)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to EdgeStack object")

View File

@@ -6,8 +6,6 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
// BucketName represents the name of the bucket where this service stores data.
@@ -159,7 +157,6 @@ func (service *Service) EndpointsByTeamID(teamID portainer.TeamID) ([]portainer.
return true
}
}
return false
}),
)
@@ -169,13 +166,11 @@ func (service *Service) EndpointsByTeamID(teamID portainer.TeamID) ([]portainer.
func (service *Service) GetNextIdentifier() int {
var identifier int
if err := service.connection.UpdateTx(func(tx portainer.Transaction) error {
service.connection.UpdateTx(func(tx portainer.Transaction) error {
identifier = service.Tx(tx).GetNextIdentifier()
return nil
}); err != nil {
log.Error().Err(err).Str("bucket", BucketName).Msg("could not get the next identifier")
}
})
return identifier
}

View File

@@ -20,10 +20,10 @@ func (service ServiceTx) BucketName() string {
// Endpoint returns an environment(endpoint) by ID.
func (service ServiceTx) Endpoint(ID portainer.EndpointID) (*portainer.Endpoint, error) {
var endpoint portainer.Endpoint
identifier := service.service.connection.ConvertToKey(int(ID))
if err := service.tx.GetObject(BucketName, identifier, &endpoint); err != nil {
err := service.tx.GetObject(BucketName, identifier, &endpoint)
if err != nil {
return nil, err
}
@@ -36,7 +36,8 @@ func (service ServiceTx) Endpoint(ID portainer.EndpointID) (*portainer.Endpoint,
func (service ServiceTx) UpdateEndpoint(ID portainer.EndpointID, endpoint *portainer.Endpoint) error {
identifier := service.service.connection.ConvertToKey(int(ID))
if err := service.tx.UpdateObject(BucketName, identifier, endpoint); err != nil {
err := service.tx.UpdateObject(BucketName, identifier, endpoint)
if err != nil {
return err
}
@@ -44,7 +45,6 @@ func (service ServiceTx) UpdateEndpoint(ID portainer.EndpointID, endpoint *porta
if len(endpoint.EdgeID) > 0 {
service.service.idxEdgeID[endpoint.EdgeID] = ID
}
service.service.heartbeats.Store(ID, endpoint.LastCheckInDate)
service.service.mu.Unlock()
@@ -57,7 +57,8 @@ func (service ServiceTx) UpdateEndpoint(ID portainer.EndpointID, endpoint *porta
func (service ServiceTx) DeleteEndpoint(ID portainer.EndpointID) error {
identifier := service.service.connection.ConvertToKey(int(ID))
if err := service.tx.DeleteObject(BucketName, identifier); err != nil {
err := service.tx.DeleteObject(BucketName, identifier)
if err != nil {
return err
}
@@ -69,7 +70,6 @@ func (service ServiceTx) DeleteEndpoint(ID portainer.EndpointID) error {
break
}
}
service.service.heartbeats.Delete(ID)
service.service.mu.Unlock()
@@ -82,7 +82,7 @@ func (service ServiceTx) DeleteEndpoint(ID portainer.EndpointID) error {
func (service ServiceTx) Endpoints() ([]portainer.Endpoint, error) {
var endpoints = make([]portainer.Endpoint, 0)
return endpoints, service.tx.GetAll(
return endpoints, service.tx.GetAllWithJsoniter(
BucketName,
&portainer.Endpoint{},
dataservices.AppendFn(&endpoints),
@@ -107,7 +107,8 @@ func (service ServiceTx) UpdateHeartbeat(endpointID portainer.EndpointID) {
// CreateEndpoint assign an ID to a new environment(endpoint) and saves it.
func (service ServiceTx) Create(endpoint *portainer.Endpoint) error {
if err := service.tx.CreateObjectWithId(BucketName, int(endpoint.ID), endpoint); err != nil {
err := service.tx.CreateObjectWithId(BucketName, int(endpoint.ID), endpoint)
if err != nil {
return err
}
@@ -115,7 +116,6 @@ func (service ServiceTx) Create(endpoint *portainer.Endpoint) error {
if len(endpoint.EdgeID) > 0 {
service.service.idxEdgeID[endpoint.EdgeID] = endpoint.ID
}
service.service.heartbeats.Store(endpoint.ID, endpoint.LastCheckInDate)
service.service.mu.Unlock()
@@ -134,7 +134,6 @@ func (service ServiceTx) EndpointsByTeamID(teamID portainer.TeamID) ([]portainer
return true
}
}
return false
}),
)

View File

@@ -41,7 +41,7 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
func (service *Service) Create(endpointGroup *portainer.EndpointGroup) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
endpointGroup.ID = portainer.EndpointGroupID(id)
return int(endpointGroup.ID), endpointGroup
},

View File

@@ -13,7 +13,7 @@ type ServiceTx struct {
func (service ServiceTx) Create(endpointGroup *portainer.EndpointGroup) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
endpointGroup.ID = portainer.EndpointGroupID(id)
return int(endpointGroup.ID), endpointGroup
},

View File

@@ -32,7 +32,8 @@ func (service *Service) RegisterUpdateStackFunction(
// NewService creates a new instance of a service.
func NewService(connection portainer.Connection) (*Service, error) {
if err := connection.SetServiceName(BucketName); err != nil {
err := connection.SetServiceName(BucketName)
if err != nil {
return nil, err
}
@@ -64,7 +65,8 @@ func (service *Service) EndpointRelation(endpointID portainer.EndpointID) (*port
var endpointRelation portainer.EndpointRelation
identifier := service.connection.ConvertToKey(int(endpointID))
if err := service.connection.GetObject(BucketName, identifier, &endpointRelation); err != nil {
err := service.connection.GetObject(BucketName, identifier, &endpointRelation)
if err != nil {
return nil, err
}
@@ -159,24 +161,19 @@ func (service *Service) updateEdgeStacksAfterRelationChange(previousRelationStat
// list how many time this stack is referenced in all relations
// in order to update the stack deployments count
for refStackId, refStackEnabled := range stacksToUpdate {
if !refStackEnabled {
continue
}
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
if refStackEnabled {
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
}
}
}
}
if err := service.updateStackFn(refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
service.updateStackFn(refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
})
}
}
}

View File

@@ -33,7 +33,8 @@ func (service ServiceTx) EndpointRelation(endpointID portainer.EndpointID) (*por
var endpointRelation portainer.EndpointRelation
identifier := service.service.connection.ConvertToKey(int(endpointID))
if err := service.tx.GetObject(BucketName, identifier, &endpointRelation); err != nil {
err := service.tx.GetObject(BucketName, identifier, &endpointRelation)
if err != nil {
return nil, err
}
@@ -128,23 +129,19 @@ func (service ServiceTx) updateEdgeStacksAfterRelationChange(previousRelationSta
// list how many time this stack is referenced in all relations
// in order to update the stack deployments count
for refStackId, refStackEnabled := range stacksToUpdate {
if !refStackEnabled {
continue
}
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
if refStackEnabled {
numDeployments := 0
for _, r := range relations {
for sId, enabled := range r.EdgeStacks {
if enabled && sId == refStackId {
numDeployments += 1
}
}
}
}
if err := service.service.updateStackFnTx(service.tx, refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
}); err != nil {
log.Error().Err(err).Msg("could not update the number of deployments")
service.service.updateStackFnTx(service.tx, refStackId, func(edgeStack *portainer.EdgeStack) {
edgeStack.NumDeployments = numDeployments
})
}
}
}

View File

@@ -0,0 +1,43 @@
package fdoprofile
import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
)
// BucketName represents the name of the bucket where this service stores data.
const BucketName = "fdo_profiles"
// Service represents a service for managingFDO Profiles data.
type Service struct {
dataservices.BaseDataService[portainer.FDOProfile, portainer.FDOProfileID]
}
// NewService creates a new instance of a service.
func NewService(connection portainer.Connection) (*Service, error) {
err := connection.SetServiceName(BucketName)
if err != nil {
return nil, err
}
return &Service{
BaseDataService: dataservices.BaseDataService[portainer.FDOProfile, portainer.FDOProfileID]{
Bucket: BucketName,
Connection: connection,
},
}, nil
}
// Create assign an ID to a new FDO Profile and saves it.
func (service *Service) Create(FDOProfile *portainer.FDOProfile) error {
return service.Connection.CreateObjectWithId(
BucketName,
int(FDOProfile.ID),
FDOProfile,
)
}
// GetNextIdentifier returns the next identifier for a FDO Profile.
func (service *Service) GetNextIdentifier() int {
return service.Connection.GetNextIdentifier(BucketName)
}

View File

@@ -45,7 +45,7 @@ func (service *Service) HelmUserRepositoryByUserID(userID portainer.UserID) ([]p
func (service *Service) Create(record *portainer.HelmUserRepository) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
record.ID = portainer.HelmUserRepositoryID(id)
return int(record.ID), record
},

View File

@@ -17,8 +17,8 @@ func IsErrObjectNotFound(e error) bool {
}
// AppendFn appends elements to the given collection slice
func AppendFn[T any](collection *[]T) func(obj any) (any, error) {
return func(obj any) (any, error) {
func AppendFn[T any](collection *[]T) func(obj interface{}) (interface{}, error) {
return func(obj interface{}) (interface{}, error) {
element, ok := obj.(*T)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("type assertion failed")
@@ -32,8 +32,8 @@ func AppendFn[T any](collection *[]T) func(obj any) (any, error) {
}
// FilterFn appends elements to the given collection when the predicate is true
func FilterFn[T any](collection *[]T, predicate func(T) bool) func(obj any) (any, error) {
return func(obj any) (any, error) {
func FilterFn[T any](collection *[]T, predicate func(T) bool) func(obj interface{}) (interface{}, error) {
return func(obj interface{}) (interface{}, error) {
element, ok := obj.(*T)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("type assertion failed")
@@ -50,8 +50,8 @@ func FilterFn[T any](collection *[]T, predicate func(T) bool) func(obj any) (any
// FirstFn sets the element to the first one that satisfies the predicate and stops the computation, returns ErrStop on
// success
func FirstFn[T any](element *T, predicate func(T) bool) func(obj any) (any, error) {
return func(obj any) (any, error) {
func FirstFn[T any](element *T, predicate func(T) bool) func(obj interface{}) (interface{}, error) {
return func(obj interface{}) (interface{}, error) {
e, ok := obj.(*T)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("type assertion failed")

View File

@@ -1,6 +1,8 @@
package dataservices
import (
"io"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/database/models"
)
@@ -15,6 +17,7 @@ type (
Endpoint() EndpointService
EndpointGroup() EndpointGroupService
EndpointRelation() EndpointRelationService
FDOProfile() FDOProfileService
HelmUserRepository() HelmUserRepositoryService
Registry() RegistryService
ResourceControl() ResourceControlService
@@ -35,7 +38,6 @@ type (
}
DataStore interface {
Connection() portainer.Connection
Open() (newStore bool, err error)
Init() error
Close() error
@@ -44,7 +46,7 @@ type (
MigrateData() error
Rollback(force bool) error
CheckCurrentEdition() error
Backup(path string) (string, error)
BackupTo(w io.Writer) error
Export(filename string) (err error)
DataStoreTx
@@ -71,9 +73,8 @@ type (
}
PendingActionsService interface {
BaseCRUD[portainer.PendingAction, portainer.PendingActionID]
BaseCRUD[portainer.PendingActions, portainer.PendingActionsID]
GetNextIdentifier() int
DeleteByEndpointID(ID portainer.EndpointID) error
}
// EdgeStackService represents a service to manage Edge stacks
@@ -119,6 +120,12 @@ type (
BucketName() string
}
// FDOProfileService represents a service to manage FDO Profiles
FDOProfileService interface {
BaseCRUD[portainer.FDOProfile, portainer.FDOProfileID]
GetNextIdentifier() int
}
// HelmUserRepositoryService represents a service to manage HelmUserRepositories
HelmUserRepositoryService interface {
BaseCRUD[portainer.HelmUserRepository, portainer.HelmUserRepositoryID]
@@ -145,7 +152,7 @@ type (
APIKeyRepository interface {
BaseCRUD[portainer.APIKey, portainer.APIKeyID]
GetAPIKeysByUserID(userID portainer.UserID) ([]portainer.APIKey, error)
GetAPIKeyByDigest(digest string) (*portainer.APIKey, error)
GetAPIKeyByDigest(digest []byte) (*portainer.APIKey, error)
}
// SettingsService represents a service for managing application settings

View File

@@ -1,12 +1,10 @@
package pendingactions
import (
"fmt"
"time"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
const (
@@ -14,11 +12,11 @@ const (
)
type Service struct {
dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]
dataservices.BaseDataService[portainer.PendingActions, portainer.PendingActionsID]
}
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]
dataservices.BaseDataServiceTx[portainer.PendingActions, portainer.PendingActionsID]
}
func NewService(connection portainer.Connection) (*Service, error) {
@@ -28,34 +26,28 @@ func NewService(connection portainer.Connection) (*Service, error) {
}
return &Service{
BaseDataService: dataservices.BaseDataService[portainer.PendingAction, portainer.PendingActionID]{
BaseDataService: dataservices.BaseDataService[portainer.PendingActions, portainer.PendingActionsID]{
Bucket: BucketName,
Connection: connection,
},
}, nil
}
func (s Service) Create(config *portainer.PendingAction) error {
func (s Service) Create(config *portainer.PendingActions) error {
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
return s.Tx(tx).Create(config)
})
}
func (s Service) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
func (s Service) Update(ID portainer.PendingActionsID, config *portainer.PendingActions) error {
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
return s.Tx(tx).Update(ID, config)
})
}
func (s Service) DeleteByEndpointID(ID portainer.EndpointID) error {
return s.Connection.UpdateTx(func(tx portainer.Transaction) error {
return s.Tx(tx).DeleteByEndpointID(ID)
})
}
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.PendingAction, portainer.PendingActionID]{
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.PendingActions, portainer.PendingActionsID]{
Bucket: BucketName,
Connection: service.Connection,
Tx: tx,
@@ -63,42 +55,19 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
}
}
func (s ServiceTx) Create(config *portainer.PendingAction) error {
return s.Tx.CreateObject(BucketName, func(id uint64) (int, any) {
config.ID = portainer.PendingActionID(id)
func (s ServiceTx) Create(config *portainer.PendingActions) error {
return s.Tx.CreateObject(BucketName, func(id uint64) (int, interface{}) {
config.ID = portainer.PendingActionsID(id)
config.CreatedAt = time.Now().Unix()
return int(config.ID), config
})
}
func (s ServiceTx) Update(ID portainer.PendingActionID, config *portainer.PendingAction) error {
func (s ServiceTx) Update(ID portainer.PendingActionsID, config *portainer.PendingActions) error {
return s.BaseDataServiceTx.Update(ID, config)
}
func (s ServiceTx) DeleteByEndpointID(ID portainer.EndpointID) error {
log.Debug().Int("endpointId", int(ID)).Msg("deleting pending actions for endpoint")
pendingActions, err := s.BaseDataServiceTx.ReadAll()
if err != nil {
return fmt.Errorf("failed to retrieve pending-actions for endpoint (%d): %w", ID, err)
}
for _, pendingAction := range pendingActions {
if pendingAction.EndpointID == ID {
err := s.BaseDataServiceTx.Delete(pendingAction.ID)
if err != nil {
log.Debug().Int("endpointId", int(ID)).Msgf("failed to delete pending action: %v", err)
}
}
}
return nil
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service ServiceTx) GetNextIdentifier() int {
return service.Tx.GetNextIdentifier(BucketName)
}
// GetNextIdentifier returns the next identifier for a custom template.
func (service *Service) GetNextIdentifier() int {
return service.Connection.GetNextIdentifier(BucketName)

View File

@@ -42,7 +42,7 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
func (service *Service) Create(registry *portainer.Registry) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
registry.ID = portainer.RegistryID(id)
return int(registry.ID), registry
},

View File

@@ -13,7 +13,7 @@ type ServiceTx struct {
func (service ServiceTx) Create(registry *portainer.Registry) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
registry.ID = portainer.RegistryID(id)
return int(registry.ID), registry
},

View File

@@ -52,7 +52,7 @@ func (service *Service) ResourceControlByResourceIDAndType(resourceID string, re
err := service.Connection.GetAll(
BucketName,
&portainer.ResourceControl{},
func(obj any) (any, error) {
func(obj interface{}) (interface{}, error) {
rc, ok := obj.(*portainer.ResourceControl)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to ResourceControl object")
@@ -84,7 +84,7 @@ func (service *Service) ResourceControlByResourceIDAndType(resourceID string, re
func (service *Service) Create(resourceControl *portainer.ResourceControl) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
resourceControl.ID = portainer.ResourceControlID(id)
return int(resourceControl.ID), resourceControl
},

View File

@@ -23,7 +23,7 @@ func (service ServiceTx) ResourceControlByResourceIDAndType(resourceID string, r
err := service.Tx.GetAll(
BucketName,
&portainer.ResourceControl{},
func(obj any) (any, error) {
func(obj interface{}) (interface{}, error) {
rc, ok := obj.(*portainer.ResourceControl)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to ResourceControl object")
@@ -55,7 +55,7 @@ func (service ServiceTx) ResourceControlByResourceIDAndType(resourceID string, r
func (service ServiceTx) Create(resourceControl *portainer.ResourceControl) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
resourceControl.ID = portainer.ResourceControlID(id)
return int(resourceControl.ID), resourceControl
},

View File

@@ -42,7 +42,7 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
func (service *Service) Create(role *portainer.Role) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
role.ID = portainer.RoleID(id)
return int(role.ID), role
},

View File

@@ -13,7 +13,7 @@ type ServiceTx struct {
func (service ServiceTx) Create(role *portainer.Role) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
role.ID = portainer.RoleID(id)
return int(role.ID), role
},

View File

@@ -33,7 +33,7 @@ func TestService_StackByWebhookID(t *testing.T) {
b := stackBuilder{t: t, store: store}
b.createNewStack(newGuidString(t))
for range 10 {
for i := 0; i < 10; i++ {
b.createNewStack("")
}
webhookID := newGuidString(t)

View File

@@ -42,7 +42,7 @@ func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
func (service *Service) Create(tag *portainer.Tag) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
tag.ID = portainer.TagID(id)
return int(tag.ID), tag
},

View File

@@ -15,7 +15,7 @@ type ServiceTx struct {
func (service ServiceTx) Create(tag *portainer.Tag) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
tag.ID = portainer.TagID(id)
return int(tag.ID), tag
},

View File

@@ -19,7 +19,8 @@ type Service struct {
// NewService creates a new instance of a service.
func NewService(connection portainer.Connection) (*Service, error) {
if err := connection.SetServiceName(BucketName); err != nil {
err := connection.SetServiceName(BucketName)
if err != nil {
return nil, err
}
@@ -31,16 +32,6 @@ func NewService(connection portainer.Connection) (*Service, error) {
}, nil
}
func (service *Service) Tx(tx portainer.Transaction) ServiceTx {
return ServiceTx{
BaseDataServiceTx: dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]{
Bucket: BucketName,
Connection: service.Connection,
Tx: tx,
},
}
}
// TeamByName returns a team by name.
func (service *Service) TeamByName(name string) (*portainer.Team, error) {
var t portainer.Team
@@ -68,7 +59,7 @@ func (service *Service) TeamByName(name string) (*portainer.Team, error) {
func (service *Service) Create(team *portainer.Team) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
team.ID = portainer.TeamID(id)
return int(team.ID), team
},

View File

@@ -1,48 +0,0 @@
package team
import (
"errors"
"strings"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
dserrors "github.com/portainer/portainer/api/dataservices/errors"
)
type ServiceTx struct {
dataservices.BaseDataServiceTx[portainer.Team, portainer.TeamID]
}
// TeamByName returns a team by name.
func (service ServiceTx) TeamByName(name string) (*portainer.Team, error) {
var t portainer.Team
err := service.Tx.GetAll(
BucketName,
&portainer.Team{},
dataservices.FirstFn(&t, func(e portainer.Team) bool {
return strings.EqualFold(e.Name, name)
}),
)
if errors.Is(err, dataservices.ErrStop) {
return &t, nil
}
if err == nil {
return nil, dserrors.ErrObjectNotFound
}
return nil, err
}
// CreateTeam creates a new Team.
func (service ServiceTx) Create(team *portainer.Team) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
team.ID = portainer.TeamID(id)
return int(team.ID), team
},
)
}

View File

@@ -72,7 +72,7 @@ func (service *Service) TeamMembershipsByTeamID(teamID portainer.TeamID) ([]port
func (service *Service) Create(membership *portainer.TeamMembership) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
membership.ID = portainer.TeamMembershipID(id)
return int(membership.ID), membership
},
@@ -84,8 +84,8 @@ func (service *Service) DeleteTeamMembershipByUserID(userID portainer.UserID) er
return service.Connection.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
membership, ok := obj.(*portainer.TeamMembership)
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")
//return fmt.Errorf("Failed to convert to TeamMembership object: %s", obj)
@@ -105,8 +105,8 @@ func (service *Service) DeleteTeamMembershipByTeamID(teamID portainer.TeamID) er
return service.Connection.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
membership, ok := obj.(*portainer.TeamMembership)
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")
//return fmt.Errorf("Failed to convert to TeamMembership object: %s", obj)
@@ -125,8 +125,8 @@ func (service *Service) DeleteTeamMembershipByTeamIDAndUserID(teamID portainer.T
return service.Connection.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
membership, ok := obj.(*portainer.TeamMembership)
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")
//return fmt.Errorf("Failed to convert to TeamMembership object: %s", obj)

View File

@@ -43,7 +43,7 @@ func (service ServiceTx) TeamMembershipsByTeamID(teamID portainer.TeamID) ([]por
func (service ServiceTx) Create(membership *portainer.TeamMembership) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
membership.ID = portainer.TeamMembershipID(id)
return int(membership.ID), membership
},
@@ -55,7 +55,7 @@ func (service ServiceTx) DeleteTeamMembershipByUserID(userID portainer.UserID) e
return service.Tx.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")
@@ -76,7 +76,7 @@ func (service ServiceTx) DeleteTeamMembershipByTeamID(teamID portainer.TeamID) e
return service.Tx.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")
@@ -96,7 +96,7 @@ func (service ServiceTx) DeleteTeamMembershipByTeamIDAndUserID(teamID portainer.
return service.Tx.DeleteAllObjects(
BucketName,
&portainer.TeamMembership{},
func(obj any) (id int, ok bool) {
func(obj interface{}) (id int, ok bool) {
membership, ok := obj.(portainer.TeamMembership)
if !ok {
log.Debug().Str("obj", fmt.Sprintf("%#v", obj)).Msg("failed to convert to TeamMembership object")

View File

@@ -53,7 +53,7 @@ func (service ServiceTx) UsersByRole(role portainer.UserRole) ([]portainer.User,
func (service ServiceTx) Create(user *portainer.User) error {
return service.Tx.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
user.ID = portainer.UserID(id)
user.Username = strings.ToLower(user.Username)

View File

@@ -82,7 +82,7 @@ func (service *Service) UsersByRole(role portainer.UserRole) ([]portainer.User,
func (service *Service) Create(user *portainer.User) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
user.ID = portainer.UserID(id)
user.Username = strings.ToLower(user.Username)

View File

@@ -81,7 +81,7 @@ func (service *Service) WebhookByToken(token string) (*portainer.Webhook, error)
func (service *Service) Create(webhook *portainer.Webhook) error {
return service.Connection.CreateObject(
BucketName,
func(id uint64) (int, any) {
func(id uint64) (int, interface{}) {
webhook.ID = portainer.WebhookID(id)
return int(webhook.ID), webhook
},

View File

@@ -9,19 +9,12 @@ import (
"github.com/rs/zerolog/log"
)
// Backup takes an optional output path and creates a backup of the database.
// The database connection is stopped before running the backup to avoid any
// corruption and if a path is not given a default is used.
// The path or an error are returned.
func (store *Store) Backup(path string) (string, error) {
func (store *Store) Backup() (string, error) {
if err := store.createBackupPath(); err != nil {
return "", err
}
backupFilename := store.backupFilename()
if path != "" {
backupFilename = path
}
log.Info().Str("from", store.connection.GetDatabaseFilePath()).Str("to", backupFilename).Msgf("Backing up database")
// Close the store before backing up
@@ -76,7 +69,7 @@ func (store *Store) RestoreFromFile(backupFilename string) error {
func (store *Store) createBackupPath() error {
backupDir := path.Join(store.connection.GetStorePath(), "backups")
if exists, _ := store.fileService.FileExists(backupDir); !exists {
if err := os.MkdirAll(backupDir, 0o700); err != nil {
if err := os.MkdirAll(backupDir, 0700); err != nil {
return fmt.Errorf("unable to create backup folder: %w", err)
}
}

View File

@@ -39,7 +39,7 @@ func TestBackup(t *testing.T) {
SchemaVersion: portainer.APIVersion,
}
store.VersionService.UpdateVersion(&v)
store.Backup("")
store.Backup()
if !isFileExist(backupFileName) {
t.Errorf("Expect backup file to be created %s", backupFileName)
@@ -50,12 +50,12 @@ func TestBackup(t *testing.T) {
func TestRestore(t *testing.T) {
_, store := MustNewTestStore(t, true, false)
t.Run("Basic Restore", func(t *testing.T) {
t.Run(fmt.Sprintf("Basic Restore"), func(t *testing.T) {
// override and set initial db version and edition
updateEdition(store, portainer.PortainerCE)
updateVersion(store, "2.4")
store.Backup("")
store.Backup()
updateVersion(store, "2.16")
testVersion(store, "2.16", t)
store.Restore()
@@ -64,11 +64,11 @@ func TestRestore(t *testing.T) {
testVersion(store, "2.4", t)
})
t.Run("Basic Restore After Multiple Backups", func(t *testing.T) {
t.Run(fmt.Sprintf("Basic Restore After Multiple Backups"), func(t *testing.T) {
// override and set initial db version and edition
updateEdition(store, portainer.PortainerCE)
updateVersion(store, "2.4")
store.Backup("")
store.Backup()
updateVersion(store, "2.14")
updateVersion(store, "2.16")
testVersion(store, "2.16", t)

View File

@@ -31,7 +31,7 @@ func (store *Store) Open() (newStore bool, err error) {
}
if encryptionReq {
backupFilename, err := store.Backup("")
backupFilename, err := store.Backup()
if err != nil {
return false, fmt.Errorf("failed to backup database prior to encrypting: %w", err)
}

View File

@@ -56,3 +56,13 @@ func testVersion(store *Store, versionWant string, t *testing.T) {
t.Errorf("Expect store version to be %s but was %s", versionWant, v.SchemaVersion)
}
}
func testEdition(store *Store, editionWant portainer.SoftwareEdition, t *testing.T) {
v, err := store.VersionService.Version()
if err != nil {
log.Fatal().Err(err).Msg("")
}
if portainer.SoftwareEdition(v.Edition) != editionWant {
t.Errorf("Expect store edition to be %s but was %s", editionWant.GetEditionLabel(), portainer.SoftwareEdition(v.Edition).GetEditionLabel())
}
}

View File

@@ -40,7 +40,7 @@ func (store *Store) MigrateData() error {
}
// before we alter anything in the DB, create a backup
_, err = store.Backup("")
_, err = store.Backup()
if err != nil {
return errors.Wrap(err, "while backing up database")
}
@@ -86,7 +86,6 @@ func (store *Store) newMigratorParameters(version *models.Version) *migrator.Mig
EdgeStackService: store.EdgeStackService,
EdgeJobService: store.EdgeJobService,
TunnelServerService: store.TunnelServerService,
PendingActionsService: store.PendingActionsService,
}
}
@@ -132,6 +131,7 @@ func (store *Store) FailSafeMigrate(migrator *migrator.Migrator, version *models
// Rollback to a pre-upgrade backup copy/snapshot of portainer.db
func (store *Store) connectionRollback(force bool) error {
if !force {
confirmed, err := cli.Confirm("Are you sure you want to rollback your database to the previous backup?")
if err != nil || !confirmed {

View File

@@ -165,7 +165,7 @@ func TestRollback(t *testing.T) {
_, store := MustNewTestStore(t, false, false)
store.VersionService.UpdateVersion(&v)
_, err := store.Backup("")
_, err := store.Backup()
if err != nil {
log.Fatal().Err(err).Msg("")
}
@@ -199,7 +199,7 @@ func TestRollback(t *testing.T) {
_, store := MustNewTestStore(t, true, false)
store.VersionService.UpdateVersion(&v)
_, err := store.Backup("")
_, err := store.Backup()
if err != nil {
log.Fatal().Err(err).Msg("")
}
@@ -305,7 +305,7 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
os.WriteFile(
gotPath,
gotJSON,
0o600,
0600,
)
t.Errorf(
"migrate data from %s to %s failed\nwrote migrated input to %s\nmismatch (-want +got):\n%s",
@@ -321,7 +321,7 @@ func migrateDBTestHelper(t *testing.T, srcPath, wantPath string, overrideInstanc
// importJSON reads input JSON and commits it to a portainer datastore.Store.
// Errors are logged with the testing package.
func importJSON(t *testing.T, r io.Reader, store *Store) error {
objects := make(map[string]any)
objects := make(map[string]interface{})
// Parse json into map of objects.
d := json.NewDecoder(r)
@@ -337,9 +337,9 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
for k, v := range objects {
switch k {
case "version":
versions, ok := v.(map[string]any)
versions, ok := v.(map[string]interface{})
if !ok {
t.Logf("failed casting %s to map[string]any", k)
t.Logf("failed casting %s to map[string]interface{}", k)
}
// New format db
@@ -404,9 +404,9 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
}
case "dockerhub":
obj, ok := v.([]any)
obj, ok := v.([]interface{})
if !ok {
t.Logf("failed to cast %s to []any", k)
t.Logf("failed to cast %s to []interface{}", k)
}
err := con.CreateObjectWithStringId(
k,
@@ -418,9 +418,9 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
}
case "ssl":
obj, ok := v.(map[string]any)
obj, ok := v.(map[string]interface{})
if !ok {
t.Logf("failed to case %s to map[string]any", k)
t.Logf("failed to case %s to map[string]interface{}", k)
}
err := con.CreateObjectWithStringId(
k,
@@ -432,9 +432,9 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
}
case "settings":
obj, ok := v.(map[string]any)
obj, ok := v.(map[string]interface{})
if !ok {
t.Logf("failed to case %s to map[string]any", k)
t.Logf("failed to case %s to map[string]interface{}", k)
}
err := con.CreateObjectWithStringId(
k,
@@ -446,9 +446,9 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
}
case "tunnel_server":
obj, ok := v.(map[string]any)
obj, ok := v.(map[string]interface{})
if !ok {
t.Logf("failed to case %s to map[string]any", k)
t.Logf("failed to case %s to map[string]interface{}", k)
}
err := con.CreateObjectWithStringId(
k,
@@ -462,18 +462,18 @@ func importJSON(t *testing.T, r io.Reader, store *Store) error {
continue
default:
objlist, ok := v.([]any)
objlist, ok := v.([]interface{})
if !ok {
t.Logf("failed to cast %s to []any", k)
t.Logf("failed to cast %s to []interface{}", k)
}
for _, obj := range objlist {
value, ok := obj.(map[string]any)
value, ok := obj.(map[string]interface{})
if !ok {
t.Logf("failed to cast %v to map[string]any", obj)
t.Logf("failed to cast %v to map[string]interface{}", obj)
} else {
var ok bool
var id any
var id interface{}
switch k {
case "endpoint_relations":
// TODO: need to make into an int, then do that weird

View File

@@ -12,13 +12,13 @@ const dummyLogoURL = "example.com"
// initTestingDBConn creates a settings service with raw database DB connection
// for unit testing usage only since using NewStore will cause cycle import inside migrator pkg
func initTestingSettingsService(dbConn portainer.Connection, preSetObj map[string]any) error {
func initTestingSettingsService(dbConn portainer.Connection, preSetObj map[string]interface{}) error {
//insert a obj
return dbConn.UpdateObject("settings", []byte("SETTINGS"), preSetObj)
}
func setup(store *Store) error {
dummySettingsObj := map[string]any{
dummySettingsObj := map[string]interface{}{
"LogoURL": dummyLogoURL,
}

View File

@@ -99,7 +99,7 @@ func (store *Store) getOrMigrateLegacyVersion() (*models.Version, error) {
return &models.Version{
SchemaVersion: dbVersionToSemanticVersion(dbVersion),
Edition: edition,
InstanceID: instanceId,
InstanceID: string(instanceId),
}, nil
}
@@ -111,6 +111,5 @@ func (store *Store) finishMigrateLegacyVersion(versionToWrite *models.Version) e
store.connection.DeleteObject(bucketName, []byte(legacyDBVersionKey))
store.connection.DeleteObject(bucketName, []byte(legacyEditionKey))
store.connection.DeleteObject(bucketName, []byte(legacyInstanceKey))
return err
}

View File

@@ -0,0 +1,117 @@
package datastore
import (
"context"
"github.com/docker/docker/api/types"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
dockerclient "github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/rs/zerolog/log"
)
type PostInitMigrator struct {
kubeFactory *cli.ClientFactory
dockerFactory *dockerclient.ClientFactory
dataStore dataservices.DataStore
}
func NewPostInitMigrator(kubeFactory *cli.ClientFactory, dockerFactory *dockerclient.ClientFactory, dataStore dataservices.DataStore) *PostInitMigrator {
return &PostInitMigrator{
kubeFactory: kubeFactory,
dockerFactory: dockerFactory,
dataStore: dataStore,
}
}
func (migrator *PostInitMigrator) PostInitMigrate() error {
if err := migrator.PostInitMigrateIngresses(); err != nil {
return err
}
migrator.PostInitMigrateGPUs()
return nil
}
func (migrator *PostInitMigrator) PostInitMigrateIngresses() error {
endpoints, err := migrator.dataStore.Endpoint().Endpoints()
if err != nil {
return err
}
for i := range endpoints {
// Early exit if we do not need to migrate!
if !endpoints[i].PostInitMigrations.MigrateIngresses {
return nil
}
err := migrator.kubeFactory.MigrateEndpointIngresses(&endpoints[i])
if err != nil {
log.Debug().Err(err).Msg("failure migrating endpoint ingresses")
}
}
return nil
}
// PostInitMigrateGPUs will check all docker endpoints for containers with GPUs and set EnableGPUManagement to true if any are found
// If there's an error getting the containers, we'll log it and move on
func (migrator *PostInitMigrator) PostInitMigrateGPUs() {
environments, err := migrator.dataStore.Endpoint().Endpoints()
if err != nil {
log.Err(err).Msg("failure getting endpoints")
return
}
for i := range environments {
if environments[i].Type == portainer.DockerEnvironment {
// // Early exit if we do not need to migrate!
if !environments[i].PostInitMigrations.MigrateGPUs {
return
}
// set the MigrateGPUs flag to false so we don't run this again
environments[i].PostInitMigrations.MigrateGPUs = false
migrator.dataStore.Endpoint().UpdateEndpoint(environments[i].ID, &environments[i])
// create a docker client
dockerClient, err := migrator.dockerFactory.CreateClient(&environments[i], "", nil)
if err != nil {
log.Err(err).Msg("failure creating docker client for environment: " + environments[i].Name)
return
}
defer dockerClient.Close()
// get all containers
containers, err := dockerClient.ContainerList(context.Background(), types.ContainerListOptions{All: true})
if err != nil {
log.Err(err).Msg("failed to list containers")
return
}
// check for a gpu on each container. If even one GPU is found, set EnableGPUManagement to true for the whole endpoint
containersLoop:
for _, container := range containers {
// https://www.sobyte.net/post/2022-10/go-docker/ has nice documentation on the docker client with GPUs
containerDetails, err := dockerClient.ContainerInspect(context.Background(), container.ID)
if err != nil {
log.Err(err).Msg("failed to inspect container")
return
}
deviceRequests := containerDetails.HostConfig.Resources.DeviceRequests
for _, deviceRequest := range deviceRequests {
if deviceRequest.Driver == "nvidia" {
environments[i].EnableGPUManagement = true
migrator.dataStore.Endpoint().UpdateEndpoint(environments[i].ID, &environments[i])
break containersLoop
}
}
}
}
}
}

View File

@@ -15,7 +15,7 @@ func migrationError(err error, context string) error {
return errors.Wrap(err, "failed in "+context)
}
func GetFunctionName(i any) string {
func GetFunctionName(i interface{}) string {
return runtime.FuncForPC(reflect.ValueOf(i).Pointer()).Name()
}
@@ -39,19 +39,20 @@ func (m *Migrator) Migrate() error {
latestMigrations := m.LatestMigrations()
if latestMigrations.Version.Equal(schemaVersion) &&
version.MigratorCount != len(latestMigrations.MigrationFuncs) {
if err := runMigrations(latestMigrations.MigrationFuncs); err != nil {
err := runMigrations(latestMigrations.MigrationFuncs)
if err != nil {
return err
}
newMigratorCount = len(latestMigrations.MigrationFuncs)
}
} else {
// regular path when major/minor/patch versions differ
for _, migration := range m.migrations {
if schemaVersion.LessThan(migration.Version) {
log.Info().Msgf("migrating data to %s", migration.Version.String())
if err := runMigrations(migration.MigrationFuncs); err != nil {
log.Info().Msgf("migrating data to %s", migration.Version.String())
err := runMigrations(migration.MigrationFuncs)
if err != nil {
return err
}
}
@@ -62,14 +63,16 @@ func (m *Migrator) Migrate() error {
}
}
if err := m.Always(); err != nil {
err = m.Always()
if err != nil {
return migrationError(err, "Always migrations returned error")
}
version.SchemaVersion = portainer.APIVersion
version.MigratorCount = newMigratorCount
if err := m.versionService.UpdateVersion(version); err != nil {
err = m.versionService.UpdateVersion(version)
if err != nil {
return migrationError(err, "StoreDBVersion")
}
@@ -96,7 +99,6 @@ func (m *Migrator) NeedsMigration() bool {
// In this particular instance we should log a fatal error
if m.CurrentDBEdition() != portainer.PortainerCE {
log.Fatal().Msg("the Portainer database is set for Portainer Business Edition, please follow the instructions in our documentation to downgrade it: https://documentation.portainer.io/v2.0-be/downgrade/be-to-ce/")
return false
}

View File

@@ -7,7 +7,6 @@ import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/chisel/crypto"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
@@ -38,11 +37,9 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Info().Msg("ServerInfo object not found")
return nil
}
log.Error().
Err(err).
Msg("Failed to read ServerInfo from DB")
return err
}
@@ -52,15 +49,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
log.Error().
Err(err).
Msg("Failed to read ServerInfo from DB")
return err
}
if err := m.fileService.StoreChiselPrivateKey(key); err != nil {
err = m.fileService.StoreChiselPrivateKey(key)
if err != nil {
log.Error().
Err(err).
Msg("Failed to save Chisel private key to disk")
return err
}
} else {
@@ -68,14 +64,14 @@ func (m *Migrator) convertSeedToPrivateKeyForDB100() error {
}
serverInfo.PrivateKeySeed = ""
if err := m.TunnelServerService.UpdateInfo(serverInfo); err != nil {
err = m.TunnelServerService.UpdateInfo(serverInfo)
if err != nil {
log.Error().
Err(err).
Msg("Failed to clean private key seed in DB")
} else {
log.Info().Msg("Success to migrate private key seed to private key file")
}
return err
}
@@ -88,8 +84,9 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
}
for _, edgeStack := range edgeStacks {
for environmentID, environmentStatus := range edgeStack.Status {
// Skip if status is already updated
// skip if status is already updated
if len(environmentStatus.Status) > 0 {
continue
}
@@ -149,7 +146,10 @@ func (m *Migrator) updateEdgeStackStatusForDB100() error {
edgeStack.Status[environmentID] = environmentStatus
}
if err := m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack); err != nil {
err = m.edgeStackService.UpdateEdgeStackFunc(edgeStack.ID, func(stack *portainer.EdgeStack) {
stack.Status = edgeStack.Status
})
if err != nil {
return err
}
}

View File

@@ -24,26 +24,18 @@ func (migrator *Migrator) updateAppTemplatesVersionForDB110() error {
return migrator.settingsService.UpdateSettings(settings)
}
// In PortainerCE the resource overcommit option should always be true across all endpoints
func (migrator *Migrator) updateResourceOverCommitToDB110() error {
log.Info().Msg("updating resource overcommit setting to true")
endpoints, err := migrator.endpointService.Endpoints()
// setUseCacheForDB110 sets the user cache to true for all users
func (migrator *Migrator) setUserCacheForDB110() error {
users, err := migrator.userService.ReadAll()
if err != nil {
return err
}
for _, endpoint := range endpoints {
if endpoint.Type == portainer.KubernetesLocalEnvironment ||
endpoint.Type == portainer.AgentOnKubernetesEnvironment ||
endpoint.Type == portainer.EdgeAgentOnKubernetesEnvironment {
endpoint.Kubernetes.Configuration.EnableResourceOverCommit = true
err = migrator.endpointService.UpdateEndpoint(endpoint.ID, &endpoint)
if err != nil {
return err
}
for i := range users {
user := &users[i]
user.UseCache = true
if err := migrator.userService.Update(user.ID, user); err != nil {
return err
}
}

View File

@@ -1,32 +0,0 @@
package migrator
import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/rs/zerolog/log"
)
func (migrator *Migrator) cleanPendingActionsForDeletedEndpointsForDB111() error {
log.Info().Msg("cleaning up pending actions for deleted endpoints")
pendingActions, err := migrator.pendingActionsService.ReadAll()
if err != nil {
return err
}
endpoints := make(map[portainer.EndpointID]struct{})
for _, action := range pendingActions {
endpoints[action.EndpointID] = struct{}{}
}
for endpointId := range endpoints {
_, err := migrator.endpointService.Endpoint(endpointId)
if dataservices.IsErrObjectNotFound(err) {
err := migrator.pendingActionsService.DeleteByEndpointID(endpointId)
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -1,33 +0,0 @@
package migrator
import (
"github.com/segmentio/encoding/json"
"github.com/rs/zerolog/log"
)
func (migrator *Migrator) migratePendingActionsDataForDB130() error {
log.Info().Msg("Migrating pending actions data")
pendingActions, err := migrator.pendingActionsService.ReadAll()
if err != nil {
return err
}
for _, pa := range pendingActions {
actionData, err := json.Marshal(pa.ActionData)
if err != nil {
return err
}
pa.ActionData = string(actionData)
// Update the pending action
err = migrator.pendingActionsService.Update(pa.ID, &pa)
if err != nil {
return err
}
}
return nil
}

View File

@@ -56,6 +56,9 @@ func (m *Migrator) updateUsersAndRolesToDBVersion22() error {
endpointAdministratorRole.Authorizations = authorization.DefaultEndpointAuthorizationsForEndpointAdministratorRole()
err = m.roleService.Update(endpointAdministratorRole.ID, endpointAdministratorRole)
if err != nil {
return err
}
helpDeskRole, err := m.roleService.Read(portainer.RoleID(2))
if err != nil {
@@ -65,6 +68,9 @@ func (m *Migrator) updateUsersAndRolesToDBVersion22() error {
helpDeskRole.Authorizations = authorization.DefaultEndpointAuthorizationsForHelpDeskRole(settings.AllowVolumeBrowserForRegularUsers)
err = m.roleService.Update(helpDeskRole.ID, helpDeskRole)
if err != nil {
return err
}
standardUserRole, err := m.roleService.Read(portainer.RoleID(3))
if err != nil {
@@ -74,6 +80,9 @@ func (m *Migrator) updateUsersAndRolesToDBVersion22() error {
standardUserRole.Authorizations = authorization.DefaultEndpointAuthorizationsForStandardUserRole(settings.AllowVolumeBrowserForRegularUsers)
err = m.roleService.Update(standardUserRole.ID, standardUserRole)
if err != nil {
return err
}
readOnlyUserRole, err := m.roleService.Read(portainer.RoleID(4))
if err != nil {

View File

@@ -32,8 +32,8 @@ func (m *Migrator) updateStacksToDB24() error {
for idx := range stacks {
stack := &stacks[idx]
stack.Status = portainer.StackStatusActive
if err := m.stackService.Update(stack.ID, stack); err != nil {
err := m.stackService.Update(stack.ID, stack)
if err != nil {
return err
}
}

View File

@@ -123,7 +123,7 @@ func (m *Migrator) updateDockerhubToDB32() error {
migrated = true
} else {
// delete subsequent duplicates
m.registryService.Delete(r.ID)
m.registryService.Delete(portainer.RegistryID(r.ID))
}
}
}

View File

@@ -93,7 +93,9 @@ func (m *Migrator) updateEdgeStackStatusForDB80() error {
edgeStack.Status[endpointId] = status
}
err = m.edgeStackService.UpdateEdgeStack(edgeStack.ID, &edgeStack)
err = m.edgeStackService.UpdateEdgeStackFunc(edgeStack.ID, func(stack *portainer.EdgeStack) {
stack.Status = edgeStack.Status
})
if err != nil {
return err
}

View File

@@ -13,7 +13,7 @@ import (
"github.com/portainer/portainer/api/dataservices/endpointgroup"
"github.com/portainer/portainer/api/dataservices/endpointrelation"
"github.com/portainer/portainer/api/dataservices/extension"
"github.com/portainer/portainer/api/dataservices/pendingactions"
"github.com/portainer/portainer/api/dataservices/fdoprofile"
"github.com/portainer/portainer/api/dataservices/registry"
"github.com/portainer/portainer/api/dataservices/resourcecontrol"
"github.com/portainer/portainer/api/dataservices/role"
@@ -40,6 +40,7 @@ type (
endpointService *endpoint.Service
endpointRelationService *endpointrelation.Service
extensionService *extension.Service
fdoProfilesService *fdoprofile.Service
registryService *registry.Service
resourceControlService *resourcecontrol.Service
roleService *role.Service
@@ -57,7 +58,6 @@ type (
edgeStackService *edgestack.Service
edgeJobService *edgejob.Service
TunnelServerService *tunnelserver.Service
pendingActionsService *pendingactions.Service
}
// MigratorParameters represents the required parameters to create a new Migrator instance.
@@ -67,6 +67,7 @@ type (
EndpointService *endpoint.Service
EndpointRelationService *endpointrelation.Service
ExtensionService *extension.Service
FDOProfilesService *fdoprofile.Service
RegistryService *registry.Service
ResourceControlService *resourcecontrol.Service
RoleService *role.Service
@@ -84,7 +85,6 @@ type (
EdgeStackService *edgestack.Service
EdgeJobService *edgejob.Service
TunnelServerService *tunnelserver.Service
PendingActionsService *pendingactions.Service
}
)
@@ -96,6 +96,7 @@ func NewMigrator(parameters *MigratorParameters) *Migrator {
endpointService: parameters.EndpointService,
endpointRelationService: parameters.EndpointRelationService,
extensionService: parameters.ExtensionService,
fdoProfilesService: parameters.FDOProfilesService,
registryService: parameters.RegistryService,
resourceControlService: parameters.ResourceControlService,
roleService: parameters.RoleService,
@@ -113,7 +114,6 @@ func NewMigrator(parameters *MigratorParameters) *Migrator {
edgeStackService: parameters.EdgeStackService,
edgeJobService: parameters.EdgeJobService,
TunnelServerService: parameters.TunnelServerService,
pendingActionsService: parameters.PendingActionsService,
}
migrator.initMigrations()
@@ -230,16 +230,10 @@ func (m *Migrator) initMigrations() {
)
m.addMigrations("2.20",
m.updateAppTemplatesVersionForDB110,
m.updateResourceOverCommitToDB110,
)
m.addMigrations("2.20.2",
m.cleanPendingActionsForDeletedEndpointsForDB111,
)
m.addMigrations("2.22.0",
m.migratePendingActionsDataForDB130,
m.setUserCacheForDB110,
)
// Add new migrations above...
// Add new migrations below...
// One function per migration, each versions migration funcs in the same file.
}

View File

@@ -1,98 +0,0 @@
package datastore
import (
"testing"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/portainer/portainer/api/pendingactions/handlers"
)
type cleanNAPWithOverridePolicies struct {
EndpointGroupID portainer.EndpointGroupID
}
func Test_ConvertCleanNAPWithOverridePoliciesPayload(t *testing.T) {
t.Run("test ConvertCleanNAPWithOverridePoliciesPayload", func(t *testing.T) {
_, store := MustNewTestStore(t, true, false)
defer store.Close()
gid := portainer.EndpointGroupID(1)
testData := []struct {
Name string
PendingAction portainer.PendingAction
Expected any
Err bool
}{
{
Name: "test actiondata with EndpointGroupID 1",
PendingAction: handlers.NewCleanNAPWithOverridePolicies(
1,
&gid,
),
Expected: portainer.EndpointGroupID(1),
},
{
Name: "test actionData nil",
PendingAction: handlers.NewCleanNAPWithOverridePolicies(
2,
nil,
),
Expected: nil,
},
{
Name: "test actionData empty and expected error",
PendingAction: portainer.PendingAction{
EndpointID: 2,
Action: actions.CleanNAPWithOverridePolicies,
ActionData: "",
},
Expected: nil,
Err: true,
},
}
for _, d := range testData {
err := store.PendingActions().Create(&d.PendingAction)
if err != nil {
t.Error(err)
return
}
pendingActions, err := store.PendingActions().ReadAll()
if err != nil {
t.Error(err)
return
}
for _, endpointPendingAction := range pendingActions {
t.Run(d.Name, func(t *testing.T) {
if endpointPendingAction.Action == actions.CleanNAPWithOverridePolicies {
var payload cleanNAPWithOverridePolicies
err := endpointPendingAction.UnmarshallActionData(&payload)
if d.Err && err == nil {
t.Error(err)
}
if d.Expected == nil && payload.EndpointGroupID != 0 {
t.Errorf("expected nil, got %d", payload.EndpointGroupID)
}
if d.Expected != nil {
expected := d.Expected.(portainer.EndpointGroupID)
if d.Expected != nil && expected != payload.EndpointGroupID {
t.Errorf("expected EndpointGroupID %d, got %d", expected, payload.EndpointGroupID)
}
}
}
})
}
store.PendingActions().Delete(d.PendingAction.ID)
}
})
}

View File

@@ -1,184 +0,0 @@
package postinit
import (
"context"
"github.com/docker/docker/api/types/container"
"github.com/docker/docker/client"
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
dockerClient "github.com/portainer/portainer/api/docker/client"
"github.com/portainer/portainer/api/internal/endpointutils"
"github.com/portainer/portainer/api/kubernetes/cli"
"github.com/portainer/portainer/api/pendingactions/actions"
"github.com/rs/zerolog/log"
)
type PostInitMigrator struct {
kubeFactory *cli.ClientFactory
dockerFactory *dockerClient.ClientFactory
dataStore dataservices.DataStore
assetsPath string
kubernetesDeployer portainer.KubernetesDeployer
}
func NewPostInitMigrator(
kubeFactory *cli.ClientFactory,
dockerFactory *dockerClient.ClientFactory,
dataStore dataservices.DataStore,
assetsPath string,
kubernetesDeployer portainer.KubernetesDeployer,
) *PostInitMigrator {
return &PostInitMigrator{
kubeFactory: kubeFactory,
dockerFactory: dockerFactory,
dataStore: dataStore,
assetsPath: assetsPath,
kubernetesDeployer: kubernetesDeployer,
}
}
// PostInitMigrate will run all post-init migrations, which require docker/kube clients for all edge or non-edge environments
func (postInitMigrator *PostInitMigrator) PostInitMigrate() error {
environments, err := postInitMigrator.dataStore.Endpoint().Endpoints()
if err != nil {
log.Error().Err(err).Msg("Error getting environments")
return err
}
for _, environment := range environments {
// edge environments will run after the server starts, in pending actions
if endpointutils.IsEdgeEndpoint(&environment) {
log.Info().Msgf("Adding pending action 'PostInitMigrateEnvironment' for environment %d", environment.ID)
err = postInitMigrator.createPostInitMigrationPendingAction(environment.ID)
if err != nil {
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environment.ID)
}
} else {
// non-edge environments will run before the server starts.
err = postInitMigrator.MigrateEnvironment(&environment)
if err != nil {
log.Error().Err(err).Msgf("Error running post-init migrations for non-edge environment %d", environment.ID)
}
}
}
return nil
}
// try to create a post init migration pending action. If it already exists, do nothing
// this function exists for readability, not reusability
// TODO: This should be moved into pending actions as part of the pending action migration
func (postInitMigrator *PostInitMigrator) createPostInitMigrationPendingAction(environmentID portainer.EndpointID) error {
// If there are no pending actions for the given endpoint, create one
err := postInitMigrator.dataStore.PendingActions().Create(&portainer.PendingAction{
EndpointID: environmentID,
Action: actions.PostInitMigrateEnvironment,
})
if err != nil {
log.Error().Err(err).Msgf("Error creating pending action for environment %d", environmentID)
}
return nil
}
// MigrateEnvironment runs migrations on a single environment
func (migrator *PostInitMigrator) MigrateEnvironment(environment *portainer.Endpoint) error {
log.Info().Msgf("Executing post init migration for environment %d", environment.ID)
switch {
case endpointutils.IsKubernetesEndpoint(environment):
// get the kubeclient for the environment, and skip all kube migrations if there's an error
kubeclient, err := migrator.kubeFactory.GetPrivilegedKubeClient(environment)
if err != nil {
log.Error().Err(err).Msgf("Error creating kubeclient for environment: %d", environment.ID)
return err
}
// if one environment fails, it is logged and the next migration runs. The error is returned at the end and handled by pending actions
err = migrator.MigrateIngresses(*environment, kubeclient)
if err != nil {
return err
}
return nil
case endpointutils.IsDockerEndpoint(environment):
// get the docker client for the environment, and skip all docker migrations if there's an error
dockerClient, err := migrator.dockerFactory.CreateClient(environment, "", nil)
if err != nil {
log.Error().Err(err).Msgf("Error creating docker client for environment: %d", environment.ID)
return err
}
defer dockerClient.Close()
migrator.MigrateGPUs(*environment, dockerClient)
}
return nil
}
func (migrator *PostInitMigrator) MigrateIngresses(environment portainer.Endpoint, kubeclient *cli.KubeClient) error {
// Early exit if we do not need to migrate!
if !environment.PostInitMigrations.MigrateIngresses {
return nil
}
log.Debug().Msgf("Migrating ingresses for environment %d", environment.ID)
err := migrator.kubeFactory.MigrateEndpointIngresses(&environment, migrator.dataStore, kubeclient)
if err != nil {
log.Error().Err(err).Msgf("Error migrating ingresses for environment %d", environment.ID)
return err
}
return nil
}
// MigrateGPUs will check all docker endpoints for containers with GPUs and set EnableGPUManagement to true if any are found
// If there's an error getting the containers, we'll log it and move on
func (migrator *PostInitMigrator) MigrateGPUs(e portainer.Endpoint, dockerClient *client.Client) error {
return migrator.dataStore.UpdateTx(func(tx dataservices.DataStoreTx) error {
environment, err := tx.Endpoint().Endpoint(e.ID)
if err != nil {
log.Error().Err(err).Msgf("Error getting environment %d", environment.ID)
return err
}
// Early exit if we do not need to migrate!
if !environment.PostInitMigrations.MigrateGPUs {
return nil
}
log.Debug().Msgf("Migrating GPUs for environment %d", e.ID)
// get all containers
containers, err := dockerClient.ContainerList(context.Background(), container.ListOptions{All: true})
if err != nil {
log.Error().Err(err).Msgf("failed to list containers for environment %d", environment.ID)
return err
}
// check for a gpu on each container. If even one GPU is found, set EnableGPUManagement to true for the whole environment
containersLoop:
for _, container := range containers {
// https://www.sobyte.net/post/2022-10/go-docker/ has nice documentation on the docker client with GPUs
containerDetails, err := dockerClient.ContainerInspect(context.Background(), container.ID)
if err != nil {
log.Error().Err(err).Msg("failed to inspect container")
continue
}
deviceRequests := containerDetails.HostConfig.Resources.DeviceRequests
for _, deviceRequest := range deviceRequests {
if deviceRequest.Driver == "nvidia" {
environment.EnableGPUManagement = true
break containersLoop
}
}
}
// set the MigrateGPUs flag to false so we don't run this again
environment.PostInitMigrations.MigrateGPUs = false
err = tx.Endpoint().UpdateEndpoint(environment.ID, environment)
if err != nil {
log.Error().Err(err).Msgf("Error updating EnableGPUManagement flag for environment %d", environment.ID)
return err
}
return nil
})
}

View File

@@ -17,6 +17,7 @@ import (
"github.com/portainer/portainer/api/dataservices/endpointgroup"
"github.com/portainer/portainer/api/dataservices/endpointrelation"
"github.com/portainer/portainer/api/dataservices/extension"
"github.com/portainer/portainer/api/dataservices/fdoprofile"
"github.com/portainer/portainer/api/dataservices/helmuserrepository"
"github.com/portainer/portainer/api/dataservices/pendingactions"
"github.com/portainer/portainer/api/dataservices/registry"
@@ -54,6 +55,7 @@ type Store struct {
EndpointService *endpoint.Service
EndpointRelationService *endpointrelation.Service
ExtensionService *extension.Service
FDOProfilesService *fdoprofile.Service
HelmUserRepositoryService *helmuserrepository.Service
RegistryService *registry.Service
ResourceControlService *resourcecontrol.Service
@@ -136,6 +138,12 @@ func (store *Store) initServices() error {
}
store.ExtensionService = extensionService
fdoProfilesService, err := fdoprofile.NewService(store.connection)
if err != nil {
return err
}
store.FDOProfilesService = fdoProfilesService
helmUserRepositoryService, err := helmuserrepository.NewService(store.connection)
if err != nil {
return err
@@ -281,6 +289,11 @@ func (store *Store) EndpointRelation() dataservices.EndpointRelationService {
return store.EndpointRelationService
}
// FDOProfile gives access to the FDOProfile data management layer
func (store *Store) FDOProfile() dataservices.FDOProfileService {
return store.FDOProfilesService
}
// HelmUserRepository access the helm user repository settings
func (store *Store) HelmUserRepository() dataservices.HelmUserRepositoryService {
return store.HelmUserRepositoryService
@@ -385,10 +398,11 @@ type storeExport struct {
User []portainer.User `json:"users,omitempty"`
Version models.Version `json:"version,omitempty"`
Webhook []portainer.Webhook `json:"webhooks,omitempty"`
Metadata map[string]any `json:"metadata,omitempty"`
Metadata map[string]interface{} `json:"metadata,omitempty"`
}
func (store *Store) Export(filename string) (err error) {
backup := storeExport{}
if c, err := store.CustomTemplate().ReadAll(); err != nil {
@@ -592,7 +606,6 @@ func (store *Store) Export(filename string) (err error) {
if err != nil {
return err
}
return os.WriteFile(filename, b, 0600)
}
@@ -603,7 +616,7 @@ func (store *Store) Import(filename string) (err error) {
if err != nil {
return err
}
err = json.Unmarshal(s, &backup)
err = json.Unmarshal([]byte(s), &backup)
if err != nil {
return err
}

View File

@@ -16,9 +16,7 @@ func (tx *StoreTx) IsErrObjectNotFound(err error) bool {
func (tx *StoreTx) CustomTemplate() dataservices.CustomTemplateService { return nil }
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService {
return tx.store.PendingActionsService.Tx(tx.tx)
}
func (tx *StoreTx) PendingActions() dataservices.PendingActionsService { return nil }
func (tx *StoreTx) EdgeGroup() dataservices.EdgeGroupService {
return tx.store.EdgeGroupService.Tx(tx.tx)
@@ -44,6 +42,7 @@ func (tx *StoreTx) EndpointRelation() dataservices.EndpointRelationService {
return tx.store.EndpointRelationService.Tx(tx.tx)
}
func (tx *StoreTx) FDOProfile() dataservices.FDOProfileService { return nil }
func (tx *StoreTx) HelmUserRepository() dataservices.HelmUserRepositoryService { return nil }
func (tx *StoreTx) Registry() dataservices.RegistryService {
@@ -69,10 +68,7 @@ func (tx *StoreTx) Snapshot() dataservices.SnapshotService {
}
func (tx *StoreTx) SSLSettings() dataservices.SSLSettingsService { return nil }
func (tx *StoreTx) Stack() dataservices.StackService {
return tx.store.StackService.Tx(tx.tx)
}
func (tx *StoreTx) Stack() dataservices.StackService { return nil }
func (tx *StoreTx) Tag() dataservices.TagService {
return tx.store.TagService.Tx(tx.tx)
@@ -82,10 +78,7 @@ func (tx *StoreTx) TeamMembership() dataservices.TeamMembershipService {
return tx.store.TeamMembershipService.Tx(tx.tx)
}
func (tx *StoreTx) Team() dataservices.TeamService {
return tx.store.TeamService.Tx(tx.tx)
}
func (tx *StoreTx) Team() dataservices.TeamService { return nil }
func (tx *StoreTx) TunnelServer() dataservices.TunnelServerService { return nil }
func (tx *StoreTx) User() dataservices.UserService {

View File

@@ -38,7 +38,6 @@
"TenantID": ""
},
"ComposeSyntaxMaxVersion": "",
"ContainerEngine": "",
"Edge": {
"AsyncMode": false,
"CommandInterval": 0,
@@ -584,6 +583,7 @@
"AuthenticationMethod": 1,
"BlackListedLabels": [],
"Edge": {
"AsyncMode": false,
"CommandInterval": 0,
"PingInterval": 0,
"SnapshotInterval": 0
@@ -602,7 +602,7 @@
"RequiredPasswordLength": 12
},
"KubeconfigExpiry": "0",
"KubectlShellImage": "portainer/kubectl-shell:2.22.0",
"KubectlShellImage": "portainer/kubectl-shell",
"LDAPSettings": {
"AnonymousMode": true,
"AutoCreateUsers": true,
@@ -631,7 +631,6 @@
"LogoURL": "",
"OAuthSettings": {
"AccessTokenURI": "",
"AuthStyle": 0,
"AuthorizationURI": "",
"ClientID": "",
"DefaultTeamID": 0,
@@ -644,10 +643,17 @@
"Scopes": "",
"UserIdentifier": ""
},
"ShowKomposeBuildOption": false,
"SnapshotInterval": "5m",
"TemplatesURL": "",
"TrustOnFirstConnect": false,
"UserSessionTimeout": "8h",
"fdoConfiguration": {
"enabled": false,
"ownerPassword": "",
"ownerURL": "",
"ownerUsername": ""
},
"openAMTConfiguration": {
"certFileContent": "",
"certFileName": "",
@@ -663,7 +669,6 @@
"snapshots": [
{
"Docker": {
"ContainerCount": 0,
"DockerSnapshotRaw": {
"Containers": null,
"Images": null,
@@ -671,7 +676,6 @@
"Architecture": "",
"BridgeNfIp6tables": false,
"BridgeNfIptables": false,
"CDISpecDirs": null,
"CPUSet": false,
"CPUShares": false,
"CgroupDriver": "",
@@ -769,7 +773,6 @@
"GpuUseList": null,
"HealthyContainerCount": 0,
"ImageCount": 9,
"IsPodman": false,
"NodeCount": 0,
"RunningContainerCount": 5,
"ServiceCount": 0,
@@ -804,6 +807,7 @@
"FromAppTemplate": false,
"GitConfig": null,
"Id": 2,
"IsComposeFormat": false,
"Name": "alpine",
"Namespace": "",
"Option": null,
@@ -826,6 +830,7 @@
"FromAppTemplate": false,
"GitConfig": null,
"Id": 5,
"IsComposeFormat": false,
"Name": "redis",
"Namespace": "",
"Option": null,
@@ -848,6 +853,7 @@
"FromAppTemplate": false,
"GitConfig": null,
"Id": 6,
"IsComposeFormat": false,
"Name": "nginx",
"Namespace": "",
"Option": null,
@@ -897,7 +903,7 @@
"color": ""
},
"TokenIssueAt": 0,
"UseCache": false,
"UseCache": true,
"Username": "admin"
},
{
@@ -927,11 +933,11 @@
"color": ""
},
"TokenIssueAt": 0,
"UseCache": false,
"UseCache": true,
"Username": "prabhat"
}
],
"version": {
"VERSION": "{\"SchemaVersion\":\"2.22.0\",\"MigratorCount\":1,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
"VERSION": "{\"SchemaVersion\":\"2.20.0\",\"MigratorCount\":2,\"Edition\":1,\"InstanceID\":\"463d5c47-0ea5-4aca-85b1-405ceefee254\"}"
}
}

View File

@@ -52,24 +52,27 @@ func NewTestStore(t testing.TB, init, secure bool) (bool, *Store, func(), error)
}
if init {
if err := store.Init(); err != nil {
err = store.Init()
if err != nil {
return newStore, nil, nil, err
}
}
if newStore {
// From MigrateData
// from MigrateData
v := models.Version{
SchemaVersion: portainer.APIVersion,
Edition: int(portainer.PortainerCE),
}
if err := store.VersionService.UpdateVersion(&v); err != nil {
err = store.VersionService.UpdateVersion(&v)
if err != nil {
return newStore, nil, nil, err
}
}
teardown := func() {
if err := store.Close(); err != nil {
err := store.Close()
if err != nil {
log.Fatal().Err(err).Msg("")
}
}

118
api/demo/demo.go Normal file
View File

@@ -0,0 +1,118 @@
package demo
import (
portainer "github.com/portainer/portainer/api"
"github.com/portainer/portainer/api/dataservices"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"
)
type EnvironmentDetails struct {
Enabled bool `json:"enabled"`
Users []portainer.UserID `json:"users"`
Environments []portainer.EndpointID `json:"environments"`
}
type Service struct {
details EnvironmentDetails
}
func NewService() *Service {
return &Service{}
}
func (service *Service) Details() EnvironmentDetails {
return service.details
}
func (service *Service) Init(store dataservices.DataStore, cryptoService portainer.CryptoService) error {
log.Info().Msg("starting demo environment")
isClean, err := isCleanStore(store)
if err != nil {
return errors.WithMessage(err, "failed checking if store is clean")
}
if !isClean {
return errors.New(" Demo environment can only be initialized on a clean database")
}
id, err := initDemoUser(store, cryptoService)
if err != nil {
return errors.WithMessage(err, "failed creating demo user")
}
endpointIds, err := initDemoEndpoints(store)
if err != nil {
return errors.WithMessage(err, "failed creating demo endpoint")
}
err = initDemoSettings(store)
if err != nil {
return errors.WithMessage(err, "failed updating demo settings")
}
service.details = EnvironmentDetails{
Enabled: true,
Users: []portainer.UserID{id},
// endpoints 2,3 are created after deployment of portainer
Environments: endpointIds,
}
return nil
}
func isCleanStore(store dataservices.DataStore) (bool, error) {
endpoints, err := store.Endpoint().Endpoints()
if err != nil {
return false, err
}
if len(endpoints) > 0 {
return false, nil
}
users, err := store.User().ReadAll()
if err != nil {
return false, err
}
if len(users) > 0 {
return false, nil
}
return true, nil
}
func (service *Service) IsDemo() bool {
return service.details.Enabled
}
func (service *Service) IsDemoEnvironment(environmentID portainer.EndpointID) bool {
if !service.IsDemo() {
return false
}
for _, demoEndpointID := range service.details.Environments {
if environmentID == demoEndpointID {
return true
}
}
return false
}
func (service *Service) IsDemoUser(userID portainer.UserID) bool {
if !service.IsDemo() {
return false
}
for _, demoUserID := range service.details.Users {
if userID == demoUserID {
return true
}
}
return false
}

Some files were not shown because too many files have changed in this diff Show More