First Working Prototype
This application is a simple proof of concept demonstrating an agent capable of taking a prompt and generating a patch implementing code satisfying the prompt along with an accompanying unit test.
This commit is contained in:
151
pkg/llm/relevent_docs.go
Normal file
151
pkg/llm/relevent_docs.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package llm
|
||||
|
||||
import (
|
||||
"ai-code-assistant/pkg/database"
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/cenkalti/backoff/v5"
|
||||
"github.com/tmc/langchaingo/callbacks"
|
||||
"github.com/tmc/langchaingo/schema"
|
||||
"github.com/tmc/langchaingo/vectorstores"
|
||||
"github.com/tmc/langchaingo/vectorstores/pgvector"
|
||||
"log/slog"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type RelevantDocs struct {
|
||||
CallbacksHandler callbacks.Handler
|
||||
db *database.Database
|
||||
llm *LLM
|
||||
repoID string
|
||||
size int
|
||||
}
|
||||
|
||||
type FileChunkID struct {
|
||||
Name string
|
||||
ChunkID int
|
||||
Start uint64
|
||||
End uint64
|
||||
Score float32
|
||||
Doc *schema.Document
|
||||
}
|
||||
|
||||
func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs {
|
||||
return &RelevantDocs{
|
||||
db: db,
|
||||
llm: llm,
|
||||
repoID: repoID,
|
||||
size: size,
|
||||
}
|
||||
}
|
||||
|
||||
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
|
||||
conn, err := rd.db.DB(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer conn.Release()
|
||||
|
||||
vectorStore, err := pgvector.New(ctx,
|
||||
pgvector.WithConn(conn),
|
||||
pgvector.WithEmbedder(rd.llm.Embedder()),
|
||||
pgvector.WithCollectionName("file_chunks"),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
|
||||
retr.CallbacksHandler = rd.CallbacksHandler
|
||||
|
||||
docs, err := retr.GetRelevantDocuments(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var chunks []*FileChunkID
|
||||
|
||||
for _, doc := range docs {
|
||||
chunk := &FileChunkID{
|
||||
Score: doc.Score,
|
||||
}
|
||||
|
||||
if filePath, ok := doc.Metadata["file_path"].(string); ok {
|
||||
chunk.Name = filePath
|
||||
}
|
||||
|
||||
if chunkID, ok := doc.Metadata["chunk_id"].(string); ok {
|
||||
id, _ := strconv.ParseInt(chunkID, 10, 64)
|
||||
chunk.ChunkID = int(id)
|
||||
}
|
||||
|
||||
if start, ok := doc.Metadata["start"].(string); ok {
|
||||
chunk.Start, _ = strconv.ParseUint(start, 10, 64)
|
||||
}
|
||||
|
||||
if end, ok := doc.Metadata["end"].(string); ok {
|
||||
chunk.End, _ = strconv.ParseUint(end, 10, 64)
|
||||
}
|
||||
|
||||
chunk.Doc = &doc
|
||||
chunks = append(chunks, chunk)
|
||||
}
|
||||
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
func (rd *RelevantDocs) RankChunks(ctx context.Context, query string, chunks []*FileChunkID) error {
|
||||
var didErr error
|
||||
|
||||
slices.SortFunc(chunks, func(a, b *FileChunkID) int {
|
||||
if didErr != nil {
|
||||
return 0
|
||||
}
|
||||
|
||||
retr, err := rd.CompareChunks(ctx, query, a, b)
|
||||
if err != nil {
|
||||
didErr = err
|
||||
}
|
||||
|
||||
return retr
|
||||
})
|
||||
|
||||
return didErr
|
||||
}
|
||||
|
||||
func (rd *RelevantDocs) CompareChunks(ctx context.Context, query string, chunk1, chunk2 *FileChunkID) (int, error) {
|
||||
slog.Info("comparing chunks", "chunk_1_name", chunk1.Name, "chunk_1_id", chunk1.ChunkID, "chunk_2_name", chunk2.Name, "chunk_2_id", chunk2.ChunkID)
|
||||
|
||||
prompt := `Given the following two pieces of code pick the most relevant chunk to the task described below. Reply as a json object in the format {"chunk_id": "<chunk>"}. Only reply in JSON. Do not include any explanation or code.`
|
||||
|
||||
prompt += "\n\n" + query + "\n\n"
|
||||
|
||||
// Now that we have candidates we need to compare them against each other to find the most appropriate place to
|
||||
// inject them.
|
||||
|
||||
prompt += "-- chunk_id: chunk_1 --\n"
|
||||
prompt += chunk1.Doc.PageContent
|
||||
|
||||
prompt += "-- chunk_id: chunk_2 --\n"
|
||||
prompt += chunk2.Doc.PageContent
|
||||
|
||||
op := func() (int, error) {
|
||||
rsp, err := rd.llm.CodePrompt(ctx, prompt)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if strings.Contains(rsp, "chunk_1") {
|
||||
return -1, nil
|
||||
} else if strings.Contains(rsp, "chunk_2") {
|
||||
return 1, nil
|
||||
}
|
||||
|
||||
return 0, fmt.Errorf("compare response didn't contain a chunk id: %s", rsp)
|
||||
}
|
||||
|
||||
return backoff.Retry(ctx, op, backoff.WithBackOff(backoff.NewConstantBackOff(10*time.Millisecond)), backoff.WithMaxTries(1))
|
||||
}
|
||||
Reference in New Issue
Block a user