89 lines
3.6 KiB
Python
89 lines
3.6 KiB
Python
import pymysql
|
||
import pandas as pd
|
||
from sklearn.metrics.pairwise import cosine_similarity
|
||
|
||
# Open the database connection shared by every function in this script.
# Each setting can be overridden through an environment variable so real
# credentials do not have to live in source control; the defaults are the
# values this script has always used.
import os

connection = pymysql.connect(
    host=os.environ.get('FILMSYSTEM_DB_HOST', 'localhost'),
    user=os.environ.get('FILMSYSTEM_DB_USER', 'root'),
    password=os.environ.get('FILMSYSTEM_DB_PASSWORD', 'root'),
    database=os.environ.get('FILMSYSTEM_DB_NAME', 'filmsystem'),
    charset='utf8mb4',
    # Rows are returned as dicts keyed by column name.
    cursorclass=pymysql.cursors.DictCursor,
    port=int(os.environ.get('FILMSYSTEM_DB_PORT', '3316')),
)
|
||
|
||
# Read movie, rating and user data from the database.
def fetch_data():
    """Load the movie, rating and user tables into DataFrames.

    Reads through the module-level ``connection``.

    Returns:
        A tuple ``(movies, ratings, users)`` of DataFrames with columns
        ``id, genres, rating_count, average_rating, actor``;
        ``user_id, item_id, rating``; and
        ``user_id, age, gender, occupation`` respectively.
    """
    def _query_frame(cursor, sql, columns):
        # Label the columns explicitly: for an empty result set,
        # pd.DataFrame([]) would have no columns at all, and the later
        # merges / column lookups would fail with KeyError.
        cursor.execute(sql)
        return pd.DataFrame(list(cursor.fetchall()), columns=columns)

    with connection.cursor() as cursor:
        movies = _query_frame(
            cursor,
            "SELECT id, genres, rating_count, average_rating, actor FROM movie",
            ['id', 'genres', 'rating_count', 'average_rating', 'actor'],
        )
        ratings = _query_frame(
            cursor,
            "SELECT user_id, item_id, rating FROM rating",
            ['user_id', 'item_id', 'rating'],
        )
        users = _query_frame(
            cursor,
            "SELECT user_id, age, gender, occupation FROM user_default",
            ['user_id', 'age', 'gender', 'occupation'],
        )

    return movies, ratings, users
|
||
|
||
# Summarise, per movie, the demographics of the users who rated it.
def calculate_user_preferences(ratings, users):
    """Aggregate rater demographics for every rated movie.

    Each rating row is left-joined with its rater's profile on ``user_id``,
    and the result is grouped by ``item_id``.

    Args:
        ratings: DataFrame with at least ``user_id`` and ``item_id`` columns.
        users: DataFrame with ``user_id``, ``age``, ``gender`` and
            ``occupation`` columns.

    Returns:
        DataFrame with columns ``item_id``, ``age_mean``, ``age_std``,
        ``gender_male_ratio`` (share of non-missing genders equal to ``'M'``,
        0 when none are known) and ``common_occupation`` (first entry of
        ``Series.mode()``, or ``-1`` when no occupation is known).
    """
    def _male_share(genders):
        shares = genders.value_counts(normalize=True)
        return shares.get('M', 0)

    def _most_common(occupations):
        modes = occupations.mode()
        if modes.empty:
            return -1
        return modes[0]

    joined = ratings.merge(users, on='user_id', how='left')
    summary = joined.groupby('item_id').agg(
        age_mean=('age', 'mean'),
        age_std=('age', 'std'),
        gender_male_ratio=('gender', _male_share),
        common_occupation=('occupation', _most_common),
    )
    return summary.reset_index()
