diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..e8f289d --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,3 @@ +{ + "enableAllProjectMcpServers": false +} \ No newline at end of file diff --git a/cmd/context.go b/cmd/context.go index 1eb0268..bed4d0c 100644 --- a/cmd/context.go +++ b/cmd/context.go @@ -155,16 +155,26 @@ func toMessages(inputs []string) []client.NewMessage { func addMessagesToContextCmd() *cobra.Command { var id string + var name string var messages []string var files []string cmd := &cobra.Command{ Use: "add", Short: "Add messages to a given context", Run: func(cmd *cobra.Command, args []string) { + if id == "" && name == "" { + exitIfError(errors.New("you should pass a context ID or a context name")) + } c, err := client.New() exitIfError(err) ctx := context.Background() messages := toMessages(messages) + if name != "" { + context, err := c.GetContextByName(ctx, name) + exitIfError(err) + id = context.ID + } + for _, input := range files { role, path, found := strings.Cut(input, ":") if !found { @@ -188,9 +198,8 @@ func addMessagesToContextCmd() *cobra.Command { printJson(*response) }, } - cmd.PersistentFlags().StringVar(&id, "id", "", "The context ID") - err := cmd.MarkPersistentFlagRequired("id") - exitIfError(err) + cmd.PersistentFlags().StringVar(&id, "context-id", "", "The context ID") + cmd.PersistentFlags().StringVar(&name, "context-name", "", "The context name") cmd.PersistentFlags().StringArrayVar(&messages, "message", []string{}, "Messages to add to this context. They should be prefixed by the role name (example: user:hello-world)") cmd.PersistentFlags().StringArrayVar(&files, "message-from-file", []string{}, "A list of files paths, the content will be added to the context. They should be prefixed by the role name (example: user:/my/file)") return cmd diff --git a/cmd/conversation.go b/cmd/conversation.go index 2815934..aba044e 100644 --- a/cmd/conversation.go +++ b/cmd/conversation.go @@ -33,6 +33,7 @@ func buildConversationCmd() *cobra.Command { var ragModel string var ragProvider string var ragLimit uint32 + var systemPromptID string cmd := &cobra.Command{ Use: "conversation", Short: `Send a message to an AI provider. @@ -96,6 +97,7 @@ If a context ID is provided, it will be used as input for the conversation. Else QueryOptions: options, NewContextOptions: contextOptions, Messages: msg, + SystemPromptID: systemPromptID, } if interactive { input.Stream = stream @@ -167,6 +169,7 @@ If a context ID is provided, it will be used as input for the conversation. Else exitIfError(err) cmd.PersistentFlags().StringVar(&system, "system", "", "System promt for the AI provider") + cmd.PersistentFlags().StringVar(&systemPromptID, "system-prompt-id", "", "ID of a system prompt to use (will be concatenated with --system if both are provided)") cmd.PersistentFlags().StringVar(&contextID, "context-id", "", "The ID of the context to reuse for this conversation") cmd.PersistentFlags().StringVar(&contextName, "context-name", "", "The name of the context to reuse for this conversation") cmd.PersistentFlags().StringVar(&newContextName, "new-context-name", "", "The name of the new context that will be created for this conversation if a context ID is not provided") diff --git a/cmd/root.go b/cmd/root.go index 6f84a43..5221202 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -82,6 +82,10 @@ func Run() error { Use: "embedding", Short: "Embedding commands", } + systemPromptCmd := &cobra.Command{ + Use: "system-prompt", + Short: "System prompt subcommands", + } serverCmd := buildServerCmd() embeddingCmd.AddCommand(embeddingMatchCmd()) documentCmd.AddCommand(documentListCmd()) @@ -103,6 +107,11 @@ func Run() error { contextMessageCmd.AddCommand(deleteContextMessagesCmd()) contextSourceCmd.AddCommand(contextSourceContextDeleteCmd()) contextSourceCmd.AddCommand(contextSourceContextAddCmd()) + systemPromptCmd.AddCommand(systemPromptListCmd()) + systemPromptCmd.AddCommand(systemPromptGetCmd()) + systemPromptCmd.AddCommand(systemPromptCreateCmd()) + systemPromptCmd.AddCommand(systemPromptUpdateCmd()) + systemPromptCmd.AddCommand(systemPromptDeleteCmd()) conversationCmd := buildConversationCmd() rootCmd.AddCommand(embeddingCmd) @@ -110,6 +119,7 @@ func Run() error { rootCmd.AddCommand(documentChunkCmd) rootCmd.AddCommand(conversationCmd) rootCmd.AddCommand(contextCmd) + rootCmd.AddCommand(systemPromptCmd) rootCmd.AddCommand(serverCmd) shutdown, err := initOpentelemetry() if err != nil { diff --git a/cmd/server.go b/cmd/server.go index 9f1e393..e633d3b 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -52,7 +52,7 @@ func RunServer() error { rag := rag.New(db, embeddingProviders) ai := assistant.New(clients, manager, rag) - handlersBuilder := handlers.NewBuilder(ai, manager, rag) + handlersBuilder := handlers.NewBuilder(ai, manager, rag, db) server, err := http.New(config.HTTP, registry, handlersBuilder) if err != nil { return err diff --git a/cmd/system.go b/cmd/system.go new file mode 100644 index 0000000..f7a5c38 --- /dev/null +++ b/cmd/system.go @@ -0,0 +1,160 @@ +package cmd + +import ( + "context" + "errors" + "os" + + "github.com/appclacks/maizai/internal/http/client" + "github.com/spf13/cobra" +) + +func systemPromptListCmd() *cobra.Command { + cmd := &cobra.Command{ + Use: "list", + Short: "List system prompts", + Run: func(cmd *cobra.Command, args []string) { + client, err := client.New() + exitIfError(err) + ctx := context.Background() + prompts, err := client.ListSystemPrompts(ctx) + exitIfError(err) + printJson(prompts) + }, + } + return cmd +} + +func systemPromptGetCmd() *cobra.Command { + var id string + cmd := &cobra.Command{ + Use: "get", + Short: "Get a system prompt by ID", + Run: func(cmd *cobra.Command, args []string) { + if id == "" { + exitIfError(errors.New("the command expects a system prompt id as input")) + } + client, err := client.New() + exitIfError(err) + ctx := context.Background() + prompt, err := client.GetSystemPrompt(ctx, id) + exitIfError(err) + printJson(*prompt) + }, + } + cmd.PersistentFlags().StringVar(&id, "id", "", "The ID of the system prompt to retrieve") + err := cmd.MarkPersistentFlagRequired("id") + exitIfError(err) + return cmd +} + +func systemPromptCreateCmd() *cobra.Command { + var name string + var description string + var content string + var contentFile string + cmd := &cobra.Command{ + Use: "create", + Short: "Create a new system prompt", + Run: func(cmd *cobra.Command, args []string) { + if content == "" && contentFile == "" { + exitIfError(errors.New("either --content or --content-from-file must be provided")) + } + if content != "" && contentFile != "" { + exitIfError(errors.New("cannot specify both --content and --content-from-file")) + } + + finalContent := content + if contentFile != "" { + fileContent, err := os.ReadFile(contentFile) + if err != nil { + exitIfError(err) + } + finalContent = string(fileContent) + } + + c, err := client.New() + exitIfError(err) + ctx := context.Background() + input := client.CreateSystemPromptInput{ + Name: name, + Description: description, + Content: finalContent, + } + response, err := c.CreateSystemPrompt(ctx, input) + exitIfError(err) + printJson(*response) + }, + } + cmd.PersistentFlags().StringVar(&name, "name", "", "The name of the new system prompt") + err := cmd.MarkPersistentFlagRequired("name") + exitIfError(err) + cmd.PersistentFlags().StringVar(&description, "description", "", "The description of the new system prompt") + cmd.PersistentFlags().StringVar(&content, "content", "", "The content of the system prompt") + cmd.PersistentFlags().StringVar(&contentFile, "content-from-file", "", "Path to file containing the system prompt content") + return cmd +} + +func systemPromptUpdateCmd() *cobra.Command { + var id string + var content string + var contentFile string + cmd := &cobra.Command{ + Use: "update", + Short: "Update a system prompt's content", + Run: func(cmd *cobra.Command, args []string) { + if content == "" && contentFile == "" { + exitIfError(errors.New("either --content or --content-from-file must be provided")) + } + if content != "" && contentFile != "" { + exitIfError(errors.New("cannot specify both --content and --content-from-file")) + } + + finalContent := content + if contentFile != "" { + fileContent, err := os.ReadFile(contentFile) + if err != nil { + exitIfError(err) + } + finalContent = string(fileContent) + } + + c, err := client.New() + exitIfError(err) + ctx := context.Background() + input := client.UpdateSystemPromptInput{ + ID: id, + Content: finalContent, + } + response, err := c.UpdateSystemPrompt(ctx, input) + exitIfError(err) + printJson(*response) + }, + } + cmd.PersistentFlags().StringVar(&id, "id", "", "The ID of the system prompt to update") + err := cmd.MarkPersistentFlagRequired("id") + exitIfError(err) + cmd.PersistentFlags().StringVar(&content, "content", "", "The new content for the system prompt") + cmd.PersistentFlags().StringVar(&contentFile, "content-from-file", "", "Path to file containing the new system prompt content") + return cmd +} + +func systemPromptDeleteCmd() *cobra.Command { + var id string + cmd := &cobra.Command{ + Use: "delete", + Short: "Delete a system prompt by ID", + Run: func(cmd *cobra.Command, args []string) { + client, err := client.New() + exitIfError(err) + ctx := context.Background() + response, err := client.DeleteSystemPrompt(ctx, id) + exitIfError(err) + printJson(*response) + }, + } + cmd.PersistentFlags().StringVar(&id, "id", "", "The ID of the system prompt to delete") + err := cmd.MarkPersistentFlagRequired("id") + exitIfError(err) + return cmd +} \ No newline at end of file diff --git a/doc/openapi.yaml b/doc/openapi.yaml index 5ea10d5..b32b79d 100644 --- a/doc/openapi.yaml +++ b/doc/openapi.yaml @@ -60,6 +60,21 @@ paths: $ref: '#/components/schemas/ClientContext' description: OK /api/v1/context/{id}/message: + delete: + description: Delete all messages for a given context + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientResponse' + description: OK post: description: Add new messages for a given context parameters: @@ -293,6 +308,81 @@ paths: schema: $ref: '#/components/schemas/ClientResponse' description: OK + /api/v1/system-prompt: + get: + description: List system prompts + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientListSystemPromptsOutput' + description: OK + post: + description: Create a new system prompt + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ClientCreateSystemPromptInput' + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientResponse' + description: OK + /api/v1/system-prompt/{id}: + delete: + description: Delete a system prompt by ID + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientResponse' + description: OK + get: + description: Get a system prompt by ID + parameters: + - in: path + name: id + required: true + schema: + type: string + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientSystemPrompt' + description: OK + put: + description: Update a system prompt + parameters: + - in: path + name: id + required: true + schema: + type: string + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/ClientUpdateSystemPromptInput' + responses: + "200": + content: + application/json: + schema: + $ref: '#/components/schemas/ClientResponse' + description: OK components: schemas: ClientContext: @@ -384,7 +474,7 @@ components: messages: description: messages attached to this context items: - $ref: '#/components/schemas/ClientCreateContextMessage' + $ref: '#/components/schemas/ClientNewMessage' nullable: true type: array name: @@ -395,35 +485,29 @@ components: required: - name type: object - ClientCreateContextMessage: - properties: - content: - description: The message content - type: string - role: - description: The message role - type: string - required: - - role - - content - type: object ClientCreateConversationInput: properties: context-id: description: The ID of an existing context to use for this conversation type: string + messages: + description: The messages to provide the the AI provider + items: + $ref: '#/components/schemas/ClientNewMessage' + nullable: true + type: array new-context: $ref: '#/components/schemas/ClientContextOptions' - prompt: - description: The prompt that will be passed to the AI provider - type: string query-options: $ref: '#/components/schemas/ClientQueryOptions' stream: description: Streaming mode using SSE type: boolean + system-prompt-id: + description: The ID of a system prompt to use for this conversation + type: string required: - - prompt + - messages type: object ClientCreateDocumentInput: properties: @@ -434,6 +518,21 @@ components: required: - name type: object + ClientCreateSystemPromptInput: + properties: + content: + description: The system prompt content + type: string + description: + description: The system prompt description + type: string + name: + description: The system prompt name + type: string + required: + - name + - content + type: object ClientDocument: properties: created-at: @@ -506,6 +605,14 @@ components: nullable: true type: array type: object + ClientListSystemPromptsOutput: + properties: + system_prompts: + items: + $ref: '#/components/schemas/ClientSystemPrompt' + nullable: true + type: array + type: object ClientMessage: properties: content: @@ -522,6 +629,18 @@ components: description: The message role type: string type: object + ClientNewMessage: + properties: + content: + description: The message content + type: string + role: + description: The message role + type: string + required: + - role + - content + type: object ClientQueryOptions: properties: max-tokens: @@ -581,6 +700,25 @@ components: text: type: string type: object + ClientSystemPrompt: + properties: + content: + description: The system prompt content + type: string + created-at: + description: The system prompt creation date + format: date-time + type: string + description: + description: The system prompt description + type: string + id: + description: The system prompt ID + type: string + name: + description: The system prompt name + type: string + type: object ClientUpdateContextMessageInput: properties: content: @@ -593,6 +731,14 @@ components: - role - content type: object + ClientUpdateSystemPromptInput: + properties: + content: + description: The system prompt content + type: string + required: + - content + type: object SharedContextSources: properties: contexts: diff --git a/internal/contextstore/memory/memory.go b/internal/contextstore/memory/memory.go index 40249a3..229436f 100644 --- a/internal/contextstore/memory/memory.go +++ b/internal/contextstore/memory/memory.go @@ -11,7 +11,8 @@ import ( type MemoryContextStore struct { state map[string]*shared.Context - lock sync.RWMutex + //system map[string]*shared.SystemPrompt + lock sync.RWMutex } func New() *MemoryContextStore { @@ -20,6 +21,30 @@ func New() *MemoryContextStore { } } +func (m *MemoryContextStore) CreateSystemPrompt(ctx context.Context, prompt shared.SystemPrompt) error { + return nil +} + +func (m *MemoryContextStore) GetSystemPrompt(ctx context.Context, id string) (*shared.SystemPrompt, error) { + return nil, nil +} + +func (m *MemoryContextStore) GetSystemPromptByName(ctx context.Context, name string) (*shared.SystemPrompt, error) { + return nil, nil +} + +func (m *MemoryContextStore) ListSystemPrompts(ctx context.Context) ([]shared.SystemPrompt, error) { + return nil, nil +} + +func (m *MemoryContextStore) UpdateSystemPrompt(ctx context.Context, id string, content string) error { + return nil +} + +func (m *MemoryContextStore) DeleteSystemPrompt(ctx context.Context, id string) error { + return nil +} + func (m *MemoryContextStore) DeleteContext(ctx context.Context, id string) error { m.lock.Lock() defer m.lock.Unlock() diff --git a/internal/database/migrations/20250425_system.up.sql b/internal/database/migrations/20250425_system.up.sql new file mode 100644 index 0000000..dd9672d --- /dev/null +++ b/internal/database/migrations/20250425_system.up.sql @@ -0,0 +1,11 @@ +--;; +create table if not exists system_prompt ( + id uuid not null primary key, + name varchar(255) not null unique, + description text, + content text not null, + created_at timestamp not null +); +--;; +CREATE INDEX IF NOT EXISTS idx_system_prompt_name ON system_prompt(name); +--;; diff --git a/internal/database/queries/models.go b/internal/database/queries/models.go index 473c096..aa7593d 100644 --- a/internal/database/queries/models.go +++ b/internal/database/queries/models.go @@ -45,3 +45,11 @@ type DocumentChunk struct { Embedding pgvector.Vector CreatedAt pgtype.Timestamp } + +type SystemPrompt struct { + ID pgtype.UUID + Name string + Description pgtype.Text + Content string + CreatedAt pgtype.Timestamp +} diff --git a/internal/database/queries/system_prompt.sql.go b/internal/database/queries/system_prompt.sql.go new file mode 100644 index 0000000..b24ee5a --- /dev/null +++ b/internal/database/queries/system_prompt.sql.go @@ -0,0 +1,134 @@ +// Code generated by sqlc. DO NOT EDIT. +// versions: +// sqlc v1.28.0 +// source: system_prompt.sql + +package queries + +import ( + "context" + + "github.com/jackc/pgx/v5/pgtype" +) + +const createSystemPrompt = `-- name: CreateSystemPrompt :exec +INSERT INTO system_prompt ( + id, name, description, content, created_at +) VALUES ( + $1, $2, $3, $4, $5 +) +` + +type CreateSystemPromptParams struct { + ID pgtype.UUID + Name string + Description pgtype.Text + Content string + CreatedAt pgtype.Timestamp +} + +func (q *Queries) CreateSystemPrompt(ctx context.Context, arg CreateSystemPromptParams) error { + _, err := q.db.Exec(ctx, createSystemPrompt, + arg.ID, + arg.Name, + arg.Description, + arg.Content, + arg.CreatedAt, + ) + return err +} + +const deleteSystemPrompt = `-- name: DeleteSystemPrompt :exec +DELETE FROM system_prompt +WHERE id = $1 +` + +func (q *Queries) DeleteSystemPrompt(ctx context.Context, id pgtype.UUID) error { + _, err := q.db.Exec(ctx, deleteSystemPrompt, id) + return err +} + +const getSystemPrompt = `-- name: GetSystemPrompt :one +SELECT id, name, description, content, created_at +FROM system_prompt +WHERE id = $1 +` + +func (q *Queries) GetSystemPrompt(ctx context.Context, id pgtype.UUID) (SystemPrompt, error) { + row := q.db.QueryRow(ctx, getSystemPrompt, id) + var i SystemPrompt + err := row.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.Content, + &i.CreatedAt, + ) + return i, err +} + +const getSystemPromptByName = `-- name: GetSystemPromptByName :one +SELECT id, name, description, content, created_at +FROM system_prompt +WHERE name = $1 +` + +func (q *Queries) GetSystemPromptByName(ctx context.Context, name string) (SystemPrompt, error) { + row := q.db.QueryRow(ctx, getSystemPromptByName, name) + var i SystemPrompt + err := row.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.Content, + &i.CreatedAt, + ) + return i, err +} + +const listSystemPrompts = `-- name: ListSystemPrompts :many +SELECT id, name, description, content, created_at +FROM system_prompt +` + +func (q *Queries) ListSystemPrompts(ctx context.Context) ([]SystemPrompt, error) { + rows, err := q.db.Query(ctx, listSystemPrompts) + if err != nil { + return nil, err + } + defer rows.Close() + var items []SystemPrompt + for rows.Next() { + var i SystemPrompt + if err := rows.Scan( + &i.ID, + &i.Name, + &i.Description, + &i.Content, + &i.CreatedAt, + ); err != nil { + return nil, err + } + items = append(items, i) + } + if err := rows.Err(); err != nil { + return nil, err + } + return items, nil +} + +const updateSystemPrompt = `-- name: UpdateSystemPrompt :exec +UPDATE system_prompt +SET content = $2 +WHERE id = $1 +` + +type UpdateSystemPromptParams struct { + ID pgtype.UUID + Content string +} + +func (q *Queries) UpdateSystemPrompt(ctx context.Context, arg UpdateSystemPromptParams) error { + _, err := q.db.Exec(ctx, updateSystemPrompt, arg.ID, arg.Content) + return err +} diff --git a/internal/database/root.go b/internal/database/root.go index 4e88737..476b149 100644 --- a/internal/database/root.go +++ b/internal/database/root.go @@ -123,7 +123,7 @@ func pgxTime(t time.Time) pgtype.Timestamp { } func (c *Database) beginTx(ctx context.Context, options pgx.TxOptions) (pgx.Tx, *queries.Queries, func(), error) { - tx, err := c.conn.BeginTx(ctx, pgx.TxOptions{}) + tx, err := c.conn.BeginTx(ctx, options) if err != nil { return nil, nil, nil, err } diff --git a/internal/database/system.go b/internal/database/system.go new file mode 100644 index 0000000..a5656d2 --- /dev/null +++ b/internal/database/system.go @@ -0,0 +1,112 @@ +package database + +import ( + "context" + "fmt" + + "github.com/appclacks/maizai/internal/database/queries" + "github.com/appclacks/maizai/pkg/shared" + "github.com/jackc/pgx/v5" + er "github.com/mcorbin/corbierror" +) + +func (c *Database) CreateSystemPrompt(ctx context.Context, prompt shared.SystemPrompt) error { + err := c.queries.CreateSystemPrompt(ctx, queries.CreateSystemPromptParams{ + ID: pgxID(prompt.ID), + Name: prompt.Name, + Description: pgxText(prompt.Description), + Content: prompt.Content, + CreatedAt: pgxTime(prompt.CreatedAt), + }) + if err != nil { + return err + } + return nil +} + +func (c *Database) GetSystemPrompt(ctx context.Context, id string) (*shared.SystemPrompt, error) { + prompt, err := c.queries.GetSystemPrompt(ctx, pgxID(id)) + if err != nil { + if err == pgx.ErrNoRows { + return nil, er.Newf("system prompt %s doesn't exist", er.NotFound, true, id) + } + return nil, err + } + return &shared.SystemPrompt{ + ID: prompt.ID.String(), + Name: prompt.Name, + Description: prompt.Description.String, + Content: prompt.Content, + CreatedAt: prompt.CreatedAt.Time, + }, nil +} + +func (c *Database) GetSystemPromptByName(ctx context.Context, name string) (*shared.SystemPrompt, error) { + prompt, err := c.queries.GetSystemPromptByName(ctx, name) + if err != nil { + if err == pgx.ErrNoRows { + return nil, er.Newf("system prompt with name %s doesn't exist", er.NotFound, true, name) + } + return nil, err + } + return &shared.SystemPrompt{ + ID: prompt.ID.String(), + Name: prompt.Name, + Description: prompt.Description.String, + Content: prompt.Content, + CreatedAt: prompt.CreatedAt.Time, + }, nil +} + +func (c *Database) SystemPromptExistsByName(ctx context.Context, name string) (bool, error) { + _, err := c.queries.GetSystemPromptByName(ctx, name) + if err != nil { + if err == pgx.ErrNoRows { + return false, nil + } + return false, fmt.Errorf("fail to check system prompt %s: %w", name, err) + } + return true, nil +} + +func (c *Database) ListSystemPrompts(ctx context.Context) ([]shared.SystemPrompt, error) { + prompts, err := c.queries.ListSystemPrompts(ctx) + if err != nil { + return nil, err + } + result := make([]shared.SystemPrompt, 0, len(prompts)) + for _, p := range prompts { + result = append(result, shared.SystemPrompt{ + ID: p.ID.String(), + Name: p.Name, + Description: p.Description.String, + CreatedAt: p.CreatedAt.Time, + }) + } + return result, nil +} + +func (c *Database) UpdateSystemPrompt(ctx context.Context, id string, content string) error { + err := c.queries.UpdateSystemPrompt(ctx, queries.UpdateSystemPromptParams{ + ID: pgxID(id), + Content: content, + }) + if err != nil { + if err == pgx.ErrNoRows { + return er.Newf("system prompt %s doesn't exist", er.NotFound, true, id) + } + return err + } + return nil +} + +func (c *Database) DeleteSystemPrompt(ctx context.Context, id string) error { + err := c.queries.DeleteSystemPrompt(ctx, pgxID(id)) + if err != nil { + if err == pgx.ErrNoRows { + return er.Newf("system prompt %s doesn't exist", er.NotFound, true, id) + } + return err + } + return nil +} diff --git a/internal/database/system_test.go b/internal/database/system_test.go new file mode 100644 index 0000000..74ca45d --- /dev/null +++ b/internal/database/system_test.go @@ -0,0 +1,76 @@ +package database_test + +import ( + "context" + "testing" + "time" + + "github.com/appclacks/maizai/pkg/shared" + "github.com/google/uuid" + "github.com/stretchr/testify/assert" +) + +func TestSystemCRUD(t *testing.T) { + ctx := context.Background() + prompt := shared.SystemPrompt{ + ID: uuid.NewString(), + Name: "prompt1", + Content: "my prompt", + Description: "abc", + CreatedAt: time.Now().UTC(), + } + err := TestComponent.CreateSystemPrompt(ctx, prompt) + assert.NoError(t, err) + + result, err := TestComponent.GetSystemPrompt(ctx, prompt.ID) + assert.NoError(t, err) + assert.Equal(t, prompt.Name, result.Name) + assert.Equal(t, prompt.Content, result.Content) + assert.Equal(t, prompt.Description, result.Description) + + result, err = TestComponent.GetSystemPromptByName(ctx, prompt.Name) + assert.NoError(t, err) + assert.Equal(t, prompt.Name, result.Name) + assert.Equal(t, prompt.Content, result.Content) + assert.Equal(t, prompt.Description, result.Description) + + found, err := TestComponent.SystemPromptExistsByName(ctx, prompt.Name) + assert.NoError(t, err) + assert.True(t, found) + + found, err = TestComponent.SystemPromptExistsByName(ctx, "unknown") + assert.NoError(t, err) + assert.False(t, found) + + err = TestComponent.UpdateSystemPrompt(ctx, prompt.ID, "new content") + assert.NoError(t, err) + + result, err = TestComponent.GetSystemPrompt(ctx, prompt.ID) + assert.NoError(t, err) + assert.Equal(t, prompt.Name, result.Name) + assert.Equal(t, "new content", result.Content) + + newPrompt := shared.SystemPrompt{ + ID: uuid.NewString(), + Name: "prompt2", + Content: "my second prompt", + Description: "abc", + CreatedAt: time.Now().UTC(), + } + err = TestComponent.CreateSystemPrompt(ctx, newPrompt) + assert.NoError(t, err) + + prompts, err := TestComponent.ListSystemPrompts(ctx) + assert.NoError(t, err) + assert.Len(t, prompts, 2) + + err = TestComponent.DeleteSystemPrompt(ctx, prompt.ID) + assert.NoError(t, err) + + _, err = TestComponent.GetSystemPrompt(ctx, prompt.ID) + assert.ErrorContains(t, err, "doesn't exist") + + err = TestComponent.DeleteSystemPrompt(ctx, newPrompt.ID) + assert.NoError(t, err) + +} diff --git a/internal/http/client/conversation.go b/internal/http/client/conversation.go index 851efd3..53a5ce7 100644 --- a/internal/http/client/conversation.go +++ b/internal/http/client/conversation.go @@ -32,6 +32,7 @@ type CreateConversationInput struct { ContextID string `json:"context-id,omitempty" description:"The ID of an existing context to use for this conversation"` NewContextOptions ContextOptions `json:"new-context" description:"Options to create a new context"` Stream bool `json:"stream" description:"Streaming mode using SSE"` + SystemPromptID string `json:"system-prompt-id,omitempty" description:"The ID of a system prompt to use for this conversation"` } type Result struct { diff --git a/internal/http/client/system.go b/internal/http/client/system.go new file mode 100644 index 0000000..e42959c --- /dev/null +++ b/internal/http/client/system.go @@ -0,0 +1,84 @@ +package client + +import ( + "context" + "fmt" + "net/http" + "time" +) + +type SystemPrompt struct { + ID string `json:"id" description:"The system prompt ID"` + Name string `json:"name" description:"The system prompt name"` + Description string `json:"description,omitempty" description:"The system prompt description"` + Content string `json:"content" description:"The system prompt content"` + CreatedAt time.Time `json:"created-at" description:"The system prompt creation date"` +} + +type CreateSystemPromptInput struct { + Name string `json:"name" required:"true" description:"The system prompt name"` + Description string `json:"description" description:"The system prompt description"` + Content string `json:"content" required:"true" description:"The system prompt content"` +} + +type GetSystemPromptInput struct { + ID string `param:"id" path:"id"` +} + +type DeleteSystemPromptInput struct { + ID string `param:"id" path:"id"` +} + +type UpdateSystemPromptInput struct { + ID string `json:"-" param:"id" path:"id"` + Content string `json:"content" required:"true" description:"The system prompt content"` +} + +type ListSystemPromptsOutput struct { + SystemPrompts []SystemPrompt `json:"system_prompts"` +} + +func (c *Client) ListSystemPrompts(ctx context.Context) (*ListSystemPromptsOutput, error) { + var result ListSystemPromptsOutput + _, err := c.sendRequest(ctx, "/api/v1/system-prompt", http.MethodGet, nil, &result, nil) + if err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) GetSystemPrompt(ctx context.Context, id string) (*SystemPrompt, error) { + var result SystemPrompt + _, err := c.sendRequest(ctx, fmt.Sprintf("/api/v1/system-prompt/%s", id), http.MethodGet, nil, &result, nil) + if err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) CreateSystemPrompt(ctx context.Context, input CreateSystemPromptInput) (*Response, error) { + var result Response + _, err := c.sendRequest(ctx, "/api/v1/system-prompt", http.MethodPost, input, &result, nil) + if err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) UpdateSystemPrompt(ctx context.Context, input UpdateSystemPromptInput) (*Response, error) { + var result Response + _, err := c.sendRequest(ctx, fmt.Sprintf("/api/v1/system-prompt/%s", input.ID), http.MethodPut, input, &result, nil) + if err != nil { + return nil, err + } + return &result, nil +} + +func (c *Client) DeleteSystemPrompt(ctx context.Context, id string) (*Response, error) { + var result Response + _, err := c.sendRequest(ctx, fmt.Sprintf("/api/v1/system-prompt/%s", id), http.MethodDelete, nil, &result, nil) + if err != nil { + return nil, err + } + return &result, nil +} \ No newline at end of file diff --git a/internal/http/handlers/builder.go b/internal/http/handlers/builder.go index 19ae8ab..bfae89d 100644 --- a/internal/http/handlers/builder.go +++ b/internal/http/handlers/builder.go @@ -38,6 +38,14 @@ type Rag interface { ListDocumentChunksForDocument(ctx context.Context, id string) ([]rag.DocumentChunk, error) } +type SystemPromptManager interface { + CreateSystemPrompt(ctx context.Context, prompt shared.SystemPrompt) error + GetSystemPrompt(ctx context.Context, id string) (*shared.SystemPrompt, error) + ListSystemPrompts(ctx context.Context) ([]shared.SystemPrompt, error) + UpdateSystemPrompt(ctx context.Context, id string, content string) error + DeleteSystemPrompt(ctx context.Context, id string) error +} + func newResponse(messages ...string) client.Response { return client.Response{ Messages: messages, @@ -45,15 +53,17 @@ func newResponse(messages ...string) client.Response { } type Builder struct { - assistant Assistant - ctxManager ContextManager - ragManager Rag + assistant Assistant + ctxManager ContextManager + ragManager Rag + systemPromptManager SystemPromptManager } -func NewBuilder(assistant Assistant, ctxManager ContextManager, ragManager Rag) *Builder { +func NewBuilder(assistant Assistant, ctxManager ContextManager, ragManager Rag, systemPromptManager SystemPromptManager) *Builder { return &Builder{ - assistant: assistant, - ctxManager: ctxManager, - ragManager: ragManager, + assistant: assistant, + ctxManager: ctxManager, + ragManager: ragManager, + systemPromptManager: systemPromptManager, } } diff --git a/internal/http/handlers/conversation.go b/internal/http/handlers/conversation.go index 536dadb..b6bac30 100644 --- a/internal/http/handlers/conversation.go +++ b/internal/http/handlers/conversation.go @@ -28,9 +28,23 @@ func (b *Builder) Conversation(ec echo.Context) error { } ctx := ec.Request().Context() + + systemPrompt := payload.QueryOptions.System + if payload.SystemPromptID != "" { + prompt, err := b.systemPromptManager.GetSystemPrompt(ctx, payload.SystemPromptID) + if err != nil { + return err + } + if systemPrompt != "" { + systemPrompt = prompt.Content + "\n\n" + systemPrompt + } else { + systemPrompt = prompt.Content + } + } + queryOpts := aggregates.QueryOptions{ Model: payload.QueryOptions.Model, - System: payload.QueryOptions.System, + System: systemPrompt, Temperature: payload.QueryOptions.Temperature, MaxTokens: payload.QueryOptions.MaxTokens, Provider: payload.QueryOptions.Provider, diff --git a/internal/http/handlers/system.go b/internal/http/handlers/system.go new file mode 100644 index 0000000..271d1c1 --- /dev/null +++ b/internal/http/handlers/system.go @@ -0,0 +1,85 @@ +package handlers + +import ( + "net/http" + + "github.com/appclacks/maizai/internal/http/client" + "github.com/appclacks/maizai/pkg/shared" + "github.com/labstack/echo/v4" +) + +func toClientSystemPrompt(prompt shared.SystemPrompt) client.SystemPrompt { + return client.SystemPrompt{ + ID: prompt.ID, + Name: prompt.Name, + Description: prompt.Description, + Content: prompt.Content, + CreatedAt: prompt.CreatedAt, + } +} + +func (b *Builder) ListSystemPrompts(ec echo.Context) error { + prompts, err := b.systemPromptManager.ListSystemPrompts(ec.Request().Context()) + if err != nil { + return err + } + output := client.ListSystemPromptsOutput{ + SystemPrompts: []client.SystemPrompt{}, + } + for _, prompt := range prompts { + output.SystemPrompts = append(output.SystemPrompts, toClientSystemPrompt(prompt)) + } + return ec.JSON(http.StatusOK, output) +} + +func (b *Builder) GetSystemPrompt(ec echo.Context) error { + var payload client.GetSystemPromptInput + if err := ec.Bind(&payload); err != nil { + return err + } + prompt, err := b.systemPromptManager.GetSystemPrompt(ec.Request().Context(), payload.ID) + if err != nil { + return err + } + return ec.JSON(http.StatusOK, toClientSystemPrompt(*prompt)) +} + +func (b *Builder) CreateSystemPrompt(ec echo.Context) error { + var payload client.CreateSystemPromptInput + if err := ec.Bind(&payload); err != nil { + return err + } + prompt, err := shared.NewSystemPrompt(payload.Name, payload.Description, payload.Content) + if err != nil { + return err + } + err = b.systemPromptManager.CreateSystemPrompt(ec.Request().Context(), *prompt) + if err != nil { + return err + } + return ec.JSON(http.StatusOK, newResponse("system prompt created")) +} + +func (b *Builder) UpdateSystemPrompt(ec echo.Context) error { + var payload client.UpdateSystemPromptInput + if err := ec.Bind(&payload); err != nil { + return err + } + err := b.systemPromptManager.UpdateSystemPrompt(ec.Request().Context(), payload.ID, payload.Content) + if err != nil { + return err + } + return ec.JSON(http.StatusOK, newResponse("system prompt updated")) +} + +func (b *Builder) DeleteSystemPrompt(ec echo.Context) error { + var payload client.DeleteSystemPromptInput + if err := ec.Bind(&payload); err != nil { + return err + } + err := b.systemPromptManager.DeleteSystemPrompt(ec.Request().Context(), payload.ID) + if err != nil { + return err + } + return ec.JSON(http.StatusOK, newResponse("system prompt deleted")) +} diff --git a/internal/http/server.go b/internal/http/server.go index 7809275..02f7dbf 100644 --- a/internal/http/server.go +++ b/internal/http/server.go @@ -233,6 +233,46 @@ func New(config Configuration, registry *prometheus.Registry, builder *handlers. response: client.ListDocumentChunksOutput{}, description: "Return chunks matching the provided input", }, + { + path: "/system-prompt", + method: http.MethodGet, + handler: builder.ListSystemPrompts, + payload: nil, + response: client.ListSystemPromptsOutput{}, + description: "List system prompts", + }, + { + path: "/system-prompt/:id", + method: http.MethodGet, + handler: builder.GetSystemPrompt, + payload: client.GetSystemPromptInput{}, + response: client.SystemPrompt{}, + description: "Get a system prompt by ID", + }, + { + path: "/system-prompt", + method: http.MethodPost, + handler: builder.CreateSystemPrompt, + payload: client.CreateSystemPromptInput{}, + response: client.Response{}, + description: "Create a new system prompt", + }, + { + path: "/system-prompt/:id", + method: http.MethodPut, + handler: builder.UpdateSystemPrompt, + payload: client.UpdateSystemPromptInput{}, + response: client.Response{}, + description: "Update a system prompt", + }, + { + path: "/system-prompt/:id", + method: http.MethodDelete, + handler: builder.DeleteSystemPrompt, + payload: client.DeleteSystemPromptInput{}, + response: client.Response{}, + description: "Delete a system prompt by ID", + }, } err = openapiSpec(e, definitions) diff --git a/main_test.go b/main_test.go index dfdc397..922a73a 100644 --- a/main_test.go +++ b/main_test.go @@ -548,7 +548,7 @@ func TestIntegration(t *testing.T) { rag := rag.New(db, embeddingClients) ai := assistant.New(clients, manager, rag) - handlersBuilder := handlers.NewBuilder(ai, manager, rag) + handlersBuilder := handlers.NewBuilder(ai, manager, rag, db) server, err := mhttp.New(config.HTTP, registry, handlersBuilder) assert.NoError(t, err) diff --git a/pkg/context/manager.go b/pkg/context/manager.go index 9f4fb9d..8f3df92 100644 --- a/pkg/context/manager.go +++ b/pkg/context/manager.go @@ -24,6 +24,11 @@ type ContextStore interface { DeleteContextSourceContext(ctx context.Context, contextID string, sourceContextID string) error CreateContextSourceContext(ctx context.Context, contextID string, sourceContextID string) error DeleteContextMessages(ctx context.Context, contextID string) error + CreateSystemPrompt(ctx context.Context, prompt shared.SystemPrompt) error + GetSystemPrompt(ctx context.Context, id string) (*shared.SystemPrompt, error) + ListSystemPrompts(ctx context.Context) ([]shared.SystemPrompt, error) + DeleteSystemPrompt(ctx context.Context, id string) error + UpdateSystemPrompt(ctx context.Context, id string, content string) error } type ContextManager struct { diff --git a/pkg/context/system.go b/pkg/context/system.go new file mode 100644 index 0000000..1247953 --- /dev/null +++ b/pkg/context/system.go @@ -0,0 +1,45 @@ +package context + +import ( + "context" + "errors" + + "github.com/appclacks/maizai/internal/id" + "github.com/appclacks/maizai/pkg/shared" +) + +func (c *ContextManager) CreateSystemPrompt(ctx context.Context, prompt shared.SystemPrompt) error { + err := prompt.Validate() + if err != nil { + return err + } + return c.store.CreateSystemPrompt(ctx, prompt) +} + +func (c *ContextManager) GetSystemPrompt(ctx context.Context, promptID string) (*shared.SystemPrompt, error) { + if err := id.Validate(promptID, "Invalid system prompt ID"); err != nil { + return nil, err + } + return c.store.GetSystemPrompt(ctx, promptID) +} + +func (c *ContextManager) ListSystemPrompts(ctx context.Context) ([]shared.SystemPrompt, error) { + return c.store.ListSystemPrompts(ctx) +} + +func (c *ContextManager) DeleteSystemPrompt(ctx context.Context, promptID string) error { + if err := id.Validate(promptID, "Invalid system prompt ID"); err != nil { + return err + } + return c.store.DeleteSystemPrompt(ctx, promptID) +} + +func (c *ContextManager) UpdateSystemPrompt(ctx context.Context, promptID string, content string) error { + if err := id.Validate(promptID, "Invalid system prompt ID"); err != nil { + return err + } + if content == "" { + return errors.New("System prompt content is empty") + } + return c.store.UpdateSystemPrompt(ctx, promptID, content) +} diff --git a/pkg/shared/system.go b/pkg/shared/system.go new file mode 100644 index 0000000..4f26b61 --- /dev/null +++ b/pkg/shared/system.go @@ -0,0 +1,47 @@ +package shared + +import ( + "errors" + "time" + + "github.com/appclacks/maizai/internal/id" + "github.com/google/uuid" +) + +type SystemPrompt struct { + ID string `json:"id"` + Name string `json:"name"` + Content string `json:"content"` + Description string `json:"description"` + CreatedAt time.Time `json:"created-at"` +} + +func NewSystemPrompt(name, description, content string) (*SystemPrompt, error) { + id, err := uuid.NewV6() + if err != nil { + return nil, err + } + return &SystemPrompt{ + ID: id.String(), + Name: name, + Description: description, + Content: content, + CreatedAt: time.Now().UTC(), + }, nil +} + +func (s SystemPrompt) Validate() error { + if err := id.Validate(s.ID, "Invalid system prompt ID"); err != nil { + return err + } + if s.Name == "" { + return errors.New("A system prompt name is mandatory") + } + if s.Content == "" { + return errors.New("System prompt content is empty") + } + if s.CreatedAt.IsZero() { + return errors.New("A system prompt should have a creation date") + } + return nil +} diff --git a/poc/add-file-context.sh b/poc/add-file-context.sh index 430ab24..4a5eeda 100755 --- a/poc/add-file-context.sh +++ b/poc/add-file-context.sh @@ -5,5 +5,4 @@ set -e contextName=$1 filePath=$2 -contextID=$(maizai context get --name $contextName | jq -r .id) -maizai context message add --id $contextID --message-from-file "user:$filePath" +maizai context message add --context-name $contextName --message-from-file "user:$filePath" diff --git a/poc/complete-buffer.sh b/poc/complete-buffer.sh index e3344fc..3d0e114 100755 --- a/poc/complete-buffer.sh +++ b/poc/complete-buffer.sh @@ -17,12 +17,10 @@ args+=(--model $MODEL) echo "source context: '${SOURCE_CONTEXT}'" if [ ! -z "${SOURCE_CONTEXT}" ]; then - sourceContextID=$(maizai context get --name "$SOURCE_CONTEXT" | jq -r '.id') - args+=(--source-context $sourceContextID) + args+=(--source-context-name "$SOURCE_CONTEXT") fi - echo "pass file content: '${PASS_FILE_CONTENT}'" if [ "${PASS_FILE_CONTENT}" == "true" ]; then args+=(--message-from-file "assistant:$filePath") diff --git a/queries/system_prompt.sql b/queries/system_prompt.sql new file mode 100644 index 0000000..8c61b01 --- /dev/null +++ b/queries/system_prompt.sql @@ -0,0 +1,29 @@ +-- name: CreateSystemPrompt :exec +INSERT INTO system_prompt ( + id, name, description, content, created_at +) VALUES ( + $1, $2, $3, $4, $5 +); + +-- name: GetSystemPrompt :one +SELECT id, name, description, content, created_at +FROM system_prompt +WHERE id = $1; + +-- name: GetSystemPromptByName :one +SELECT id, name, description, content, created_at +FROM system_prompt +WHERE name = $1; + +-- name: ListSystemPrompts :many +SELECT id, name, description, content, created_at +FROM system_prompt; + +-- name: UpdateSystemPrompt :exec +UPDATE system_prompt +SET content = $2 +WHERE id = $1; + +-- name: DeleteSystemPrompt :exec +DELETE FROM system_prompt +WHERE id = $1;