import {DateRange} from 'rsuite/DateRangePicker'
import {OptionalAIModelIdMap, OptionalAIModelMap, TOKEN_BASED_MODEL_IDS} from '../types/AiModel'
import {AiPricing, OtherPricingTier} from '../types/AiPricing'
import {AuditInfoStored, AuditWarning, GroupAuditInfo, ModelAuditMap, UserAuditMap} from '../types/AuditInfo'
import {UserAuditLog, UserPromptAudit} from '../types/UserAuditLog'
import {UserInfoMap} from '../types/UserInfo'
import {transformArrayToEntryMapByProp} from './genericUtils'
import {getDateRangeEndSeconds, getDateRangeStartSeconds} from './dateUtils'

/**
 * Builds a predicate that tells whether an audit entry was created inside `dateRange`.
 *
 * Both bounds are inclusive and are resolved through getDateRangeStartSeconds /
 * getDateRangeEndSeconds on every call of the returned predicate; the end bound is only
 * looked up when the start bound is satisfied.
 *
 * @param dateRange the rsuite date range to test against
 * @returns a predicate over stored audit entries, suitable for `Array.prototype.filter`
 */
export const isAuditInfoInDateRange = (dateRange: DateRange) => (auditInfo: AuditInfoStored): boolean => {
	const {logCreationDate} = auditInfo
	const startsAfterRangeStart = logCreationDate >= getDateRangeStartSeconds(dateRange)
	return startsAfterRangeStart && logCreationDate <= getDateRangeEndSeconds(dateRange)
}

/**
 * Buckets audit logs into conversations.
 *
 * A conversation is identified by the first truthy id among `promptId`, `tempPromptId` and
 * `chatId` (the latter only when the log variant carries it), combined with the log's `modelId`
 * as `"<id>-<modelId>"`. Logs with none of those ids are dropped. Within a bucket, logs keep
 * their input order.
 *
 * @param auditLogs the stored audit logs to group
 * @returns an object mapping conversation keys to the logs belonging to them
 */
export const groupLogsByConversation = (auditLogs: AuditInfoStored[]): GroupAuditInfo => {
	const conversations = {} as GroupAuditInfo

	for (const entry of auditLogs) {
		const chatId = 'chatId' in entry ? entry.chatId : undefined
		const conversationRef = entry.promptId || entry.tempPromptId || chatId

		// Without any prompt/chat identifier the log cannot be attributed to a conversation.
		if (!conversationRef) continue

		const conversationKey = `${conversationRef}-${entry.modelId}`
		if (!conversations[conversationKey]) conversations[conversationKey] = []
		conversations[conversationKey].push(entry)
	}

	return conversations
}

/**
 * Collects every warning attached to the given audit logs into one flat list.
 *
 * Logs that have no `warnings` property (or whose `warnings` is undefined) contribute nothing;
 * the warnings of each log are appended in log order.
 *
 * @param auditLogs the audit logs to inspect
 * @returns all warnings, flattened one level
 */
export const getAuditWarnings = (auditLogs: AuditInfoStored[]): AuditWarning[] =>
	auditLogs.flatMap(entry => {
		const entryWarnings = 'warnings' in entry ? entry.warnings : undefined
		return entryWarnings === undefined ? [] : entryWarnings
	})

/**
 * Turns raw audit logs into one UserAuditLog summary per conversation.
 *
 * Logs are bucketed with groupLogsByConversation. Conversation-level fields (executor, model,
 * ids, date, hashKey) are taken from the first log of each bucket, while warnings, messages,
 * the viewed flag and the audit log ids aggregate over every log in the bucket.
 *
 * @param aiModelMap model metadata used to resolve the model display name ('' when unknown)
 * @param userInfoMap user metadata used to resolve the executor's full name (undefined when unknown)
 * @param rawLogs the stored audit logs to summarise
 * @returns one summary per conversation, in the grouping object's key order
 */
export const parseAuditLogs = (aiModelMap: OptionalAIModelMap, userInfoMap: UserInfoMap, rawLogs: AuditInfoStored[]): UserAuditLog[] => {
	const conversations = Object.values(groupLogsByConversation(rawLogs))

	return conversations.map((conversationLogs) => {
		const [headLog] = conversationLogs
		const chatId = 'chatId' in headLog ? headLog.chatId : undefined

		const messages = conversationLogs.map(entry => ({
			userPrompt: entry.userPrompt,
			date: entry.logCreationDate,
			warnings: entry.warnings ?? []
		})) as UserPromptAudit[]

		const userLog: UserAuditLog = {
			// NOTE(review): preference here is promptId, chatId, tempPromptId, whereas
			// groupLogsByConversation keys on promptId, tempPromptId, chatId — confirm intended.
			promptId: headLog.promptId || chatId || headLog.tempPromptId,
			// NOTE(review): logCreationDate looks like epoch seconds, scaled to milliseconds here — confirm.
			conversationDate: headLog.logCreationDate * 1000,
			isOutputSaved: Boolean(headLog.promptId) || Boolean(chatId),
			userFullName: userInfoMap[headLog.executorUserId]?.userFullName,
			messagesSentToModel: conversationLogs.length,
			aiModelId: headLog.modelId,
			aiModelName: aiModelMap[headLog.modelId]?.name ?? '',
			warnings: getAuditWarnings(conversationLogs),
			messages,
			// The conversation counts as viewed only when every one of its logs is viewed.
			isViewed: conversationLogs.every(entry => entry.isViewed),
			hashKey: headLog.hashKey,
			auditLogIds: conversationLogs.map(entry => entry.hashKey)
		}

		return userLog
	})
}

