film_recom_sys/similarity_recommend.py
2025-01-31 21:36:31 +08:00

89 lines
3.6 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os

import pandas as pd
import pymysql
from sklearn.metrics.pairwise import cosine_similarity
# Database connection shared by every function in this script.
# Each setting can be overridden through an environment variable; the defaults
# are the values this script used to hard-code, so existing setups keep working.
# NOTE(review): avoid relying on the default password -- set
# FILMSYSTEM_DB_PASSWORD in the environment instead of committing credentials.
connection = pymysql.connect(
    host=os.environ.get("FILMSYSTEM_DB_HOST", "localhost"),
    user=os.environ.get("FILMSYSTEM_DB_USER", "root"),
    password=os.environ.get("FILMSYSTEM_DB_PASSWORD", "root"),
    database=os.environ.get("FILMSYSTEM_DB_NAME", "filmsystem"),
    charset="utf8mb4",
    # Rows come back as dicts keyed by column name (relied on by fetch_data).
    cursorclass=pymysql.cursors.DictCursor,
    port=int(os.environ.get("FILMSYSTEM_DB_PORT", "3316")),
)
def fetch_data(conn=None):
    """Read movies, ratings and user demographics from the database.

    Args:
        conn: Optional pymysql connection to read from. When omitted, the
            module-level ``connection`` is used, so existing zero-argument
            calls behave exactly as before.

    Returns:
        ``(movies, ratings, users)`` -- three DataFrames whose columns are the
        columns selected from the ``movie``, ``rating`` and ``user_default``
        tables. The column lists are passed explicitly so that an empty table
        still yields a DataFrame with the expected columns; ``pd.DataFrame``
        of an empty result has no columns at all, which would make the later
        ``merge``/``groupby``/column lookups fail with ``KeyError``.
    """
    if conn is None:
        conn = connection
    movie_cols = ["id", "genres", "rating_count", "average_rating", "actor"]
    rating_cols = ["user_id", "item_id", "rating"]
    user_cols = ["user_id", "age", "gender", "occupation"]
    with conn.cursor() as cursor:
        # Movie metadata
        cursor.execute("SELECT id, genres, rating_count, average_rating, actor FROM movie")
        movies = pd.DataFrame(cursor.fetchall(), columns=movie_cols)
        # Individual user ratings
        cursor.execute("SELECT user_id, item_id, rating FROM rating")
        ratings = pd.DataFrame(cursor.fetchall(), columns=rating_cols)
        # User demographics
        cursor.execute("SELECT user_id, age, gender, occupation FROM user_default")
        users = pd.DataFrame(cursor.fetchall(), columns=user_cols)
    return movies, ratings, users
def calculate_user_preferences(ratings, users):
    """Summarise, for every rated movie, the demographics of its raters.

    Ratings are left-joined to users on ``user_id`` and grouped by ``item_id``.

    Returns:
        DataFrame with columns ``item_id``, ``age_mean``, ``age_std``,
        ``gender_male_ratio`` (share of known genders equal to ``'M'``; 0 when
        none are known) and ``common_occupation`` (most frequent occupation,
        or -1 when no occupation is known).
    """

    def male_share(genders):
        # Fraction of non-null gender values that are 'M'.
        return genders.value_counts(normalize=True).get('M', 0)

    def top_occupation(occupations):
        # Most frequent occupation; -1 marks "no occupation data".
        modes = occupations.mode()
        return -1 if modes.empty else modes[0]

    joined = pd.merge(ratings, users, how='left', on='user_id')
    summary = joined.groupby('item_id').agg(
        age_mean=('age', 'mean'),
        age_std=('age', 'std'),
        gender_male_ratio=('gender', male_share),
        common_occupation=('occupation', top_occupation),
    )
    return summary.reset_index()
