diff --git a/README.md b/README.md
index 25300abf36..1214ec3d33 100644
--- a/README.md
+++ b/README.md
@@ -439,6 +439,17 @@ Use this command to validate the contents of a package using the package specifi
 
 The command ensures that the package is aligned with the package spec and the README file is up-to-date with its template (if present).
 
+### `elastic-package modify`
+
+_Context: package_
+
+Use this command to apply modifications to a package.
+
+These modifications include applying best practices, generating ingest pipeline tags, and more. Run this command without any arguments to see a list of modifiers.
+
+Use --modifiers to specify which modifiers to run, separated by commas.
+
+
 ### `elastic-package profiles`
 
 _Context: global_
diff --git a/cmd/modify.go b/cmd/modify.go
new file mode 100644
index 0000000000..84e0ea83fc
--- /dev/null
+++ b/cmd/modify.go
@@ -0,0 +1,112 @@
+// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
+// or more contributor license agreements. Licensed under the Elastic License;
+// you may not use this file except in compliance with the Elastic License.
+
+package cmd
+
+import (
+	"fmt"
+	"io"
+	"sort"
+	"text/tabwriter"
+
+	"github.com/spf13/cobra"
+	"github.com/spf13/pflag"
+
+	"github.com/elastic/elastic-package/internal/cobraext"
+	"github.com/elastic/elastic-package/internal/fleetpkg"
+	"github.com/elastic/elastic-package/internal/modify"
+	"github.com/elastic/elastic-package/internal/modify/pipelinetag"
+	"github.com/elastic/elastic-package/internal/packages"
+)
+
+const modifyLongDescription = `Use this command to apply modifications to a package.
+
+These modifications include applying best practices, generating ingest pipeline tags, and more. Run this command without any arguments to see a list of modifiers.
+
+Use --modifiers to specify which modifiers to run, separated by commas.
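+
+For example: elastic-package modify --modifiers pipeline-tag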
+`
+
+func setupModifyCommand() *cobraext.Command {
+	modifiers := []*modify.Modifier{
+		pipelinetag.Modifier,
+	}
+	sort.Slice(modifiers, func(i, j int) bool {
+		return modifiers[i].Name < modifiers[j].Name
+	})
+
+	validModifier := func(name string) bool {
+		for _, modifier := range modifiers {
+			if modifier.Name == name {
+				return true
+			}
+		}
+
+		return false
+	}
+
+	listModifiers := func(w io.Writer) {
+		tw := tabwriter.NewWriter(w, 0, 2, 3, ' ', 0)
+		for _, a := range modifiers {
+			_, _ = fmt.Fprintf(tw, "%s\t%s\n", a.Name, a.Doc)
+		}
+		_ = tw.Flush()
+		_, _ = fmt.Fprintln(w, "")
+	}
+
+	cmd := &cobra.Command{
+		Use:   "modify",
+		Short: "Modify package assets",
+		Long:  modifyLongDescription,
+		RunE: func(cmd *cobra.Command, args []string) error {
+			cmd.Println("Modify package assets")
+
+			selectedModifiers, err := cmd.Flags().GetStringSlice("modifiers")
+			if err != nil {
+				return cobraext.FlagParsingError(err, "modifiers")
+			}
+			if len(selectedModifiers) == 0 {
+				_, _ = fmt.Fprint(cmd.OutOrStderr(), "Please provide at least one modifier:\n\n")
+				listModifiers(cmd.OutOrStderr())
+				return nil
+			}
+			for _, selected := range selectedModifiers {
+				if !validModifier(selected) {
+					_, _ = fmt.Fprint(cmd.OutOrStderr(), "Please provide a valid modifier:\n\n")
+					listModifiers(cmd.OutOrStderr())
+					return cobraext.FlagParsingError(fmt.Errorf("invalid modifier: %q", selected), "modifiers")
+				}
+			}
+
+			pkgRootPath, err := packages.FindPackageRoot()
+			if err != nil {
+				return fmt.Errorf("locating package root failed: %w", err)
+			}
+
+			// Run only the modifiers selected with --modifiers. The package is
+			// reloaded before each modifier so that it operates on the current
+			// on-disk state, including changes written by earlier modifiers.
+			for _, selected := range selectedModifiers {
+				for _, modifier := range modifiers {
+					if modifier.Name != selected {
+						continue
+					}
+					pkg, err := fleetpkg.Load(pkgRootPath)
+					if err != nil {
+						return fmt.Errorf("failed to load package from %q: %w", pkgRootPath, err)
+					}
+					if err = modifier.Run(pkg); err != nil {
+						return fmt.Errorf("failed to apply modifier %q: %w", modifier.Name, err)
+					}
+				}
+			}
+
+			return nil
+		},
+	}
+
+	cmd.PersistentFlags().StringSliceP("modifiers", "m", nil, "List of modifiers to run, separated by commas")
+
+	// Expose each modifier's flags on the command, namespaced as
+	// "<modifier-name>.<flag>" to avoid collisions between modifiers.
+	for _, m := range modifiers {
+		prefix := m.Name + "."
+ + m.Flags.VisitAll(func(f *pflag.Flag) { + name := prefix + f.Name + cmd.Flags().Var(f.Value, name, f.Usage) + }) + } + + return cobraext.NewCommand(cmd, cobraext.ContextPackage) +} diff --git a/cmd/root.go b/cmd/root.go index 9176f1ea21..c5f29569bc 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -32,6 +32,7 @@ var commands = []*cobraext.Command{ setupInstallCommand(), setupLinksCommand(), setupLintCommand(), + setupModifyCommand(), setupProfilesCommand(), setupReportsCommand(), setupServiceCommand(), diff --git a/go.mod b/go.mod index eca71c4e64..38e14229a7 100644 --- a/go.mod +++ b/go.mod @@ -22,6 +22,7 @@ require ( github.com/fatih/color v1.18.0 github.com/go-viper/mapstructure/v2 v2.5.0 github.com/gobwas/glob v0.2.3 + github.com/goccy/go-yaml v1.18.0 github.com/google/go-cmp v0.7.0 github.com/google/go-github/v32 v32.1.0 github.com/google/go-querystring v1.2.0 @@ -35,6 +36,7 @@ require ( github.com/rogpeppe/go-internal v1.14.1 github.com/shirou/gopsutil/v3 v3.24.5 github.com/spf13/cobra v1.10.2 + github.com/spf13/pflag v1.0.10 github.com/stretchr/testify v1.11.1 go.yaml.in/yaml/v2 v2.4.3 go4.org/netipx v0.0.0-20231129151722-fdeea329fbba @@ -144,7 +146,6 @@ require ( github.com/shopspring/decimal v1.4.0 // indirect github.com/spf13/afero v1.15.0 // indirect github.com/spf13/cast v1.7.0 // indirect - github.com/spf13/pflag v1.0.10 // indirect github.com/stretchr/objx v0.5.2 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect github.com/tklauser/numcpus v0.6.1 // indirect diff --git a/go.sum b/go.sum index 33dc3d3778..a0d367f41c 100644 --- a/go.sum +++ b/go.sum @@ -144,6 +144,8 @@ github.com/go-viper/mapstructure/v2 v2.5.0 h1:vM5IJoUAy3d7zRSVtIwQgBj7BiWtMPfmPE github.com/go-viper/mapstructure/v2 v2.5.0/go.mod h1:oJDH3BJKyqBA2TXFhDsKDGDTlndYOZ6rGS0BRZIxGhM= github.com/gobwas/glob v0.2.3 h1:A4xDbljILXROh+kObIiy5kIaPYD8e96x1tgBhUI5J+Y= github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8= +github.com/goccy/go-yaml v1.18.0 h1:8W7wMFS12Pcas7KU+VVkaiCng+kG8QiFeFwzFb+rwuw= +github.com/goccy/go-yaml v1.18.0/go.mod h1:XBurs7gK8ATbW4ZPGKgcbrY1Br56PdM69F7LkFRi1kA= github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= github.com/golang/snappy v0.0.1/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/google/btree v1.1.3 h1:CVpQJjYgC4VbzxeGVHfvZrv1ctoYCAI8vbl07Fcxlyg= diff --git a/internal/fleetpkg/fleetpkg.go b/internal/fleetpkg/fleetpkg.go new file mode 100644 index 0000000000..e9bfa4f11c --- /dev/null +++ b/internal/fleetpkg/fleetpkg.go @@ -0,0 +1,207 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package fleetpkg + +import ( + "encoding/json" + + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" + + "github.com/elastic/elastic-package/internal/yamledit" +) + +// Package is a fleet package. +type Package struct { + Manifest Manifest + Input *DataStream + DataStreams map[string]*DataStream + + sourceDir string +} + +// Path is the path to the root of the package. +func (i *Package) Path() string { + return i.sourceDir +} + +// Manifest is the package manifest. 
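+// Only the manifest fields needed by modifiers are decoded; Doc retains the
+// full parsed YAML document so it can be edited in place.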
+type Manifest struct { + Name string `yaml:"name"` + Title string `yaml:"title"` + Version string `yaml:"version"` + Description string `yaml:"description"` + Type string `yaml:"type"` + FormatVersion string `yaml:"format_version"` + Owner struct { + Github string `yaml:"github"` + Type string `yaml:"type"` + } `yaml:"owner"` + + Doc *yamledit.Document `yaml:"-"` +} + +// Path is the path to the manifest file. +func (m *Manifest) Path() string { + return m.Doc.Filename() +} + +// DataStreamManifest is the data stream manifest file. +type DataStreamManifest struct { + Title string `yaml:"title"` + Type string `yaml:"type"` + + Doc *yamledit.Document `yaml:"-"` +} + +// Path is the path to the manifest file. +func (m *DataStreamManifest) Path() string { + return m.Doc.Filename() +} + +// DataStream is a data stream within the package. +type DataStream struct { + Manifest DataStreamManifest + Pipelines map[string]*Pipeline + + sourceDir string +} + +// Path is the path to the data stream. +func (d *DataStream) Path() string { + return d.sourceDir +} + +// Pipeline is an ingest pipeline. +type Pipeline struct { + Description string `yaml:"description"` + Processors []*Processor `yaml:"processors,omitempty"` + OnFailure []*Processor `yaml:"on_failure,omitempty"` + + Doc *yamledit.Document `yaml:"-"` +} + +// Path is the path to the pipeline. +func (p *Pipeline) Path() string { + return p.Doc.Filename() +} + +// Processor is an ingest pipeline processor. +type Processor struct { + Type string + Attributes map[string]any + OnFailure []*Processor + + Node ast.Node +} + +// GetAttribute gets an attribute of the processor. +func (p *Processor) GetAttribute(key string) (any, bool) { + v, ok := p.Attributes[key] + if !ok { + return nil, false + } + + return v, true +} + +// GetAttributeString gets a string attribute of the processor. +func (p *Processor) GetAttributeString(key string) (string, bool) { + v, ok := p.Attributes[key].(string) + if !ok { + return "", false + } + + return v, true +} + +// GetAttributeFloat gets a float attribute of the processor. +func (p *Processor) GetAttributeFloat(key string) (float64, bool) { + v, ok := p.Attributes[key].(float64) + if !ok { + return 0, false + } + + return v, true +} + +// GetAttributeInt gets an int attribute of the processor. +func (p *Processor) GetAttributeInt(key string) (int, bool) { + v, ok := p.Attributes[key].(int) + if !ok { + return 0, false + } + + return v, true +} + +// GetAttributeBool gets a bool attribute of the processor. +func (p *Processor) GetAttributeBool(key string) (bool, bool) { + v, ok := p.Attributes[key].(bool) + if !ok { + return false, false + } + + return v, true +} + +// UnmarshalYAML implements a YAML unmarshaler. +func (p *Processor) UnmarshalYAML(node ast.Node) error { + var procMap map[string]struct { + Attributes map[string]any `yaml:",inline"` + OnFailure []*Processor `yaml:"on_failure"` + } + if err := yaml.NodeToValue(node, &procMap); err != nil { + return err + } + + // The struct representation used here is much more convenient + // to work with than the original map of map format. + for k, v := range procMap { + p.Type = k + p.Attributes = v.Attributes + p.OnFailure = v.OnFailure + + delete(p.Attributes, "on_failure") + + break + } + + p.Node = node + + return nil +} + +// MarshalJSON implements a JSON marshaler. 
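+// The processor is encoded back into its single-key {type: attributes} form,
+// with any on_failure processors folded back into the attributes.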
+func (p *Processor) MarshalJSON() ([]byte, error) { + properties := make(map[string]any, len(p.Attributes)+1) + for k, v := range p.Attributes { + properties[k] = v + } + if len(p.OnFailure) > 0 { + properties["on_failure"] = p.OnFailure + } + return json.Marshal(map[string]any{ + p.Type: properties, + }) +} + +// Validation is the validation.yml file of a package. +type Validation struct { + Errors struct { + ExcludeChecks []string `yaml:"exclude_checks,omitempty"` + } `yaml:"errors,omitempty"` + + DocsStructureEnforced struct { + Enabled bool `yaml:"enabled"` + Version int `yaml:"version"` + Skip []struct { + Title string `yaml:"title"` + Reason string `yaml:"reason"` + } `yaml:"skip,omitempty"` + } `yaml:"docs_structure_enforced"` + + Doc *yamledit.Document `yaml:"-"` +} diff --git a/internal/fleetpkg/fleetpkg_test.go b/internal/fleetpkg/fleetpkg_test.go new file mode 100644 index 0000000000..39c2ecee6b --- /dev/null +++ b/internal/fleetpkg/fleetpkg_test.go @@ -0,0 +1,300 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package fleetpkg + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/parser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestProcessor_UnmarshalYAML(t *testing.T) { + src := []byte(` +set: + field: key + value: some_value + on_failure: + - set: + field: event.kind + value: pipeline_error +`) + + want := Processor{ + Type: "set", + Attributes: map[string]interface{}{ + "field": "key", + "value": "some_value", + }, + OnFailure: []*Processor{ + { + Type: "set", + Attributes: map[string]interface{}{ + "field": "event.kind", + "value": "pipeline_error", + }, + }, + }, + } + + f, err := parser.ParseBytes(src, parser.ParseComments) + require.NoError(t, err) + + var got Processor + gotErr := yaml.NodeToValue(f.Docs[0].Body, &got) + require.NoError(t, gotErr) + + assert.Equal(t, want.Type, got.Type) + assert.Equal(t, want.Attributes, got.Attributes) + assert.Len(t, got.OnFailure, 1) + assert.Equal(t, want.OnFailure[0].Type, got.OnFailure[0].Type) + assert.Equal(t, want.OnFailure[0].Attributes, got.OnFailure[0].Attributes) + assert.Empty(t, want.OnFailure[0].OnFailure) + + assert.Equal(t, f.Docs[0].Body, got.Node) + + onFailurePath, err := yaml.PathString("$.set.on_failure[0]") + require.NoError(t, err) + onFailureNode, err := onFailurePath.FilterFile(f) + require.NoError(t, err) + assert.Equal(t, onFailureNode, got.OnFailure[0].Node) +} + +func TestProcessor_GetAttribute(t *testing.T) { + p := Processor{ + Attributes: map[string]any{ + "string": "test", + "int": 1, + "bool": true, + "float": 23.4, + }, + } + + testCases := []struct { + name string + key string + want any + wantFound bool + }{ + { + name: "ok", + key: "string", + want: "test", + wantFound: true, + }, + { + name: "not-found", + key: "missing", + wantFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, gotFound := p.GetAttribute(tc.key) + + if tc.wantFound { + assert.True(t, gotFound) + assert.Equal(t, tc.want, got) + } else { + assert.False(t, gotFound) + } + }) + } +} + +func TestProcessor_GetAttributeString(t *testing.T) { + p := Processor{ + Attributes: map[string]any{ + "string": "test", + "int": 1, + "bool": true, + "float": 23.4, + }, + } + + testCases := []struct { + name string + key string 
+ want any + wantFound bool + }{ + { + name: "ok", + key: "string", + want: "test", + wantFound: true, + }, + { + name: "wrong-type", + key: "int", + wantFound: false, + }, + { + name: "not-found", + key: "missing", + wantFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, gotFound := p.GetAttributeString(tc.key) + + if tc.wantFound { + assert.True(t, gotFound) + assert.Equal(t, tc.want, got) + } else { + assert.False(t, gotFound) + } + }) + } +} + +func TestProcessor_GetAttributeFloat(t *testing.T) { + p := Processor{ + Attributes: map[string]any{ + "string": "test", + "int": 1, + "bool": true, + "float": 23.4, + }, + } + + testCases := []struct { + name string + key string + want any + wantFound bool + }{ + { + name: "ok", + key: "float", + want: 23.4, + wantFound: true, + }, + { + name: "wrong-type", + key: "string", + wantFound: false, + }, + { + name: "not-found", + key: "missing", + wantFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, gotFound := p.GetAttributeFloat(tc.key) + + if tc.wantFound { + assert.True(t, gotFound) + assert.Equal(t, tc.want, got) + } else { + assert.False(t, gotFound) + } + }) + } +} + +func TestProcessor_GetAttributeInt(t *testing.T) { + p := Processor{ + Attributes: map[string]any{ + "string": "test", + "int": 1, + "bool": true, + "float": 23.4, + }, + } + + testCases := []struct { + name string + key string + want any + wantFound bool + }{ + { + name: "ok", + key: "int", + want: 1, + wantFound: true, + }, + { + name: "wrong-type", + key: "string", + wantFound: false, + }, + { + name: "not-found", + key: "missing", + wantFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, gotFound := p.GetAttributeInt(tc.key) + + if tc.wantFound { + assert.True(t, gotFound) + assert.Equal(t, tc.want, got) + } else { + assert.False(t, gotFound) + } + }) + } +} + +func TestProcessor_GetAttributeBool(t *testing.T) { + p := Processor{ + Attributes: map[string]any{ + "string": "test", + "int": 1, + "bool": true, + "float": 23.4, + }, + } + + testCases := []struct { + name string + key string + want any + wantFound bool + }{ + { + name: "ok", + key: "bool", + want: true, + wantFound: true, + }, + { + name: "wrong-type", + key: "string", + wantFound: false, + }, + { + name: "not-found", + key: "missing", + wantFound: false, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, gotFound := p.GetAttributeBool(tc.key) + + if tc.wantFound { + assert.True(t, gotFound) + assert.Equal(t, tc.want, got) + } else { + assert.False(t, gotFound) + } + }) + } +} diff --git a/internal/fleetpkg/load.go b/internal/fleetpkg/load.go new file mode 100644 index 0000000000..888391aa6c --- /dev/null +++ b/internal/fleetpkg/load.go @@ -0,0 +1,74 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package fleetpkg + +import ( + "path/filepath" + + "github.com/elastic/elastic-package/internal/yamledit" +) + +// Load will load a package from the given directory. 
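+// The package manifest, each data stream manifest, and any ingest pipelines
+// under a data stream's elasticsearch/ingest_pipeline directory are parsed;
+// for input packages the package root itself becomes the single input data stream.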
+func Load(dir string) (*Package, error) { + pkg := Package{ + sourceDir: dir, + } + + // ------------------------------------------------------------------------- + // manifest.yml + + if _, err := yamledit.ParseDocumentFile(filepath.Join(dir, "manifest.yml"), &pkg.Manifest); err != nil { + return nil, err + } + + // ------------------------------------------------------------------------- + // Data Streams + + var dataStreamManifests []string + if pkg.Manifest.Type == "input" { + dataStreamManifests = []string{filepath.Join(dir, "manifest.yml")} + } else { + var err error + pkg.DataStreams = map[string]*DataStream{} + if dataStreamManifests, err = filepath.Glob(filepath.Join(dir, "data_stream/*/manifest.yml")); err != nil { + return nil, err + } + } + for _, manifestPath := range dataStreamManifests { + ds := &DataStream{ + sourceDir: filepath.Dir(manifestPath), + } + if pkg.Manifest.Type == "input" { + pkg.Input = ds + } else { + pkg.DataStreams[filepath.Base(ds.sourceDir)] = ds + + if _, err := yamledit.ParseDocumentFile(manifestPath, &ds.Manifest); err != nil { + return nil, err + } + + // ----------------------------------------------------------------- + // Pipelines + + pipelines, err := filepath.Glob(filepath.Join(ds.sourceDir, "elasticsearch/ingest_pipeline/*.yml")) + if err != nil { + return nil, err + } + + if len(pipelines) > 0 { + ds.Pipelines = map[string]*Pipeline{} + } + for _, pipelinePath := range pipelines { + var pipeline Pipeline + if _, err = yamledit.ParseDocumentFile(pipelinePath, &pipeline); err != nil { + return nil, err + } + ds.Pipelines[filepath.Base(pipelinePath)] = &pipeline + } + } + } + + return &pkg, nil +} diff --git a/internal/modify/modify.go b/internal/modify/modify.go new file mode 100644 index 0000000000..e390d0672f --- /dev/null +++ b/internal/modify/modify.go @@ -0,0 +1,18 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package modify + +import ( + "github.com/spf13/pflag" + + "github.com/elastic/elastic-package/internal/fleetpkg" +) + +type Modifier struct { + Name string + Doc string + Flags pflag.FlagSet + Run func(pkg *fleetpkg.Package) error +} diff --git a/internal/modify/pipelinetag/pipelinetag.go b/internal/modify/pipelinetag/pipelinetag.go new file mode 100644 index 0000000000..75e4bdba2f --- /dev/null +++ b/internal/modify/pipelinetag/pipelinetag.go @@ -0,0 +1,158 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. 
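+
+// Package pipelinetag generates deterministic tag values for ingest pipeline
+// processors that have no tag, an empty tag, or a duplicate tag. Tags are
+// derived from the processor type, field, target_field, and an FNV-32a hash
+// of the processor definition.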
+ +package pipelinetag + +import ( + "encoding/json" + "fmt" + "hash/fnv" + "strings" + + "github.com/elastic/elastic-package/internal/fleetpkg" + "github.com/elastic/elastic-package/internal/modify" + "github.com/elastic/elastic-package/internal/yamledit" +) + +const Name = "pipeline-tag" + +var pathCleaner = strings.NewReplacer(".", "_", " ", "_", "@", "") + +var Modifier = &modify.Modifier{ + Name: Name, + Doc: "Generate tags for ingest pipeline processors", + Run: run, +} + +type processorNode struct { + Processor *fleetpkg.Processor + Parent *processorNode +} + +func (p *processorNode) ParentProcessor() *fleetpkg.Processor { + if p.Parent != nil { + return p.Parent.Processor + } + + return nil +} + +func run(pkg *fleetpkg.Package) error { + fmt.Println("Generating pipeline tags") + + for _, ds := range pkg.DataStreams { + for _, pipeline := range ds.Pipelines { + if err := processPipeline(pipeline); err != nil { + return err + } + + if pipeline.Doc.Modified() { + if err := pipeline.Doc.WriteFile(); err != nil { + return err + } + } + } + } + + return nil +} + +func processPipeline(pipeline *fleetpkg.Pipeline) error { + seen := map[string]*processorNode{} + + for _, proc := range pipeline.Processors { + node := processorNode{ + Processor: proc, + } + + if err := processTag(pipeline, &node, seen); err != nil { + return err + } + } + + return nil +} + +func processTag(pipeline *fleetpkg.Pipeline, node *processorNode, seen map[string]*processorNode) error { + var invalid bool + var err error + + tag, ok := node.Processor.Attributes["tag"].(string) + if ok { + if tag == "" { + invalid = true + } else if _, dup := seen[tag]; dup { + invalid = true + + } + } else { + invalid = true + } + + if invalid { + if tag, err = generateTag(node.Processor, node.ParentProcessor()); err != nil { + return err + } + if _, err = pipeline.Doc.SetKeyValue(fmt.Sprintf("%s.%s", node.Processor.Node.GetPath(), node.Processor.Type), "tag", tag, yamledit.IndexPrepend); err != nil { + return err + } + } + + seen[tag] = node + + for _, onFailProc := range node.Processor.OnFailure { + onFailProcNode := &processorNode{ + Processor: onFailProc, + Parent: node, + } + + if err = processTag(pipeline, onFailProcNode, seen); err != nil { + return err + } + } + + return nil +} + +func generateTag(proc, parent *fleetpkg.Processor) (string, error) { + hash, err := generateProcessorHash(proc, parent) + if err != nil { + return "", err + } + + field, ok := proc.Attributes["field"].(string) + if !ok || field == "" { + return proc.Type + "_" + hash, nil + } + field = pathCleaner.Replace(field) + + targetField, ok := proc.Attributes["target_field"].(string) + if !ok || targetField == "" { + return fmt.Sprintf("%s_%s_%s", proc.Type, field, hash), nil + } + targetField = pathCleaner.Replace(targetField) + + return fmt.Sprintf("%s_%s_to_%s_%s", proc.Type, field, targetField, hash), nil +} + +func generateProcessorHash(proc, parent *fleetpkg.Processor) (string, error) { + b, err := json.Marshal(proc) + if err != nil { + return "", fmt.Errorf("failed to marshal processor for hashing: %w", err) + } + + h := fnv.New32a() + _, _ = h.Write(b) + + if parent != nil { + b, err = json.Marshal(parent) + if err != nil { + return "", fmt.Errorf("failed to marshal parent processor for hashing: %w", err) + } + + _, _ = h.Write(b) + } + + return fmt.Sprintf("%08x", h.Sum32()), nil +} diff --git a/internal/modify/pipelinetag/pipelinetag_test.go b/internal/modify/pipelinetag/pipelinetag_test.go new file mode 100644 index 0000000000..3141d87be3 --- 
/dev/null +++ b/internal/modify/pipelinetag/pipelinetag_test.go @@ -0,0 +1,58 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package pipelinetag + +import ( + "errors" + "testing" + + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-package/internal/fleetpkg" + "github.com/elastic/elastic-package/internal/yamledit" +) + +func getNodeString(f *ast.File, path string) (string, error) { + p, err := yaml.PathString(path) + if err != nil { + return "", err + } + + n, err := p.FilterFile(f) + if err != nil { + return "", err + } + + sn, ok := n.(*ast.StringNode) + if !ok { + return "", errors.New("expected string node") + } + + return sn.Value, nil +} + +func assertProcessorTag(t *testing.T, pipeline *fleetpkg.Pipeline, path string, want string) { + got, err := getNodeString(pipeline.Doc.AST(), path) + assert.NoError(t, err) + assert.Equal(t, want, got) +} + +func Test_GenerateTags(t *testing.T) { + var pipeline fleetpkg.Pipeline + _, err := yamledit.ParseDocumentFile("testdata/default.yml", &pipeline) + require.NoError(t, err) + + err = processPipeline(&pipeline) + require.NoError(t, err) + + assert.True(t, pipeline.Doc.Modified()) + + assertProcessorTag(t, &pipeline, "$.processors[0].set.tag", "set_sample_field_71a88542") + assertProcessorTag(t, &pipeline, "$.processors[1].set.tag", "valid_tag") + assertProcessorTag(t, &pipeline, "$.processors[1].set.on_failure[0].set.tag", "set_sample_field_a7a6e0d7") +} diff --git a/internal/modify/pipelinetag/testdata/default.yml b/internal/modify/pipelinetag/testdata/default.yml new file mode 100644 index 0000000000..2b1576fefd --- /dev/null +++ b/internal/modify/pipelinetag/testdata/default.yml @@ -0,0 +1,19 @@ +--- +description: Pipeline for processing sample logs +processors: + - set: + field: sample_field + value: "1" + - set: + tag: valid_tag + field: sample_field + value: "1" + on_failure: + - set: + field: sample_field + value: "1" + +on_failure: + - set: + field: error.message + value: '{{ _ingest.on_failure_message }}' diff --git a/internal/yamledit/testdata/invalid.yml b/internal/yamledit/testdata/invalid.yml new file mode 100644 index 0000000000..397db75f0d --- /dev/null +++ b/internal/yamledit/testdata/invalid.yml @@ -0,0 +1 @@ +: diff --git a/internal/yamledit/testdata/valid.yml b/internal/yamledit/testdata/valid.yml new file mode 100644 index 0000000000..b25ea302db --- /dev/null +++ b/internal/yamledit/testdata/valid.yml @@ -0,0 +1,24 @@ +--- +string: value +int: 1 +bool: true +list: + - one + - two + - three +map: + string: value + int: 1 + bool: true + list: + - one + - two + - three + map: + string: value + int: 1 + bool: true + list: + - one + - two + - three diff --git a/internal/yamledit/yamledit.go b/internal/yamledit/yamledit.go new file mode 100644 index 0000000000..a470a8082e --- /dev/null +++ b/internal/yamledit/yamledit.go @@ -0,0 +1,427 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. 
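+
+// Package yamledit provides helpers for loading, editing, and writing back
+// YAML documents as goccy/go-yaml AST nodes, preserving comments and tracking
+// whether a document has been modified.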
+ +package yamledit + +import ( + "errors" + "fmt" + "hash/fnv" + "io" + "os" + "reflect" + "slices" + "strconv" + "strings" + + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" + "github.com/goccy/go-yaml/parser" + "github.com/goccy/go-yaml/printer" +) + +const ( + // IndexPrepend is a shorthand for an index that prepends to a list. + IndexPrepend = 0 + // IndexAppend is a shorthand for an index that appends to a list. + IndexAppend = -1 +) + +// ErrInvalidNodeType indicates a node was not of an expected type. +var ErrInvalidNodeType = errors.New("invalid node type") + +// Document defines a YAML document. +type Document struct { + f *ast.File + h uint64 +} + +// AST returns the underlying AST of the document. +func (d *Document) AST() *ast.File { + return d.f +} + +// Modified returns true if the document has been modified. +func (d *Document) Modified() bool { + return d.Hash() != d.h +} + +// Hash returns a FNV1a 64-bit hash of the document. +func (d *Document) Hash() uint64 { + h := fnv.New64a() + _, _ = h.Write([]byte(d.f.String())) + + return h.Sum64() +} + +// Filename returns the filename of the document (if applicable). +func (d *Document) Filename() string { + return d.f.Name +} + +// WriteFile writes the document to the original file. +func (d *Document) WriteFile() error { + if d.f.Name == "" { + return errors.New("failed to write document: empty filename") + } + + return d.WriteFileAs(d.Filename()) +} + +// WriteFileAs writes the document to the given file. +func (d *Document) WriteFileAs(filename string) error { + p := printer.Printer{} + data := p.PrintNode(d.f.Docs[0]) + + if err := os.WriteFile(filename, data, 0o644); err != nil { + return fmt.Errorf("failed to write document to file %q: %w", filename, err) + } + + return nil +} + +// Write writes the document to writer. +func (d *Document) Write(w io.Writer) (int, error) { + p := printer.Printer{} + data := p.PrintNode(d.f.Docs[0]) + + return w.Write(data) +} + +// Parse will attempt to parse the document into v. +func (d *Document) Parse(v any) error { + if err := yaml.NodeToValue(d.f.Docs[0].Body, v); err != nil { + return err + } + + // Set the Document field on v, if v is a pointer to a struct and the field + // on the struct is exported. + if rv := reflect.ValueOf(v); rv.Kind() == reflect.Ptr { + if structValue := rv.Elem(); structValue.Kind() == reflect.Struct { + for i := 0; i < structValue.NumField(); i++ { + if field := structValue.Field(i); field.CanAddr() && field.CanSet() { + if _, ok := field.Interface().(*Document); ok { + field.Set(reflect.ValueOf(d)) + break + } + } + } + } + } + + return nil +} + +// GetNode gets the node the given path. +func (d *Document) GetNode(path string) (ast.Node, error) { + p, err := yaml.PathString(path) + if err != nil { + return nil, err + } + + return p.FilterFile(d.f) +} + +// GetMappingNode gets the mapping node at the given path. +func (d *Document) GetMappingNode(path string) (*ast.MappingNode, error) { + n, err := d.GetNode(path) + if err != nil { + return nil, err + } + + mn, ok := n.(*ast.MappingNode) + if !ok { + return nil, fmt.Errorf("%w: expected a MappingNode, got %T", ErrInvalidNodeType, n) + } + + return mn, nil +} + +// GetSequenceNode gets the sequence node at the given path. 
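+// An error wrapping ErrInvalidNodeType is returned if the node at the path is
+// not a sequence.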
+func (d *Document) GetSequenceNode(path string) (*ast.SequenceNode, error) { + n, err := d.GetNode(path) + if err != nil { + return nil, err + } + + mn, ok := n.(*ast.SequenceNode) + if !ok { + return nil, fmt.Errorf("%w: expected a SequenceNode, got %T", ErrInvalidNodeType, n) + } + + return mn, nil +} + +// GetParentNode gets the parent node for the node at the given path. If the +// node is the root node of the document, nil is returned. If the leaf node of +// the path does not exist in the document, but the parent node exists, the +// parent node will still be returned. +func (d *Document) GetParentNode(path string) (ast.Node, error) { + if path == "$" { + return nil, nil + } + + lastDot := strings.LastIndex(path, ".") + lastSeq := strings.LastIndex(path, "[") + + parentStr := path[:max(lastDot, lastSeq)] + parentPath, _ := yaml.PathString(parentStr) + + return parentPath.FilterFile(d.f) +} + +// DeleteNode deletes the node from the document at the given path. Only supported +// in cases where the parent node is mapping or sequence node. +func (d *Document) DeleteNode(path string) (bool, error) { + p, err := yaml.PathString(path) + if err != nil { + return false, err + } + n, err := p.FilterFile(d.f) + if err != nil { + return false, err + } + + parentNode, err := d.GetParentNode(path) + if err != nil { + return false, err + } + if parentNode == nil { + return false, errors.New("cannot delete root node") + } + + switch parentNode.Type() { + case ast.MappingType: + mn := parentNode.(*ast.MappingNode) + for i, kv := range mn.Values { + if kv.Value == n { + mn.Values = append(mn.Values[:i], mn.Values[i+1:]...) + + return true, nil + } + } + case ast.SequenceType: + sn := parentNode.(*ast.SequenceNode) + index := getPathIndex(path) + + sn.Values = append(sn.Values[:index], sn.Values[index+1:]...) + sn.ValueHeadComments = append(sn.ValueHeadComments[:index], sn.ValueHeadComments[index+1:]...) + + return true, nil + default: + return false, fmt.Errorf("unable to delete node with parent type %s at %q: %w", parentNode.Type(), p, err) + } + + return false, nil +} + +// AddValue is like AddNode, but takes a raw value rather than a YAML node. +func (d *Document) AddValue(path string, v any, index int, replace bool) (bool, error) { + n, err := yaml.ValueToNode(v) + if err != nil { + return false, err + } + + return d.AddNode(path, n, index, replace) +} + +// PrependValue prepends the value to the sequence at the given path. +func (d *Document) PrependValue(path string, v any) (bool, error) { + return d.AddValue(path, v, IndexPrepend, false) +} + +// AppendValue appends the value to the sequence at the given path. +func (d *Document) AppendValue(path string, v any) (bool, error) { + return d.AddValue(path, v, IndexAppend, false) +} + +// AddNode adds a node to the sequence node at the given path. The insertion +// point and behavior at insertion can be controlled by the index and replace +// arguments, respectively. 
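+// Use IndexPrepend or IndexAppend to insert at the start or end of the
+// sequence; when replace is true, an existing element at a positive index is
+// only overwritten if the new node differs from it.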
+func (d *Document) AddNode(path string, n ast.Node, index int, replace bool) (bool, error) { + sn, err := d.GetSequenceNode(path) + if err != nil { + return false, err + } + + doReplace := replace && index > 0 && index < len(sn.Values) + if doReplace { + if !nodeEqual(sn.Values[index], n) { + sn.Values[index] = n + sn.ValueHeadComments[index] = n.GetComment() + + return true, nil + } + return false, nil + } + + if index < 0 || index >= len(sn.Values) { + sn.Values = append(sn.Values, n) + sn.ValueHeadComments = append(sn.ValueHeadComments, n.GetComment()) + } else { + sn.Values = slices.Insert(sn.Values, index, n) + sn.ValueHeadComments = slices.Insert(sn.ValueHeadComments, index, n.GetComment()) + } + + return true, nil +} + +// SetKeyValue is like SetKeyNode, but takes a raw value rather than a YAML node. +func (d *Document) SetKeyValue(path, key string, v any, index int) (bool, error) { + n, err := yaml.ValueToNode(v) + if err != nil { + return false, err + } + + return d.SetKeyNode(path, key, n, index) +} + +// SetKeyNode sets the node for a key in a mapping node at the given path. +func (d *Document) SetKeyNode(path, key string, n ast.Node, index int) (bool, error) { + mn, err := d.GetMappingNode(path) + if err != nil { + return false, err + } + + for _, kv := range mn.Values { + if kv.Key.String() != key { + continue + } + + p, err := yaml.PathString(path + "." + key) + if err != nil { + return false, err + } + + err = p.ReplaceWithNode(d.f, n) + return err == nil, err + } + + newNode, err := yaml.ValueToNode(map[string]any{ + key: n, + }) + if err != nil { + return false, err + } + + newValue := newNode.(*ast.MappingNode).Values[0] + newValue.AddColumn(mn.GetToken().Position.IndentNum) + + if index >= 0 && index < len(mn.Values) { + mn.Values = slices.Insert(mn.Values, index, newValue) + } else { + mn.Values = append(mn.Values, newValue) + } + + return true, nil +} + +// ParseDocumentFile is like NewDocumentFile, but will unmarshal the +// document into value given by v. +func ParseDocumentFile(filename string, v any) (*Document, error) { + d, err := NewDocumentFile(filename) + if err != nil { + return nil, err + } + + if err = d.Parse(v); err != nil { + return d, err + } + + return d, nil +} + +// ParseDocumentBytes is like NewDocumentBytes, but will unmarshal the +// document into value given by v. +func ParseDocumentBytes(data []byte, v any) (*Document, error) { + d, err := NewDocumentBytes(data) + if err != nil { + return nil, err + } + + if err = d.Parse(v); err != nil { + return nil, err + } + + return d, nil +} + +// NewDocumentFile creates a new document from the given yaml file. +func NewDocumentFile(filename string) (*Document, error) { + data, err := os.ReadFile(filename) + if err != nil { + return nil, fmt.Errorf("unable to read file %q: %w", filename, err) + } + + d, err := NewDocumentBytes(data) + if err != nil { + return nil, fmt.Errorf("unable to parse file %q: %w", filename, err) + } + d.f.Name = filename + + return d, nil +} + +// NewDocumentBytes creates a new document from the given yaml bytes. +func NewDocumentBytes(data []byte) (*Document, error) { + var d Document + var err error + + d.f, err = parser.ParseBytes(data, parser.ParseComments) + if err != nil { + return nil, err + } + d.h = d.Hash() + + return &d, nil +} + +// getPathIndex gets the index referenced at the end of the path. The last +// element of the path must be a sequence, otherwise -1 will be returned. 
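+// For example, getPathIndex("$.processors[2]") returns 2, while
+// getPathIndex("$.processors") returns -1.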
+func getPathIndex(path string) int { + lastDot := strings.LastIndex(path, ".") + lastSeq := strings.LastIndex(path, "[") + + if lastSeq == -1 { + return -1 + } + if lastDot > lastSeq { + return -1 + } + + closeSeq := strings.LastIndex(path[lastSeq:], "]") + if closeSeq == -1 { + return -1 + } + + i, err := strconv.Atoi(path[lastSeq+1 : lastSeq+closeSeq]) + if err != nil { + return -1 + } + + return i +} + +// cutPath splits the last element of the path, returning the parent path and +// the last element of the path, or an error if the path is not valid. +func cutPath(path string) (string, string, error) { + idx := strings.LastIndex(path, ".") + if idx < 0 { + return "", "", fmt.Errorf("unable to get parent path of %q", path) + } + + before := path[:idx] + after := path[idx+1:] + + return before, after, nil +} + +// nodeEqual returns true if the two nodes are equal. +func nodeEqual(a, b ast.Node) bool { + var x, y any + _ = yaml.NodeToValue(a, &x) + _ = yaml.NodeToValue(b, &y) + + return reflect.DeepEqual(x, y) +} diff --git a/internal/yamledit/yamledit_test.go b/internal/yamledit/yamledit_test.go new file mode 100644 index 0000000000..d2ea8d0c10 --- /dev/null +++ b/internal/yamledit/yamledit_test.go @@ -0,0 +1,910 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package yamledit + +import ( + "testing" + + "github.com/goccy/go-yaml" + "github.com/goccy/go-yaml/ast" + "github.com/goccy/go-yaml/parser" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func mustNewDocumentFile(filename string) *Document { + d, err := NewDocumentFile(filename) + if err != nil { + panic(err) + } + + return d +} + +func mustYAMLPathString(path string) *yaml.Path { + p, err := yaml.PathString(path) + if err != nil { + panic(err) + } + + return p +} + +func mustNodeFromString(s string) ast.Node { + f, err := parser.ParseBytes([]byte(s), parser.ParseComments) + if err != nil { + panic(err) + } + + return f.Docs[0].Body +} + +func TestDocument_GetNode(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + wantErr bool + }{ + { + name: "ok", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.string", + }, + { + name: "not-found", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.missing", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.doc.GetNode(tc.path) + + if tc.wantErr { + assert.Error(t, err) + } else { + want, lookupErr := mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + + assert.NoError(t, err) + assert.Equal(t, want, got) + } + }) + } +} + +func TestDocument_GetMappingNode(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + wantErr bool + }{ + { + name: "ok", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + }, + { + name: "wrong-type", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + wantErr: true, + }, + { + name: "not-found", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.missing", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.doc.GetMappingNode(tc.path) + + if tc.wantErr { + assert.Error(t, err) + } else { + wantRaw, lookupErr := 
mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + want, ok := wantRaw.(*ast.MappingNode) + require.True(t, ok) + require.NotNil(t, want) + + assert.NoError(t, err) + assert.Equal(t, want, got) + } + }) + } +} + +func TestDocument_GetSequenceNode(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + wantErr bool + }{ + { + name: "ok", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + }, + { + name: "wrong-type", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + wantErr: true, + }, + { + name: "not-found", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.missing", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.doc.GetSequenceNode(tc.path) + + if tc.wantErr { + assert.Error(t, err) + } else { + wantRaw, lookupErr := mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + want, ok := wantRaw.(*ast.SequenceNode) + require.True(t, ok) + require.NotNil(t, want) + + assert.NoError(t, err) + assert.Equal(t, want, got) + } + }) + } +} + +func TestDocument_GetParentNode(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + wantNode *yaml.Path + wantErr bool + }{ + { + name: "ok-mapping", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.string", + wantNode: mustYAMLPathString("$"), + }, + { + name: "ok-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list[1]", + wantNode: mustYAMLPathString("$.list"), + }, + { + name: "ok-root", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$", + }, + { + name: "ok-not-found-parent-exists", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.missing", + wantNode: mustYAMLPathString("$"), + }, + { + name: "bad-not-found", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.missing.gone", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := tc.doc.GetParentNode(tc.path) + + if tc.wantErr { + assert.Error(t, err) + } else { + var want ast.Node + if tc.wantNode != nil { + want, err = tc.wantNode.FilterFile(tc.doc.f) + require.NoError(t, err) + require.NotNil(t, want) + } + + assert.NoError(t, err) + assert.Equal(t, want, got) + } + }) + } +} + +func TestDocument_AddValue(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + value any + index int + replace bool + want ast.Node + wantMod bool + wantErr bool + }{ + { + name: "ok-append", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + index: IndexAppend, + value: "four", + want: mustNodeFromString("[one, two, three, four]"), + wantMod: true, + }, + { + name: "ok-prepend", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + index: IndexPrepend, + value: "four", + want: mustNodeFromString("[four, one, two, three]"), + wantMod: true, + }, + { + name: "ok-replace", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + index: 1, + replace: true, + value: "four", + want: mustNodeFromString("[one, four, three]"), + wantMod: true, + }, + { + name: "ok-replace-equal", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + index: 1, + replace: true, + value: "two", + want: mustNodeFromString("[one, two, three]"), + }, + { + name: "bad-not-a-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + index: 1, + replace: true, + value: "two", + 
wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMod, gotErr := tc.doc.AddValue(tc.path, tc.value, tc.index, tc.replace) + + if tc.wantErr { + assert.Error(t, gotErr) + assert.False(t, gotMod) + assert.False(t, tc.doc.Modified()) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, tc.wantMod, gotMod) + assert.Equal(t, tc.wantMod, tc.doc.Modified()) + + gotNode, lookupErr := mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + require.NotNil(t, gotNode) + + var gotV any + require.NoError(t, yaml.NodeToValue(gotNode, &gotV)) + var wantV any + require.NoError(t, yaml.NodeToValue(tc.want, &wantV)) + + assert.Equal(t, wantV, gotV) + } + }) + } +} + +func TestDocument_PrependValue(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + value any + want ast.Node + wantMod bool + wantErr bool + }{ + { + name: "ok", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + value: "four", + want: mustNodeFromString("[four, one, two, three]"), + wantMod: true, + }, + { + name: "bad-not-a-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + value: "two", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMod, gotErr := tc.doc.PrependValue(tc.path, tc.value) + + if tc.wantErr { + assert.Error(t, gotErr) + assert.False(t, gotMod) + assert.False(t, tc.doc.Modified()) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, tc.wantMod, gotMod) + assert.Equal(t, tc.wantMod, tc.doc.Modified()) + + gotNode, lookupErr := mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + require.NotNil(t, gotNode) + + var gotV any + require.NoError(t, yaml.NodeToValue(gotNode, &gotV)) + var wantV any + require.NoError(t, yaml.NodeToValue(tc.want, &wantV)) + + assert.Equal(t, wantV, gotV) + } + }) + } +} + +func TestDocument_AppendValue(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + value any + want ast.Node + wantMod bool + wantErr bool + }{ + { + name: "ok", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + value: "four", + want: mustNodeFromString("[one, two, three, four]"), + wantMod: true, + }, + { + name: "bad-not-a-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + value: "two", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMod, gotErr := tc.doc.AppendValue(tc.path, tc.value) + + if tc.wantErr { + assert.Error(t, gotErr) + assert.False(t, gotMod) + assert.False(t, tc.doc.Modified()) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, tc.wantMod, gotMod) + assert.Equal(t, tc.wantMod, tc.doc.Modified()) + + gotNode, lookupErr := mustYAMLPathString(tc.path).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + require.NotNil(t, gotNode) + + var gotV any + require.NoError(t, yaml.NodeToValue(gotNode, &gotV)) + var wantV any + require.NoError(t, yaml.NodeToValue(tc.want, &wantV)) + + assert.Equal(t, wantV, gotV) + } + }) + } +} + +func TestDocument_SetKeyValue(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + key string + value any + index int + wantMod bool + wantErr bool + }{ + { + name: "ok-append-new", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + key: "new_item", + index: IndexAppend, + value: "foobar", + wantMod: true, + }, + { + name: "ok-prepend-new", + doc: 
mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + key: "new_item", + index: IndexPrepend, + value: "foobar", + wantMod: true, + }, + { + name: "ok-prepend-new", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + key: "string", + index: IndexAppend, + value: "foobar", + wantMod: true, + }, + { + name: "bad-not-a-mapping", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list", + index: IndexAppend, + value: "foobar", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMod, gotErr := tc.doc.SetKeyValue(tc.path, tc.key, tc.value, tc.index) + + if tc.wantErr { + assert.Error(t, gotErr) + assert.False(t, gotMod) + assert.False(t, tc.doc.Modified()) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, tc.wantMod, gotMod) + assert.Equal(t, tc.wantMod, tc.doc.Modified()) + + gotNode, lookupErr := mustYAMLPathString(tc.path + "." + tc.key).FilterFile(tc.doc.f) + require.NoError(t, lookupErr) + require.NotNil(t, gotNode) + + var gotV any + require.NoError(t, yaml.NodeToValue(gotNode, &gotV)) + + assert.Equal(t, tc.value, gotV) + } + }) + } +} + +func TestDocument_DeleteNode(t *testing.T) { + testCases := []struct { + name string + doc *Document + path string + wantMod bool + wantErr bool + }{ + { + name: "ok-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list[1]", + wantMod: true, + }, + { + name: "ok-mapping", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map", + wantMod: true, + }, + { + name: "bad-missing-sequence", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.list[4]", + wantErr: true, + }, + { + name: "bad-missing-mapping", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$.map.missing", + wantErr: true, + }, + { + name: "bad-invalid-path", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: ".map.missing", + wantErr: true, + }, + { + name: "bad-root", + doc: mustNewDocumentFile("testdata/valid.yml"), + path: "$", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotMod, gotErr := tc.doc.DeleteNode(tc.path) + + if tc.wantErr { + assert.Error(t, gotErr) + assert.False(t, gotMod) + assert.False(t, tc.doc.Modified()) + } else { + assert.NoError(t, gotErr) + assert.Equal(t, tc.wantMod, gotMod) + assert.Equal(t, tc.wantMod, tc.doc.Modified()) + } + }) + } +} + +func TestNewDocumentFile(t *testing.T) { + testCases := []struct { + name string + filename string + wantErr bool + }{ + { + name: "ok", + filename: "testdata/valid.yml", + }, + { + name: "bad", + filename: "testdata/invalid.yml", + wantErr: true, + }, + { + name: "missing", + filename: "testdata/missing.yml", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewDocumentFile(tc.filename) + + if tc.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.NotNil(t, got) + assert.Equal(t, tc.filename, got.Filename()) + } + }) + } +} + +func TestParseDocumentFile(t *testing.T) { + t.Run("ok", func(t *testing.T) { + type TestStruct struct { + String string `yaml:"string"` + Int int `yaml:"int"` + Bool bool `yaml:"bool"` + List []string `yaml:"list"` + Map map[string]any `yaml:"map"` + } + + want := TestStruct{ + String: "value", + Int: 1, + Bool: true, + List: []string{"one", "two", "three"}, + Map: map[string]any{ + "string": "value", + "int": uint64(1), + "bool": true, + "list": []any{"one", "two", "three"}, + "map": 
map[string]any{ + "string": "value", + "int": uint64(1), + "bool": true, + "list": []any{"one", "two", "three"}, + }, + }, + } + + var v TestStruct + d, err := ParseDocumentFile("testdata/valid.yml", &v) + + require.NoError(t, err) + require.NotNil(t, d) + + assert.Equal(t, want, v) + }) + + t.Run("bad-file", func(t *testing.T) { + type TestStruct struct { + String int `yaml:"string"` + } + + var v TestStruct + _, err := ParseDocumentFile("testdata/bad.yml", &v) + + require.Error(t, err) + }) + + t.Run("bad-parse", func(t *testing.T) { + type TestStruct struct { + String int `yaml:"string"` + } + + var v TestStruct + _, err := ParseDocumentFile("testdata/valid.yml", &v) + + require.Error(t, err) + }) +} + +func TestNewDocumentBytes(t *testing.T) { + testCases := []struct { + name string + in []byte + wantErr bool + }{ + { + name: "ok", + in: []byte(`--- +string: value +int: 1 +bool: true +list: + - one + - two + - three +map: + string: value + int: 1 + bool: true + list: + - one + - two + - three + map: + string: value + int: 1 + bool: true + list: + - one + - two + - three +`), + }, + { + name: "bad", + in: []byte(`:`), + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got, err := NewDocumentBytes(tc.in) + + if tc.wantErr { + assert.Error(t, err) + assert.Nil(t, got) + } else { + assert.NoError(t, err) + assert.NotNil(t, got) + } + }) + } +} + +func TestParseDocumentBytes(t *testing.T) { + t.Run("ok", func(t *testing.T) { + type TestStruct struct { + String string `yaml:"string"` + Int int `yaml:"int"` + Bool bool `yaml:"bool"` + List []string `yaml:"list"` + Map map[string]any `yaml:"map"` + } + + want := TestStruct{ + String: "value", + Int: 1, + Bool: true, + List: []string{"one", "two", "three"}, + Map: map[string]any{ + "string": "value", + "int": uint64(1), + "bool": true, + "list": []any{"one", "two", "three"}, + "map": map[string]any{ + "string": "value", + "int": uint64(1), + "bool": true, + "list": []any{"one", "two", "three"}, + }, + }, + } + + var v TestStruct + d, err := ParseDocumentBytes([]byte(`--- +string: value +int: 1 +bool: true +list: + - one + - two + - three +map: + string: value + int: 1 + bool: true + list: + - one + - two + - three + map: + string: value + int: 1 + bool: true + list: + - one + - two + - three +`), &v) + + require.NoError(t, err) + require.NotNil(t, d) + + assert.Equal(t, want, v) + }) + + t.Run("bad-file", func(t *testing.T) { + type TestStruct struct { + String int `yaml:"string"` + } + + var v TestStruct + _, err := ParseDocumentBytes([]byte(`:`), &v) + + require.Error(t, err) + }) + + t.Run("bad-parse", func(t *testing.T) { + type TestStruct struct { + String int `yaml:"string"` + } + + var v TestStruct + _, err := ParseDocumentBytes([]byte(`--- +string: value +`), &v) + + require.Error(t, err) + }) +} + +func Test_getPathIndex(t *testing.T) { + testCases := []struct { + name string + path string + want int + }{ + { + name: "single-index", + path: "$.test[1]", + want: 1, + }, + { + name: "multiple-indices", + path: "$.test[1].attributes[3]", + want: 3, + }, + { + name: "bad-malformed-index", + path: "$.test[1", + want: -1, + }, + { + name: "bad-no-index", + path: "$.test", + want: -1, + }, + { + name: "bad-index-attribute", + path: "$.test[1].attribute", + want: -1, + }, + { + name: "bad-index-all", + path: "$.test[*]", + want: -1, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + got := getPathIndex(tc.path) + + assert.Equal(t, tc.want, got) + }) + } +} + 
+func Test_cutPath(t *testing.T) { + testCases := []struct { + name string + path string + wantBefore string + wantAfter string + wantErr bool + }{ + { + name: "ok", + path: "$.test.attribute", + wantBefore: "$.test", + wantAfter: "attribute", + }, + { + name: "bad-split-root", + path: "$", + wantErr: true, + }, + } + + for _, tc := range testCases { + t.Run(tc.name, func(t *testing.T) { + gotBefore, gotAfter, err := cutPath(tc.path) + + if tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.wantBefore, gotBefore) + assert.Equal(t, tc.wantAfter, gotAfter) + } + }) + } +} + +func Test_nodeEqual(t *testing.T) { + t.Run("equal", func(t *testing.T) { + a, err := parser.ParseBytes([]byte(`string: "foobar"`), parser.ParseComments) + require.NoError(t, err) + b, err := parser.ParseBytes([]byte(`string: "foobar"`), parser.ParseComments) + require.NoError(t, err) + + assert.True(t, nodeEqual(a.Docs[0].Body, b.Docs[0].Body)) + }) + t.Run("not-equal", func(t *testing.T) { + a, err := parser.ParseBytes([]byte(`string: "foo"`), parser.ParseComments) + require.NoError(t, err) + b, err := parser.ParseBytes([]byte(`string: "bar"`), parser.ParseComments) + require.NoError(t, err) + + assert.False(t, nodeEqual(a.Docs[0].Body, b.Docs[0].Body)) + }) +}
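For illustration, a sketch of the output the pipeline-tag modifier is expected to produce for the bundled internal/modify/pipelinetag/testdata/default.yml fixture. The tag values are the ones asserted in pipelinetag_test.go; untagged processors get a tag (here of the form `<type>_<field>_<hash>`) prepended to their mapping, pre-existing unique tags are kept, and the top-level on_failure handler is left untouched:

```yaml
---
description: Pipeline for processing sample logs
processors:
  - set:
      tag: set_sample_field_71a88542
      field: sample_field
      value: "1"
  - set:
      tag: valid_tag
      field: sample_field
      value: "1"
      on_failure:
        - set:
            tag: set_sample_field_a7a6e0d7
            field: sample_field
            value: "1"

on_failure:
  - set:
      field: error.message
      value: '{{ _ingest.on_failure_message }}'
```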