/**
 * Indexes audit logs by their `executorUserId` via transformArrayToEntryMapByProp.
 *
 * @param auditLogs the audit logs to index
 * @returns the resulting UserAuditMap
 */
export const getUserAuditMap = (auditLogs: AuditInfoStored[]): UserAuditMap => {
	return transformArrayToEntryMapByProp(auditLogs, 'executorUserId')
}

/**
 * Indexes audit logs by their `modelId` via transformArrayToEntryMapByProp.
 *
 * @param auditLogs the audit logs to index
 * @returns the resulting ModelAuditMap
 */
export const getModelAuditMap = (auditLogs: AuditInfoStored[]): ModelAuditMap => {
	return transformArrayToEntryMapByProp(auditLogs, 'modelId')
}

/**
 * Computes the token-based cost of audited usage, per model and in total.
 *
 * For each model in TOKEN_BASED_MODEL_IDS, every log is priced with the first pricing entry for
 * that model whose window contains the log's creation date (`startDate <= date`, and
 * `date <= endDate` unless `endDate` is undefined). When getOtherPricingTier finds an applicable
 * tier for the log, that tier's per-token prices replace the base prices. Missing token counts
 * and a missing pricing entry contribute 0.
 *
 * @param modelAuditMap audit logs keyed by model id
 * @param aiPricings all known pricing entries
 * @returns `total` over all token-based models, and `modelCostMap` with an entry for every
 *          token-based model (0 when it has no logs)
 */
export const getModelAuditMapCost = (modelAuditMap: ModelAuditMap, aiPricings: AiPricing[]): { total: number, modelCostMap: OptionalAIModelIdMap<number>} => {
	let total = 0
	const modelCostMap: OptionalAIModelIdMap<number> = {}

	TOKEN_BASED_MODEL_IDS.forEach(modelId => {
		// Only this model's pricings can ever match its logs; narrow once instead of per log.
		const modelPricings = aiPricings.filter(aiPricing => aiPricing.modelId === modelId)
		const logs = modelAuditMap[modelId]

		const modelCost = logs?.reduce<number>((acc, log) => {
			const logPricing = modelPricings.find(aiPricing =>
				aiPricing.startDate <= log.logCreationDate &&
				(aiPricing.endDate === undefined || log.logCreationDate <= aiPricing.endDate))

			// An applicable tier overrides both base per-token prices.
			const otherPricingTier = getOtherPricingTier(logPricing, log)
			const priceInputToken = otherPricingTier ? otherPricingTier.priceInputToken : (logPricing?.priceInputToken ?? 0)
			const priceOutputToken = otherPricingTier ? otherPricingTier.priceOutputToken : (logPricing?.priceOutputToken ?? 0)

			return acc + (log.consumedOutputTokens ?? 0) * priceOutputToken + (log.consumedInputTokens ?? 0) * priceInputToken
		}, 0) ?? 0

		total += modelCost
		modelCostMap[modelId] = modelCost
	})

	return {
		total,
		modelCostMap
	}
}

/**
 * Picks the volume pricing tier that applies to an audit log.
 *
 * A tier applies when the log's input token count is at least the tier's `tokenLimit`. Among the
 * applicable tiers, the one with the highest `tokenLimit` wins, and ties go to the tier listed
 * first. Logs with no (or zero) input tokens never select a tier.
 *
 * `aiPricing.otherPricingTiers` is only read, never reordered. Sorting it in place would mutate
 * shared pricing data on every call.
 *
 * @param aiPricing the pricing entry for the log, if any
 * @param auditLog the audit log whose input tokens select the tier
 * @returns the applicable tier, or undefined when there is no pricing, no tiers, or no tier is reached
 */
export const getOtherPricingTier = (aiPricing: AiPricing | undefined, auditLog: AuditInfoStored): OtherPricingTier | undefined => {
	const inputTokens = auditLog.consumedInputTokens
	const tiers = aiPricing?.otherPricingTiers
	// A falsy token count (undefined or 0) never selects a tier.
	if (!inputTokens || !tiers) return undefined

	let applicableTier: OtherPricingTier | undefined
	for (const tier of tiers) {
		// Strict '>' keeps the earliest of equally-limited tiers.
		if (inputTokens >= tier.tokenLimit && (applicableTier === undefined || tier.tokenLimit > applicableTier.tokenLimit)) {
			applicableTier = tier
		}
	}
	return applicableTier
}