First Working Prototype

This application is a simple proof of concept demonstrating an agent capable of taking a prompt and generating a patch implementing code satisfying the prompt along with an accompanying unit test.
This commit is contained in:
2025-04-20 07:47:41 -04:00
commit 4b8b8132fd
15 changed files with 1797 additions and 0 deletions

96
pkg/llm/llm.go Normal file
View File

@@ -0,0 +1,96 @@
package llm
import (
"ai-code-assistant/pkg/config"
"bytes"
"context"
"embed"
"fmt"
"github.com/tmc/langchaingo/embeddings"
"github.com/tmc/langchaingo/llms"
"text/template"
)
type contextKey string
const contextKeyLLM contextKey = "llm"
//go:embed prompts
var prompts embed.FS
type LLM struct {
code llms.Model
chat llms.Model
embedder embeddings.Embedder
}
func FromConfig(cfg *config.Configuration) (*LLM, error) {
embedLLM, err := cfg.Embedding.GetEmbedding()
if err != nil {
return nil, fmt.Errorf("unable to get embedding model: %w", err)
}
codeLLM, err := cfg.Code.GetModel()
if err != nil {
return nil, fmt.Errorf("unable to get coder model: %w", err)
}
chatLLM, err := cfg.Chat.GetModel()
if err != nil {
return nil, fmt.Errorf("unable to get chat model: %w", err)
}
embedder, err := embeddings.NewEmbedder(embedLLM)
if err != nil {
return nil, err
}
return &LLM{
embedder: embedder,
code: codeLLM,
chat: chatLLM,
}, nil
}
func FromContext(ctx context.Context) *LLM {
return ctx.Value(contextKeyLLM).(*LLM)
}
func WrapContext(ctx context.Context, llmRef *LLM) context.Context {
return context.WithValue(ctx, contextKeyLLM, llmRef)
}
func (llm *LLM) GetEmbedding(ctx context.Context, texts ...string) ([][]float32, error) {
return llm.embedder.EmbedDocuments(ctx, texts)
}
func (llm *LLM) Embedder() embeddings.Embedder {
return llm.embedder
}
func (llm *LLM) CodePrompt(ctx context.Context, prompt string) (string, error) {
return llm.code.Call(ctx, prompt)
}
func (llm *LLM) ChatPrompt(ctx context.Context, prompt string) (string, error) {
return llm.chat.Call(ctx, prompt)
}
func GetPrompt(name string, data any) (string, error) {
tmplText, err := prompts.ReadFile("prompts/" + name + ".tmpl")
if err != nil {
return "", err
}
tmpl, err := template.New(name).Parse(string(tmplText))
if err != nil {
return "", err
}
var buf bytes.Buffer
if err := tmpl.Execute(&buf, data); err != nil {
return "", err
}
return buf.String(), nil
}

View File

@@ -0,0 +1,5 @@
given the following prompt create a descriptive commit message indicating the changes made
only give a single commit message, do not describe your thought process
files changed: {{range .Files}}{{.}} {{end}}
prompt: {{ .Prompt }}

View File

@@ -0,0 +1,11 @@
given the following code snippet in markdown format perform the following task:
{{ .Prompt }}
return the modified function as a code block formatted in markdown, make sure to retain as much of the preceding and
trailing context as possible
!IMPORTANT do not add any explanation
```go
{{ .Context }}
```

View File

@@ -0,0 +1,15 @@
given the following code block in markdown that performs the following task:
{{ .Prompt }}
{{- if .TestFileExists -}}
create a unit test for the following code as a code block formatted in markdown, only include the unit test itself do
not include any imports
!IMPORTANT do not add any explanation
{{- else -}}
create a unit test for the following code as a code block formatted in markdown, include any required imports
!IMPORTANT do not add any explanation
{{- end -}}
```go
{{ .Context }}
```"

151
pkg/llm/relevent_docs.go Normal file
View File

@@ -0,0 +1,151 @@
package llm
import (
"ai-code-assistant/pkg/database"
"context"
"fmt"
"github.com/cenkalti/backoff/v5"
"github.com/tmc/langchaingo/callbacks"
"github.com/tmc/langchaingo/schema"
"github.com/tmc/langchaingo/vectorstores"
"github.com/tmc/langchaingo/vectorstores/pgvector"
"log/slog"
"slices"
"strconv"
"strings"
"time"
)
type RelevantDocs struct {
CallbacksHandler callbacks.Handler
db *database.Database
llm *LLM
repoID string
size int
}
type FileChunkID struct {
Name string
ChunkID int
Start uint64
End uint64
Score float32
Doc *schema.Document
}
func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs {
return &RelevantDocs{
db: db,
llm: llm,
repoID: repoID,
size: size,
}
}
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
conn, err := rd.db.DB(ctx)
if err != nil {
return nil, err
}
defer conn.Release()
vectorStore, err := pgvector.New(ctx,
pgvector.WithConn(conn),
pgvector.WithEmbedder(rd.llm.Embedder()),
pgvector.WithCollectionName("file_chunks"),
)
if err != nil {
return nil, err
}
retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
retr.CallbacksHandler = rd.CallbacksHandler
docs, err := retr.GetRelevantDocuments(ctx, query)
if err != nil {
return nil, err
}
var chunks []*FileChunkID
for _, doc := range docs {
chunk := &FileChunkID{
Score: doc.Score,
}
if filePath, ok := doc.Metadata["file_path"].(string); ok {
chunk.Name = filePath
}
if chunkID, ok := doc.Metadata["chunk_id"].(string); ok {
id, _ := strconv.ParseInt(chunkID, 10, 64)
chunk.ChunkID = int(id)
}
if start, ok := doc.Metadata["start"].(string); ok {
chunk.Start, _ = strconv.ParseUint(start, 10, 64)
}
if end, ok := doc.Metadata["end"].(string); ok {
chunk.End, _ = strconv.ParseUint(end, 10, 64)
}
chunk.Doc = &doc
chunks = append(chunks, chunk)
}
return chunks, nil
}
func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error {
var didErr error
slices.SortFunc(chunks, func(a, b *FileChunkID) int {
if didErr != nil {
return 0
}
retr, err := rd.CompareChunks(ctx, query, a, b)
if err != nil {
didErr = err
}
return retr
})
return didErr
}
func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) {
slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID)
prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": "<chunk>"}. Only reply in JSON. Do not include any explanation or code.`
prompt += "\n\n" + query + "\n\n"
// Now that we have candidates we need to compare them against each other to find the most appropriate place to
// inject them.
prompt += "-- chunk_id: chunk_1 --\n"
prompt += chunk1.Doc.PageContent
prompt += "-- chunk_id: chunk_2 --\n"
prompt += chunk2.Doc.PageContent
op := func() (int, error) {
rsp, err := rd.llm.CodePrompt(ctx, prompt)
if err != nil {
return 0, err
}
if strings.Contains(rsp, "chunk_1") {
return -1, nil
} else if strings.Contains(rsp, "chunk_2") {
return 1, nil
}
return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp)
}
return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1))
}