87 lines
2.1 KiB
Go
87 lines
2.1 KiB
Go
package llm
|
|
|
|
import (
	"context"
	"fmt"
	"strconv"

	"ai-code-assistant/pkg/database"
	"github.com/tmc/langchaingo/callbacks"
	"github.com/tmc/langchaingo/schema"
	"github.com/tmc/langchaingo/vectorstores"
)
|
|
|
|
// RelevantDocs attempts to find the most relevant file chunks based on context from the prompt.
|
|
type RelevantDocs struct {
|
|
CallbacksHandler callbacks.Handler
|
|
db *database.Database
|
|
llm *LLM
|
|
repoID string
|
|
size int
|
|
}
|
|
|
|
// FileChunkID is a pointer to a repository file chunk that has been indexed.
|
|
type FileChunkID struct {
|
|
Name string
|
|
ChunkID int
|
|
Start uint64
|
|
End uint64
|
|
Score float32
|
|
Doc *schema.Document
|
|
}
|
|
|
|
// NewGetRelevantDocs creates a new RelevantDocs scanner.
|
|
func NewGetRelevantDocs(db *database.Database, llm *LLM, repoID string, size int) *RelevantDocs {
|
|
return &RelevantDocs{
|
|
db: db,
|
|
llm: llm,
|
|
repoID: repoID,
|
|
size: size,
|
|
}
|
|
}
|
|
|
|
// GetRelevantFileChunks will scan for relevant documents based on a prompt.
|
|
func (rd *RelevantDocs) GetRelevantFileChunks(ctx context.Context, query string) ([]*FileChunkID, error) {
|
|
vectorStore, closeFunc, err := rd.db.GetVectorStore(ctx, rd.llm.Embedder())
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer closeFunc()
|
|
|
|
retr := vectorstores.ToRetriever(vectorStore, rd.size, vectorstores.WithFilters(map[string]any{"type": "file_chunk", "repo_id": rd.repoID}))
|
|
retr.CallbacksHandler = rd.CallbacksHandler
|
|
|
|
docs, err := retr.GetRelevantDocuments(ctx, query)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var chunks []*FileChunkID
|
|
|
|
for _, doc := range docs {
|
|
chunk := &FileChunkID{
|
|
Score: doc.Score,
|
|
}
|
|
|
|
if filePath, ok := doc.Metadata["file_path"].(string); ok {
|
|
chunk.Name = filePath
|
|
}
|
|
|
|
if chunkID, ok := doc.Metadata["chunk_id"].(string); ok {
|
|
id, _ := strconv.ParseInt(chunkID, 10, 64)
|
|
chunk.ChunkID = int(id)
|
|
}
|
|
|
|
if start, ok := doc.Metadata["start"].(string); ok {
|
|
chunk.Start, _ = strconv.ParseUint(start, 10, 64)
|
|
}
|
|
|
|
if end, ok := doc.Metadata["end"].(string); ok {
|
|
chunk.End, _ = strconv.ParseUint(end, 10, 64)
|
|
}
|
|
|
|
chunk.Doc = &doc
|
|
chunks = append(chunks, chunk)
|
|
}
|
|
|
|
return chunks, nil
|
|
}
|