Checks to make sure repo is indexed before generating code. Don't generate tests for changes to tests. Remove unused code. Fix bootstrapping issue with langchaingo tables.
93 lines
2.0 KiB
Go
93 lines
2.0 KiB
Go
package llm
|
|
|
|
import (
|
|
"ai-code-assistant/pkg/database"
|
|
"context"
|
|
"github.com/tmc/langchaingo/callbacks"
|
|
"github.com/tmc/langchaingo/schema"
|
|
"github.com/tmc/langchaingo/vectorstores"
|
|
"github.com/tmc/langchaingo/vectorstores/pgvector"
|
|
"strconv"
|
|
)
|
|
|
|
type RelevantDocs struct {
|
|
CallbacksHandler callbacks.Handler
|
|
db *database.Database
|
|
llm *LLM
|
|
repoID string
|
|
size int
|
|
}
|
|
|
|
type FileChunkID struct {
|
|
Name string
|
|
ChunkID int
|
|
Start uint64
|
|
End uint64
|
|
Score float32
|
|
Doc *schema.Document
|
|
}
|
|
|
|
func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs {
|
|
return &RelevantDocs{
|
|
db: db,
|
|
llm: llm,
|
|
repoID: repoID,
|
|
size: size,
|
|
}
|
|
}
|
|
|
|
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
|
|
conn, err := rd.db.DB(ctx)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer conn.Release()
|
|
|
|
vectorStore, err := pgvector.New(ctx,
|
|
pgvector.WithConn(conn),
|
|
pgvector.WithEmbedder(rd.llm.Embedder()),
|
|
pgvector.WithCollectionName("file_chunks"),
|
|
)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
|
|
retr.CallbacksHandler = rd.CallbacksHandler
|
|
|
|
docs, err := retr.GetRelevantDocuments(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var chunks []*FileChunkID
|
|
|
|
for _, doc := range docs {
|
|
chunk := &FileChunkID{
|
|
Score: doc.Score,
|
|
}
|
|
|
|
if filePath, ok := doc.Metadata["file_path"].(string); ok {
|
|
chunk.Name = filePath
|
|
}
|
|
|
|
if chunkID, ok := doc.Metadata["chunk_id"].(string); ok {
|
|
id, _ := strconv.ParseInt(chunkID, 10, 64)
|
|
chunk.ChunkID = int(id)
|
|
}
|
|
|
|
if start, ok := doc.Metadata["start"].(string); ok {
|
|
chunk.Start, _ = strconv.ParseUint(start, 10, 64)
|
|
}
|
|
|
|
if end, ok := doc.Metadata["end"].(string); ok {
|
|
chunk.End, _ = strconv.ParseUint(end, 10, 64)
|
|
}
|
|
|
|
chunk.Doc = &doc
|
|
chunks = append(chunks, chunk)
|
|
}
|
|
|
|
return chunks, nil
|
|
}
|