First Working Prototype
This application is a simple proof of concept: an agent that takes a prompt, generates a patch implementing the requested code change, and produces an accompanying unit test.
pkg/config/config.go (new file)
package config

import (
    "context"
    "fmt"

    "github.com/tmc/langchaingo/embeddings"
    "github.com/tmc/langchaingo/llms"
    "github.com/tmc/langchaingo/llms/ollama"
    "github.com/tmc/langchaingo/llms/openai"
)

type contextKey string

const (
    contextKeyConfig contextKey = "config"
)

// Configuration is the top-level application configuration, described with YAML tags.
type Configuration struct {
    Database struct {
        ConnString string `yaml:"conn_string"`
    } `yaml:"database"`
    Logging struct {
        Level string `yaml:"level"`
    } `yaml:"logging"`
    Embedding    LLMConfig `yaml:"embedding"`
    Code         LLMConfig `yaml:"code"`
    Chat         LLMConfig `yaml:"chat"`
    RelevantDocs int       `yaml:"relevant_docs"`
}

// LLMConfig selects a model backend ("ollama" or "openai") and its settings.
type LLMConfig struct {
    Type   string       `yaml:"type"`
    Model  string       `yaml:"model"`
    OLlama OLlamaConfig `yaml:"ollama"`
    OpenAI OpenAIConfig `yaml:"openai"`
}

type OLlamaConfig struct {
    URL string `yaml:"url"`
}

type OpenAIConfig struct {
    Key string `yaml:"key"`
    URL string `yaml:"url"`
}

// GetModel builds an llms.Model for the configured backend.
func (cfg LLMConfig) GetModel() (llms.Model, error) {
    switch cfg.Type {
    case "ollama":
        var opts []ollama.Option

        if cfg.Model != "" {
            opts = append(opts, ollama.WithModel(cfg.Model))
        }

        if cfg.OLlama.URL != "" {
            opts = append(opts, ollama.WithServerURL(cfg.OLlama.URL))
        }

        return ollama.New(opts...)
    case "openai":
        var opts []openai.Option

        if cfg.Model != "" {
            opts = append(opts, openai.WithModel(cfg.Model))
        }

        if cfg.OpenAI.URL != "" {
            opts = append(opts, openai.WithBaseURL(cfg.OpenAI.URL))
        }

        if cfg.OpenAI.Key != "" {
            opts = append(opts, openai.WithToken(cfg.OpenAI.Key))
        }

        return openai.New(opts...)
    default:
        return nil, fmt.Errorf("unknown model type: %s", cfg.Type)
    }
}

// GetEmbedding builds an embedding client; only the ollama backend is supported here.
func (cfg LLMConfig) GetEmbedding() (embeddings.EmbedderClient, error) {
    switch cfg.Type {
    case "ollama":
        var opts []ollama.Option

        if cfg.Model != "" {
            opts = append(opts, ollama.WithModel(cfg.Model))
        }

        if cfg.OLlama.URL != "" {
            opts = append(opts, ollama.WithServerURL(cfg.OLlama.URL))
        }

        return ollama.New(opts...)
    default:
        return nil, fmt.Errorf("unknown embedding type: %s", cfg.Type)
    }
}

// FromContext retrieves the Configuration previously stored with WrapContext.
func FromContext(ctx context.Context) *Configuration {
    return ctx.Value(contextKeyConfig).(*Configuration)
}

// WrapContext stores the Configuration on the context for downstream packages.
func WrapContext(ctx context.Context, cfg *Configuration) context.Context {
    return context.WithValue(ctx, contextKeyConfig, cfg)
}
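For orientation, a minimal usage sketch (not part of the commit). It assumes a Configuration has already been populated from YAML; the model name and server URL below are placeholders, not values from this repository:

```go
package main

import (
    "context"
    "log"

    "ai-code-assistant/pkg/config"
)

func main() {
    // Illustrative values only; a real Configuration would be unmarshalled from YAML.
    cfg := &config.Configuration{
        Chat: config.LLMConfig{
            Type:   "ollama",
            Model:  "llama3",                                           // assumed model name
            OLlama: config.OLlamaConfig{URL: "http://localhost:11434"}, // assumed Ollama URL
        },
    }

    chatModel, err := cfg.Chat.GetModel()
    if err != nil {
        log.Fatalf("building chat model: %v", err)
    }
    _ = chatModel

    // WrapContext/FromContext pass the config through a context.Context.
    ctx := config.WrapContext(context.Background(), cfg)
    _ = config.FromContext(ctx)
}
```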
pkg/database/database.go (new file)
package database

import (
    "context"
    "embed"
    "errors"
    "fmt"
    "strconv"

    "ai-code-assistant/pkg/config"

    "github.com/golang-migrate/migrate/v4"
    _ "github.com/golang-migrate/migrate/v4/database/postgres"
    "github.com/golang-migrate/migrate/v4/source/iofs"
    "github.com/jackc/pgx/v5"
    "github.com/jackc/pgx/v5/pgxpool"
    _ "github.com/lib/pq"
)

//go:embed migrations
var migrations embed.FS

type contextKey string

const (
    contextKeyDB contextKey = "db"
)

func preparedStatements() map[string]string {
    return map[string]string{
        "insert_repo":    `INSERT INTO repos (repo_id, repo_hash, repo_path) VALUES ($1, $2, $3) ON CONFLICT DO NOTHING`,
        "repo_from_path": `SELECT repo_id FROM repos WHERE repo_path = $1`,
        "get_chunk":      `SELECT document FROM langchain_pg_embedding WHERE JSON_EXTRACT_PATH_TEXT(cmetadata, 'chunk_id')=$1 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'file_path')=$2 AND JSON_EXTRACT_PATH_TEXT(cmetadata, 'repo_id')=$3`,
    }
}

// Database wraps a pgx connection pool and the prepared statements used by the assistant.
type Database struct {
    pool *pgxpool.Pool
}

// DB acquires a pooled connection and registers the prepared statements on it.
// Callers must Release the returned connection.
func (db *Database) DB(ctx context.Context) (*pgxpool.Conn, error) {
    conn, err := db.pool.Acquire(ctx)
    if err != nil {
        return nil, err
    }

    for nm, query := range preparedStatements() {
        if _, err := conn.Conn().Prepare(ctx, nm, query); err != nil {
            return nil, fmt.Errorf("problem preparing statement %s: %w", nm, err)
        }
    }

    return conn, nil
}

// RepoIDFromPath returns the repo_id for a previously registered repository path.
func (db *Database) RepoIDFromPath(ctx context.Context, path string) (string, error) {
    conn, err := db.DB(ctx)
    if err != nil {
        return "", err
    }
    defer conn.Release()

    var repoID string

    if err := conn.QueryRow(ctx, "repo_from_path", path).Scan(&repoID); err != nil {
        return "", err
    }

    return repoID, nil
}

// GetChunk returns the stored document for a single file chunk.
func (db *Database) GetChunk(ctx context.Context, chunkID int, path, repoID string) (string, error) {
    conn, err := db.DB(ctx)
    if err != nil {
        return "", err
    }
    defer conn.Release()

    var chunk string

    chunkIDStr := strconv.FormatInt(int64(chunkID), 10)

    if err := conn.QueryRow(ctx, "get_chunk", chunkIDStr, path, repoID).Scan(&chunk); err != nil {
        return "", err
    }

    return chunk, nil
}

// GetChunkContext concatenates the chunks in a window around chunkID, stopping at the
// first chunk that does not exist.
func (db *Database) GetChunkContext(ctx context.Context, chunkID, distance int, path, repoID string) (string, error) {
    minChunk := chunkID - distance
    if minChunk < 0 {
        minChunk = 0
    }

    chunkContext := ""

    for id := minChunk; id < minChunk+(distance*2); id++ {
        chunk, err := db.GetChunk(ctx, id, path, repoID)
        if err == nil {
            chunkContext += chunk
        } else if !errors.Is(err, pgx.ErrNoRows) {
            return "", err
        } else {
            break
        }
    }

    return chunkContext, nil
}

// FromConfig runs the embedded migrations and opens a connection pool.
func FromConfig(ctx context.Context, cfg *config.Configuration) (*Database, error) {
    migFS, err := iofs.New(migrations, "migrations")
    if err != nil {
        return nil, err
    }

    mig, err := migrate.NewWithSourceInstance("iofs", migFS, cfg.Database.ConnString)
    if err != nil {
        return nil, err
    }

    if err := mig.Up(); err != nil && !errors.Is(err, migrate.ErrNoChange) {
        return nil, fmt.Errorf("problem performing database migrations: %w", err)
    }

    pool, err := pgxpool.New(ctx, cfg.Database.ConnString)
    if err != nil {
        return nil, err
    }

    return &Database{pool: pool}, nil
}

// FromContext retrieves the Database previously stored with WrapContext.
func FromContext(ctx context.Context) *Database {
    return ctx.Value(contextKeyDB).(*Database)
}

// WrapContext stores the Database on the context for downstream packages.
func WrapContext(ctx context.Context, cfg *Database) context.Context {
    return context.WithValue(ctx, contextKeyDB, cfg)
}
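A minimal wiring sketch (not part of the commit) showing how this package is consumed: build a Database from the configuration, stash it on the context, and resolve a repo ID. The connection string and repository path are placeholders:

```go
package main

import (
    "context"
    "log"

    "ai-code-assistant/pkg/config"
    "ai-code-assistant/pkg/database"
)

func main() {
    ctx := context.Background()

    // Illustrative connection string; the real value comes from the YAML config.
    cfg := &config.Configuration{}
    cfg.Database.ConnString = "postgres://user:pass@localhost:5432/assistant?sslmode=disable"

    db, err := database.FromConfig(ctx, cfg) // runs migrations, then opens a pgx pool
    if err != nil {
        log.Fatalf("connecting to database: %v", err)
    }

    ctx = database.WrapContext(ctx, db)

    // Look up a repository previously indexed under this (illustrative) path.
    repoID, err := database.FromContext(ctx).RepoIDFromPath(ctx, "/path/to/repo")
    if err != nil {
        log.Fatalf("looking up repo: %v", err)
    }
    log.Println("repo id:", repoID)
}
```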
pkg/database/migrations/01_schema.up.sql (new file)
CREATE TABLE IF NOT EXISTS repos (
    repo_id UUID PRIMARY KEY,
    repo_hash TEXT NOT NULL,
    repo_path TEXT NOT NULL,
    UNIQUE(repo_hash, repo_path)
);
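The "insert_repo" prepared statement declared in database.go writes into this table. A hypothetical helper sketch (the package name, function name, and argument values are illustrative):

```go
// Package repostore is a hypothetical home for this sketch.
package repostore

import (
    "context"

    "ai-code-assistant/pkg/database"
)

// registerRepo shows how the "insert_repo" prepared statement maps onto the repos table.
func registerRepo(ctx context.Context, db *database.Database, id, hash, path string) error {
    conn, err := db.DB(ctx) // prepared statements are registered on the acquired connection
    if err != nil {
        return err
    }
    defer conn.Release()

    // ON CONFLICT DO NOTHING makes re-registering the same repo a no-op.
    _, err = conn.Exec(ctx, "insert_repo", id, hash, path)
    return err
}
```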
pkg/llm/llm.go (new file)
package llm

import (
    "bytes"
    "context"
    "embed"
    "fmt"
    "text/template"

    "ai-code-assistant/pkg/config"

    "github.com/tmc/langchaingo/embeddings"
    "github.com/tmc/langchaingo/llms"
)

type contextKey string

const contextKeyLLM contextKey = "llm"

//go:embed prompts
var prompts embed.FS

// LLM bundles the embedding, code, and chat models used by the assistant.
type LLM struct {
    code     llms.Model
    chat     llms.Model
    embedder embeddings.Embedder
}

// FromConfig constructs the embedding, code, and chat models described in the configuration.
func FromConfig(cfg *config.Configuration) (*LLM, error) {
    embedLLM, err := cfg.Embedding.GetEmbedding()
    if err != nil {
        return nil, fmt.Errorf("unable to get embedding model: %w", err)
    }

    codeLLM, err := cfg.Code.GetModel()
    if err != nil {
        return nil, fmt.Errorf("unable to get coder model: %w", err)
    }

    chatLLM, err := cfg.Chat.GetModel()
    if err != nil {
        return nil, fmt.Errorf("unable to get chat model: %w", err)
    }

    embedder, err := embeddings.NewEmbedder(embedLLM)
    if err != nil {
        return nil, err
    }

    return &LLM{
        embedder: embedder,
        code:     codeLLM,
        chat:     chatLLM,
    }, nil
}

// FromContext retrieves the LLM previously stored with WrapContext.
func FromContext(ctx context.Context) *LLM {
    return ctx.Value(contextKeyLLM).(*LLM)
}

// WrapContext stores the LLM on the context for downstream packages.
func WrapContext(ctx context.Context, llmRef *LLM) context.Context {
    return context.WithValue(ctx, contextKeyLLM, llmRef)
}

// GetEmbedding embeds the given texts with the configured embedder.
func (llm *LLM) GetEmbedding(ctx context.Context, texts ...string) ([][]float32, error) {
    return llm.embedder.EmbedDocuments(ctx, texts)
}

func (llm *LLM) Embedder() embeddings.Embedder {
    return llm.embedder
}

// CodePrompt sends a prompt to the code model and returns its completion.
func (llm *LLM) CodePrompt(ctx context.Context, prompt string) (string, error) {
    return llm.code.Call(ctx, prompt)
}

// ChatPrompt sends a prompt to the chat model and returns its completion.
func (llm *LLM) ChatPrompt(ctx context.Context, prompt string) (string, error) {
    return llm.chat.Call(ctx, prompt)
}

// GetPrompt renders the named template from the embedded prompts directory with the given data.
func GetPrompt(name string, data any) (string, error) {
    tmplText, err := prompts.ReadFile("prompts/" + name + ".tmpl")
    if err != nil {
        return "", err
    }

    tmpl, err := template.New(name).Parse(string(tmplText))
    if err != nil {
        return "", err
    }

    var buf bytes.Buffer
    if err := tmpl.Execute(&buf, data); err != nil {
        return "", err
    }

    return buf.String(), nil
}
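A minimal sketch of wiring this package from a Configuration (not part of the commit). The model names are placeholders; real deployments would read them from the YAML config:

```go
package main

import (
    "context"
    "log"

    "ai-code-assistant/pkg/config"
    "ai-code-assistant/pkg/llm"
)

func main() {
    ctx := context.Background()

    // Backend choices and model names here are illustrative only.
    cfg := &config.Configuration{
        Embedding: config.LLMConfig{Type: "ollama", Model: "nomic-embed-text"},
        Code:      config.LLMConfig{Type: "ollama", Model: "codellama"},
        Chat:      config.LLMConfig{Type: "ollama", Model: "llama3"},
    }

    models, err := llm.FromConfig(cfg) // builds embedder, code model, and chat model
    if err != nil {
        log.Fatalf("building models: %v", err)
    }

    ctx = llm.WrapContext(ctx, models)

    answer, err := llm.FromContext(ctx).ChatPrompt(ctx, "Summarise what this repository does.")
    if err != nil {
        log.Fatalf("chat prompt: %v", err)
    }
    log.Println(answer)
}
```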
pkg/llm/prompts/generate_commitmsg.tmpl (new file)
given the following prompt create a descriptive commit message indicating the changes made
only give a single commit message, do not describe your thought process

files changed: {{range .Files}}{{.}} {{end}}
prompt: {{ .Prompt }}
pkg/llm/prompts/generate_patch.tmpl (new file)
given the following code snippet in markdown format perform the following task:
{{ .Prompt }}

return the modified function as a code block formatted in markdown, make sure to retain as much of the preceding and
trailing context as possible

!IMPORTANT do not add any explanation

```go
{{ .Context }}
```
pkg/llm/prompts/generate_unittest.tmpl (new file)
given the following code block in markdown that performs the following task:
{{ .Prompt }}

{{- if .TestFileExists -}}
create a unit test for the following code as a code block formatted in markdown, only include the unit test itself do
not include any imports
!IMPORTANT do not add any explanation
{{- else -}}
create a unit test for the following code as a code block formatted in markdown, include any required imports
!IMPORTANT do not add any explanation
{{- end -}}

```go
{{ .Context }}
```
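These templates are rendered through GetPrompt in llm.go. A rendering sketch (not part of the commit); the data struct mirrors the placeholders in generate_unittest.tmpl and its values are illustrative:

```go
package main

import (
    "fmt"
    "log"

    "ai-code-assistant/pkg/llm"
)

func main() {
    // Field names mirror the template placeholders; the values are placeholders.
    data := struct {
        Prompt         string
        Context        string
        TestFileExists bool
    }{
        Prompt:         "add input validation to ParseConfig",
        Context:        "func ParseConfig(raw []byte) (*Config, error) { /* ... */ }",
        TestFileExists: false, // false: the template asks the model to include imports
    }

    prompt, err := llm.GetPrompt("generate_unittest", data)
    if err != nil {
        log.Fatalf("rendering template: %v", err)
    }
    fmt.Println(prompt)
}
```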
pkg/llm/relevent_docs.go (new file)
package llm

import (
    "context"
    "fmt"
    "log/slog"
    "slices"
    "strconv"
    "strings"
    "time"

    "ai-code-assistant/pkg/database"

    "github.com/cenkalti/backoff/v5"
    "github.com/tmc/langchaingo/callbacks"
    "github.com/tmc/langchaingo/schema"
    "github.com/tmc/langchaingo/vectorstores"
    "github.com/tmc/langchaingo/vectorstores/pgvector"
)

// RelevantDocs retrieves and ranks the code chunks most relevant to a task for one repository.
type RelevantDocs struct {
    CallbacksHandler callbacks.Handler
    db               *database.Database
    llm              *LLM
    repoID           string
    size             int
}

// FileChunkID identifies a retrieved file chunk along with its retrieval score and document.
type FileChunkID struct {
    Name    string
    ChunkID int
    Start   uint64
    End     uint64
    Score   float32
    Doc     *schema.Document
}

// NewGetRelevantDocs constructs a RelevantDocs that returns up to size candidate chunks.
func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs {
    return &RelevantDocs{
        db:     db,
        llm:    llm,
        repoID: repoID,
        size:   size,
    }
}

// GetRelevantFileChunks runs a vector similarity search over the repository's indexed file chunks.
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
    conn, err := rd.db.DB(ctx)
    if err != nil {
        return nil, err
    }
    defer conn.Release()

    vectorStore, err := pgvector.New(ctx,
        pgvector.WithConn(conn),
        pgvector.WithEmbedder(rd.llm.Embedder()),
        pgvector.WithCollectionName("file_chunks"),
    )
    if err != nil {
        return nil, err
    }

    retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
    retr.CallbacksHandler = rd.CallbacksHandler

    docs, err := retr.GetRelevantDocuments(ctx, query)
    if err != nil {
        return nil, err
    }

    var chunks []*FileChunkID

    for _, doc := range docs {
        chunk := &FileChunkID{
            Score: doc.Score,
        }

        if filePath, ok := doc.Metadata["file_path"].(string); ok {
            chunk.Name = filePath
        }

        if chunkID, ok := doc.Metadata["chunk_id"].(string); ok {
            id, _ := strconv.ParseInt(chunkID, 10, 64)
            chunk.ChunkID = int(id)
        }

        if start, ok := doc.Metadata["start"].(string); ok {
            chunk.Start, _ = strconv.ParseUint(start, 10, 64)
        }

        if end, ok := doc.Metadata["end"].(string); ok {
            chunk.End, _ = strconv.ParseUint(end, 10, 64)
        }

        chunk.Doc = &doc
        chunks = append(chunks, chunk)
    }

    return chunks, nil
}

// RankChunks sorts chunks by pairwise relevance comparisons made by the code model.
func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error {
    var didErr error

    slices.SortFunc(chunks, func(a, b *FileChunkID) int {
        if didErr != nil {
            return 0
        }

        retr, err := rd.CompareChunks(ctx, query, a, b)
        if err != nil {
            didErr = err
        }

        return retr
    })

    return didErr
}

// CompareChunks asks the code model which of two chunks is more relevant to the query;
// it returns -1 if chunk1 wins and 1 if chunk2 wins.
func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) {
    slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID)

    prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": "<chunk>"}. Only reply in JSON. Do not include any explanation or code.`

    prompt += "\n\n" + query + "\n\n"

    // Now that we have candidates we need to compare them against each other to find the most appropriate place to
    // inject them.

    prompt += "-- chunk_id: chunk_1 --\n"
    prompt += chunk1.Doc.PageContent

    prompt += "-- chunk_id: chunk_2 --\n"
    prompt += chunk2.Doc.PageContent

    op := func() (int, error) {
        rsp, err := rd.llm.CodePrompt(ctx, prompt)
        if err != nil {
            return 0, err
        }

        if strings.Contains(rsp, "chunk_1") {
            return -1, nil
        } else if strings.Contains(rsp, "chunk_2") {
            return 1, nil
        }

        return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp)
    }

    return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1))
}
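A sketch tying the pieces together (not part of the commit): retrieve candidate chunks for a task and let the code model rank them. The package name, helper name, candidate count, and task text are illustrative; it assumes a Database and LLM built as in the earlier sketches:

```go
// Package agent is a hypothetical home for this sketch.
package agent

import (
    "context"
    "log"

    "ai-code-assistant/pkg/database"
    "ai-code-assistant/pkg/llm"
)

// findContext retrieves the top candidate chunks for a task and ranks them with the code model.
func findContext(ctx context.Context, db *database.Database, models *llm.LLM, repoID, task string) error {
    rd := llm.NewGetRelevantDocs(db, models, repoID, 5) // 5 candidate chunks (illustrative)

    chunks, err := rd.GetRelevantFileChunks(ctx, task)
    if err != nil {
        return err
    }

    if err := rd.RankChunks(ctx, task, chunks); err != nil {
        return err
    }

    for _, c := range chunks {
        log.Printf("%s chunk %d (score %.3f)", c.Name, c.ChunkID, c.Score)
    }
    return nil
}
```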