diff --git a/.junie/guidelines.md b/.junie/guidelines.md index 59194cd..74217a7 100644 --- a/.junie/guidelines.md +++ b/.junie/guidelines.md @@ -1,10 +1,13 @@ +## Global rules +- You must always ask before creating mocks -## SQLC queries +## SQLite queries -When writing SQL queries ensure you annotate your queries - -following are examples of correct annotations +- When writing sqlite queries use the tool [sqlc](https://sqlc.dev/) +- When writing SQL queries ensure you annotate your queries +- Following are examples of correct annotations +- after you finished writing the queries, use the command `task gen` to generate the boilerplate go code ```sql -- name: GetAuthor :one @@ -80,10 +83,12 @@ review `taskfile.yaml` for list of commands available in repo. ### Testing: - Write **unit tests** using use [odize](https://github.com/code-gorilla-au/odize) as the test framework and parallel execution. +- Do not mock out the database, we're using sqlite and embedded db for tests. +- Think about edge cases, within reason. - **Mock external interfaces** cleanly using generated ([Moq](https://github.com/matryer/moq)) or handwritten mocks. - Separate **fast unit tests** from slower integration and E2E tests. - Ensure **test coverage** for every exported function, with behavioural checks. -- Test coverage command is: `task go-cover`. +- Test command with coverage is: `task go-cover`. #### Example odize framework @@ -114,8 +119,7 @@ func TestQueries(t *testing.T) { "owner", "login", })) - - // Interpolated parameters + odize.AssertTrue(t, containsAll(q, []string{ "owner:" + owner, "topic:" + topic, diff --git a/cmd/local/main.go b/cmd/local/main.go new file mode 100644 index 0000000..64a3547 --- /dev/null +++ b/cmd/local/main.go @@ -0,0 +1,34 @@ +package main + +import ( + "context" + "os" + "watchtower/internal/database" + "watchtower/internal/logging" + "watchtower/internal/watchtower" +) + +func main() { + + appConfig := watchtower.LoadConfig() + logger := logging.FromContext(context.Background()).With("service", "local") + + _, db, err := database.NewDBFromProvider(appConfig.AppDir) + if err != nil { + logger.Error("Error creating database", "error", err) + os.Exit(1) + } + + defer func() { + if err = db.Close(); err != nil { + logger.Error("Error closing database", "error", err) + } + }() + + migrator := database.NewMigrator(db) + + if err = migrator.Init(); err != nil { + logger.Error("Error running migrations", "error", err) + os.Exit(1) + } +} diff --git a/frontend/src/lib/wailsjs/go/watchtower/Service.d.ts b/frontend/src/lib/wailsjs/go/watchtower/Service.d.ts index 12b3b65..80a2c1e 100755 --- a/frontend/src/lib/wailsjs/go/watchtower/Service.d.ts +++ b/frontend/src/lib/wailsjs/go/watchtower/Service.d.ts @@ -9,6 +9,10 @@ export function CreateOrganisation(arg1:string,arg2:string,arg3:string,arg4:stri export function CreateProduct(arg1:string,arg2:string,arg3:Array,arg4:number):Promise; +export function CreateUnreadPRNotification():Promise; + +export function CreateUnreadSecurityNotification():Promise; + export function DeleteAllOrgs():Promise; export function DeleteOldNotifications():Promise; @@ -37,7 +41,7 @@ export function GetSecurityByOrganisation(arg1:number):Promise>; -export function GetUnreadNotifications(arg1:number):Promise>; +export function GetUnreadNotifications():Promise>; export function MarkNotificationAsRead(arg1:number):Promise; @@ -53,4 +57,4 @@ export function SyncProduct(arg1:number):Promise; export function UpdateOrganisation(arg1:organisations.UpdateOrgParams):Promise; -export function UpdateProduct(arg1:number,arg2:string,arg3:Array):Promise; +export function UpdateProduct(arg1:number,arg2:string,arg3:string,arg4:Array):Promise; diff --git a/frontend/src/lib/wailsjs/go/watchtower/Service.js b/frontend/src/lib/wailsjs/go/watchtower/Service.js index 632a38d..f57af68 100755 --- a/frontend/src/lib/wailsjs/go/watchtower/Service.js +++ b/frontend/src/lib/wailsjs/go/watchtower/Service.js @@ -10,6 +10,14 @@ export function CreateProduct(arg1, arg2, arg3, arg4) { return window['go']['watchtower']['Service']['CreateProduct'](arg1, arg2, arg3, arg4); } +export function CreateUnreadPRNotification() { + return window['go']['watchtower']['Service']['CreateUnreadPRNotification'](); +} + +export function CreateUnreadSecurityNotification() { + return window['go']['watchtower']['Service']['CreateUnreadSecurityNotification'](); +} + export function DeleteAllOrgs() { return window['go']['watchtower']['Service']['DeleteAllOrgs'](); } @@ -66,8 +74,8 @@ export function GetSecurityByProductID(arg1) { return window['go']['watchtower']['Service']['GetSecurityByProductID'](arg1); } -export function GetUnreadNotifications(arg1) { - return window['go']['watchtower']['Service']['GetUnreadNotifications'](arg1); +export function GetUnreadNotifications() { + return window['go']['watchtower']['Service']['GetUnreadNotifications'](); } export function MarkNotificationAsRead(arg1) { @@ -98,6 +106,6 @@ export function UpdateOrganisation(arg1) { return window['go']['watchtower']['Service']['UpdateOrganisation'](arg1); } -export function UpdateProduct(arg1, arg2, arg3) { - return window['go']['watchtower']['Service']['UpdateProduct'](arg1, arg2, arg3); +export function UpdateProduct(arg1, arg2, arg3, arg4) { + return window['go']['watchtower']['Service']['UpdateProduct'](arg1, arg2, arg3, arg4); } diff --git a/frontend/src/lib/watchtower/products.svelte.ts b/frontend/src/lib/watchtower/products.svelte.ts index 4534d87..98dc90f 100644 --- a/frontend/src/lib/watchtower/products.svelte.ts +++ b/frontend/src/lib/watchtower/products.svelte.ts @@ -48,8 +48,8 @@ export class ProductsService { return product; } - async update(id: number, name: string, tags: string[]) { - const product = await UpdateProduct(id, name, tags); + async update(id: number, name: string, description: string, tags: string[]) { + const product = await UpdateProduct(id, name, description, tags); this.internalUpdateProduct(product); return product; diff --git a/frontend/src/routes/(orgs)/products/[product_id]/edit/+page.svelte b/frontend/src/routes/(orgs)/products/[product_id]/edit/+page.svelte index 674e955..4423efc 100644 --- a/frontend/src/routes/(orgs)/products/[product_id]/edit/+page.svelte +++ b/frontend/src/routes/(orgs)/products/[product_id]/edit/+page.svelte @@ -25,7 +25,12 @@ try { pageState.loading = true; pageState.error = undefined; - await productSvc.update(product.id, formData.name, formData.tags.split(",")); + await productSvc.update( + product.id, + formData.name, + formData.description, + formData.tags.split(",") + ); } catch (e) { const err = e as Error; pageState.error = err.message; diff --git a/go.mod b/go.mod index 79f23f2..f8ac31c 100644 --- a/go.mod +++ b/go.mod @@ -7,6 +7,7 @@ require ( github.com/code-gorilla-au/fetch v1.1.0 github.com/code-gorilla-au/odize v1.3.4 github.com/go-co-op/gocron/v2 v2.19.0 + github.com/google/uuid v1.6.0 github.com/wailsapp/wails/v2 v2.11.0 modernc.org/sqlite v1.42.2 ) @@ -16,7 +17,6 @@ require ( github.com/dustin/go-humanize v1.0.1 // indirect github.com/go-ole/go-ole v1.3.0 // indirect github.com/godbus/dbus/v5 v5.1.0 // indirect - github.com/google/uuid v1.6.0 // indirect github.com/gorilla/websocket v1.5.3 // indirect github.com/jchv/go-winloader v0.0.0-20210711035445-715c2860da7e // indirect github.com/joho/godotenv v1.5.1 // indirect diff --git a/internal/database/database.go b/internal/database/database.go index 958177b..854cfb6 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -4,6 +4,7 @@ import ( "database/sql" _ "embed" "path" + "strings" _ "modernc.org/sqlite" ) @@ -47,3 +48,7 @@ func resolveDBPath(filePath string) string { return path.Join(filePath, dbName, "?_busy_timeout=5000") } + +func IsErrUniqueConstraint(err error) bool { + return err != nil && strings.Contains(err.Error(), "constraint failed: UNIQUE constraint failed") +} diff --git a/internal/database/models.gen.go b/internal/database/models.gen.go index 840f684..801bc93 100644 --- a/internal/database/models.gen.go +++ b/internal/database/models.gen.go @@ -22,6 +22,7 @@ type Organisation struct { type OrganisationNotification struct { ID int64 OrganisationID sql.NullInt64 + ExternalID string Type string Content string Status string diff --git a/internal/database/notifications.sql.gen.go b/internal/database/notifications.sql.gen.go index 3278ec4..7b0540e 100644 --- a/internal/database/notifications.sql.gen.go +++ b/internal/database/notifications.sql.gen.go @@ -12,30 +12,39 @@ import ( const createOrgNotification = `-- name: CreateOrgNotification :one INSERT INTO organisation_notifications (organisation_id, + external_id, type, content, created_at, updated_at) VALUES (?, + ?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER), CAST(strftime('%s', 'now') AS INTEGER)) -RETURNING id, organisation_id, type, content, status, created_at, updated_at +RETURNING id, organisation_id, external_id, type, content, status, created_at, updated_at ` type CreateOrgNotificationParams struct { OrganisationID sql.NullInt64 + ExternalID string Type string Content string } func (q *Queries) CreateOrgNotification(ctx context.Context, arg CreateOrgNotificationParams) (OrganisationNotification, error) { - row := q.db.QueryRowContext(ctx, createOrgNotification, arg.OrganisationID, arg.Type, arg.Content) + row := q.db.QueryRowContext(ctx, createOrgNotification, + arg.OrganisationID, + arg.ExternalID, + arg.Type, + arg.Content, + ) var i OrganisationNotification err := row.Scan( &i.ID, &i.OrganisationID, + &i.ExternalID, &i.Type, &i.Content, &i.Status, @@ -46,7 +55,9 @@ func (q *Queries) CreateOrgNotification(ctx context.Context, arg CreateOrgNotifi } const deleteOrgNotificationByDate = `-- name: DeleteOrgNotificationByDate :exec -DELETE FROM organisation_notifications WHERE created_at < ? +DELETE +FROM organisation_notifications +WHERE created_at < ? ` func (q *Queries) DeleteOrgNotificationByDate(ctx context.Context, createdAt int64) error { @@ -54,9 +65,33 @@ func (q *Queries) DeleteOrgNotificationByDate(ctx context.Context, createdAt int return err } +const getNotificationByExternalID = `-- name: GetNotificationByExternalID :one +SELECT id, organisation_id, external_id, type, content, status, created_at, updated_at +FROM organisation_notifications +WHERE external_id = ? +` + +func (q *Queries) GetNotificationByExternalID(ctx context.Context, externalID string) (OrganisationNotification, error) { + row := q.db.QueryRowContext(ctx, getNotificationByExternalID, externalID) + var i OrganisationNotification + err := row.Scan( + &i.ID, + &i.OrganisationID, + &i.ExternalID, + &i.Type, + &i.Content, + &i.Status, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getUnreadNotificationsByOrgID = `-- name: GetUnreadNotificationsByOrgID :many -SELECT id, organisation_id, type, content, status, created_at, updated_at FROM organisation_notifications WHERE organisation_id = ? -AND status = 'unread' +SELECT id, organisation_id, external_id, type, content, status, created_at, updated_at +FROM organisation_notifications +WHERE organisation_id = ? + AND status = 'unread' ` func (q *Queries) GetUnreadNotificationsByOrgID(ctx context.Context, organisationID sql.NullInt64) ([]OrganisationNotification, error) { @@ -71,6 +106,7 @@ func (q *Queries) GetUnreadNotificationsByOrgID(ctx context.Context, organisatio if err := rows.Scan( &i.ID, &i.OrganisationID, + &i.ExternalID, &i.Type, &i.Content, &i.Status, @@ -97,7 +133,7 @@ SET type = ?, status = ?, updated_at = CAST(strftime('%s', 'now') AS INTEGER) WHERE id = ? -RETURNING id, organisation_id, type, content, status, created_at, updated_at +RETURNING id, organisation_id, external_id, type, content, status, created_at, updated_at ` type UpdateOrgNotificationByIDParams struct { @@ -118,6 +154,7 @@ func (q *Queries) UpdateOrgNotificationByID(ctx context.Context, arg UpdateOrgNo err := row.Scan( &i.ID, &i.OrganisationID, + &i.ExternalID, &i.Type, &i.Content, &i.Status, @@ -132,7 +169,7 @@ UPDATE organisation_notifications SET status = ?, updated_at = CAST(strftime('%s', 'now') AS INTEGER) WHERE id = ? -RETURNING id, organisation_id, type, content, status, created_at, updated_at +RETURNING id, organisation_id, external_id, type, content, status, created_at, updated_at ` type UpdateOrgNotificationStatusByIDParams struct { @@ -146,6 +183,7 @@ func (q *Queries) UpdateOrgNotificationStatusByID(ctx context.Context, arg Updat err := row.Scan( &i.ID, &i.OrganisationID, + &i.ExternalID, &i.Type, &i.Content, &i.Status, diff --git a/internal/database/products.sql.gen.go b/internal/database/products.sql.gen.go index 0d64eba..beed269 100644 --- a/internal/database/products.sql.gen.go +++ b/internal/database/products.sql.gen.go @@ -80,13 +80,6 @@ VALUES (?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER)) -ON CONFLICT (external_id) DO UPDATE SET title = excluded.title, - repository_name = excluded.repository_name, - url = excluded.url, - state = excluded.state, - author = excluded.author, - merged_at = excluded.merged_at, - updated_at = CAST(strftime('%s', 'now') AS INTEGER) RETURNING id, external_id, title, repository_name, url, state, author, merged_at, created_at, updated_at ` @@ -189,13 +182,6 @@ VALUES (?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER)) -ON CONFLICT (external_id) DO UPDATE SET repository_name = excluded.repository_name, - package_name = excluded.package_name, - state = excluded.state, - severity = excluded.severity, - patched_version = excluded.patched_version, - fixed_at = excluded.fixed_at, - updated_at = CAST(strftime('%s', 'now') AS INTEGER) RETURNING id, external_id, repository_name, package_name, state, severity, patched_version, fixed_at, created_at, updated_at ` @@ -354,6 +340,31 @@ func (q *Queries) GetProductByID(ctx context.Context, id int64) (Product, error) return i, err } +const getPullRequestByExternalID = `-- name: GetPullRequestByExternalID :one +SELECT id, external_id, title, repository_name, url, state, author, merged_at, created_at, updated_at +FROM pull_requests +WHERE external_id = ? +LIMIT 1 +` + +func (q *Queries) GetPullRequestByExternalID(ctx context.Context, externalID string) (PullRequest, error) { + row := q.db.QueryRowContext(ctx, getPullRequestByExternalID, externalID) + var i PullRequest + err := row.Scan( + &i.ID, + &i.ExternalID, + &i.Title, + &i.RepositoryName, + &i.Url, + &i.State, + &i.Author, + &i.MergedAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getPullRequestByProductIDAndState = `-- name: GetPullRequestByProductIDAndState :many SELECT pr.id, pr.external_id, pr.title, pr.repository_name, pr.url, pr.state, pr.author, pr.merged_at, pr.created_at, pr.updated_at, r.topic as tag, p.name as product_name FROM pull_requests pr @@ -494,6 +505,68 @@ func (q *Queries) GetPullRequestsByOrganisationAndState(ctx context.Context, arg return items, nil } +const getRecentPullRequests = `-- name: GetRecentPullRequests :many +SELECT external_id +FROM pull_requests +WHERE created_at >= unixepoch() - 300 +AND state = 'OPEN' +ORDER BY created_at DESC +` + +func (q *Queries) GetRecentPullRequests(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getRecentPullRequests) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var external_id string + if err := rows.Scan(&external_id); err != nil { + return nil, err + } + items = append(items, external_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const getRecentSecurity = `-- name: GetRecentSecurity :many +SELECT external_id +FROM securities +WHERE created_at >= unixepoch() - 300 +and state = 'OPEN' +ORDER BY created_at DESC +` + +func (q *Queries) GetRecentSecurity(ctx context.Context) ([]string, error) { + rows, err := q.db.QueryContext(ctx, getRecentSecurity) + if err != nil { + return nil, err + } + defer rows.Close() + var items []string + for rows.Next() { + var external_id string + if err := rows.Scan(&external_id); err != nil { + return nil, err + } + items = append(items, external_id) + } + if err := rows.Close(); err != nil { + return nil, err + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + const getRepoByName = `-- name: GetRepoByName :one SELECT id, name, url, topic, owner, created_at, updated_at FROM repositories @@ -569,6 +642,31 @@ func (q *Queries) GetReposByProductID(ctx context.Context, id int64) ([]GetRepos return items, nil } +const getSecurityByExternalID = `-- name: GetSecurityByExternalID :one +SELECT id, external_id, repository_name, package_name, state, severity, patched_version, fixed_at, created_at, updated_at +FROM securities +WHERE external_id = ? +LIMIT 1 +` + +func (q *Queries) GetSecurityByExternalID(ctx context.Context, externalID string) (Security, error) { + row := q.db.QueryRowContext(ctx, getSecurityByExternalID, externalID) + var i Security + err := row.Scan( + &i.ID, + &i.ExternalID, + &i.RepositoryName, + &i.PackageName, + &i.State, + &i.Severity, + &i.PatchedVersion, + &i.FixedAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const getSecurityByOrganisationAndState = `-- name: GetSecurityByOrganisationAndState :many SELECT s.id, s.external_id, s.repository_name, s.package_name, s.state, s.severity, s.patched_version, s.fixed_at, s.created_at, s.updated_at, r.topic as tag, p.name as product_name FROM securities s @@ -794,6 +892,55 @@ func (q *Queries) UpdateProductSync(ctx context.Context, id int64) error { return err } +const updatePullRequest = `-- name: UpdatePullRequest :one +UPDATE pull_requests +SET title = ?, + repository_name = ?, + url = ?, + state = ?, + author = ?, + merged_at = ?, + updated_at = CAST(strftime('%s', 'now') AS INTEGER) +WHERE id = ? +RETURNING id, external_id, title, repository_name, url, state, author, merged_at, created_at, updated_at +` + +type UpdatePullRequestParams struct { + Title string + RepositoryName string + Url string + State string + Author string + MergedAt sql.NullInt64 + ID int64 +} + +func (q *Queries) UpdatePullRequest(ctx context.Context, arg UpdatePullRequestParams) (PullRequest, error) { + row := q.db.QueryRowContext(ctx, updatePullRequest, + arg.Title, + arg.RepositoryName, + arg.Url, + arg.State, + arg.Author, + arg.MergedAt, + arg.ID, + ) + var i PullRequest + err := row.Scan( + &i.ID, + &i.ExternalID, + &i.Title, + &i.RepositoryName, + &i.Url, + &i.State, + &i.Author, + &i.MergedAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} + const updateRepo = `-- name: UpdateRepo :one UPDATE repositories SET name = ?, @@ -833,3 +980,52 @@ func (q *Queries) UpdateRepo(ctx context.Context, arg UpdateRepoParams) (Reposit ) return i, err } + +const updateSecurity = `-- name: UpdateSecurity :one +UPDATE securities +SET repository_name = ?, + package_name = ?, + state = ?, + severity = ?, + patched_version = ?, + fixed_at = ?, + updated_at = CAST(strftime('%s', 'now') AS INTEGER) +WHERE external_id = ? +RETURNING id, external_id, repository_name, package_name, state, severity, patched_version, fixed_at, created_at, updated_at +` + +type UpdateSecurityParams struct { + RepositoryName string + PackageName string + State string + Severity string + PatchedVersion string + FixedAt sql.NullInt64 + ExternalID string +} + +func (q *Queries) UpdateSecurity(ctx context.Context, arg UpdateSecurityParams) (Security, error) { + row := q.db.QueryRowContext(ctx, updateSecurity, + arg.RepositoryName, + arg.PackageName, + arg.State, + arg.Severity, + arg.PatchedVersion, + arg.FixedAt, + arg.ExternalID, + ) + var i Security + err := row.Scan( + &i.ID, + &i.ExternalID, + &i.RepositoryName, + &i.PackageName, + &i.State, + &i.Severity, + &i.PatchedVersion, + &i.FixedAt, + &i.CreatedAt, + &i.UpdatedAt, + ) + return i, err +} diff --git a/internal/database/queries/notifications.sql b/internal/database/queries/notifications.sql index 94375a5..0695934 100644 --- a/internal/database/queries/notifications.sql +++ b/internal/database/queries/notifications.sql @@ -1,16 +1,23 @@ -- name: CreateOrgNotification :one INSERT INTO organisation_notifications (organisation_id, + external_id, type, content, created_at, updated_at) VALUES (?, + ?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER), CAST(strftime('%s', 'now') AS INTEGER)) RETURNING *; +-- name: GetNotificationByExternalID :one +SELECT * +FROM organisation_notifications +WHERE external_id = ?; + -- name: UpdateOrgNotificationByID :one UPDATE organisation_notifications SET type = ?, @@ -28,8 +35,12 @@ WHERE id = ? RETURNING *; -- name: GetUnreadNotificationsByOrgID :many -SELECT * FROM organisation_notifications WHERE organisation_id = ? -AND status = 'unread'; +SELECT * +FROM organisation_notifications +WHERE organisation_id = ? + AND status = 'unread'; -- name: DeleteOrgNotificationByDate :exec -DELETE FROM organisation_notifications WHERE created_at < ?; +DELETE +FROM organisation_notifications +WHERE created_at < ?; diff --git a/internal/database/queries/products.sql b/internal/database/queries/products.sql index 849c7c1..01dc8e6 100644 --- a/internal/database/queries/products.sql +++ b/internal/database/queries/products.sql @@ -132,15 +132,26 @@ VALUES (?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER)) -ON CONFLICT (external_id) DO UPDATE SET title = excluded.title, - repository_name = excluded.repository_name, - url = excluded.url, - state = excluded.state, - author = excluded.author, - merged_at = excluded.merged_at, - updated_at = CAST(strftime('%s', 'now') AS INTEGER) RETURNING *; +-- name: UpdatePullRequest :one +UPDATE pull_requests +SET title = ?, + repository_name = ?, + url = ?, + state = ?, + author = ?, + merged_at = ?, + updated_at = CAST(strftime('%s', 'now') AS INTEGER) +WHERE id = ? +RETURNING *; + +-- name: GetPullRequestByExternalID :one +SELECT * +FROM pull_requests +WHERE external_id = ? +LIMIT 1; + -- name: GetPullRequestByProductIDAndState :many SELECT pr.*, r.topic as tag, p.name as product_name FROM pull_requests pr @@ -179,6 +190,22 @@ WHERE external_id IN (SELECT pr.external_id FROM JSON_EACH(p.tags) WHERE JSON_EACH.value = r.topic)); + +-- name: GetRecentPullRequests :many +SELECT external_id +FROM pull_requests +WHERE created_at >= unixepoch() - 300 +AND state = 'OPEN' +ORDER BY created_at DESC; + + +-- name: GetRecentSecurity :many +SELECT external_id +FROM securities +WHERE created_at >= unixepoch() - 300 +and state = 'OPEN' +ORDER BY created_at DESC; + -- name: CreateSecurity :one INSERT INTO securities (external_id, repository_name, @@ -197,15 +224,26 @@ VALUES (?, ?, ?, CAST(strftime('%s', 'now') AS INTEGER)) -ON CONFLICT (external_id) DO UPDATE SET repository_name = excluded.repository_name, - package_name = excluded.package_name, - state = excluded.state, - severity = excluded.severity, - patched_version = excluded.patched_version, - fixed_at = excluded.fixed_at, - updated_at = CAST(strftime('%s', 'now') AS INTEGER) RETURNING *; +-- name: UpdateSecurity :one +UPDATE securities +SET repository_name = ?, + package_name = ?, + state = ?, + severity = ?, + patched_version = ?, + fixed_at = ?, + updated_at = CAST(strftime('%s', 'now') AS INTEGER) +WHERE external_id = ? +RETURNING *; + +-- name: GetSecurityByExternalID :one +SELECT * +FROM securities +WHERE external_id = ? +LIMIT 1; + -- name: GetSecurityByProductIDAndState :many SELECT s.*, r.topic as tag, p.name as product_name diff --git a/internal/database/schemas/schema.sql b/internal/database/schemas/schema.sql index acde7b8..eddbb27 100644 --- a/internal/database/schemas/schema.sql +++ b/internal/database/schemas/schema.sql @@ -70,6 +70,7 @@ CREATE TABLE IF NOT EXISTS organisation_notifications ( id INTEGER PRIMARY KEY AUTOINCREMENT, organisation_id INTEGER REFERENCES organisations (id), + external_id TEXT UNIQUE NOT NULL, type TEXT NOT NULL, content TEXT NOT NULL, status TEXT NOT NULL DEFAULT 'unread', diff --git a/internal/notifications/interfaces.go b/internal/notifications/interfaces.go index 831b387..0451957 100644 --- a/internal/notifications/interfaces.go +++ b/internal/notifications/interfaces.go @@ -11,6 +11,7 @@ type Store interface { CreateOrgNotification(ctx context.Context, arg database.CreateOrgNotificationParams) (database.OrganisationNotification, error) UpdateOrgNotificationByID(ctx context.Context, arg database.UpdateOrgNotificationByIDParams) (database.OrganisationNotification, error) UpdateOrgNotificationStatusByID(ctx context.Context, arg database.UpdateOrgNotificationStatusByIDParams) (database.OrganisationNotification, error) + GetNotificationByExternalID(ctx context.Context, externalID string) (database.OrganisationNotification, error) GetUnreadNotificationsByOrgID(ctx context.Context, organisationID sql.NullInt64) ([]database.OrganisationNotification, error) DeleteOrgNotificationByDate(ctx context.Context, createdAt int64) error } diff --git a/internal/notifications/service.go b/internal/notifications/service.go index 3ae874a..396ff11 100644 --- a/internal/notifications/service.go +++ b/internal/notifications/service.go @@ -16,34 +16,42 @@ func New(db Store, txnDB *sql.DB, txnFunc func(tx *sql.Tx) Store) *Service { } } +type CreateNotificationParams struct { + OrgID int64 + ExternalID string + NotificationType string + Content string +} + // CreateNotification creates a new notification for a specific organisation with the given type and content. -func (s *Service) CreateNotification(ctx context.Context, orgID int64, notificationType string, content string) (Notification, error) { - logger := logging.FromContext(ctx).With("orgID", orgID, "service", "notifications") +func (s *Service) CreateNotification(ctx context.Context, params CreateNotificationParams) error { + logger := logging.FromContext(ctx).With("orgID", params.OrgID, "service", "notifications") logger.Debug("Creating notification for org") - model, err := s.store.CreateOrgNotification(ctx, database.CreateOrgNotificationParams{ + _, err := s.store.CreateOrgNotification(ctx, database.CreateOrgNotificationParams{ + ExternalID: params.ExternalID, OrganisationID: sql.NullInt64{ - Int64: orgID, + Int64: params.OrgID, Valid: true, }, - Type: notificationType, - Content: content, + Type: params.NotificationType, + Content: params.Content, }) - if err != nil { - logger.Error("Error creating notification", "error", err) - return Notification{}, err + if err == nil || database.IsErrUniqueConstraint(err) { + return nil } - return fromNotificationModel(model), nil + logger.Error("Error creating notification", "error", err) + return err } // GetUnreadNotifications fetches all unread notifications for the specified organisation ID. Returns a list of notifications or an error. -func (s *Service) GetUnreadNotifications(ctx context.Context, orgID int64) ([]Notification, error) { - logger := logging.FromContext(ctx).With("orgID", orgID, "service", "notifications") +func (s *Service) GetUnreadNotifications(ctx context.Context) ([]Notification, error) { + logger := logging.FromContext(ctx).With("service", "notifications") logger.Debug("Fetching unread notifications") models, err := s.store.GetUnreadNotificationsByOrgID(ctx, sql.NullInt64{ - Int64: orgID, + Int64: 0, Valid: true, }) if err != nil { @@ -54,6 +62,19 @@ func (s *Service) GetUnreadNotifications(ctx context.Context, orgID int64) ([]No return fromNotificationModels(models), nil } +func (s *Service) GetNotificationByExternalID(ctx context.Context, externalID string) (Notification, error) { + logger := logging.FromContext(ctx).With("externalID", externalID, "service", "notifications") + logger.Debug("Fetching notification by external ID") + + model, err := s.store.GetNotificationByExternalID(ctx, externalID) + if err != nil { + logger.Error("Error fetching notification by external ID", "error", err) + return Notification{}, err + } + + return fromNotificationModel(model), nil +} + // MarkNotificationAsRead updates the status of a notification to "read" based on the provided notification ID. func (s *Service) MarkNotificationAsRead(ctx context.Context, notificationID int64) error { logger := logging.FromContext(ctx).With("notificationID", notificationID, "service", "notifications") diff --git a/internal/notifications/service_test.go b/internal/notifications/service_test.go index 957c3ca..224b305 100644 --- a/internal/notifications/service_test.go +++ b/internal/notifications/service_test.go @@ -27,53 +27,54 @@ func TestService(t *testing.T) { notifType := "test-type" content := "test-content" - notif, err := s.CreateNotification(ctx, orgID, notifType, content) + err := s.CreateNotification(ctx, CreateNotificationParams{ + OrgID: orgID, + NotificationType: notifType, + Content: content, + ExternalID: "test-external-id", + }) odize.AssertNoError(t, err) - odize.AssertEqual(t, orgID, notif.OrganisationID) - odize.AssertEqual(t, notifType, notif.Type) - odize.AssertEqual(t, content, notif.Content) - odize.AssertEqual(t, StatusUnread, notif.Status) - odize.AssertTrue(t, notif.ID > 0) - odize.AssertFalse(t, notif.CreatedAt.IsZero()) - odize.AssertFalse(t, notif.UpdatedAt.IsZero()) - }). - Test("GetUnreadNotifications should return only unread notifications for an org", func(t *testing.T) { - orgID := int64(2) - - _, err := s.CreateNotification(ctx, orgID, "type1", "content1") - odize.AssertNoError(t, err) - - notif2, err := s.CreateNotification(ctx, orgID, "type2", "content2") - odize.AssertNoError(t, err) - - err = s.MarkNotificationAsRead(ctx, notif2.ID) + unread, err := s.GetNotificationByExternalID(ctx, "test-external-id") odize.AssertNoError(t, err) - unread, err := s.GetUnreadNotifications(ctx, orgID) - odize.AssertNoError(t, err) - - odize.AssertEqual(t, 1, len(unread)) - odize.AssertEqual(t, "content1", unread[0].Content) - odize.AssertEqual(t, StatusUnread, unread[0].Status) + odize.AssertEqual(t, orgID, unread.OrganisationID) + odize.AssertEqual(t, notifType, unread.Type) + odize.AssertEqual(t, content, unread.Content) + odize.AssertEqual(t, StatusUnread, unread.Status) + odize.AssertTrue(t, unread.ID > 0) + odize.AssertFalse(t, unread.CreatedAt.IsZero()) + odize.AssertFalse(t, unread.UpdatedAt.IsZero()) }). Test("MarkNotificationAsRead should update notification status", func(t *testing.T) { orgID := int64(3) - notif, err := s.CreateNotification(ctx, orgID, "type", "content") + err := s.CreateNotification(ctx, CreateNotificationParams{ + OrgID: orgID, + NotificationType: "type", + Content: "content", + ExternalID: "ext3", + }) odize.AssertNoError(t, err) - odize.AssertEqual(t, StatusUnread, notif.Status) - err = s.MarkNotificationAsRead(ctx, notif.ID) + notif, err := s.GetNotificationByExternalID(ctx, "ext3") odize.AssertNoError(t, err) - unread, err := s.GetUnreadNotifications(ctx, orgID) + err = s.MarkNotificationAsRead(ctx, notif.ID) odize.AssertNoError(t, err) - odize.AssertEqual(t, 0, len(unread)) + + updatedNotif, updateErr := s.GetNotificationByExternalID(ctx, "ext3") + odize.AssertNoError(t, updateErr) + odize.AssertEqual(t, StatusRead, updatedNotif.Status) }). Test("DeleteNotificationsByDate should delete old notifications", func(t *testing.T) { orgID := int64(4) - _, err := s.CreateNotification(ctx, orgID, "type", "content") + err := s.CreateNotification(ctx, CreateNotificationParams{ + OrgID: orgID, + NotificationType: "type", + Content: "content", + ExternalID: "ext4", + }) odize.AssertNoError(t, err) cutoff := time.Now().Add(1 * time.Minute) @@ -81,7 +82,7 @@ func TestService(t *testing.T) { err = s.DeleteNotificationsByDate(ctx, cutoff) odize.AssertNoError(t, err) - unread, err := s.GetUnreadNotifications(ctx, orgID) + unread, err := s.GetUnreadNotifications(ctx) odize.AssertNoError(t, err) odize.AssertEqual(t, 0, len(unread)) }). diff --git a/internal/products/interfaces.go b/internal/products/interfaces.go index 5e8aea1..55d4986 100644 --- a/internal/products/interfaces.go +++ b/internal/products/interfaces.go @@ -7,24 +7,47 @@ import ( ) type ProductStore interface { + ProductBaseStore + RepoStore + PullRequestStore + SecurityStore +} + +var _ ProductStore = (*database.Queries)(nil) + +type ProductBaseStore interface { CreateProduct(ctx context.Context, arg database.CreateProductParams) (database.Product, error) GetProductByID(ctx context.Context, id int64) (database.Product, error) ListProductsByOrganisation(ctx context.Context, organisationID sql.NullInt64) ([]database.Product, error) UpdateProduct(ctx context.Context, arg database.UpdateProductParams) (database.Product, error) UpdateProductSync(ctx context.Context, id int64) error - DeleteSecurityByProductID(ctx context.Context, id int64) error - DeletePullRequestsByProductID(ctx context.Context, id int64) error - DeleteReposByProductID(ctx context.Context, id int64) error DeleteProduct(ctx context.Context, id int64) error +} +type RepoStore interface { CreateRepo(ctx context.Context, arg database.CreateRepoParams) (database.Repository, error) UpdateRepo(ctx context.Context, arg database.UpdateRepoParams) (database.Repository, error) GetRepoByName(ctx context.Context, name string) (database.Repository, error) GetReposByProductID(ctx context.Context, id int64) ([]database.GetReposByProductIDRow, error) + DeleteReposByProductID(ctx context.Context, id int64) error +} + +type PullRequestStore interface { GetPullRequestByProductIDAndState(ctx context.Context, arg database.GetPullRequestByProductIDAndStateParams) ([]database.GetPullRequestByProductIDAndStateRow, error) GetPullRequestsByOrganisationAndState(ctx context.Context, arg database.GetPullRequestsByOrganisationAndStateParams) ([]database.GetPullRequestsByOrganisationAndStateRow, error) CreatePullRequest(ctx context.Context, arg database.CreatePullRequestParams) (database.PullRequest, error) + UpdatePullRequest(ctx context.Context, arg database.UpdatePullRequestParams) (database.PullRequest, error) + GetPullRequestByExternalID(ctx context.Context, externalID string) (database.PullRequest, error) + GetRecentPullRequests(ctx context.Context) ([]string, error) + DeletePullRequestsByProductID(ctx context.Context, id int64) error +} + +type SecurityStore interface { GetSecurityByProductIDAndState(ctx context.Context, arg database.GetSecurityByProductIDAndStateParams) ([]database.GetSecurityByProductIDAndStateRow, error) GetSecurityByOrganisationAndState(ctx context.Context, arg database.GetSecurityByOrganisationAndStateParams) ([]database.GetSecurityByOrganisationAndStateRow, error) CreateSecurity(ctx context.Context, arg database.CreateSecurityParams) (database.Security, error) + GetRecentSecurity(ctx context.Context) ([]string, error) + UpdateSecurity(ctx context.Context, arg database.UpdateSecurityParams) (database.Security, error) + GetSecurityByExternalID(ctx context.Context, externalID string) (database.Security, error) + DeleteSecurityByProductID(ctx context.Context, id int64) error } diff --git a/internal/products/service.go b/internal/products/service.go index 146d574..905e5d3 100644 --- a/internal/products/service.go +++ b/internal/products/service.go @@ -281,34 +281,119 @@ func (s *Service) BulkUpsertRepos(ctx context.Context, paramsList []CreateRepoPa return nil } +// CreatePullRequest creates a new pull request entry in the database using the provided parameters. +func (s *Service) CreatePullRequest(ctx context.Context, params CreatePRParams) error { + logger := logging.FromContext(ctx).With("service", "products") + + logger.Debug("Creating pull request") + + var mergedAt sql.NullInt64 + if params.MergedAt != nil { + mergedAt.Valid = true + mergedAt.Int64 = params.MergedAt.Unix() + } + + _, err := s.store.CreatePullRequest(ctx, database.CreatePullRequestParams{ + ExternalID: params.ExternalID, + Title: params.Title, + RepositoryName: params.RepositoryName, + Url: params.Url, + State: params.State, + Author: params.Author, + MergedAt: mergedAt, + CreatedAt: params.CreatedAt.Unix(), + }) + + if err != nil { + logger.Error("Error creating pull request", "error", err) + return err + } + + return nil +} + +func (s *Service) UpdatePullRequest(ctx context.Context, params UpdatePRParams) error { + logger := logging.FromContext(ctx).With("service", "products") + + logger.Debug("Updating pull request") + + var mergedAt sql.NullInt64 + if params.MergedAt != nil { + mergedAt.Valid = true + mergedAt.Int64 = params.MergedAt.Unix() + } + + _, err := s.store.UpdatePullRequest(ctx, database.UpdatePullRequestParams{ + Title: params.Title, + RepositoryName: params.RepositoryName, + Url: params.Url, + State: params.State, + Author: params.Author, + MergedAt: mergedAt, + ID: params.ID, + }) + + if err != nil { + logger.Error("Error updating pull request", "error", err) + return err + } + + return nil +} + +func (s *Service) UpsertPullRequest(ctx context.Context, params CreatePRParams) error { + logger := logging.FromContext(ctx).With("service", "products") + + createErr := s.CreatePullRequest(ctx, params) + if createErr == nil { + return nil + } + + if !database.IsErrUniqueConstraint(createErr) { + logger.Error("Error creating pull request", "error", createErr) + return createErr + } + + pr, getErr := s.store.GetPullRequestByExternalID(ctx, params.ExternalID) + if getErr != nil { + logger.Error("Error fetching pull request", "error", getErr) + return getErr + } + + return s.UpdatePullRequest(ctx, UpdatePRParams{ + ID: pr.ID, + ExternalID: pr.ExternalID, + Title: params.Title, + RepositoryName: params.RepositoryName, + Url: params.Url, + State: params.State, + Author: params.Author, + MergedAt: params.MergedAt, + }) +} + +func (s *Service) GetRecentPullRequests(ctx context.Context) ([]string, error) { + logger := logging.FromContext(ctx).With("service", "products") + logger.Debug("Getting recent pull requests") + externalIDs, err := s.store.GetRecentPullRequests(ctx) + if err != nil { + logger.Error("Error fetching recent pull requests", "error", err) + return nil, err + } + + return externalIDs, nil +} + func (s *Service) BulkCreatePullRequest(ctx context.Context, paramsList []CreatePRParams) error { logger := logging.FromContext(ctx) for _, params := range paramsList { - mergedAt := sql.NullInt64{ - Valid: false, - Int64: 0, - } - if params.MergedAt != nil { - mergedAt.Int64 = params.MergedAt.Unix() - mergedAt.Valid = true - } - _, err := s.store.CreatePullRequest(ctx, database.CreatePullRequestParams{ - ExternalID: params.ExternalID, - Title: params.Title, - RepositoryName: params.RepositoryName, - Url: params.Url, - State: params.State, - Author: params.Author, - MergedAt: mergedAt, - CreatedAt: params.CreatedAt.Unix(), - }) - if err != nil { + if err := s.UpsertPullRequest(ctx, params); err != nil { logger.Error("Error creating pull request", "error", err) - return err } + } return nil @@ -350,26 +435,24 @@ func (s *Service) GetSecurityByOrg(ctx context.Context, orgID int64) ([]Security return orgToSecurityDTOs(model), nil } +func (s *Service) GetRecentSecurity(ctx context.Context) ([]string, error) { + logger := logging.FromContext(ctx).With("service", "products") + logger.Debug("Getting recent security") + + externalIDs, err := s.store.GetRecentSecurity(ctx) + if err != nil { + logger.Error("Error fetching recent security", "error", err) + return nil, err + } + + return externalIDs, nil +} + func (s *Service) BulkCreateSecurity(ctx context.Context, paramsList []CreateSecurityParams) error { logger := logging.FromContext(ctx) for _, params := range paramsList { - fixedAt := sql.NullInt64{} - if params.FixedAt != nil { - fixedAt.Int64 = params.FixedAt.Unix() - fixedAt.Valid = true - } - - _, err := s.store.CreateSecurity(ctx, database.CreateSecurityParams{ - ExternalID: params.ExternalID, - RepositoryName: params.RepositoryName, - PackageName: params.PackageName, - State: params.State, - Severity: params.Severity, - PatchedVersion: params.PatchedVersion, - FixedAt: fixedAt, - }) - if err != nil { + if err := s.UpsertSecurity(ctx, params); err != nil { logger.Error("Error creating security", "error", err) return err } @@ -378,6 +461,75 @@ func (s *Service) BulkCreateSecurity(ctx context.Context, paramsList []CreateSec return nil } +func (s *Service) UpdateSecurity(ctx context.Context, params UpdateSecurityParams) error { + logger := logging.FromContext(ctx).With("service", "products") + + logger.Debug("Updating security") + + var fixedAt sql.NullInt64 + if params.FixedAt != nil { + fixedAt.Valid = true + fixedAt.Int64 = params.FixedAt.Unix() + } + + _, err := s.store.UpdateSecurity(ctx, database.UpdateSecurityParams{ + RepositoryName: params.RepositoryName, + PackageName: params.PackageName, + State: params.State, + Severity: params.Severity, + PatchedVersion: params.PatchedVersion, + FixedAt: fixedAt, + ExternalID: params.ExternalID, + }) + + if err != nil { + logger.Error("Error updating security", "error", err) + return err + } + + return nil +} + +func (s *Service) UpsertSecurity(ctx context.Context, params CreateSecurityParams) error { + logger := logging.FromContext(ctx).With("service", "products") + + var fixedAt sql.NullInt64 + if params.FixedAt != nil { + fixedAt.Int64 = params.FixedAt.Unix() + fixedAt.Valid = true + } + + _, createErr := s.store.CreateSecurity(ctx, database.CreateSecurityParams{ + ExternalID: params.ExternalID, + RepositoryName: params.RepositoryName, + PackageName: params.PackageName, + State: params.State, + Severity: params.Severity, + PatchedVersion: params.PatchedVersion, + FixedAt: fixedAt, + CreatedAt: params.CreatedAt.Unix(), + }) + + if createErr == nil { + return nil + } + + if !strings.Contains(createErr.Error(), "constraint failed: UNIQUE constraint failed") { + logger.Error("Error creating security", "error", createErr) + return createErr + } + + return s.UpdateSecurity(ctx, UpdateSecurityParams{ + ExternalID: params.ExternalID, + RepositoryName: params.RepositoryName, + PackageName: params.PackageName, + State: params.State, + Severity: params.Severity, + PatchedVersion: params.PatchedVersion, + FixedAt: params.FixedAt, + }) +} + func (s *Service) BulkInsertRepos(ctx context.Context, reposList []github.Node[github.Repository], tag string) error { params := toCreateRepoFromGithub(reposList, tag) diff --git a/internal/products/service_test.go b/internal/products/service_test.go index 5e5e9a3..4820315 100644 --- a/internal/products/service_test.go +++ b/internal/products/service_test.go @@ -163,6 +163,88 @@ func TestService(t *testing.T) { odize.AssertEqual(t, "url2", model.Url) odize.AssertEqual(t, "topic2", model.Topic) }). + Test("CreatePullRequest should create a new pull request", func(t *testing.T) { + params := CreatePRParams{ + ExternalID: "new-pr", + Title: "New PR", + RepositoryName: "repo1", + Url: "url1", + State: "OPEN", + Author: "author1", + CreatedAt: time.Now(), + } + + err := s.CreatePullRequest(ctx, params) + odize.AssertNoError(t, err) + + pr, err := _testDB.GetPullRequestByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, params.Title, pr.Title) + odize.AssertEqual(t, params.Author, pr.Author) + }). + Test("UpdatePullRequest should update an existing pull request", func(t *testing.T) { + params := CreatePRParams{ + ExternalID: "update-pr", + Title: "Original Title", + RepositoryName: "repo1", + Url: "url1", + State: "OPEN", + Author: "author1", + CreatedAt: time.Now(), + } + + err := s.CreatePullRequest(ctx, params) + odize.AssertNoError(t, err) + + pr, _ := _testDB.GetPullRequestByExternalID(ctx, params.ExternalID) + + updateParams := UpdatePRParams{ + ID: pr.ID, + Title: "Updated Title", + RepositoryName: "repo1", + Url: "url-updated", + State: "CLOSED", + Author: "author-updated", + } + + err = s.UpdatePullRequest(ctx, updateParams) + odize.AssertNoError(t, err) + + updatedPr, err := _testDB.GetPullRequestByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, updateParams.Title, updatedPr.Title) + odize.AssertEqual(t, updateParams.State, updatedPr.State) + odize.AssertEqual(t, updateParams.Author, updatedPr.Author) + }). + Test("UpsertPullRequest should create when not exists and update when exists", func(t *testing.T) { + params := CreatePRParams{ + ExternalID: "upsert-pr", + Title: "Upsert Title", + RepositoryName: "repo1", + Url: "url1", + State: "OPEN", + Author: "author1", + CreatedAt: time.Now(), + } + + // Create + err := s.UpsertPullRequest(ctx, params) + odize.AssertNoError(t, err) + + pr, err := _testDB.GetPullRequestByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, params.Title, pr.Title) + + // Update + params.Title = "Upsert Updated Title" + err = s.UpsertPullRequest(ctx, params) + odize.AssertNoError(t, err) + + updatedPr, err := _testDB.GetPullRequestByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, "Upsert Updated Title", updatedPr.Title) + odize.AssertEqual(t, pr.ID, updatedPr.ID) + }). Test("GetPullRequests and GetPullRequestByOrg", func(t *testing.T) { tag := fmt.Sprintf("pr-tag-%d", time.Now().UnixNano()) prod, _ := s.Create(ctx, CreateProductParams{Name: "PR Product", Tags: []string{tag}}) @@ -197,6 +279,33 @@ func TestService(t *testing.T) { odize.AssertNoError(t, err) odize.AssertTrue(t, len(orgPrs) > 0) }). + Test("GetRecentPullRequests should return external IDs of recent PRs", func(t *testing.T) { + params := CreatePRParams{ + ExternalID: "recent-pr-1", + Title: "Recent PR", + RepositoryName: "repo1", + Url: "url1", + State: "OPEN", + Author: "author1", + CreatedAt: time.Now(), + } + + err := s.CreatePullRequest(ctx, params) + odize.AssertNoError(t, err) + + recent, err := s.GetRecentPullRequests(ctx) + odize.AssertNoError(t, err) + odize.AssertTrue(t, len(recent) > 0) + + found := false + for _, id := range recent { + if id == params.ExternalID { + found = true + break + } + } + odize.AssertTrue(t, found) + }). Test("GetSecurity and GetSecurityByOrg", func(t *testing.T) { tag := fmt.Sprintf("sec-tag-%d", time.Now().UnixNano()) prod, _ := s.Create(ctx, CreateProductParams{Name: "Sec Product", Tags: []string{tag}}) @@ -229,6 +338,32 @@ func TestService(t *testing.T) { odize.AssertNoError(t, err) odize.AssertTrue(t, len(orgSecs) > 0) }). + Test("GetRecentSecurity should return external IDs of recent security alerts", func(t *testing.T) { + params := CreateSecurityParams{ + ExternalID: "recent-sec-1", + RepositoryName: "repo1", + PackageName: "pkg1", + State: "OPEN", + Severity: "HIGH", + CreatedAt: time.Now(), + } + + err := s.UpsertSecurity(ctx, params) + odize.AssertNoError(t, err) + + recent, err := s.GetRecentSecurity(ctx) + odize.AssertNoError(t, err) + odize.AssertTrue(t, len(recent) > 0) + + found := false + for _, id := range recent { + if id == params.ExternalID { + found = true + break + } + } + odize.AssertTrue(t, found) + }). Test("BulkInsertRepos", func(t *testing.T) { tag := "bulk-repo-tag" repos := []github.Node[github.Repository]{ @@ -289,6 +424,68 @@ func TestService(t *testing.T) { err := s.BulkInsertRepoDetails(ctx, repoDetails) odize.AssertNoError(t, err) }). + Test("UpdateSecurity should update an existing security alert", func(t *testing.T) { + params := CreateSecurityParams{ + ExternalID: "update-sec", + RepositoryName: "repo1", + PackageName: "pkg1", + State: "OPEN", + Severity: "HIGH", + CreatedAt: time.Now(), + } + + err := s.UpsertSecurity(ctx, params) + odize.AssertNoError(t, err) + + updateParams := UpdateSecurityParams{ + ExternalID: params.ExternalID, + RepositoryName: "repo1", + PackageName: "pkg1-updated", + State: "FIXED", + Severity: "CRITICAL", + PatchedVersion: "1.2.3", + } + + err = s.UpdateSecurity(ctx, updateParams) + odize.AssertNoError(t, err) + + updated, err := _testDB.GetSecurityByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, updateParams.PackageName, updated.PackageName) + odize.AssertEqual(t, updateParams.State, updated.State) + odize.AssertEqual(t, updateParams.Severity, updated.Severity) + odize.AssertEqual(t, updateParams.PatchedVersion, updated.PatchedVersion) + }). + Test("UpsertSecurity should create when not exists and update when exists", func(t *testing.T) { + params := CreateSecurityParams{ + ExternalID: "upsert-sec", + RepositoryName: "repo1", + PackageName: "pkg-upsert", + State: "OPEN", + Severity: "LOW", + CreatedAt: time.Now(), + } + + // Create + err := s.UpsertSecurity(ctx, params) + odize.AssertNoError(t, err) + + sec, err := _testDB.GetSecurityByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, params.PackageName, sec.PackageName) + + // Update + params.PackageName = "pkg-upsert-updated" + params.Severity = "MEDIUM" + err = s.UpsertSecurity(ctx, params) + odize.AssertNoError(t, err) + + updated, err := _testDB.GetSecurityByExternalID(ctx, params.ExternalID) + odize.AssertNoError(t, err) + odize.AssertEqual(t, "pkg-upsert-updated", updated.PackageName) + odize.AssertEqual(t, "MEDIUM", updated.Severity) + odize.AssertEqual(t, sec.ID, updated.ID) + }). Run() odize.AssertNoError(t, err) diff --git a/internal/products/types.go b/internal/products/types.go index b13745c..0aa9387 100644 --- a/internal/products/types.go +++ b/internal/products/types.go @@ -97,6 +97,17 @@ type CreatePRParams struct { CreatedAt time.Time } +type UpdatePRParams struct { + ID int64 + ExternalID string + Title string + RepositoryName string + Url string + State string + Author string + MergedAt *time.Time +} + type CreateSecurityParams struct { ExternalID string RepositoryName string @@ -107,3 +118,14 @@ type CreateSecurityParams struct { FixedAt *time.Time CreatedAt time.Time } + +type UpdateSecurityParams struct { + ID int64 + ExternalID string + RepositoryName string + PackageName string + State string + Severity string + PatchedVersion string + FixedAt *time.Time +} diff --git a/config.go b/internal/watchtower/config.go similarity index 81% rename from config.go rename to internal/watchtower/config.go index e03b50e..31af487 100644 --- a/config.go +++ b/internal/watchtower/config.go @@ -1,7 +1,7 @@ -package main +package watchtower import ( - "fmt" + "log" "log/slog" "os" "path" @@ -9,12 +9,6 @@ import ( "github.com/code-gorilla-au/env" ) -type Config struct { - Env string - AppDir string - LogLevel slog.Level -} - const appDirPath = "watchtower" func LoadConfig() Config { @@ -31,13 +25,13 @@ func LoadConfig() Config { appDir := path.Join(homeDir, appDirPath) if environment == "local" { - fmt.Print("LOCAL MODE") + log.Println("LOCAL MODE") current, _ := os.Getwd() appDir = path.Join(current, localDevDBPath) } else { // folder can already exist - _ = os.Mkdir(appDir, 0755) + _ = os.Mkdir(appDir, 0750) } return Config{ diff --git a/internal/watchtower/notifications.go b/internal/watchtower/notifications.go index 8f9f458..34a2338 100644 --- a/internal/watchtower/notifications.go +++ b/internal/watchtower/notifications.go @@ -6,8 +6,8 @@ import ( ) // GetUnreadNotifications retrieves a list of unread notifications for the specified organization ID. -func (s *Service) GetUnreadNotifications(orgID int64) ([]notifications.Notification, error) { - return s.notificationSvc.GetUnreadNotifications(s.ctx, orgID) +func (s *Service) GetUnreadNotifications() ([]notifications.Notification, error) { + return s.notificationSvc.GetUnreadNotifications(s.ctx) } // MarkNotificationAsRead marks a notification as read based on the provided notification ID. diff --git a/internal/watchtower/notifications_test.go b/internal/watchtower/notifications_test.go index b8ce873..b38646f 100644 --- a/internal/watchtower/notifications_test.go +++ b/internal/watchtower/notifications_test.go @@ -7,6 +7,7 @@ import ( "watchtower/internal/notifications" "github.com/code-gorilla-au/odize" + "github.com/google/uuid" ) func TestService_Notifications(t *testing.T) { @@ -21,36 +22,51 @@ func TestService_Notifications(t *testing.T) { err := group. Test("GetUnreadNotifications should return unread notifications", func(t *testing.T) { - orgID := int64(1001) - - // Seed a notification using the internal notification service - notif, err := s.notificationSvc.CreateNotification(ctx, orgID, "test-type", "test-content") + orgID := int64(0) + + err := s.notificationSvc.CreateNotification(ctx, notifications.CreateNotificationParams{ + OrgID: orgID, + NotificationType: "test-type", + Content: "test-content", + ExternalID: uuid.New().String(), + }) odize.AssertNoError(t, err) - // Fetch unread notifications - unread, err := s.GetUnreadNotifications(orgID) + unread, err := s.GetUnreadNotifications() odize.AssertNoError(t, err) odize.AssertEqual(t, 1, len(unread)) - odize.AssertEqual(t, notif.ID, unread[0].ID) odize.AssertEqual(t, "test-content", unread[0].Content) odize.AssertEqual(t, notifications.StatusUnread, unread[0].Status) }). Test("MarkNotificationAsRead should mark a notification as read", func(t *testing.T) { orgID := int64(1002) - notif, err := s.notificationSvc.CreateNotification(ctx, orgID, "type", "content") + err := s.notificationSvc.CreateNotification(ctx, notifications.CreateNotificationParams{ + OrgID: orgID, + NotificationType: "type", + Content: "content", + ExternalID: "test-external-id-2", + }) odize.AssertNoError(t, err) - err = s.MarkNotificationAsRead(notif.ID) + unread, err := s.GetUnreadNotifications() odize.AssertNoError(t, err) - unread, err := s.GetUnreadNotifications(orgID) + err = s.MarkNotificationAsRead(unread[0].ID) odize.AssertNoError(t, err) - odize.AssertEqual(t, 0, len(unread)) + + verifyUnread, err := s.GetUnreadNotifications() + odize.AssertNoError(t, err) + odize.AssertEqual(t, 0, len(verifyUnread)) }). Test("DeleteOldNotifications should delete notifications", func(t *testing.T) { orgID := int64(1003) - _, err := s.notificationSvc.CreateNotification(ctx, orgID, "type", "content") + err := s.notificationSvc.CreateNotification(ctx, notifications.CreateNotificationParams{ + OrgID: orgID, + NotificationType: "type", + Content: "content", + ExternalID: "test-external-id-3", + }) odize.AssertNoError(t, err) // DeleteOldNotifications uses time.Now() and the query uses created_at < ?. @@ -61,7 +77,7 @@ func TestService_Notifications(t *testing.T) { err = s.DeleteOldNotifications() odize.AssertNoError(t, err) - unread, err := s.GetUnreadNotifications(orgID) + unread, err := s.GetUnreadNotifications() odize.AssertNoError(t, err) odize.AssertEqual(t, 0, len(unread)) diff --git a/internal/watchtower/sync.go b/internal/watchtower/sync.go index 6e2d759..888b933 100644 --- a/internal/watchtower/sync.go +++ b/internal/watchtower/sync.go @@ -32,6 +32,61 @@ func (s *Service) Startup(ctx context.Context) { s.ctx = ctx } +// CreateUnreadPRNotification generates unread notifications for recent pull requests by fetching their IDs and creating notifications. +func (s *Service) CreateUnreadPRNotification() error { + logger := logging.FromContext(s.ctx) + + prIDs, err := s.productSvc.GetRecentPullRequests(s.ctx) + if err != nil { + logging.FromContext(s.ctx).Error("Error fetching recent pull requests", "error", err) + return err + } + + logger.Debug("Creating unread notifications for pull requests", "count", len(prIDs)) + + for _, id := range prIDs { + if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{ + OrgID: 0, + ExternalID: id, + NotificationType: "OPEN_PULL_REQUEST", + Content: "New pull request", + }); notifyErr != nil { + logger.Error("Error creating notification", "error", err) + return err + } + } + + return nil +} + +// CreateUnreadSecurityNotification generates unread security notifications for recent security alerts. +// It retrieves recent security-related IDs and creates notifications for each using the notification service. +// Returns an error if fetching security IDs or creating notifications fails. +func (s *Service) CreateUnreadSecurityNotification() error { + logger := logging.FromContext(s.ctx) + externalIDs, err := s.productSvc.GetRecentSecurity(s.ctx) + if err != nil { + logger.Error("Error fetching recent security", "error", err) + return err + } + + logger.Debug("creating unread notifications for security alerts", "count", len(externalIDs)) + + for _, id := range externalIDs { + if notifyErr := s.notificationSvc.CreateNotification(s.ctx, notifications.CreateNotificationParams{ + OrgID: 0, + ExternalID: id, + NotificationType: "SECURITY_ALERT", + Content: "New security alert", + }); notifyErr != nil { + logger.Error("Error creating notification", "error", err) + return err + } + } + + return nil +} + // SyncOrgs synchronizes stale organisations by retrieving them and invoking the sync process for each. func (s *Service) SyncOrgs() error { logger := logging.FromContext(s.ctx) diff --git a/internal/watchtower/types.go b/internal/watchtower/types.go index cb937b4..ca5f741 100644 --- a/internal/watchtower/types.go +++ b/internal/watchtower/types.go @@ -2,6 +2,7 @@ package watchtower import ( "context" + "log/slog" "watchtower/internal/notifications" "watchtower/internal/organisations" "watchtower/internal/products" @@ -14,3 +15,9 @@ type Service struct { notificationSvc *notifications.Service ghClient ghClient } + +type Config struct { + Env string + AppDir string + LogLevel slog.Level +} diff --git a/internal/watchtower/worker.go b/internal/watchtower/worker.go index 42fddf4..66c0b75 100644 --- a/internal/watchtower/worker.go +++ b/internal/watchtower/worker.go @@ -2,19 +2,23 @@ package watchtower import ( "context" + "log/slog" "time" "watchtower/internal/logging" "github.com/go-co-op/gocron/v2" + "github.com/google/uuid" ) type Workers struct { watchTower *Service cron gocron.Scheduler + logger *slog.Logger } func NewWorkers(wt *Service) (*Workers, error) { + logger := logging.FromContext(context.Background()).With("service", "workers") s, err := gocron.NewScheduler() if err != nil { return nil, err @@ -23,29 +27,24 @@ func NewWorkers(wt *Service) (*Workers, error) { return &Workers{ watchTower: wt, cron: s, + logger: logger, }, nil } func (w *Workers) AddJobs() error { - logger := logging.FromContext(context.Background()).With("service", "workers") - - if _, err := w.cron.NewJob(gocron.DurationJob(time.Minute*2), gocron.NewTask(func() { - logger.Debug("Running syncing orgs worker") - if err := w.watchTower.SyncOrgs(); err != nil { - logger.Error("Error syncing orgs", "error", err) - } - })); err != nil { + if _, err := w.cron.NewJob( + gocron.DurationJob(time.Minute*2), + gocron.NewTask(w.jobSyncOrgs), + gocron.WithEventListeners(gocron.AfterJobRuns(w.afterOrgSync)), + ); err != nil { return err } - if _, err := w.cron.NewJob(gocron.DurationJob(time.Minute*10), gocron.NewTask(func() { - logger.Debug("Running remove old notifications worker") - - if err := w.watchTower.DeleteOldNotifications(); err != nil { - logger.Error("Error syncing orgs", "error", err) - } - })); err != nil { + if _, err := w.cron.NewJob( + gocron.DurationJob(time.Minute*10), + gocron.NewTask(w.jobDeleteOldNotifications), + ); err != nil { return err } @@ -53,13 +52,47 @@ func (w *Workers) AddJobs() error { } func (w *Workers) Start(ctx context.Context) { - logger := logging.FromContext(ctx) - logger.Debug("Starting workers") + w.logger.Debug("Starting workers") + w.cron.Start() } func (w *Workers) Stop() { + w.logger.Debug("Stopping workers") + if err := w.cron.StopJobs(); err != nil { - logging.FromContext(context.Background()).Error("Error stopping org sync worker", "error", err) + w.logger.Error("Error stopping worker", "error", err) + } + + if err := w.cron.Shutdown(); err != nil { + w.logger.Error("Error shutting down worker", "error", err) + } +} + +func (w *Workers) jobSyncOrgs() { + w.logger.Debug("Running syncing orgs worker") + + if err := w.watchTower.SyncOrgs(); err != nil { + w.logger.Error("Error syncing orgs", "error", err) + } +} + +func (w *Workers) jobDeleteOldNotifications() { + w.logger.Debug("Running remove old notifications worker") + + if err := w.watchTower.DeleteOldNotifications(); err != nil { + w.logger.Error("Error syncing orgs", "error", err) + } +} + +func (w *Workers) afterOrgSync(jobID uuid.UUID, jobName string) { + w.logger.Debug("Running notification worker") + + if err := w.watchTower.CreateUnreadPRNotification(); err != nil { + w.logger.Error("Error creating unread PR notification", "error", err) + } + + if err := w.watchTower.CreateUnreadSecurityNotification(); err != nil { + w.logger.Error("Error creating unread security notification", "error", err) } } diff --git a/main.go b/main.go index 05885e6..07e3dcb 100644 --- a/main.go +++ b/main.go @@ -23,7 +23,7 @@ var assets embed.FS func main() { ctx := context.Background() - appConfig := LoadConfig() + appConfig := watchtower.LoadConfig() logger := logging.New(appConfig.LogLevel) logger.Debug("Starting watchtower", "config", appConfig) diff --git a/taskfile.yaml b/taskfile.yaml index 276eb80..46d3b9d 100644 --- a/taskfile.yaml +++ b/taskfile.yaml @@ -19,11 +19,17 @@ tasks: cmds: - wails dev - clean-db: + db-clean: desc: Clean the database. cmds: - rm -f $LOCAL_DEV_DIR/watchtower.db + db-migrate: + desc: Run database migrations. + deps: [gen] + cmds: + - go run ./cmd/local + gen: desc: Generate go code. cmds: @@ -141,7 +147,7 @@ tasks: reset: desc: Reset the application. - deps: [clean-db] + deps: [db-clean] cmds: - task: dev