|
||
|
||
# Compute pairwise movie similarity.
def calculate_similarity(movies, movie_user_features, *, normalize=False):
    """Build one feature vector per movie and return their cosine similarity.

    Features are ``rating_count``, ``average_rating``, the audience columns
    ``age_mean``, ``age_std`` and ``gender_male_ratio`` (taken from
    ``movie_user_features`` where ``item_id == id``), plus one 0/1 indicator
    column ``genre_<name>`` per token of the comma-separated ``genres`` field.

    Args:
        movies: DataFrame with ``id``, ``genres``, ``rating_count`` and
            ``average_rating`` columns. It is not modified.
        movie_user_features: per-movie audience features keyed by ``item_id``,
            with the columns listed above.
        normalize: if True, min-max scale every feature column to [0, 1]
            before computing similarity. Without scaling, large-magnitude
            columns such as ``rating_count`` can dominate the cosine
            similarity and drown out the 0/1 genre indicators.

    Returns:
        ``(similarity_matrix, merged_movies)``: ``similarity_matrix[i, j]`` is
        the cosine similarity between rows ``i`` and ``j`` of
        ``merged_movies``, the input movies extended with genre indicators
        and audience features.
    """
    # One 0/1 indicator column per genre token. NULL genres count as "no
    # genres" instead of crashing on None.split(). Empty tokens (e.g. from a
    # trailing comma) are ignored. The columns come back sorted, so the
    # feature order is the same on every run.
    genre_dummies = (
        movies['genres']
        .fillna('')
        .astype(str)
        .str.get_dummies(sep=',')
        .astype(int)
        .add_prefix('genre_')
    )
    genre_columns = list(genre_dummies.columns)

    # Build a new frame instead of adding columns to the caller's DataFrame.
    enriched = pd.concat([movies, genre_dummies], axis=1)

    # Attach the audience features. Any remaining missing value, such as the
    # audience columns of a movie nobody rated, becomes 0.
    enriched = enriched.merge(
        movie_user_features, left_on='id', right_on='item_id', how='left'
    ).fillna(0)

    features = [
        'rating_count', 'average_rating', 'age_mean', 'age_std', 'gender_male_ratio',
    ] + genre_columns
    feature_frame = enriched[features].astype(float)

    if normalize:
        col_min = feature_frame.min()
        col_span = feature_frame.max() - col_min
        # A constant column has span 0; divide by 1 so it becomes all zeros
        # rather than NaN.
        feature_frame = (feature_frame - col_min) / col_span.replace(0, 1)

    similarity_matrix = cosine_similarity(feature_frame.to_numpy())
    return similarity_matrix, enriched
|
||
|
||
# Compute and store the most similar movies (8 by default) for every movie.
def store_similar_movies(similarity_matrix, movies, top_n=8):
    """Write each movie's most similar movies to ``movie.similar_movie``.

    Args:
        similarity_matrix: square matrix where entry ``[i, j]`` is the
            similarity between row ``i`` and row ``j`` of ``movies``.
        movies: DataFrame whose ``id`` column holds the movie id of each row,
            in the same order as the matrix rows.
        top_n: number of neighbours stored per movie (default 8).

    Neighbour ids are stored as a comma-separated string, most similar first.
    Ties keep row order. All UPDATEs run in a single transaction: it is
    committed after the last row and rolled back if any statement fails.
    """
    import heapq

    movie_ids = movies['id'].tolist()
    positions = range(len(movie_ids))

    with connection.cursor() as cursor:
        try:
            for idx, movie_id in enumerate(movie_ids):
                scores = similarity_matrix[idx]
                # Exclude the movie itself by position. Assuming it sorts
                # first breaks when another movie ties at the top score (e.g.
                # identical feature vectors) or when its own row is all zeros.
                neighbours = heapq.nlargest(
                    top_n,
                    (j for j in positions if j != idx),
                    key=scores.__getitem__,
                )
                top_similar_ids_str = ','.join(str(movie_ids[j]) for j in neighbours)

                cursor.execute(
                    "UPDATE movie SET similar_movie = %s WHERE id = %s",
                    (top_similar_ids_str, movie_id),
                )
                print(f"Processed movie ID {movie_id}, similar movies: {top_similar_ids_str}")
            connection.commit()
        except Exception:
            # Do not leave a partially updated table behind.
            connection.rollback()
            raise
|
||
|
||
# Entry point.
def main():
    """Run the pipeline: load data, derive features, rank and store neighbours."""
    movie_df, rating_df, user_df = fetch_data()
    audience = calculate_user_preferences(rating_df, user_df)
    sim, enriched_movies = calculate_similarity(movie_df, audience)
    store_similar_movies(sim, enriched_movies)
    print("All movies processed.")
|
||
|
||
# Run the job and always release the database connection, even on failure.
# NOTE(review): this runs on import as well as when executed as a script;
# consider an `if __name__ == "__main__":` guard if the module is ever imported.
try:
    main()
finally:
    connection.close()
|