def calculate_similarity(movies, movie_user_features, scale=True):
    """Build a movie-by-movie cosine-similarity matrix.

    Each movie is described by its own statistics (``rating_count``,
    ``average_rating``), the demographics of its raters (``age_mean``,
    ``age_std``, ``gender_male_ratio`` from ``movie_user_features``) and a
    one-hot encoding of its comma-separated ``genres``.

    Args:
        movies: DataFrame with at least ``id``, ``genres``, ``rating_count``
            and ``average_rating`` columns. It is not modified.
        movie_user_features: Per-movie rater summary keyed by ``item_id``
            (as produced by ``calculate_user_preferences``).
        scale: When True (default), min-max scale every feature column to
            [0, 1] before computing similarity. Without scaling, large-valued
            columns such as ``rating_count`` dominate the cosine similarity and
            the 0/1 genre columns have almost no influence. Pass False to get
            the previous unscaled behaviour.

    Returns:
        ``(similarity_matrix, merged_movies)``: row/column ``i`` of the matrix
        corresponds to row ``i`` of ``merged_movies``, which is ``movies`` plus
        ``genre_<name>`` indicator columns, left-joined with
        ``movie_user_features`` and with missing values filled with 0.
    """
    # One-hot encode genres without mutating the caller's DataFrame.
    # str.get_dummies treats a NULL genres value as "no genres" (splitting None
    # would raise), ignores empty tokens, and emits the genre columns in sorted
    # order, so the feature layout is the same on every run.
    genre_dummies = movies['genres'].str.get_dummies(sep=',').add_prefix('genre_')
    genre_cols = list(genre_dummies.columns)

    # Combine the movie's own features with its raters' features.
    merged = movies.join(genre_dummies).merge(
        movie_user_features, left_on='id', right_on='item_id', how='left'
    ).fillna(0)

    feature_cols = ['rating_count', 'average_rating', 'age_mean', 'age_std', 'gender_male_ratio'] + genre_cols
    # astype(float) also converts DB Decimal values, which would otherwise
    # leave the columns as object dtype.
    feature_frame = merged[feature_cols].astype(float)

    if scale:
        col_min = feature_frame.min()
        col_range = feature_frame.max() - col_min
        # Constant columns carry no information; map them to 0 instead of
        # dividing by zero.
        col_range = col_range.where(col_range != 0, 1.0)
        feature_frame = (feature_frame - col_min) / col_range

    similarity_matrix = cosine_similarity(feature_frame.to_numpy())
    return similarity_matrix, merged
def store_similar_movies(similarity_matrix, movies, top_n=8):
    """Store each movie's ``top_n`` most similar movies in ``movie.similar_movie``.

    Args:
        similarity_matrix: Square array whose row ``i`` holds the similarity of
            the movie in row ``i`` of ``movies`` to every movie.
        movies: DataFrame with an ``id`` column, row-aligned with the matrix.
        top_n: Number of neighbours to store per movie (default 8).

    Neighbour IDs are stored as a comma-separated string, most similar first.
    All updates are sent as one batch and committed once. On error the
    transaction is rolled back and the exception re-raised, so the table is
    not left half-updated.
    """
    movie_ids = movies['id'].tolist()
    updates = []
    for idx, movie_id in enumerate(movie_ids):
        # Stable sort of the negated scores gives descending order with ties in
        # row order -- the same tie-breaking as sorted(..., reverse=True).
        ranked = (-similarity_matrix[idx]).argsort(kind='stable')
        # Exclude the movie itself by index rather than assuming it ranks
        # first: another movie with identical features also scores 1.0 and may
        # sort ahead of it, and an all-zero feature row can score 0 against itself.
        neighbours = [movie_ids[j] for j in ranked if j != idx][:top_n]
        updates.append((','.join(map(str, neighbours)), movie_id))

    try:
        with connection.cursor() as cursor:
            cursor.executemany("UPDATE movie SET similar_movie = %s WHERE id = %s", updates)
        connection.commit()
    except Exception:
        connection.rollback()
        raise

    for top_similar_ids_str, movie_id in updates:
        print(f"Processed movie ID {movie_id}, similar movies: {top_similar_ids_str}")
def main():
    """Run the whole pipeline.

    Loads the data, summarises rater demographics per movie, computes the
    movie similarity matrix and writes each movie's similar-movie list back
    to the database.
    """
    movie_df, rating_df, user_df = fetch_data()
    rater_profile = calculate_user_preferences(rating_df, user_df)
    sim_matrix, enriched_movies = calculate_similarity(movie_df, rater_profile)
    store_similar_movies(sim_matrix, enriched_movies)
    print("All movies processed.")
# Script entry: run the pipeline and always release the database connection,
# even if a step raises part-way through (previously an exception skipped
# connection.close()).
try:
    main()
finally:
    connection.close()